saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
saver = tf.train.Saver()
默认是保存默认图上的Variable数据。当然也可以指定保存那些Variable数据,tf.train.Saver([var_list])
。
模型的加载
loader = tf.train.Saver()
loader.restore(sess,model_dir)
-
Saver
的第一个参数是var_list
用来指定需要存储或者保存哪些变量。如果不指定的话那么默认保存和加载全部的可保存的对象。
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')
# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})
# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
表示的意思是需要加载的变量是embedding
def setup_loader(self):
self.loader = tf.train.Saver(self.var_list)
def load_session(self, itr):
self.loader.restore(self.sess, self.model_name + "_weights/" + self.dataset + "/" + itr + ".ckpt")
-----------------------TransE model中的self.var_list---------------------
self.rel_emb = tf.get_variable(name="rel_emb", initializer=tf.random_uniform(shape=[self.num_rel, self.params.emb_size], minval=-sqrt_size, maxval=sqrt_size))
self.ent_emb = tf.get_variable(name="ent_emb", initializer=tf.random_uniform(shape=[self.num_ent, self.params.emb_size], minval=-sqrt_size, maxval=sqrt_size))
self.var_list = [self.rel_emb, self.ent_emb]
模型的保存
saver = tf.train.Saver(max_to_keep=0)
saver.save(self.sess, filename)
-
os.mkdir()
只对路径的最后一级目录进行创建,如果前几级目录不存在,会报错!而os.makedirs()
可以创建多级目录,如果路径的目录都不存在,都可以创建出来。 - 按照模型和数据集合进行分文件夹的保存。
-
max_to_keep
参数:这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果想要保存模型的数量不受限制,则可以将 max_to_keep设置为None或者0,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可。 -
saver.save(sess,filename,global_step=step)
还有最后一个参数global_step
,表示保存模型名字的后缀是step。
def setup_saver(self):
self.saver = tf.train.Saver(max_to_keep=0)
def save_model(self, itr):
filename = self.model_name + "_weights/" + self.dataset + "/" + str(itr) + ".ckpt"
if not os.path.exists(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename))
self.saver.save(self.sess, filename)
例子:保存模型
# construct graph
v1 = tf.Variable([0], name='v1')
v2 = tf.Variable([0], name='v2')
# run graph
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(sess, 'ckp')
with tf.Session() as sess:
saver = tf.import_meta_graph('ckp.meta')
saver.restore(sess, 'ckp')
当执行Saver.saver操作的时候,在文件系统中生成如下文件:
- index:文件保存了一个不可变的表数据,记录Tensor元数据的信息,比如tensor存储在那个数据data文件中,以及在数据文件中的偏移,校验和其他信息。
- 数据(data) :文件记录了所有变量(Variable) 的值,当restore 某个变量时,首先从索引文件中找到相应变量在哪个数据文件,然后根据索引直接获取变量的值,从而实现变量数据的恢复。
- 元文件(meta) :保存了MetaGraphDef 的持久化数据,它包括GraphDef, SaverDef 等元数据。就是描述了图结构的信息。这也是在上例中,在调用Saver.restore 之前,得先调用tf.import_meta_graph 的真正原因;否则,缺失计算图的实例,就无法谈及恢复数据到图实例中了。
-
状态文件checkpoint:文件会记录最近一次的断点文件的前缀,根据前缀找到对应的索引和数据文件,当调用
tf.train.latest_checkpoint
,可以快速找到最近一次的断点文件,此外,Checkpoint 文件也记录了所有的断点文件列表,并且文件列表按照由旧至新的时间依次排序。当训练任务时间周期非常长,断点检查将持续进行,必将导致磁盘空间被耗尽。为了避免这个问题,存在两种基本的方法:设置max_to_keep
: 配置最近有效文件的最大数目,当新的断点文件生成时,且文件数目超过max_to_keep
,则删除最旧的断点文件;其中,max_to_keep
默认值为5,keep_checkpoint_every_n_hours
: 在训练过程中每n 小时做一次断点检查,保证只有一个断点文件;其中,该选项默认是关闭的。
├── checkpoint
├── ckp.data-00000-of-00001
├── ckp.index
├── ckp.meta
网友评论