美文网首页
tf.control_dependencies()和tf.ide

tf.control_dependencies()和tf.ide

作者: yalesaleng | 来源:发表于2018-07-16 17:09 被阅读740次

参考文献:

tf.control_dependencies()设计是用来控制计算流图的,给图中的某些计算指定顺序。
比如:我们想要获取参数更新后的值,那么我们可以这么组织我们的代码。

opt = tf.train.Optimizer().minize(loss)
with tf.control_dependencies([opt]):
  updated_weight = tf.identity(weight)

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  sess.run(updated_weight, feed_dict={...}) # 这样每次得到的都是更新后的weight

关于tf.control_dependencies的具体用法,请移步官网:https://www.tensorflow.org/api_docs/python/tf/Graph#control_dependencies

总结一句话就是,在执行某些op, tensor之前,某些op, tensor得首先被运行。


tf.identity()的用法:

tf.identity(input,name=None)
#Return a tensor with the same shape and contents as input.
#返回一个tensor,contents和shape都和input的一样。

简单来说tf.identity()就是返回一个一模一样的新的tensor。
(别人的总结:为cpu\gpu传输提供更好的性能。就像你做一个电路板,有些地方要把线路印出来,调试的时候可以看到中间结果一样,tf.identity()就是为了在图上显示这个值而创建的虚拟节点)

在Stack Overflow中有一个问题对tf.identity()进行了举例,具体如下:

x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
    y = x
init = tf.global_variables_initializer()
with tf.Session() as session:
    init.run() # 相当于session.run(init)
    for i in xrange(5):
        print(y.eval()) # y.eval()这个相当于session.run(y)

输出:
   0.0
   0.0
   0.0
   0.0
   0.0

我们的理想输出应该是:[1.0, 2.0, 3.0, 4.0, 5.0]。
之所以会产生错误结果是由于以下两个原因:
1. tf.control_dependencies()是一个在Graph上的operation,所以要想使得其参数起作用,就需要在代码11处利用sess.run()来执行;
2. y = x只是一个简单的赋值操作,而with tf.control_dependencies()作用域(也就是冒号下的代码行)只对op起作用,所以需要将tensor利用tf.identity()来转化为op。

针对以上原因,给出两个相应的解决方法:

x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
    y = x
init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run() # 相当于session.run(init)
    for i in range(5):
        sess.run(x_plus_1)
        print(y.eval()) # y.eval()这个相当于session.run(y)
x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
    y = tf.identity(x)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run() # 相当于session.run(init)
    for i in range(5):
        print(y.eval()) # y.eval()这个相当于session.run(y)

总体来说,Graph上不论是tensor还是operation的更新都要借助op来进行,而将一个tensor转化为op最简单的方法就是tf.identity()。

相关文章

网友评论

      本文标题:tf.control_dependencies()和tf.ide

      本文链接:https://www.haomeiwen.com/subject/pybbpftx.html