老师让我搞attention可视化,总结一下坑。
获取所有的节点:
cczz1=[i.name for i in sess.graph.get_operations()]
参考一本叫做《强化学习精要》的书,可知,每生成一个变量,都会生成数个相关的操作(节点),而且,每个运算都会增加相关的节点,所以可能出现一大堆和你命名相同的节点,
所以这里,采用
dd=[i for i in cczz1 if 节点名称 in i]
找到合适的时候,使用get_tensor_by_name获取tensor
来查找你所要的节点。
这里要注意,操作的首字母要大写,比如Softmax,变量后面要加:0
z=sess.graph.get_tensor_by_name("完整的名称")
但,并不是所有的都可以获取,很多包含动态循环的cell中的数据是无法获取的,如下报错
ValueError: Operation 'modeling/bidirectional_rnn/bw/bw/while/matchlstm/ExpandDims_1' has been marked as not fetchable.
image.png
从stack overflow上查到的。
所以,你需要将动态的lstm转为静态的lstm,
如果你的数据是定长的,那么操作很简单,如下,这里以(batchtime_stepsword_dimence)举例
返回值=tf.unstack(数据,axis=1)
后面所有带dynamic的操作全部换成对应的static函数,
注意:不是直接替换函数名,而是找到对应的函数。
而且,很多static的输入都是sequence类型的(这个我也不懂),
它是一个tensor的列表(类型为列表,元素类型为tensor),如上unstack操作。
但,还有一种情况,你的序列是不定长的,你要把它变成定长的才能调用静态的函数。这个我没有太好的思路,准备以后试一试。
网友评论