1. 错误信息描述
在使用(Bi)-LSTM,(Bi)-RNN 时遇到下述报错信息:
tensorflow.python.framework.errors_impl.InvalidArgumentError: seq_lens(8) > input.dims(1)
[[{{node fusion/bidirectional_rnn/bw/ReverseSequence}} = ReverseSequence[T=DT_FLOAT, Tlen=DT_INT32, batch_dim=0, seq_dim=1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](dropout_2/mul, _arg_Placeholder_0_0)]]
2. 问题定位
同样的问题,在 StackOverflow 很多人遇到类似的报错信息:
https://stackoverflow.com/questions/49078297/lstm-tensorflow-shape-error
google了一番,没有很多答案,但根据 错误提示(截图标红),并结合源码,可以发现问题是出在 tf/reverse_sequence()这个函数中;
通过查看 tf-api:
https://www.tensorflow.org/api_docs/python/tf/reverse_sequence
并结合W3Cschool的中文解释:
https://www.w3cschool.cn/tensorflow_python/tensorflow_python-xhau2ihk.html
可以 定位问题 在于:
(Bi)-RNN中利用 tf/reverse_sequence() 对序列进行翻转过程中,所传入的 seq_lengths 的值须满足: seq_lengths[i] <= input.dims[seq_dim],即seq_lengths的值不能大于所翻转序列本身的长度!
3. 解决方案
因此,在调用bidirectional_dynamic_rnn(如下)之类函数过程中,需要对传入的参数 seq_length 做一定的限制,使其不大于序列本身的长度!
outputs, state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw,
inputs=seq_encodes,
sequence_length=seq_length,
dtype=tf.float32)
如,你可能需要在数据class中做如下处理:
sample['seq_length'] = min(self.max_seq_len, len(sample['seq']))
如果你有任何疑问或不同见解,欢迎在下方留言或私信我一起探讨学习!
参考文献
https://www.tensorflow.org/api_docs/python/tf/reverse_sequence
https://www.w3cschool.cn/tensorflow_python/tensorflow_python-xhau2ihk.html
网友评论