美文网首页
[tf]模型存储和加载

[tf]模型存储和加载

作者: VanJordan | 来源:发表于2018-12-07 16:36 被阅读0次

    saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
    saver = tf.train.Saver() 默认是保存默认图上的Variable数据。当然也可以指定保存那些Variable数据,tf.train.Saver([var_list])

    模型的加载

    loader = tf.train.Saver()
    loader.restore(sess,model_dir)
    
    • Saver的第一个参数是var_list用来指定需要存储或者保存哪些变量。如果不指定的话那么默认保存和加载全部的可保存的对象
    v1 = tf.Variable(..., name='v1')
    v2 = tf.Variable(..., name='v2')
    
    # Pass the variables as a dict:
    saver = tf.train.Saver({'v1': v1, 'v2': v2})
    
    # Or pass them as a list.
    saver = tf.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
    

    表示的意思是需要加载的变量是embedding

    def setup_loader(self):
        self.loader = tf.train.Saver(self.var_list)
    
    def load_session(self, itr):
            self.loader.restore(self.sess, self.model_name + "_weights/" + self.dataset + "/" + itr + ".ckpt")
    -----------------------TransE model中的self.var_list---------------------
    self.rel_emb = tf.get_variable(name="rel_emb", initializer=tf.random_uniform(shape=[self.num_rel, self.params.emb_size], minval=-sqrt_size, maxval=sqrt_size))
    self.ent_emb = tf.get_variable(name="ent_emb", initializer=tf.random_uniform(shape=[self.num_ent, self.params.emb_size], minval=-sqrt_size, maxval=sqrt_size))
    self.var_list = [self.rel_emb, self.ent_emb]
    

    模型的保存

    saver = tf.train.Saver(max_to_keep=0)
    saver.save(self.sess, filename)
    
    • os.mkdir()只对路径的最后一级目录进行创建,如果前几级目录不存在,会报错!os.makedirs()可以创建多级目录,如果路径的目录都不存在,都可以创建出来
    • 按照模型和数据集合进行分文件夹的保存
    • max_to_keep 参数:这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果想要保存模型的数量不受限制,则可以将 max_to_keep设置为None或者0,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可。
    • saver.save(sess,filename,global_step=step)还有最后一个参数global_step,表示保存模型名字的后缀是step。
    def setup_saver(self):
        self.saver = tf.train.Saver(max_to_keep=0)
    
    def save_model(self, itr):
        filename = self.model_name + "_weights/" + self.dataset + "/" + str(itr) + ".ckpt"
        if not os.path.exists(os.path.dirname(filename)):
            os.makedirs(os.path.dirname(filename))
        self.saver.save(self.sess, filename)
    

    例子:保存模型

    # construct graph
    v1 = tf.Variable([0], name='v1')
    v2 = tf.Variable([0], name='v2')
    # run graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.save(sess, 'ckp')
    
    with tf.Session() as sess:
        saver = tf.import_meta_graph('ckp.meta')
        saver.restore(sess, 'ckp')
    

    当执行Saver.saver操作的时候,在文件系统中生成如下文件:

    • index:文件保存了一个不可变的表数据,记录Tensor元数据的信息,比如tensor存储在那个数据data文件中,以及在数据文件中的偏移,校验和其他信息。
    • 数据(data) :文件记录了所有变量(Variable) 的值,当restore 某个变量时,首先从索引文件中找到相应变量在哪个数据文件,然后根据索引直接获取变量的值,从而实现变量数据的恢复。
    • 元文件(meta) :保存了MetaGraphDef 的持久化数据,它包括GraphDef, SaverDef 等元数据。就是描述了图结构的信息。这也是在上例中,在调用Saver.restore 之前,得先调用tf.import_meta_graph 的真正原因;否则,缺失计算图的实例,就无法谈及恢复数据到图实例中了。
    • 状态文件checkpoint:文件会记录最近一次的断点文件的前缀,根据前缀找到对应的索引和数据文件,当调用tf.train.latest_checkpoint,可以快速找到最近一次的断点文件,此外,Checkpoint 文件也记录了所有的断点文件列表,并且文件列表按照由旧至新的时间依次排序。当训练任务时间周期非常长,断点检查将持续进行,必将导致磁盘空间被耗尽。为了避免这个问题,存在两种基本的方法:设置max_to_keep: 配置最近有效文件的最大数目,当新的断点文件生成时,且文件数目超过max_to_keep,则删除最旧的断点文件;其中,max_to_keep 默认值为5keep_checkpoint_every_n_hours: 在训练过程中每n 小时做一次断点检查,保证只有一个断点文件;其中,该选项默认是关闭的
    ├── checkpoint
    ├── ckp.data-00000-of-00001
    ├── ckp.index
    ├── ckp.meta
    

    相关文章

      网友评论

          本文标题:[tf]模型存储和加载

          本文链接:https://www.haomeiwen.com/subject/qugthqtx.html