美文网首页
pytorch实现RNN以及LSTM/GRU

pytorch实现RNN以及LSTM/GRU

作者: 升不上三段的大鱼 | 来源:发表于2021-01-10 10:52 被阅读0次

    pytorch提供了很方便的RNN模块,以及其他结构像LSTM和GRU。
    pytorch里的RNN需要的参数主要有:

    • input_size:input_tensor的形状是(序列长度, batch大小,input_size)
    • hidden_size:可以自己定义大小,是一个需要调的参数,hidden state是(RNN的层数*方向,batch,hidden_size),这里的方向默认是1,如果是双向的RNN,方向则是2.
    • num_layers:RNN也可以堆叠起来,默认是1层,可以设置层数。
    • batch_first:第一维是非为batch size,默认为false;如果设为true,意味着输入和输出的第一维是batch。

    代码实现很简单,对于一个一层的RNN,实现多对一的分类:

    class RNN(nn.Module):
        def __init__(self, input_dim, hidden_dim, num_layers, num_class):
            super(RNN,self).__init__()
            self.input_dim = input_dim
            self.hidden_dim = hidden_dim
            self.num_layers = num_layers
            self.rnn = nn.RNN(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
            # input: (batch_size, sequence_size, input_size)
            # many to one mode
            self.fc = nn.Linear(hidden_dim, num_class)
    
        def forward(self,x):
            # initialize hidden state
            h0 = torch.zeros(self.num_layers, x.size(0),self.hidden_dim).to('cuda:0')
            # output size: (batch_size, sequnce_size, hidden_size)
            out, _ = self.rnn(x)
            out = self.fc(out)
            return out
    

    多对一指的是一个序列的输入对应着一个值的输出,多对多指的是一个序列输入对应着一个序列的输出。
    RNN层的输出有两个,一个是最后一层RNN在所有时间上的输出,另一个是最后一个隐含状态,这里我们只需要一个输出就够了,再加上一个线性层用于分类。
    LSTM/GRU和RNN的用法基本一致,只有改一下名字就行了。

    对于一个一维的序列,假设batch size为128, 序列长度为256,输入维度为1,分类数目为3,输入的shape为(128,256,1),如果是多对一的分类,得到的输出为(128,3);如果是多对多的分类,得到的输出是(128,256,3).在输入序列的256个时间点上都有输出。

    相关文章

      网友评论

          本文标题:pytorch实现RNN以及LSTM/GRU

          本文链接:https://www.haomeiwen.com/subject/belkaktx.html