刚开始学习Rnn的结构的时候,经常看到的图是下面这样的:
理所当然的认为,如果用代码构建出来Rnn网络,应该是有多份节点的。而这个数量正好就是timestep的数量。
但是某天某人指出我的理解是错误的,并让我好好看了下tensorflow的lstm的源代码,这才发现对于Rnn的理解完全错误了。
修正后的理解如下:
1.首先节点只有1份,而timestep是用循环实现,所以是无限的。这意味着训练和推理的时候可以把全部的数据都一次传入,只要显存足够大。
2.正是因为节点只有1份,那么前置的网络就不能简单的用个全连接网络了,这样全连接权重数量限制了timestep的数量。可以直接用多层Rnn Node来替代普通全连接层。
3.每个batch的训练开始都会把state清空。
4.在训练的时候如果需要的内存太大,无法一次全部传入网络,可以把中间状态state保存下来,把timestep分成多次进行训练。下次把state传入。
网友评论