美文网首页
tensorflow固定部分参数进行训练

tensorflow固定部分参数进行训练

作者: 大脸猫猫脸大 | 来源:发表于2019-05-05 18:11 被阅读0次

模型预览

假设新模型分encoder+decoder两部分。其中encoder模块要导入预训练的参数,并且数值固定,不参与训练。decoder则是在encoder的基础上增加的分支,需要通过数据训练不断优化参数。

大体步骤

主要分为四个步骤:
1. 绘制整体网络图
2. 固定encoder参数
3. 导入encoder参数
4. 训练 + 模型保存

代码

part1:画图

#设置网络整体结构....

part2:固定参数

# 选择decode部分的参数
train_var_list = [var for var in tf.trainable_variables() if 'decode' in var.name] 
# 优化器只优化选中的参数list
with tf.control_dependencies():
      optimizer = optimizer.minimize(loss, global_step=global_step, var_list = train_var_list) #自行选择优化器

part3 导入旧参

# 选择encode部分参数
no_train_var = [var for var in tf.global_variables() if 'encode' in var.name]  #这里的'encode'是在设置网络过程中某个scope的命名
# saver选择要导入的参数
saver = tf.train.Saver(no_train_var)
# 对整个网络所有参数做初始化
init = tf.global_variables_initializer()
sess.run(init)
# encode部分参数覆盖
saver.restore(sess, weights_path) #这里的weights_path是ckpt文件保存路径

part4 训练+保存

# 训练......
# 保存模型
# 重新定义saver为选中所有参数,否则最后将只保存no_train_var
saver = tf.train.Saver()
saver.save(sess=sess, save_path=model_save_path, global_step=epoch) 

其他

  1. 对于该网络还有另外一种方法:encode前向传播保存结果,将其作为decode网络输入,进行训练。
  2. 模型导入还有其他方法,可参考https://blog.csdn.net/CV_YOU/article/details/80698942
    不同类型的模型(npy, ckpt)导入保存方式有差异。
  3. 固定参数还可以在构建网络的时候选择变量的trainable为False,或者设置变量学习率为0.

参数导入方法2

当预训练模型和新模型的图不同时,无法用Saver导入参数,这时候要用到tf.assign函数。
假设预训练模型只有encode部分,新模型encode+decode。遍历模型参数,用预训练参数进行替换。

代码

part3 导入旧参

# 导入所有参数
saver = tf.train.Saver()
# 对整个网络所有参数做初始化
init = tf.global_variables_initializer()
sess.run(init)
#读取预训练模型
reader = pywrap_tensorflow.NewCheckpointReader(weights_path)
# 逐层遍历参数并替换
for vv in tf.trainable_variables():
    if 'encode' in vv.name:
        weights = reader.get_tensor(weights_key)
        _op = tf.assign(vv, weights)
        sess.run(_op)

相关文章

网友评论

      本文标题:tensorflow固定部分参数进行训练

      本文链接:https://www.haomeiwen.com/subject/tgmhoqtx.html