-
tf.assign
是创建一个操作符这个操作符具有这个值变量的值,而=
是Python中的赋值,tensorflow函数的操作会新建一个节点,如果用Python的=
那么就相当于将变量的引用给到这个新节点上,但是在计算图上并没有相应的赋值操作节点(因为只是python对于等式右边节点的一个引用而已),而如果使用tf.assign
的话计算图中有赋值节点。 - 你只要分清哪些是
tensorflow
中的操作和哪些是python
语言的引用操作,就能分清哪些是在建图,哪些只是在改变引用。
第一个例子:因为op是tensorflow中的一个结点,而assign_add是对原始的节点a进行赋值,所以最终的结果是7。
a = tf.Variable(3)
op = tf.assign_add(a,1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(op)
sess.run(op)
sess.run(op)
sess.run(op)
print(sess.run(a))
>>7
第二个例子:注意a = a + 1
的实际操作是首先将右边a
的节点加上1,这是一个新建节点操作,a+1
返回的是这个新建节点,此时a = 新建节点
。也就是a引用的节点地址从变量a变成了a+1操作符的引用,我们打印出来可以看到,此时的a是一个add操作
。这是run(a)
就不会修改原始节点name="a"
的值,也就是始终为1
,也就是关键问题是在计算图上没有给原始的变量a进行赋值,所以他的值始终是1,加了1以后add操作符打印出来是2。
a = tf.Variable(3)
a = a + 1
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(a)
sess.run(a)
sess.run(a)
sess.run(a)
print(a)
print(sess.run(a))
>>Tensor("add:0", shape=(), dtype=int32)
4
网友评论