参考文献:
tf.control_dependencies()
设计是用来控制计算流图的,给图中的某些计算指定顺序。
比如:我们想要获取参数更新后的值,那么我们可以这么组织我们的代码。
opt = tf.train.Optimizer().minize(loss)
with tf.control_dependencies([opt]):
updated_weight = tf.identity(weight)
with tf.Session() as sess:
tf.global_variables_initializer().run()
sess.run(updated_weight, feed_dict={...}) # 这样每次得到的都是更新后的weight
关于tf.control_dependencies的具体用法,请移步官网:https://www.tensorflow.org/api_docs/python/tf/Graph#control_dependencies。
总结一句话就是,在执行某些op, tensor之前,某些op, tensor得首先被运行。
tf.identity()
的用法:
tf.identity(input,name=None)
#Return a tensor with the same shape and contents as input.
#返回一个tensor,contents和shape都和input的一样。
简单来说tf.identity()就是返回一个一模一样的新的tensor。
(别人的总结:为cpu\gpu传输提供更好的性能。就像你做一个电路板,有些地方要把线路印出来,调试的时候可以看到中间结果一样,tf.identity()就是为了在图上显示这个值而创建的虚拟节点)
在Stack Overflow中有一个问题对tf.identity()进行了举例,具体如下:
x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
y = x
init = tf.global_variables_initializer()
with tf.Session() as session:
init.run() # 相当于session.run(init)
for i in xrange(5):
print(y.eval()) # y.eval()这个相当于session.run(y)
输出:
0.0
0.0
0.0
0.0
0.0
我们的理想输出应该是:[1.0, 2.0, 3.0, 4.0, 5.0]。
之所以会产生错误结果是由于以下两个原因:
1. tf.control_dependencies()是一个在Graph上的operation,所以要想使得其参数起作用,就需要在代码11处利用sess.run()来执行;
2. y = x只是一个简单的赋值操作,而with tf.control_dependencies()作用域(也就是冒号下的代码行)只对op起作用,所以需要将tensor利用tf.identity()来转化为op。
针对以上原因,给出两个相应的解决方法:
x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
y = x
init = tf.global_variables_initializer()
with tf.Session() as sess:
init.run() # 相当于session.run(init)
for i in range(5):
sess.run(x_plus_1)
print(y.eval()) # y.eval()这个相当于session.run(y)
x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
y = tf.identity(x)
init = tf.global_variables_initializer()
with tf.Session() as sess:
init.run() # 相当于session.run(init)
for i in range(5):
print(y.eval()) # y.eval()这个相当于session.run(y)
总体来说,Graph上不论是tensor还是operation的更新都要借助op来进行,而将一个tensor转化为op最简单的方法就是tf.identity()。
网友评论