美文网首页
tensorflow固定部分参数进行训练

tensorflow固定部分参数进行训练

作者: 大脸猫猫脸大 | 来源:发表于2019-05-05 18:11 被阅读0次

    模型预览

    假设新模型分encoder+decoder两部分。其中encoder模块要导入预训练的参数,并且数值固定,不参与训练。decoder则是在encoder的基础上增加的分支,需要通过数据训练不断优化参数。

    大体步骤

    主要分为四个步骤:
    1. 绘制整体网络图
    2. 固定encoder参数
    3. 导入encoder参数
    4. 训练 + 模型保存

    代码

    part1:画图

    #设置网络整体结构....
    

    part2:固定参数

    # 选择decode部分的参数
    train_var_list = [var for var in tf.trainable_variables() if 'decode' in var.name] 
    # 优化器只优化选中的参数list
    with tf.control_dependencies():
          optimizer = optimizer.minimize(loss, global_step=global_step, var_list = train_var_list) #自行选择优化器
    

    part3 导入旧参

    # 选择encode部分参数
    no_train_var = [var for var in tf.global_variables() if 'encode' in var.name]  #这里的'encode'是在设置网络过程中某个scope的命名
    # saver选择要导入的参数
    saver = tf.train.Saver(no_train_var)
    # 对整个网络所有参数做初始化
    init = tf.global_variables_initializer()
    sess.run(init)
    # encode部分参数覆盖
    saver.restore(sess, weights_path) #这里的weights_path是ckpt文件保存路径
    

    part4 训练+保存

    # 训练......
    # 保存模型
    # 重新定义saver为选中所有参数,否则最后将只保存no_train_var
    saver = tf.train.Saver()
    saver.save(sess=sess, save_path=model_save_path, global_step=epoch) 
    

    其他

    1. 对于该网络还有另外一种方法:encode前向传播保存结果,将其作为decode网络输入,进行训练。
    2. 模型导入还有其他方法,可参考https://blog.csdn.net/CV_YOU/article/details/80698942
      不同类型的模型(npy, ckpt)导入保存方式有差异。
    3. 固定参数还可以在构建网络的时候选择变量的trainable为False,或者设置变量学习率为0.

    参数导入方法2

    当预训练模型和新模型的图不同时,无法用Saver导入参数,这时候要用到tf.assign函数。
    假设预训练模型只有encode部分,新模型encode+decode。遍历模型参数,用预训练参数进行替换。

    代码

    part3 导入旧参

    # 导入所有参数
    saver = tf.train.Saver()
    # 对整个网络所有参数做初始化
    init = tf.global_variables_initializer()
    sess.run(init)
    #读取预训练模型
    reader = pywrap_tensorflow.NewCheckpointReader(weights_path)
    # 逐层遍历参数并替换
    for vv in tf.trainable_variables():
        if 'encode' in vv.name:
            weights = reader.get_tensor(weights_key)
            _op = tf.assign(vv, weights)
            sess.run(_op)
    

    相关文章

      网友评论

          本文标题:tensorflow固定部分参数进行训练

          本文链接:https://www.haomeiwen.com/subject/tgmhoqtx.html