首先介绍两个创建variable的方法
- tf.Variable(initial_value, name, dtype, trainable, collection)
- tf.get_variable(name, shape, dtype, initializer, trainable, collection)
其中,tf.Variable
每次调用都会创建一个新的变量,如果变量名字相同,就在后面加N:
first_a = tf.Variable(name='a', initial_value=1, dtype=tf.int32)
second_a = tf.Variable(name='a', initial_value=1, dtype=tf.int32)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(first_a.name) # a_1:0
print(second_a.name) # a_2:0
而,tf.get_variable
的做法是,如果这个变量名字已经存在了,就拿这个变量,不再创建新的变量。
但是需要注意的是,一定要在scope中,使用reuse这个选项,如下是错误的。
first_a = tf.get_variable(name='a', shape=(1), initializer=tf.zeros_initializer, dtype=tf.int32)
second_a = tf.get_variable(name='a', shape=(1), initializer=tf.zeros_initializer, dtype=tf.int32)
不使用reuse是不能get_variable相同名字的变量的;而使用resue又只能在variable_scope中:
with tf.variable_scope('var_scope') as scope:
v = tf.get_variable(name='v', shape=[1], initializer=tf.zeros_initializer)
with tf.variable_scope(scope, reuse=True):
v1 = tf.get_variable(name='v', shape=[1], initializer=tf.zeros_initializer)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
assert v == v1
print(v.name) #var_scope/v:0
print(v1.name) #var_scope/v:0
tf.name_scope中
对tf.get_variable
不起作用,只对tf.Variable
起作用
with tf.name_scope("my_scope"):
v1 = tf.get_variable("var1", [1], dtype=tf.float32)
v2 = tf.Variable(1, name="var2", dtype=tf.float32)
a = tf.add(v1, v2)
print(v1.name) # var1:0
print(v2.name) # my_scope/var2:0
print(a.name) # my_scope/Add:0
tf.variable_scope中
对tf.get_variable
和tf.Variable
都起作用
with tf.variable_scope("my_scope"):
v1 = tf.get_variable("var1", [1], dtype=tf.float32)
v2 = tf.Variable(1, name="var2", dtype=tf.float32)
a = tf.add(v1, v2)
print(v1.name) # my_scope/var1:0
print(v2.name) # my_scope/var2:0
print(a.name) # my_scope/Add:0
这种机制允许在不用的name_scope中使用tf.get_variable
来share变量,但是需要注意的是,一定要声明reuse:
with tf.name_scope("foo"):
with tf.variable_scope("var_scope"):
v = tf.get_variable("var", [1])
with tf.name_scope("bar"):
with tf.variable_scope("var_scope", reuse=True):
v1 = tf.get_variable("var", [1])
assert v1 == v
print(v.name) # var_scope/var:0
print(v1.name) # var_scope/var:0****
网友评论