美文网首页
使用CNN模型做预测:前向过程解释

使用CNN模型做预测:前向过程解释

作者: 钢笔先生 | 来源:发表于2019-08-04 13:40 被阅读0次

    Time: 2019-08-04
    视频地址:https://youtu.be/6vweQjouLEE?list=PLZbbT5o_s2xrfNyHZsM6ufI0iZENK9xgG&t=23

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    import torchvision
    import torchvision.transforms as transforms
    
    torch.set_printoptions(linewidth=120)
    
    # 训练集
    train_set = torchvision.datasets.FashionMNIST(
        root='./data/FashionMNIST',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor()
        ])
    )
    
    # 构建网络
    class Network(nn.Module):
        def __init__(self):
            super(Network, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
            self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
    
            self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
            self.fc2 = nn.Linear(in_features=120, out_features=60)
            self.out = nn.Linear(in_features=60, out_features=10)
        
        def forward(self, t):
            t = F.relu(self.conv1(t))
            t = F.max_pool2d(t, kernel_size=2, stride=2)
    
            t = F.relu(self.conv2 (t))
            t = F.max_pool2d(t, kernel_size=2, stride=2)
    
            t = F.relu(self.fc1(t.reshape(-1, 12*4*4)))
            t = F.relu(self.fc2(t))
            t = self.out(t)
    
            return t
    
    torch.set_grad_enabled(False)
    
    # 实例化网络
    net = Network()
    sample = next(iter(train_set))
    image, label = sample
    
    image.shape # torch.Size([1, 28, 28])
    image.unsqueeze(0).shape # torch.Size([1, 1, 28, 28])
    
    # 执行预测
    # 预测时需要输入图片的形状为4维张量
    pred = net(image.unsqueeze(0))
    
    pred # tensor([[-0.0484, -0.0635, -0.0606, -0.1533,  0.0612,  0.0382,  0.0014, -0.0159, -0.0116, -0.1182]])
    pred.argmax(dim=1) # tensor([4])
    F.softmax(pred, dim=1)
    

    这里有一些要点需要注意:网络接收的数据是4D张量,单张图片也需要处理成4D的格式,用的是tensor.unsqueeze()方法。

    END.

    相关文章

      网友评论

          本文标题:使用CNN模型做预测:前向过程解释

          本文链接:https://www.haomeiwen.com/subject/fmbedctx.html