摘要:tensorflow
问题描述
今天看到GCN源码采用的tensorflow 1.x版本,在调用session run图结构时最后的feed_dict是这么写的
# 截取了部分
feed_dict.update({placeholders['dropout']: FLAGS.dropout})
outs = sess.run([model.opt_op, model.loss, model.accuracy], feed_dict=feed_dict)
placeholders = {
'features': tf.sparse_placeholder(tf.float32, shape=tf.constant(features[2], dtype=tf.int64)),
'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
'dropout': tf.placeholder_with_default(0., shape=())
}
根据他的逻辑取到placeholders字典的key,他的feed_dict的key是tensor对象,就是这么个玩意{tf.placeholder_with_default(0., shape=()): 0.5}
,最开始看到这个百思不得其解,一般的写法都是代码中定义一个变量,赋值为占位符,feed_dict中引用这个变量,按照上面的写法如何找到变量?也有可能这是基操我一直没有看过,下面我来测试一下,毕竟作者的代码没有报错可以跑。
将tensor赋值给变量,以变量作为feed_dict的key
这种是最常见的也是本人从接触tensorflow1.x到现在一直使用,在代码中将tensor赋值给一个变量,feed_dict传变量为key
placeholder = {'dropout': tf.placeholder(tf.float32)}
x = placeholder['dropout']
with tf.Session() as sess:
print(sess.run([x], feed_dict={x: 0.5}))
# 输出
[array(0.5, dtype=float32)]
和明显是可以的,相当于告诉代码要找到x变量,将0.5传给x表示的占位符,当然自定义的变量x要唯一
将tensor装入其他集合,以引用方式作为feed_dict的key
这种就是作者的方法,我测一下
placeholder = {'dropout': tf.placeholder(tf.float32)}
x = placeholder['dropout']
with tf.Session() as sess:
print(sess.run([x], feed_dict={placeholder['dropout']: 0.5}))
# 输出
[array(0.5, dtype=float32)]
厉害啊可以的结果是一样的,问题基本清楚了,这个地方作者采用的引用的方式,猜测一下虽然没有指定变量,但是tensorflow此时还是根据tf.placeholder(tf.float32)找到了要传值的tensor,因为引用对象保证了唯一
重新定义相同的tensor作为feed_dict的key
承接上面的猜想,这种方式应该是不行的,想想都觉得不行,也是文章最开始疑惑的地方
placeholder = {'dropout': tf.placeholder(tf.float32)}
x = placeholder['dropout']
with tf.Session() as sess:
print(sess.run([x], feed_dict={tf.placeholder(tf.float32): 0.5}))
# 直接报错
You must feed a value for placeholder tensor 'Placeholder_14' with dtype float
根据报错内容很明显tensorflow没有找到tensor的匹配关系,这是因为如果把tf.placeholder(tf.float32)作为feed_dict的key相当于是新建了一个tensor,虽然创建的语句一模一样也是不同的对象,类似于Python的类就算所有属性一样,重新实例化一个两个对象的id不一样
class a(object):
def __init__(self, name):
self.name = name
def get_name(self):
return self.name
id(a('bb'))
Out[99]: 140226928417232
id(a('bb'))
Out[100]: 140226928480528
使用tensor name获取唯一tensor对象作为key
既然知道了要让tensorflow识别到唯一的tensor对象,以上两种成功的都是引用的方式,还有一种就是通过tensor name来确定位移,使用tf.get_default_graph().get_tensor_by_name()
获得tensor对象
placeholder = {'dropout': tf.placeholder(tf.float32, name='p6')}
x = placeholder['dropout']
with tf.Session() as sess:
print(sess.run([x], feed_dict={tf.get_default_graph().get_tensor_by_name('p6:0'): 0.5}))
# 输出
[array(0.5, dtype=float32)]
不出所料是可以,在tensorflow中tensor的name的格式为
<op_name>:<output_index>
由节点名称和输出索引构成
总结
feed_dict的目的是连接tensor对象和真实数据对象,由于在graph中会定义很多tensor对象,所以feed_dict的key必须要能唯一标识出tensor对象,一种方式是指定tensor name,这种较为复杂容易弄错输出索引,另一种是直接再引用一次模型中定义的tensor对象,因此可以将模型的tensor对象赋值传递,或者装入集合中在通过集合取出,或者其他方式。
作者以一个字典集合作为中间桥梁,使得在feed_dict中不出现模型中的属性名称,一定程度上遮蔽了模型层,起到解耦作用。
网友评论