在迁移学习中——需要固定前面几层,去重新设置后面连接层并进行局部训练;在有些神经网络可视化的方法中,想看下哪种image会让某个神经元"兴奋"起来等等都需要用到权重加载和固定。
下面会介绍几种tensorflow中权重加载或者固定的方法。
以下用到的所有tf均是tensorflow的简写,即
import tensorflow as tf
一、利用Saver权重加载
这种方法在使用的时候必须定义好计算图,并且这个计算图是和原先训练时的计算图一模一样。使用方法如下,这里只列出权重加载有关步骤,计算图的定义和权重更新等步骤省略了,省略步骤的具体位置在代码的注释段.
--训练时
# define your graph
saver = tf.train.Saver()
with tf.Session() as sess:
# update the parameters of net
saver.save(sess, ckpt_file_path) # 在模型训练完成后使用
--重新使用参数(或者是为了更进一步训练,或者用来做测试等等)
# define your graph (the same graph)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) # very important
saver.restore(sess, ckpt_file_path) # 加载模型
--获得某个变量的权重值
sess.run(tf.get_default_graph().get_tensor_by_name(variable_name))
二、使用assign和trainable进行更加自如的权重加载
上述的方法只能在你的计算图不发生改变的时候使用。如果你需要增加新的变量或删除原有计算图中的某几个变量,而对其他的变量进行权重加载时可以使用如下的做法。
tips:在训练的时候最好定义好graph的命名区域,便于知道每一个变量的名字是什么,这样在加载的时候更加方便。
--训练时,有两种做法
- 方法1.使用saver(推荐)
# define your graph
## 方法同上,var_list是需要保存权重的变量名称,如果不填,默认是所有trainable_variables
saver = tf.train.Saver(var_list)
with tf.Session() as sess:
# update the parameters of net
saver.save(sess, ckpt_file_path) # 在模型训练完成后使用
- 方法2.使用字典保存
import numpy as np
# define your graph
with tf.Session() as sess:
# update the parameters of net
data = {}
for var in tf.trainable_variables():
data[var.name] = sess.run(var)
np.save(file, data) # file 是你保存数据的地址
--加载并固定权重
- 方法1对应的权重加载
# define your graph
# 在定义的时候将不需要参与更新的vars设置为trainable=False
from tensorflow.python import pywrap_tensorflow
saver = tf.train.Saver()
with tf.Session() as sess:
reader = pywrap_tensorflow.NewCheckpointReader(file)
for var in tf.global_variables():
if var.name.startswith('conv'): #假如你只想给之前定义的conv层重新加载权重
# ckpt文件中的变量名是conv/kernel,那么计算图中对应的是conv/kernel:0
# tf.assign是一个operator,必须加上sess.run才会赋值生效
sess.run(var.assign(reader.get_tensor(var.name[:-2])))
- 方法2对应的权重加载
data = np.load(file).item() # data的类型是字典
# If you know the name of variables.
# 只展示一个变量,其他的都是类似的
# trainable决定你是否要固定权重,False代表固定权重
w = tf.Variable(data[variable_name], dtype=tf.float32, trainable=False)
这个方法加载进来的权重也可以用到keras中的计算图上。
data = np.load(file).item()
# 也只是举一个例子,其他的都是类似操作
kernel_w = data[w_name]
kernel_b = data[b_name]
layer = model.get_layer(layer_name) # model就是你在keras中定义的计算图(模型)
layer.trainable = False # 如果想继续训练,设置为True
layer.set_weights([kernel_w, kernel_b]) # 顺序不要弄反了,先是权重后是bias
三、最后一个tip
如果你的计算图完全不改变,也可以不用重新定义计算图,可以用你之前用saver保存下来的'xxx.ckpt.meta'文件直接加载图(权重加载和固定同上)。
再重申一下,这种方法不用重新定义原来的计算图,但是你必须有graph的meta data且你不想改变图的任何操作而单单只为训练权重的话,可以考虑这个方法。
# You do not need to define your graph.
saver = tf.train.import_meta_graph(meta_file_path)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, ckpt_file_path)
网友评论