Why PyTorch is Dynamic

Author: 地平线上的背影 | Published 2019-02-14 08:50

    When building neural networks we sometimes run into more complicated situations. Normally a network's structure is fixed once it has been defined, but in some cases the structure has to change at run time to meet special requirements. This is where PyTorch's advantage shows, because PyTorch itself is dynamic: the computation graph is rebuilt on every forward pass, as the short sketch below illustrates. The same property is what makes dynamic architectures, such as the variable-length ("dynamic") RNN built in this article, straightforward to express.
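
    To make "dynamic" concrete, here is a minimal sketch (not part of the original script): because the forward pass is ordinary Python, the number of operations can depend on the data itself, and autograd simply records whatever actually ran.

    import torch

    x = torch.randn(3, requires_grad=True)
    y = x
    while y.norm() < 10:        # data-dependent loop: the graph grows as the loop runs
        y = y * 2
    y.sum().backward()          # gradients flow back through however many steps actually ran
    print(x.grad)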

    1. Prepare the initial parameters

    import torch
    from torch import nn
    import numpy as np
    import matplotlib.pyplot as plt
    
    # torch.manual_seed(1)    # reproducible
    
    # Hyper Parameters
    INPUT_SIZE = 1          # rnn input size (one scalar sin value per time step)
    LR = 0.02               # learning rate
    

    2. Build an ordinary RNN model

    class RNN(nn.Module):
        def __init__(self):
            super(RNN, self).__init__()
    
            self.rnn = nn.RNN(
                input_size=1,
                hidden_size=32,     # rnn hidden unit
                num_layers=1,       # number of rnn layer
                batch_first=True,   # input & output will have the batch size as the first dimension, e.g. (batch, time_step, input_size)
            )
            self.out = nn.Linear(32, 1)
    
        def forward(self, x, h_state):
            # x (batch, time_step, input_size)
            # h_state (n_layers, batch, hidden_size)
            # r_out (batch, time_step, output_size)
            r_out, h_state = self.rnn(x, h_state)
    
            outs = []                                   # this Python loop is where you can see that torch is dynamic
            for time_step in range(r_out.size(1)):      # calculate output for each time step
                outs.append(self.out(r_out[:, time_step, :]))
            return torch.stack(outs, dim=1), h_state
    
    
    rnn = RNN()
    print(rnn)
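
    As a quick hedged check (not from the original article), the same rnn instance can be called with inputs of different lengths; the Python loop in forward simply runs r_out.size(1) times, so nothing needs to be padded or rebuilt:

    x_short = torch.zeros(1, 5, 1)      # (batch, time_step=5, input_size)
    x_long  = torch.zeros(1, 23, 1)     # (batch, time_step=23, input_size)
    out_short, _ = rnn(x_short, None)   # -> shape (1, 5, 1)
    out_long,  _ = rnn(x_long, None)    # -> shape (1, 23, 1)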
    

    3. Choose the optimizer and loss function

    optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)   # optimize all rnn parameters
    loss_func = nn.MSELoss()                                # mean squared error, since this is a regression task
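
    As a side note (not in the original), nn.MSELoss with its default reduction averages the squared differences, which can be sanity-checked directly:

    # mean((1.0-0.5)^2, (0.0-1.0)^2) = mean(0.25, 1.0) = 0.625
    print(nn.MSELoss()(torch.tensor([1.0, 0.0]), torch.tensor([0.5, 1.0])))   # tensor(0.6250)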
    

    4. Initialize the hidden state and the plt figure

    h_state = None   # initial hidden state; None lets nn.RNN start from an all-zero state
    
    plt.figure(1, figsize=(12, 5))
    plt.ion()   # continuously plot
    

    5. Dynamic time steps for the RNN

    ################ static time steps ##########
    # for step in range(60):
    #     start, end = step * np.pi, (step+1)*np.pi   # time steps
    #     # use sin predicts cos
    #     steps = np.linspace(start, end, 10, dtype=np.float32)
    
    ################ dynamic time steps #########
    step = 0
    for i in range(60):
        dynamic_steps = np.random.randint(1, 4)  # a random number of time steps (1 to 3)
        start, end = step * np.pi, (step + dynamic_steps) * np.pi  # so each segment has a different length
        step += dynamic_steps
    
        # use sin predicts cos
        steps = np.linspace(start, end, 10 * dynamic_steps, dtype=np.float32)
    

    6. Training and optimization

    6.1 Data preprocessing
        print(len(steps))       # print how many time steps are fed to the RNN this iteration
    
        x_np = np.sin(steps)    # float32 so the arrays convert to torch FloatTensors
        y_np = np.cos(steps)
    
        x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])    # shape (batch, time_step, input_size)
        y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
    
        prediction, h_state = rnn(x, h_state)   # rnn output
        # !! next step is important !!
        h_state = h_state.data        # repack the hidden state, break the connection from last iteration
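        # note (added): in current PyTorch, h_state = h_state.detach() is the
        # idiomatic way to cut the graph here; .data has the same effect but
        # skips autograd's safety checks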
    
    6.2 Optimization
        loss = loss_func(prediction, y)         # mean squared error loss
        optimizer.zero_grad()                   # clear gradients for this training step
        loss.backward()                         # backpropagation, compute gradients
        optimizer.step()                        # apply gradients
    
    6.3 Visualizing the training process
        # plotting
        plt.plot(steps, y_np.flatten(), 'r-')
        plt.plot(steps, prediction.data.numpy().flatten(), 'b-')
        plt.draw()
        plt.pause(0.05)
    
    plt.ioff()
    plt.show()
    
