美文网首页
pytorch RNN运算分析

pytorch RNN运算分析

作者: 锦绣拾年 | 来源:发表于2020-12-10 22:43 被阅读0次

pytorch中的RNN计算过程,以及双向RNN拼接方式

问题起源是在看论文和读源码时,看到用pytorch中的双向RNN,就在思考pytorch中的双向RNN的结果是怎么拼接的,是论文中写的那样吗[\overrightarrow{h},\overleftarrow{h}]

#coding:utf-8
import torch
torch.manual_seed(2)
src = torch.rand((2, 2, 3)) # [bacth_size, 句长(每句2个词) ,词向量(dim=3)]
print(src)
h0 = torch.randn((1, 2, 2))#[维度1,batch_size, 维度2] 
output,_ = torch.nn.RNN(3,2,batch_first=True)(src,h0)#词向量,隐藏层.batch_first=True:表明输入的第0维是batchsize
print(output.size()) # 2,2,2

print(h0)
h1 = torch.cat((h0,h0),0)
print(h1)
print(h1.shape)# 2,2,2
model = torch.nn.RNN(3,2,bidirectional=True,batch_first=True)
output2,_ = model(src,h1)
print(output2.size())# 2,2,4
pa_dict = {}
for x in model.named_parameters():
    print(x[0])
    print(x[1].data)
    pa_dict[x[0]] = x[1].data
print(src)
# RNN运算公式:h_t=tanh(Ux+Wh_{t-1})
# src[:,0,:]:对第1个词进行运算。
# h0.squeeze() 消除为1的那一维。
left1 = torch.tanh(
    torch.matmul(src[:,0,:],pa_dict['weight_ih_l0'].transpose(0,1))+
    pa_dict['bias_ih_l0']+
    torch.matmul(h0.squeeze(),pa_dict['weight_hh_l0'].transpose(0,1))+
    pa_dict['bias_hh_l0'])
left2 = torch.tanh(
    torch.matmul(src[:,1,:],pa_dict['weight_ih_l0'].transpose(0,1))+
    pa_dict['bias_ih_l0']+
    torch.matmul(left1.squeeze(),pa_dict['weight_hh_l0'].transpose(0,1))+
    pa_dict['bias_hh_l0'])
print("left1",left1)
print("left2",left2)
# 逆向时反过来,先对第二个词进行运算。
right2 = torch.tanh(
    torch.matmul(src[:,1,:],pa_dict['weight_ih_l0_reverse'].transpose(0,1))+
    pa_dict['bias_ih_l0_reverse']+
    torch.matmul(h0.squeeze(),pa_dict['weight_hh_l0_reverse'].transpose(0,1))+
    pa_dict['bias_hh_l0_reverse'])
right1 = torch.tanh(
    torch.matmul(src[:,0,:],pa_dict['weight_ih_l0_reverse'].transpose(0,1))+
    pa_dict['bias_ih_l0_reverse']+
    torch.matmul(right2.squeeze(),pa_dict['weight_hh_l0_reverse'].transpose(0,1))+
    pa_dict['bias_hh_l0_reverse'])
print("right1",right1)
print("right2",right2)
print(output2)

output2 输出维度是(2,2,4),output输出维度是(2,2,2),其实可以看出就是进行拼接了。
结果:

left1 tensor([[ 0.0543,  0.2228],
        [ 0.8822, -0.3797]])
left2 tensor([[ 0.2540, -0.0244],
        [ 0.4441,  0.2254]])
right1 tensor([[0.5537, 0.7760],
        [0.5290, 0.5026]])
right2 tensor([[0.5573, 0.7746],
        [0.6490, 0.2075]])
tensor([[[ 0.0543,  0.2228,  0.5537,  0.7760],
         [ 0.2540, -0.0244,  0.5573,  0.7746]],

        [[ 0.8822, -0.3797,  0.5290,  0.5026],
         [ 0.4441,  0.2254,  0.6490,  0.2075]]], grad_fn=<TransposeBackward1>)

如果换成batch_size =1会更明显一点

src = torch.rand((1, 2, 3)) # [bacth_size, 句长(每句2个词) ,词向量(dim=3)]
h0 = torch.randn((1, 1, 2))#
left1 tensor([[ 0.9628, -0.9493]])
left2 tensor([[ 0.2381, -0.9160]])
right1 tensor([[-0.0205,  0.2631]])
right2 tensor([[-0.4846, -0.0552]])
tensor([[[ 0.9628, -0.9493, -0.0205,  0.2631],
         [ 0.2381, -0.9160, -0.4846, -0.0552]]], grad_fn=<TransposeBackward1>)

可以看出最后的结果是[ [left1,right1][left2,right2]]
证明双向RNN的输出和大部分论文使用相同[\overrightarrow{h},\overleftarrow{h}]

相关文章

  • pytorch RNN运算分析

    pytorch中的RNN计算过程,以及双向RNN拼接方式 问题起源是在看论文和读源码时,看到用pytorch中的双...

  • pytorch实现RNN以及LSTM/GRU

    pytorch提供了很方便的RNN模块,以及其他结构像LSTM和GRU。pytorch里的RNN需要的参数主要有:...

  • NLP自然语言理解的学习

    RNN和LSTM网络结构 原文LSTM in pytorch Pytorch上的教学word_embeddings...

  • 循环神经网络pytorch实现

    RNN pytorch 实现 LSTM 输入门: 遗忘门: 输出门: pytorch 实现 GRU 更新门: 候选...

  • CS231n Spring 2019 Assignment 3—

    NetworkVisualization-PyTorch 从上次RNN之后的三次作业就会有PyTorch和Tens...

  • pytorch RNN的一点理解

    本文主要介绍一下RNN的计算规则以及pytorch里面RNN怎么计算的,给自己备注一下。RNN是一种循环神经网络,...

  • Pytorch学习之LSTM识别MNIST数据集

    实验RNN循环神经网络识别MNIST手写数字集 本文主要是讲述pytorch实现的RNN神经网络去识别MNIST手...

  • PyTorch RNN Classification

     循环神经网络RNN让神经网络有了记忆, 对于序列话的数据,循环神经网络能达到更好的效果. 更多可以查看官网 :*...

  • PyTorch RNN Regression

     循环神经网络RNN及时预测时间序列. 更多可以查看官网 :* PyTorch 官网 载入数据 假设想要用 sin...

  • 02-25:RNN算法

    RNN算法 1、RNN算法原理 (1)RNN变种GRU (2)RNN变种LSTM LSTM缺点分析: todo: ...

网友评论

      本文标题:pytorch RNN运算分析

      本文链接:https://www.haomeiwen.com/subject/xekdgktx.html