tf.get_collection()

Author: yalesaleng | Published 2018-07-16 16:10

    This function takes two arguments: key and scope.

    Args:

    • key: The key for the collection. For example, the GraphKeys class contains many standard names for collections.
    • scope: (Optional.) If supplied, the resulting list is filtered to include only items whose name attribute matches scope using re.match. Items without a name attribute are never returned if a scope is supplied, and the choice of re.match means that a scope without special tokens filters by prefix.

    For example:

    # From 'My-TensorFlow-tutorials-master/02 CIFAR10/cifar10.py'
    # (tf.GraphKeys.VARIABLES is a deprecated alias for tf.GraphKeys.GLOBAL_VARIABLES)
    
      variables = tf.get_collection(tf.GraphKeys.VARIABLES)
      for i in variables:
          print(i)
    
    >>>   <tf.Variable 'conv1/weights:0' shape=(3, 3, 3, 96) dtype=float32_ref>
          <tf.Variable 'conv1/biases:0' shape=(96,) dtype=float32_ref>
          <tf.Variable 'conv2/weights:0' shape=(3, 3, 96, 64) dtype=float32_ref>
          <tf.Variable 'conv2/biases:0' shape=(64,) dtype=float32_ref>
          <tf.Variable 'local3/weights:0' shape=(16384, 384) dtype=float32_ref>
          <tf.Variable 'local3/biases:0' shape=(384,) dtype=float32_ref>
          <tf.Variable 'local4/weights:0' shape=(384, 192) dtype=float32_ref>
          <tf.Variable 'local4/biases:0' shape=(192,) dtype=float32_ref>
          <tf.Variable 'softmax_linear/softmax_linear:0' shape=(192, 10) dtype=float32_ref>
          <tf.Variable 'softmax_linear/biases:0' shape=(10,) dtype=float32_ref>
    

    tf.get_collection returns a list of every value stored in the collection under the given key.
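
    Building on this example, the scope argument described above filters the same list by name prefix. A minimal sketch, reusing the cifar10 graph from above:

      # re.match means prefix matching, so scope='conv1' returns only
      # the conv1/weights and conv1/biases variables from the list above.
      conv1_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='conv1')
      for v in conv1_vars:
          print(v)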


    Going further:

    tf.GraphKeys exposes many standard collection names, such as VARIABLES (which holds all variables) and REGULARIZATION_LOSSES.
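
    Under the hood, these GraphKeys members are just strings that name collections on the graph; a quick check (a small sketch added here for illustration) confirms this:

      print(tf.GraphKeys.GLOBAL_VARIABLES)       # 'variables'
      print(tf.GraphKeys.TRAINABLE_VARIABLES)    # 'trainable_variables'
      print(tf.GraphKeys.REGULARIZATION_LOSSES)  # 'regularization_losses'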

    A concrete example of using tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES):

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib import layers
    from tensorflow.examples.tutorials.mnist import input_data
    
    # FLAGS (with attributes `data_dir` and `regu`) is assumed to be defined
    # elsewhere, e.g. via tf.app.flags.
    
    def easier_network(x, reg):
      """ A network based on tf.contrib.layers, with input `x`. """
      with tf.variable_scope('EasyNet'):
         out = layers.flatten(x)
         out = layers.fully_connected(out,
                                      num_outputs=200,
                                      weights_initializer=layers.xavier_initializer(uniform=True),
                                      weights_regularizer=layers.l2_regularizer(scale=reg),
                                      activation_fn=tf.nn.tanh)
         out = layers.fully_connected(out,
                                      num_outputs=200,
                                      weights_initializer=layers.xavier_initializer(uniform=True),
                                      weights_regularizer=layers.l2_regularizer(scale=reg),
                                      activation_fn=tf.nn.tanh)
         out = layers.fully_connected(out,
                                      num_outputs=10, # Because there are ten digits!
                                      weights_initializer=layers.xavier_initializer(uniform=True),
                                      weights_regularizer=layers.l2_regularizer(scale=reg),
                                      activation_fn=None)
         return out
    
    def main(_):
      mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
      x = tf.placeholder(tf.float32, [None, 784])
      y_ = tf.placeholder(tf.float32, [None, 10])
    
      # Make a network with regularization
      y_conv = easier_network(x, FLAGS.regu)
      weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'EasyNet')
      print("")
      for w in weights:
         shp = w.get_shape().as_list()
         print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp)))
      print("")
    
      reg_ws = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 'EasyNet')
      for w in reg_ws:
         shp = w.get_shape().as_list()
         print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp)))
      print("")
    
      # Make the loss function `loss_fn` with regularization.
      cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
      loss_fn = cross_entropy + tf.reduce_sum(reg_ws)
      train_step = tf.train.AdamOptimizer(1e-4).minimize(loss_fn)
    
    main(None)
    
    >>>   - EasyNet/fully_connected/weights:0 shape:[784, 200] size:156800
          - EasyNet/fully_connected/biases:0 shape:[200] size:200
          - EasyNet/fully_connected_1/weights:0 shape:[200, 200] size:40000
          - EasyNet/fully_connected_1/biases:0 shape:[200] size:200
          - EasyNet/fully_connected_2/weights:0 shape:[200, 10] size:2000
          - EasyNet/fully_connected_2/biases:0 shape:[10] size:10
    
          - EasyNet/fully_connected/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
          - EasyNet/fully_connected_1/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
          - EasyNet/fully_connected_2/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
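
    As an aside, TF 1.x also provides tf.losses.get_regularization_loss(), which sums the REGULARIZATION_LOSSES collection for you; the sketch below would be equivalent to the tf.reduce_sum(reg_ws) line above:

      # Sums every entry of tf.GraphKeys.REGULARIZATION_LOSSES under 'EasyNet'.
      reg_loss = tf.losses.get_regularization_loss(scope='EasyNet')
      loss_fn = cross_entropy + reg_loss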
    

    From the output of the `for w in reg_ws:` loop above, we can see that every regularization loss created in the graph is gathered into tf.GraphKeys.REGULARIZATION_LOSSES.
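
    A collection is simply a named list attached to the graph, so the same mechanism works for user-defined values too. A minimal sketch ('my_losses' is an arbitrary key invented for this illustration):

      x = tf.constant(3.0)
      tf.add_to_collection('my_losses', x * 2.0)
      tf.add_to_collection('my_losses', x * 5.0)
    
      # tf.get_collection works on user-defined keys exactly as on GraphKeys.
      total = tf.add_n(tf.get_collection('my_losses'))
      with tf.Session() as sess:
         print(sess.run(total))  # 21.0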

    For more details on collections, see:
    http://blog.csdn.net/shenxiaolu1984/article/details/52815641

    For more details on tf.GraphKeys.REGULARIZATION_LOSSES, see:
    https://gxnotes.com/article/178205.html
