美文网首页算法工程
tensorflow 1.x:feed_dict原理测试,以唯一

tensorflow 1.x:feed_dict原理测试,以唯一

作者: xiaogp | 来源:发表于2021-12-30 12:44 被阅读0次

    摘要:tensorflow

    问题描述

    今天看到GCN源码采用的tensorflow 1.x版本,在调用session run图结构时最后的feed_dict是这么写的

    # 截取了部分
    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
    outs = sess.run([model.opt_op, model.loss, model.accuracy], feed_dict=feed_dict)
    
    placeholders = {
        'features': tf.sparse_placeholder(tf.float32, shape=tf.constant(features[2], dtype=tf.int64)),
        'labels': tf.placeholder(tf.float32, shape=(None, y_train.shape[1])),
        'dropout': tf.placeholder_with_default(0., shape=())
    }
    

    根据他的逻辑取到placeholders字典的key,他的feed_dict的key是tensor对象,就是这么个玩意{tf.placeholder_with_default(0., shape=()): 0.5},最开始看到这个百思不得其解,一般的写法都是代码中定义一个变量,赋值为占位符,feed_dict中引用这个变量,按照上面的写法如何找到变量?也有可能这是基操我一直没有看过,下面我来测试一下,毕竟作者的代码没有报错可以跑。


    将tensor赋值给变量,以变量作为feed_dict的key

    这种是最常见的也是本人从接触tensorflow1.x到现在一直使用,在代码中将tensor赋值给一个变量,feed_dict传变量为key

    placeholder = {'dropout': tf.placeholder(tf.float32)}
    x = placeholder['dropout']
    with tf.Session() as sess:
        print(sess.run([x], feed_dict={x: 0.5}))
    
    # 输出
    [array(0.5, dtype=float32)]
    

    和明显是可以的,相当于告诉代码要找到x变量,将0.5传给x表示的占位符,当然自定义的变量x要唯一


    将tensor装入其他集合,以引用方式作为feed_dict的key

    这种就是作者的方法,我测一下

    placeholder = {'dropout': tf.placeholder(tf.float32)}
    x = placeholder['dropout']
    with tf.Session() as sess:
        print(sess.run([x], feed_dict={placeholder['dropout']: 0.5}))
    
    # 输出
    [array(0.5, dtype=float32)]
    

    厉害啊可以的结果是一样的,问题基本清楚了,这个地方作者采用的引用的方式,猜测一下虽然没有指定变量,但是tensorflow此时还是根据tf.placeholder(tf.float32)找到了要传值的tensor,因为引用对象保证了唯一


    重新定义相同的tensor作为feed_dict的key

    承接上面的猜想,这种方式应该是不行的,想想都觉得不行,也是文章最开始疑惑的地方

        
    placeholder = {'dropout': tf.placeholder(tf.float32)}
    x = placeholder['dropout']
    with tf.Session() as sess:
        print(sess.run([x], feed_dict={tf.placeholder(tf.float32): 0.5}))
    
    # 直接报错
    You must feed a value for placeholder tensor 'Placeholder_14' with dtype float
    

    根据报错内容很明显tensorflow没有找到tensor的匹配关系,这是因为如果把tf.placeholder(tf.float32)作为feed_dict的key相当于是新建了一个tensor,虽然创建的语句一模一样也是不同的对象,类似于Python的类就算所有属性一样,重新实例化一个两个对象的id不一样

    class a(object):
        def __init__(self, name):
            self.name = name
        def get_name(self):
            return self.name
    
    id(a('bb'))
    Out[99]: 140226928417232
    id(a('bb'))
    Out[100]: 140226928480528
    

    使用tensor name获取唯一tensor对象作为key

    既然知道了要让tensorflow识别到唯一的tensor对象,以上两种成功的都是引用的方式,还有一种就是通过tensor name来确定位移,使用tf.get_default_graph().get_tensor_by_name()获得tensor对象

    placeholder = {'dropout': tf.placeholder(tf.float32, name='p6')}
    x = placeholder['dropout']
    with tf.Session() as sess:
        print(sess.run([x], feed_dict={tf.get_default_graph().get_tensor_by_name('p6:0'): 0.5}))
    
    # 输出
    [array(0.5, dtype=float32)]
    

    不出所料是可以,在tensorflow中tensor的name的格式为

    <op_name>:<output_index>
    

    由节点名称和输出索引构成


    总结

    feed_dict的目的是连接tensor对象和真实数据对象,由于在graph中会定义很多tensor对象,所以feed_dict的key必须要能唯一标识出tensor对象,一种方式是指定tensor name,这种较为复杂容易弄错输出索引,另一种是直接再引用一次模型中定义的tensor对象,因此可以将模型的tensor对象赋值传递,或者装入集合中在通过集合取出,或者其他方式。
    作者以一个字典集合作为中间桥梁,使得在feed_dict中不出现模型中的属性名称,一定程度上遮蔽了模型层,起到解耦作用

    相关文章

      网友评论

        本文标题:tensorflow 1.x:feed_dict原理测试,以唯一

        本文链接:https://www.haomeiwen.com/subject/sosiqrtx.html