美文网首页
name_scope, variable_scope的理解

name_scope, variable_scope的理解

作者: 富有的心 | 来源:发表于2018-05-24 12:03 被阅读0次

    tf. get_variable()
    用于获取一个变量,先搜索变量名,没有就新建,有就直接用,并且不受name_scope的约束
    遇到重名的变量创建且变量名没有设置为共享变量时,则会报错

    tf.Variable()
    每次都会新建变量
    会自动检测命名冲突并自行处理

    name_scope
    作用于操作:主要用于管理一个图里面的各种op,返回的是一个以scope_name命名的context manager。一个graph会维护一个name_space的堆,每一个namespace下面可以定义各种op或者子namespace,实现一种层次化有条理的管理,避免各个op之间命名冲突。

    variable_scope
    可以通过设置reuse 标志以及初始化方式来影响域下的变量,一般与tf.name_scope()配合使用,用于管理一个graph中变量的名字,避免变量之间的命名冲突,tf.variable_scope(<scope_name>)允许在一个variable_scope下面共享变量。

    import tensorflow as tf
    
    with tf.name_scope('name_scope_x'):
        var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
        var3 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
        var4 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(var1.name, sess.run(var1))
        print(var3.name, sess.run(var3))
        print(var4.name, sess.run(var4))
    # 输出结果:
    # var1:0 [-0.30036557]   可以看到前面不含有指定的'name_scope_x'
    # name_scope_x/var2:0 [ 2.]
    # name_scope_x/var2_1:0 [ 2.]  可以看到变量名自行变成了'var2_1',避免了和'var2'冲突
    

    如果使用tf.get_variable()创建变量,且没有设置共享变量,重名时会报错

    作者:C Li
    链接:https://www.zhihu.com/question/54513728/answer/181819324
    来源:知乎
    著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
    
    import tensorflow as tf
    
    with tf.name_scope('name_scope_1'):
        var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
        var2 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(var1.name, sess.run(var1))
        print(var2.name, sess.run(var2))
    
    # ValueError: Variable var1 already exists, disallowed. Did you mean 
    # to set reuse=True in VarScope? Originally defined at:
    # var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
    

    所以要共享变量,需要使用tf.variable_scope()

    作者:C Li
    链接:https://www.zhihu.com/question/54513728/answer/181819324
    来源:知乎
    著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
    
    import tensorflow as tf
    
    with tf.variable_scope('variable_scope_y') as scope:
        var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
        scope.reuse_variables()  # 设置共享变量
        var1_reuse = tf.get_variable(name='var1')
        var2 = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)
        var2_reuse = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(var1.name, sess.run(var1))
        print(var1_reuse.name, sess.run(var1_reuse))
        print(var2.name, sess.run(var2))
        print(var2_reuse.name, sess.run(var2_reuse))
    # 输出结果:
    # variable_scope_y/var1:0 [-1.59682846]
    # variable_scope_y/var1:0 [-1.59682846]   可以看到变量var1_reuse重复使用了var1
    # variable_scope_y/var2:0 [ 2.]
    # variable_scope_y/var2_1:0 [ 2.]
    

    也可以这样

    with tf.variable_scope('foo') as foo_scope:
        v = tf.get_variable('v', [1])
    with tf.variable_scope('foo', reuse=True):
        v1 = tf.get_variable('v')
    assert v1 == v
    

    或者这样:

    with tf.variable_scope('foo') as foo_scope:
        v = tf.get_variable('v', [1])
    with tf.variable_scope(foo_scope, reuse=True):
        v1 = tf.get_variable('v')
    assert v1 == v
    

    相关文章

      网友评论

          本文标题:name_scope, variable_scope的理解

          本文链接:https://www.haomeiwen.com/subject/jptndftx.html