美文网首页
PyTorch源码

PyTorch源码

作者: e237262360d2 | 来源:发表于2018-08-10 17:10 被阅读0次

x = x.view(x.size(0), -1)          

# view函数就是reshape,从1开始,-1说明自行计算行数或列数,把多行多列的变成一行,全链接的输入

nn.Linear(20, 30)

#相当于执行y = Wx + b,全链接,从20变成30,W是默认随机初始化的weight=Parameter(torch.Tensor(out_features, in_features)) [30,20]

loss_func = nn.CrossEntropyLoss()

loss = loss_func(output, b_y)

#计算交叉熵

预测的(0.0,1.0,0.0)

实际的(0.228,0.619,0.153)

H = - (0.0*ln(0.228) + 1.0*ln(0.619) + 0.0*ln(0.153)) = 0.479

import torch

import torchvision

import torchvision.transforms as transforms

import torchvision.utils as utils

from PIL import Image

import numpy as np

import cv2

img_path = "../img_66_pos_real/72264809151922400.jpg"

# transforms.ToTensor()

transform1 = transforms.Compose([     

    transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 归一化

    ]

)

##opencv

img = cv2.imread(img_path)

print("img = ", img)

img1 = transform1(img)

print("img1 = ",img1)

##PIL

img = Image.open(img_path).convert('RGB')

img2 = transform1(img)

print("img2 = ",img2)

相关文章

网友评论

      本文标题:PyTorch源码

      本文链接:https://www.haomeiwen.com/subject/cqdlbftx.html