美文网首页
tensorflow InvalidArgumentError:

tensorflow InvalidArgumentError:

作者: lzhenboy | 来源:发表于2019-04-26 22:02 被阅读0次

1. 错误信息描述

在使用(Bi)-LSTM,(Bi)-RNN 时遇到下述报错信息:

tensorflow.python.framework.errors_impl.InvalidArgumentError: seq_lens(8) > input.dims(1)
[[{{node fusion/bidirectional_rnn/bw/ReverseSequence}} = ReverseSequence[T=DT_FLOAT, Tlen=DT_INT32, batch_dim=0, seq_dim=1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](dropout_2/mul, _arg_Placeholder_0_0)]]

截图更清晰: error.png

2. 问题定位

同样的问题,在 StackOverflow 很多人遇到类似的报错信息:
https://stackoverflow.com/questions/49078297/lstm-tensorflow-shape-error
google了一番,没有很多答案,但根据 错误提示(截图标红),并结合源码,可以发现问题是出在 tf/reverse_sequence()这个函数中;
通过查看 tf-api
https://www.tensorflow.org/api_docs/python/tf/reverse_sequence
并结合W3Cschool的中文解释:
https://www.w3cschool.cn/tensorflow_python/tensorflow_python-xhau2ihk.html

image.png

可以 定位问题 在于:
(Bi)-RNN中利用 tf/reverse_sequence() 对序列进行翻转过程中,所传入的 seq_lengths 的值须满足: seq_lengths[i] <= input.dims[seq_dim],即seq_lengths的值不能大于所翻转序列本身的长度!

3. 解决方案

因此,在调用bidirectional_dynamic_rnn(如下)之类函数过程中,需要对传入的参数 seq_length 做一定的限制,使其不大于序列本身的长度!

outputs, state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw,
                                                 inputs=seq_encodes,
                                                 sequence_length=seq_length,
                                                 dtype=tf.float32)

如,你可能需要在数据class中做如下处理:

sample['seq_length'] = min(self.max_seq_len, len(sample['seq']))



如果你有任何疑问或不同见解,欢迎在下方留言或私信我一起探讨学习!

参考文献

https://www.tensorflow.org/api_docs/python/tf/reverse_sequence
https://www.w3cschool.cn/tensorflow_python/tensorflow_python-xhau2ihk.html

相关文章

网友评论

      本文标题:tensorflow InvalidArgumentError:

      本文链接:https://www.haomeiwen.com/subject/mqtcnqtx.html