6. PyTorch: Saving and Restoring Models


Author: FantDing | Published 2018-07-01 10:36

    See the official serialization tutorial.

    1. Saving only the parameters

    This is the recommended approach.

    1.1 Example

    • Training: main.py
    # file: main.py
    import torch
    import torch.nn.functional as F
    import matplotlib.pyplot as plt
    
    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.h1 = torch.nn.Linear(1, 10)
            self.h2 = torch.nn.Linear(10, 1)
    
        def forward(self, x):
            x = F.relu(self.h1(x)) 
            x = self.h2(x)
            return x
    
    def prepare_data():
        torch.manual_seed(1)  # ensure the same random numbers on every run
        x = torch.linspace(-1, 1, 50)
        x = torch.unsqueeze(x, 1)
        y = x ** 2 + 0.2 * torch.rand(x.size())
        return x, y
    
    if __name__ == "__main__":
        # 1. Prepare the data
        x, y = prepare_data()
        plt.scatter(x.numpy(), y.numpy())
        plt.show()
        # 2. Build the network
        net = Net()
        # 3. Train
        optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
        loss_F = torch.nn.MSELoss()
        for epoch in range(100):  # avoid shadowing the builtin iter()
            pred = net(x)
            loss = loss_F(pred, y)
            optimizer.zero_grad()
            loss.backward()
            print(loss.detach().numpy())
            optimizer.step()
        # Save only the network's parameters (its state_dict)
        torch.save(net.state_dict(), "./net_param.pkl")
    
    • Restoring: test.py
    from main import Net, prepare_data
    import torch
    import matplotlib.pyplot as plt
    
    if __name__ == "__main__":
        net = Net()
        x, y = prepare_data()
        plt.scatter(x.numpy(), y.numpy())
        plt.show()
        # torch.load returns a dict mapping parameter names to tensors
        net.load_state_dict(torch.load("net_param.pkl"))
        loss_F = torch.nn.MSELoss()
        pred = net(x)
        loss = loss_F(pred, y)  # matches the loss of the trained network, confirming the restore
        print(loss.detach().numpy())
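A common extension of this pattern is to checkpoint the optimizer state and the current epoch in the same dict as the model's state_dict, so that training can resume exactly where it stopped. A minimal sketch (the file name chkpt.tar and the tiny one-layer model here are illustrative assumptions, not part of main.py):

```python
import torch

# A tiny model just for illustration, not the Net from main.py
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

# Bundle model parameters, optimizer state, and bookkeeping in one dict
checkpoint = {
    "epoch": 100,
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict(),
}
torch.save(checkpoint, "./chkpt.tar")

# To resume: rebuild the objects first, then load the saved states into them
model2 = torch.nn.Linear(1, 1)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.5)
ckpt = torch.load("./chkpt.tar")
model2.load_state_dict(ckpt["model_state"])
optimizer2.load_state_dict(ckpt["optim_state"])
start_epoch = ckpt["epoch"]
```

The key point is the same as in the example above: state_dicts hold only tensors, so the classes (here the Linear and the SGD optimizer) must be constructed before loading.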
    

    1.2 Advantages

    • You can define a new class. In test.py you may define a new class whose forward works differently; as long as the parameters have the same names, they load successfully.
    class New_Net(torch.nn.Module):  # the class name is different, too
        def __init__(self):
            super().__init__()
            self.h1 = torch.nn.Linear(1, 10)
            self.h2 = torch.nn.Linear(10, 1)
    
        def forward(self, x):
            x = torch.tanh(self.h1(x))  # different activation function (torch.tanh; F.tanh is deprecated)
            x = self.h2(x)
            return x
    
    • Loading is faster.
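The "same-named parameters load successfully" point can be pushed further with load_state_dict(..., strict=False), which loads whatever names match and reports the mismatches instead of raising. A small sketch (the Old/New class names are made up for illustration):

```python
import torch

class Old(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.h1 = torch.nn.Linear(1, 10)
        self.h2 = torch.nn.Linear(10, 1)

class New(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.h1 = torch.nn.Linear(1, 10)      # same name: loaded from Old
        self.extra = torch.nn.Linear(10, 10)  # new layer: keeps its init values

old_sd = Old().state_dict()
new_net = New()
# strict=False loads matching names and reports the rest
result = new_net.load_state_dict(old_sd, strict=False)
print(sorted(result.missing_keys))     # keys in New with no source in Old (extra.*)
print(sorted(result.unexpected_keys))  # keys in Old with no home in New (h2.*)
```

This is handy when fine-tuning: load a pretrained backbone into a model that adds a new head.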

    2. Saving both the network structure and the parameters

    2.1 Example

    • main.py
    import torch
    import torch.nn.functional as F
    import matplotlib.pyplot as plt
    
    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.h1 = torch.nn.Linear(1, 10)
            self.h2 = torch.nn.Linear(10, 1)
    
        def forward(self, x):
            x = F.relu(self.h1(x))
            x = self.h2(x)
            return x
    
    def prepare_data():
        torch.manual_seed(1)  # ensure the same random numbers on every run
        x = torch.linspace(-1, 1, 50)
        x = torch.unsqueeze(x, 1)
        y = x ** 2 + 0.2 * torch.rand(x.size())
        return x, y
    
    if __name__ == "__main__":
        # 1. Prepare the data
        x, y = prepare_data()
        plt.scatter(x.numpy(), y.numpy())
        plt.show()
        # 2. Build the network
        net = Net()
        # 3. Train
        optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
        loss_F = torch.nn.MSELoss()
        for epoch in range(100):  # avoid shadowing the builtin iter()
            pred = net(x)
            loss = loss_F(pred, y)
            optimizer.zero_grad()
            loss.backward()
            print(loss.detach().numpy())
            optimizer.step()
        # Save the entire network object, not just its parameters
        torch.save(net, "./net.pkl")  # save net directly, not net.state_dict()
    
    • test.py
    from main import Net, prepare_data
    import torch
    import matplotlib.pyplot as plt
    
    if __name__ == "__main__":
        x, y = prepare_data()
        plt.scatter(x.numpy(), y.numpy())
        plt.show()
        # torch.load returns the entire Net object here, not a dict
        net = torch.load("net.pkl")
        loss_F = torch.nn.MSELoss()
        pred = net(x)
        loss = loss_F(pred, y)
        print(loss.detach().numpy())
    

    2.2 Drawbacks

    • Tied to a specific class. Even though net in test.py comes from torch.load, the class Net from training must still be imported (otherwise loading raises an error).
    • Inflexible. The structure is fixed at save time, so you cannot define new layers, change the forward pass, and so on.
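One portability note that applies to both saving styles: a file saved from GPU tensors cannot be loaded on a CPU-only machine unless map_location is passed to torch.load. A small sketch (everything here is already on the CPU, so the remapping is a no-op, but it shows the pattern):

```python
import torch

net = torch.nn.Linear(1, 1)
torch.save(net.state_dict(), "./net_param.pkl")

# map_location remaps storage devices at load time; here all tensors
# are forced onto the CPU regardless of where they were saved
state = torch.load("./net_param.pkl", map_location=torch.device("cpu"))

net2 = torch.nn.Linear(1, 1)
net2.load_state_dict(state)
```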
