美文网首页
tensorflow InvalidArgumentError:

tensorflow InvalidArgumentError:

作者: lzhenboy | 来源:发表于2019-04-26 22:02 被阅读0次

    1. 错误信息描述

    在使用(Bi)-LSTM,(Bi)-RNN 时遇到下述报错信息:

    tensorflow.python.framework.errors_impl.InvalidArgumentError: seq_lens(8) > input.dims(1)
    [[{{node fusion/bidirectional_rnn/bw/ReverseSequence}} = ReverseSequence[T=DT_FLOAT, Tlen=DT_INT32, batch_dim=0, seq_dim=1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](dropout_2/mul, _arg_Placeholder_0_0)]]

    截图更清晰: error.png

    2. 问题定位

    同样的问题,在 StackOverflow 很多人遇到类似的报错信息:
    https://stackoverflow.com/questions/49078297/lstm-tensorflow-shape-error
    google了一番,没有很多答案,但根据 错误提示(截图标红),并结合源码,可以发现问题是出在 tf/reverse_sequence()这个函数中;
    通过查看 tf-api
    https://www.tensorflow.org/api_docs/python/tf/reverse_sequence
    并结合W3Cschool的中文解释:
    https://www.w3cschool.cn/tensorflow_python/tensorflow_python-xhau2ihk.html

    image.png

    可以 定位问题 在于:
    (Bi)-RNN中利用 tf/reverse_sequence() 对序列进行翻转过程中,所传入的 seq_lengths 的值须满足: seq_lengths[i] <= input.dims[seq_dim],即seq_lengths的值不能大于所翻转序列本身的长度!

    3. 解决方案

    因此,在调用bidirectional_dynamic_rnn(如下)之类函数过程中,需要对传入的参数 seq_length 做一定的限制,使其不大于序列本身的长度!

    outputs, state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw,
                                                     inputs=seq_encodes,
                                                     sequence_length=seq_length,
                                                     dtype=tf.float32)
    

    如,你可能需要在数据class中做如下处理:

    sample['seq_length'] = min(self.max_seq_len, len(sample['seq']))
    



    如果你有任何疑问或不同见解,欢迎在下方留言或私信我一起探讨学习!

    参考文献

    https://www.tensorflow.org/api_docs/python/tf/reverse_sequence
    https://www.w3cschool.cn/tensorflow_python/tensorflow_python-xhau2ihk.html

    相关文章

      网友评论

          本文标题:tensorflow InvalidArgumentError:

          本文链接:https://www.haomeiwen.com/subject/mqtcnqtx.html