Time: 2019-08-04
视频地址:https://youtu.be/6vweQjouLEE?list=PLZbbT5o_s2xrfNyHZsM6ufI0iZENK9xgG&t=23
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
# Allow up to 120 characters per line when printing tensors.
torch.set_printoptions(linewidth=120)

# Training split of FashionMNIST. download=True fetches the files into the
# root directory if they are not already there; each sample's image is passed
# through a ToTensor transform.
train_set = torchvision.datasets.FashionMNIST(
    './data/FashionMNIST',
    transform=transforms.Compose([transforms.ToTensor()]),
    download=True,
    train=True,
)
# 构建网络
class Network(nn.Module):
    """Small CNN classifier for single-channel 28x28 images with 10 output classes.

    Two convolution stages (5x5 conv -> ReLU -> 2x2 max-pool) are followed by
    three fully connected layers: 192 -> 120 -> 60 -> 10.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        # Spatial size for a 28x28 input: conv(5) -> 24, pool(2) -> 12,
        # conv(5) -> 8, pool(2) -> 4, so conv2 yields 12 maps of 4x4.
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Compute raw (un-normalised) class scores.

        Args:
            t: batch of images shaped (batch, 1, 28, 28). A single unbatched
               image shaped (1, 28, 28) is also accepted and treated as a
               batch of one.

        Returns:
            Tensor of shape (batch, 10) holding one logit per class.
        """
        if t.dim() == 3:
            # Single (C, H, W) image: add the batch dimension explicitly.
            t = t.unsqueeze(0)
        t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)
        t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2, stride=2)
        # Flatten everything except the batch dimension. Unlike
        # reshape(-1, 192), this can never merge or split samples when the
        # input has the wrong spatial size; fc1 raises a shape error instead.
        t = torch.flatten(t, start_dim=1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        return self.out(t)
# Turn off autograd globally: this demo only runs inference, so no gradient
# graph needs to be recorded. This setting stays in effect for the rest of
# the process.
torch.set_grad_enabled(False)

# Instantiate the network. It is freshly constructed, so its weights come
# from PyTorch's default layer initialisation (no training happens here).
net = Network()

# First (image, label) pair of the training set.
sample = train_set[0]
image, label = sample
print(image.shape)  # torch.Size([1, 28, 28])

# The network is fed 4D input (batch, channels, height, width), so add a
# batch dimension of size 1 in front of the single image.
batch = image.unsqueeze(0)
print(batch.shape)  # torch.Size([1, 1, 28, 28])

# Run the prediction.
pred = net(batch)
# Example output; the exact values depend on the random initialisation:
# tensor([[-0.0484, -0.0635, -0.0606, -0.1533, 0.0612, 0.0382, 0.0014, -0.0159, -0.0116, -0.1182]])
print(pred)
# Index of the largest logit, e.g. tensor([4]).
print(pred.argmax(dim=1))
# Softmax over the class dimension turns the logits into probabilities.
print(F.softmax(pred, dim=1))
这里有一个要点需要注意：网络接收的输入是形状为 (batch, channels, height, width) 的 4D 张量，因此即使只预测单张图片，也要先用 tensor.unsqueeze(0) 在最前面增加一个大小为 1 的批次维度，把 (1, 28, 28) 处理成 (1, 1, 28, 28) 的 4D 格式。
END.
网友评论