美文网首页
tensorflow的权重加载及固定

tensorflow的权重加载及固定

作者: 7NIC7 | 来源:发表于2019-03-30 00:06 被阅读0次

    在迁移学习中——需要固定前面几层,去重新设置后面连接层并进行局部训练;在有些神经网络可视化的方法中,想看下哪种image会让某个神经元"兴奋"起来等等都需要用到权重加载和固定。
    下面会介绍几种tensorflow中权重加载或者固定的方法。
    以下用到的所有tf均是tensorflow的简写,即

import tensorflow as tf

一、利用Saver权重加载

    这种方法在使用的时候必须定义好计算图,并且这个计算图是和原先训练时的计算图一模一样。使用方法如下,这里只列出权重加载有关步骤,计算图的定义和权重更新等步骤省略了,省略步骤的具体位置在代码的注释段.

--训练时

# define your graph
saver = tf.train.Saver()   

with tf.Session() as sess:
    # update the parameters of net
    saver.save(sess, ckpt_file_path)               # 在模型训练完成后使用

--重新使用参数(或者是为了更进一步训练,或者用来做测试等等)

# define your graph (the same graph)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())    # very important
    saver.restore(sess, ckpt_file_path)            # 加载模型

--获得某个变量的权重值

sess.run(tf.get_default_graph().get_tensor_by_name(variable_name))

二、使用assign和trainable进行更加自如的权重加载

    上述的方法只能在你的计算图不发生改变的时候使用。如果你需要增加新的变量或删除原有计算图中的某几个变量,而对其他的变量进行权重加载时可以使用如下的做法。
tips:在训练的时候最好定义好graph的命名区域,便于知道每一个变量的名字是什么,这样在加载的时候更加方便。

--训练时,有两种做法

  • 方法1.使用saver(推荐)
# define your graph
## 方法同上,var_list是需要保存权重的变量名称,如果不填,默认是所有trainable_variables  
saver = tf.train.Saver(var_list)                  

with tf.Session() as sess:
    # update the parameters of net
    saver.save(sess, ckpt_file_path)               # 在模型训练完成后使用
  • 方法2.使用字典保存
import numpy as np
# define your graph
with tf.Session() as sess:
    # update the parameters of net
    data = {}
    for var in tf.trainable_variables():
        data[var.name] = sess.run(var)
    np.save(file, data)                            # file 是你保存数据的地址

--加载并固定权重

  • 方法1对应的权重加载
# define your graph
# 在定义的时候将不需要参与更新的vars设置为trainable=False
from tensorflow.python import pywrap_tensorflow
saver = tf.train.Saver()
with tf.Session() as sess:
    reader = pywrap_tensorflow.NewCheckpointReader(file)
    for var in tf.global_variables():
        if var.name.startswith('conv'): #假如你只想给之前定义的conv层重新加载权重 
            # ckpt文件中的变量名是conv/kernel,那么计算图中对应的是conv/kernel:0
            # tf.assign是一个operator,必须加上sess.run才会赋值生效
            sess.run(var.assign(reader.get_tensor(var.name[:-2])))
  • 方法2对应的权重加载
data = np.load(file).item()                        # data的类型是字典
# If you know the name of variables.
# 只展示一个变量,其他的都是类似的
# trainable决定你是否要固定权重,False代表固定权重
w = tf.Variable(data[variable_name], dtype=tf.float32, trainable=False)

    这个方法加载进来的权重也可以用到keras中的计算图上。

data = np.load(file).item()
# 也只是举一个例子,其他的都是类似操作
kernel_w = data[w_name]
kernel_b = data[b_name]
layer = model.get_layer(layer_name)                # model就是你在keras中定义的计算图(模型)
layer.trainable = False                            # 如果想继续训练,设置为True
layer.set_weights([kernel_w, kernel_b])            # 顺序不要弄反了,先是权重后是bias

三、最后一个tip

    如果你的计算图完全不改变,也可以不用重新定义计算图,可以用你之前用saver保存下来的'xxx.ckpt.meta'文件直接加载图(权重加载和固定同上)。
    再重申一下,这种方法不用重新定义原来的计算图,但是你必须有graph的meta data且你不想改变图的任何操作而单单只为训练权重的话,可以考虑这个方法。

# You do not need to define your graph.
saver = tf.train.import_meta_graph(meta_file_path)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, ckpt_file_path)

相关文章

网友评论

      本文标题:tensorflow的权重加载及固定

      本文链接:https://www.haomeiwen.com/subject/iakdbqtx.html