美文网首页PyTorch
[PyTorch]专项 输出-模型存储与加载

[PyTorch]专项 输出-模型存储与加载

作者: DDuncan | 来源:发表于2020-02-03 18:28 被阅读0次

    一、Python模块 & data

    %matplotlib inline 
    %config InlineBackend.figure_format = 'retina' 
    import matplotlib.pyplot as plt 
    import torch 
    from torch import nn 
    from torch import optim 
    import torch.nn.functional as F 
    from torchvision import datasets, transforms 
    
    #自定义模块
    import helper 
    import fc_model 
    
    #Define a transform to normalize the data 
    transform = transforms.Compose([transforms.ToTensor(), 
                                    transforms.Normalize((0.5, ), (0.5, ))]) 
    #Download and load the training data 
    trainset = datasets.FashioniNIST('~/.pytorch/F_MNIST_data', download=True, 
                                    train=True, transform=transform) 
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) 
    
    #Download and load the test data 
    testset = datasets.FashioniNIST ('~/.pytorch/F_MNIST_data', download=True,
                                    train=False, transform=transform) 
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True) 
    

    二、建立模型 & 训练

    #建立模型 自定义模块fc_model
    model = fc_model.Network(784, 10, [512, 256, 128])
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    #训练模型
    fc_model.tranin(model, trainloader, testloader, criterion, optimizer, epochs=2)
    
    • print(model)

      模型:网络架构
    • print(model.state_dict().keys())

      模型参数:存储在字典的键中

    三、存储/加载模型

    1. 存储模型(参数)

    字典checkpoint:保存记录维度的信息

    1. 网络结构
    • input
    • output
    • hidden layers
    • .state_dict() 参数(weights, bias)
    checkpoint = {'input_size': 784,
                'output_size': 10,
                'hidden_layers': [each.out_features for each in model.hidden_layers],
                'state_dict': model.state_dict()}
                 
    torch.save(checkpoint, 'checkpoint.pth')}
    
    model.hidden_layers

    注意:属性in_features, out_features

    2. 加载模型(参数)

    加载模型的参数必须与存储好的模型一致,否则加载错误

    def load_checkpoint(filepath):
        checkpoint = torch.load(filepath) 
        model = fc_model.Network(checkpoint['input_size'],
                                checkpoint['output_size'],
                                checkpoint['hidden_layers']) #.out_features提取维度信息
        model.load_state_dict(checkpoint['state_dict'])
        return model
     
    #加载模型
    model = load_checkpoint('checkpoint.pth')
    print (model) 
    
    加载模型的参数

    相关文章

      网友评论

        本文标题:[PyTorch]专项 输出-模型存储与加载

        本文链接:https://www.haomeiwen.com/subject/thwcxhtx.html