Loading and Saving Models in PyTorch (rough notes, not yet formatted)

Author: fwei | Published: 2017-07-30 17:40
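    1. Define the network structure
    # Imports needed by the snippets in these notes
    from collections import OrderedDict
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    # Assumption: _DenseBlock and _Transition are the helper blocks from torchvision's DenseNet implementation
    from torchvision.models.densenet import _DenseBlock, _Transition
    
    class DenseNet(nn.Module):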
        r"""Densenet-BC model class, based on
        `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
        Args:
            growth_rate (int) - how many filters to add each layer (`k` in paper)
            block_config (list of 4 ints) - how many layers in each pooling block
            num_init_features (int) - the number of filters to learn in the first convolution layer
            bn_size (int) - multiplicative factor for number of bottle neck layers
              (i.e. bn_size * k features in the bottleneck layer)
            drop_rate (float) - dropout rate after each dense layer
            num_classes (int) - number of classification classes
        """
        def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                     num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
    
            super(DenseNet, self).__init__()
    
            # First convolution
            self.features = nn.Sequential(OrderedDict([
                ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                ('norm0', nn.BatchNorm2d(num_init_features)),
                ('relu0', nn.ReLU(inplace=True)),
                ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
            ]))
    
            # Each denseblock
            num_features = num_init_features
            for i, num_layers in enumerate(block_config):
                block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                    bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
                self.features.add_module('denseblock%d' % (i + 1), block)
                num_features = num_features + num_layers * growth_rate
                if i != len(block_config) - 1:
                    trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                    self.features.add_module('transition%d' % (i + 1), trans)
                    num_features = num_features // 2
    
            # Final batch norm
            self.features.add_module('norm5', nn.BatchNorm2d(num_features))
    
            # Linear layer
            self.classifier = nn.Linear(num_features, num_classes)
    
        def forward(self, x):
            features = self.features(x)
            out = F.relu(features, inplace=True)
            out = F.avg_pool2d(out, kernel_size=7).view(features.size(0), -1)
            out = self.classifier(out)
            return out
    
    2. Build the model from the network definition:
    net = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24))
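The configuration above matches DenseNet-161, and it fixes the input width of the final classifier. A small sketch of the channel bookkeeping done by the loop in `__init__` follows; the helper name `final_num_features` is made up for illustration and is not part of PyTorch or torchvision.

```python
def final_num_features(num_init_features, growth_rate, block_config):
    """Mirror the channel bookkeeping of the dense-block loop in DenseNet.__init__."""
    num_features = num_init_features
    for i, num_layers in enumerate(block_config):
        # Each layer in a dense block adds `growth_rate` channels.
        num_features = num_features + num_layers * growth_rate
        # Every block except the last is followed by a transition that halves the channels.
        if i != len(block_config) - 1:
            num_features = num_features // 2
    return num_features


# Configuration used in these notes (DenseNet-161)
densenet161_width = final_num_features(96, 48, (6, 12, 36, 24))
# Default arguments of the class (DenseNet-121)
densenet121_width = final_num_features(64, 32, (6, 12, 24, 16))
print(densenet161_width, densenet121_width)
```

This is the value that `net.classifier.in_features` reports, which matters later when the last layer is replaced.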
    
    3. Load the model parameters
    net.load_state_dict(torch.load('/home/wei.fan/.torch/models/densenet161-17b70270.pth'))
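The loading call above works only when the keys and tensor shapes in the file match the model. Here is a minimal sketch of the same round trip on a tiny stand-in network, so it runs without the DenseNet weights. `make_tiny_net` and the file name are made up for illustration. `map_location="cpu"` is one way to read a checkpoint on a machine without a GPU, and `strict=False` reports mismatched keys instead of raising on them.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_tiny_net():
    # Hypothetical stand-in for the real network so the example runs quickly.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


source_net = make_tiny_net()
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "tiny_params.pth")  # hypothetical file name
    torch.save(source_net.state_dict(), ckpt_path)
    # map_location="cpu" remaps tensors that were saved from a GPU.
    state_dict = torch.load(ckpt_path, map_location="cpu")

# Strict loading (the default): every key must be present and match.
target_net = make_tiny_net()
strict_report = target_net.load_state_dict(state_dict)

# Non-strict loading: keep only the first layer's tensors and inspect what is missing.
partial_state = {k: v for k, v in state_dict.items() if k.startswith("0.")}
partial_net = make_tiny_net()
partial_report = partial_net.load_state_dict(partial_state, strict=False)
print(partial_report.missing_keys)
```

Checking `missing_keys` and `unexpected_keys` after a non-strict load is a cheap way to notice that a checkpoint and a model definition have drifted apart.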
    

    4. Train the model

    num_ftrs = net.classifier.in_features
    net.classifier = nn.Linear(num_ftrs, 100)  # resize the last layer
    net = net.cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    net = train_net()  # user-defined training function (not shown)
    torch.save(net.state_dict(), 'net_params.pkl')  # save only the model parameters
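The training function in the step above is user-defined and not shown in these notes. The sketch below is one plausible shape for it, run on a tiny stand-in model with random data. `TinyNet`, the `train_net` signature, and the data are all assumptions made for illustration. It also shows the consequence of saving parameters only: whoever reloads them has to rebuild the network with the same resized last layer first.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


class TinyNet(nn.Module):
    """Hypothetical stand-in with the same features/classifier layout."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x))


def train_net(net, criterion, optimizer, inputs, targets, num_steps=5):
    """Assumed form of a user-defined training function: a plain optimization loop."""
    net.train()
    for _ in range(num_steps):
        optimizer.zero_grad()
        loss = criterion(net(inputs), targets)
        loss.backward()
        optimizer.step()
    return net


torch.manual_seed(0)
net = TinyNet(num_classes=10)
# Replace the last layer for a 100-class task.
net.classifier = nn.Linear(net.classifier.in_features, 100)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

inputs = torch.randn(32, 8)  # random stand-in data
targets = torch.randint(0, 100, (32,))
net = train_net(net, criterion, optimizer, inputs, targets)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "net_params.pkl")
    torch.save(net.state_dict(), path)  # parameters only, no architecture
    state_dict = torch.load(path, map_location="cpu")

# Rebuild the architecture with the same 100-class head before loading.
restored = TinyNet(num_classes=100)
restored.load_state_dict(state_dict)
restored.eval()

# A model that still has the 10-class head rejects the checkpoint.
try:
    TinyNet(num_classes=10).load_state_dict(state_dict)
    head_mismatch_detected = False
except RuntimeError:
    head_mismatch_detected = True
```

Saving the `state_dict` keeps the file independent of the class's import path, at the cost of needing the model code on the loading side.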
    
