美文网首页
Tensorflow添加正则项的三种方式

Tensorflow添加正则项的三种方式

作者: cheerss | 来源:发表于2018-10-27 15:33 被阅读0次
    1. 手写计算出正则项的大小,并通过add_to_collection()方法把它加入到collection中,在需要的时候再通过get_collection()方法取出来。
      • 优势:正则项的计算表达式可以随意定义
      • 劣势:需要手写比较麻烦,毕竟常用的正则项Tensorflow中都是有的
    import os
    import tensorflow as tf
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    
    def _variable_with_weight_decay(name, wd):
        a = tf.get_variable(name, [1], initializer=tf.constant_initializer(1.0))
        loss = tf.multiply(tf.nn.l2_loss(a), wd, name="weight_loss")
        tf.add_to_collection("loss", loss)
        tf.add_to_collection(tf.GraphKeys.WEIGHTS, a)
        return a
    
    def main():
        with tf.Graph().as_default():
            a = _variable_with_weight_decay("a", 0.1)
            b = _variable_with_weight_decay("b", 0.1)
            init_op = tf.global_variables_initializer()
            all_weight_decay = tf.get_collection("loss")
            all_weights = tf.get_collection(tf.GraphKeys.WEIGHTS)
            print(all_weights)
            # 输出 [<tf.Variable 'a:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'b:0' shape=(1,) dtype=float32_ref>]
            loss_all = tf.add_n(tf.get_collection('loss'))
            
            with tf.Session() as sess:
                sess.run(init_op)
                res = sess.run(all_weight_decay)
                print(res)
                # 输出 [0.05, 0.05]
            
    if __name__ == "__main__":
        main()
    
    1. 在variable_scope中指定,或者在get_variable中指定,当然,tf.layers.conv2d等API中也都可以指定正则项。其实variable_scope就像一个默认值。所有的正则项会被加入到tf.GraphKeys.REGULARIZATION_LOSSES中,这个东西也可以通过get_collection()获取。
      • 优势:方便,可以直接调用API中自带的正则项
      • 劣势:自带的regularizer有限
    import os
    import tensorflow as tf
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    
    def main():
        with tf.Graph().as_default():
            with tf.variable_scope("first", regularizer=tf.contrib.layers.l2_regularizer(0.1)) as scope:
                a = tf.get_variable("a", [1], initializer=tf.constant_initializer(1.0))
            init_op = tf.global_variables_initializer()
            print(tf.get_collection(tf.GraphKeys.WEIGHTS))
            print(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
            loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
            
            with tf.Session() as sess:
                sess.run(init_op)
                res = sess.run(loss)
                print(res)
            
    if __name__ == "__main__":
        main()
    
    1. 手动创建一个regularizer后作用于变量。这种方式可以不仅限于tf.variable_scope(),而是所有变量。其中tf.contrib.layers.apply_regularization()这个函数如果不指定weights_list参数则默认作用于tf.GraphKeys.WEIGHTS
    import os
    import tensorflow as tf
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    
    def main():
        with tf.Graph().as_default():
            with tf.variable_scope("first") as scope:
                regularizer = tf.contrib.layers.l2_regularizer(0.1)
                a = tf.get_variable("a", [1], initializer=tf.constant_initializer(1.0))
                b = tf.get_variable("b", [1], initializer=tf.constant_initializer(1.0))
                
            tf.contrib.layers.apply_regularization(regularizer, weights_list=[a, b])
            init_op = tf.global_variables_initializer()
            loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
            
            with tf.Session() as sess:
                sess.run(init_op)
                res = sess.run(loss)
                print(res)
            
    if __name__ == "__main__":
        main()
    

    相关文章

      网友评论

          本文标题:Tensorflow添加正则项的三种方式

          本文链接:https://www.haomeiwen.com/subject/erjdtqtx.html