[tf] A Good Code Framework and Various Hooks

Author: VanJordan | Published 2019-01-16 17:17
    • images, labels = catdog_input.distorted_input(FLAGS.tfrecords_name, BATCH_SIZE) fetches the input data first.
    • Then the model's functions are called to obtain the different ops:
        logits = catdog_model.inference(images)
        loss = catdog_model.loss(logits, labels)
        # accuracy = catdog_model.accuracy(logits, labels)
    
        train_op = catdog_model.train(loss)
    
    • So all of the model's components should be initialized at model = Model() time.
    • It is therefore cleaner to expose all the interacting ops inside the train function in main.py.
    • Put the information you want to log in one place: define a class _LoggerHook that inherits from tf.train.SessionRunHook; its before_run and after_run methods run before and after each training step, respectively. They can return anything you want to inspect during the run, such as the loss or accuracy, and print it.
    import time
    from datetime import datetime
    
    import tensorflow as tf
    
    def train():
    
        # global_step is required because StopAtStepHook needs it
        global_step = tf.train.get_or_create_global_step()
    
        # Input
        images, labels = catdog_input.distorted_input(FLAGS.tfrecords_name, BATCH_SIZE)
        
        logits = catdog_model.inference(images)
        loss = catdog_model.loss(logits, labels)
        # accuracy = catdog_model.accuracy(logits, labels)
    
        train_op = catdog_model.train(loss)
    
        class _LoggerHook(tf.train.SessionRunHook):
            """
            Prints training information.
            """
            def begin(self):
                self._step = -1
                self._start_time = time.time()
    
            def before_run(self, run_context):
                self._step += 1
                # Called automatically before each training step.
                # Return everything you want to inspect during the run,
                # passed as a list if there are several tensors, e.g. [loss, accuracy]
                return tf.train.SessionRunArgs(loss)
    
            def after_run(self, run_context, run_values):
    
                # Interval (in steps) between printed messages
                display_step = 10
                if self._step % display_step == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time
                    # run_values.results is whatever before_run() returned;
                    # loss was requested above, so loss comes back here
                    # (a list in yields a list back)
                    loss = run_values.results
    
                    # Examples processed per second
                    examples_per_sec = display_step * BATCH_SIZE / duration
                    # Seconds spent per batch
                    sec_per_batch = float(duration / display_step)
                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss,
                                        examples_per_sec, sec_per_batch))
    
        with tf.train.MonitoredTrainingSession(
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_step),
                       tf.train.NanTensorHook(loss),
                       _LoggerHook()],  # pass in the _LoggerHook defined above
                config=tf.ConfigProto(
                    log_device_placement=False)) as sess:
    
            coord = tf.train.Coordinator()
            # Start the file-reading queue
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
            while not sess.should_stop():
                sess.run(train_op)
    
            coord.request_stop()
            coord.join(threads)
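    The two throughput figures _LoggerHook prints are plain arithmetic over the display_step window. A minimal, framework-free check (the numbers below are made up for illustration):

```python
def throughput(display_step, batch_size, duration):
    """Reproduce the two numbers _LoggerHook prints.

    duration is the wall-clock time elapsed over the last
    display_step training steps.
    """
    examples_per_sec = display_step * batch_size / duration
    sec_per_batch = duration / display_step
    return examples_per_sec, sec_per_batch

# e.g. 10 steps of batch size 32 taking 2.0 s in total:
print(throughput(10, 32, 2.0))  # -> (160.0, 0.2)
```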
    

    The other approach is to pass sess directly into the model, so the individual ops never need to be exposed.

    • Because the graph is finalized (frozen) once MonitoredTrainingSession is entered and can no longer be modified, the whole graph must be built beforehand. If checkpoint_dir='./' is set, the latest checkpoint there is loaded automatically and training resumes from it.
    • The training loop itself runs inside MonitoredTrainingSession.

    Execution order of a hook's internal methods

    begin() → after_create_session() → [before_run() → after_run()] around each sess.run call → end()
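    Since the original figure is not reproduced here, that call order can be sketched without TensorFlow at all. LoggingHook and run_session below are illustrative stand-ins for a real SessionRunHook and MonitoredTrainingSession (the real methods also take arguments such as run_context, omitted here for brevity):

```python
class LoggingHook:
    """Records the order in which its lifecycle methods are called."""
    def __init__(self):
        self.calls = []

    def begin(self):                 # once, before the graph is finalized
        self.calls.append("begin")

    def after_create_session(self):  # once, after the session is created
        self.calls.append("after_create_session")

    def before_run(self):            # before every sess.run call
        self.calls.append("before_run")

    def after_run(self):             # after every sess.run call
        self.calls.append("after_run")

    def end(self):                   # once, when the session closes
        self.calls.append("end")


def run_session(hook, num_steps):
    """Drive the hook the way MonitoredTrainingSession does."""
    hook.begin()
    hook.after_create_session()
    for _ in range(num_steps):
        hook.before_run()
        # ... sess.run(train_op) would execute here ...
        hook.after_run()
    hook.end()


hook = LoggingHook()
run_session(hook, 2)
print(hook.calls)
```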

    Defining an early-stopping Hook

    import tensorflow as tf
    from tensorflow.python.training.session_run_hook \
        import SessionRunHook, SessionRunArgs
    
    class EarlyStoppingHook(SessionRunHook):
    
        def __init__(self, loss_name, feed_dict={},
                     tolerance=0.001, stopping_step=1000):
            self.loss_name = loss_name
            self.feed_dict = feed_dict
            self.tolerance = tolerance
            self.stopping_step = stopping_step
    
        def begin(self):
            self._global_step = tf.train.get_global_step()
            if self._global_step is None:
                raise RuntimeError('global step must be defined')
            self._step = 0
    
        def before_run(self, run_context):
            if self._step % self.stopping_step == 0:
                graph = run_context.session.graph
                loss = graph.get_tensor_by_name(self.loss_name)
                # Resolve placeholder names to tensors for the extra feeds
                fd = {}
                for key, value in self.feed_dict.items():
                    placeholder = graph.get_tensor_by_name(key)
                    fd[placeholder] = value
                return SessionRunArgs({"step": self._global_step, "loss": loss},
                                      feed_dict=fd)
            else:
                return SessionRunArgs({"step": self._global_step})
    
        def after_run(self, run_context, run_values):
            if self._step % self.stopping_step == 0:
                global_step = run_values.results["step"]
                current_loss = run_values.results["loss"]
                if current_loss < self.tolerance:
                    run_context.request_stop()
            else:
                global_step = run_values.results["step"]
            self._step = global_step
    

    A more complete early-stopping hook

    from tensorflow.python.training import session_run_hook, training_util
    
    class EarlyStoppingHook(session_run_hook.SessionRunHook):
        def __init__(self, loss_name, feed_dict={}, tolerance=0.01, stopping_step=50, start_step=100):
            self.loss_name = loss_name
            self.feed_dict = feed_dict
            self.tolerance = tolerance
            self.stopping_step = stopping_step
            self.start_step = start_step
    
        # Initialize global and internal step counts
        def begin(self):
            self._global_step_tensor = training_util._get_or_create_global_step_read()
            if self._global_step_tensor is None:
                raise RuntimeError("Global step should be created to use EarlyStoppingHook.")
            self._prev_step = -1
            self._step = 0
    
        # Evaluate the early-stopping loss every stopping_step steps
        # (avoiding repetition when multiple run calls are made each step)
        def before_run(self, run_context):
            if (self._step % self.stopping_step == 0) and \
               (not self._step == self._prev_step) and (self._step > self.start_step):
    
                print("\n[ Early Stopping Check ]")
                
                # Get graph from run_context session
                graph = run_context.session.graph
    
                # Retrieve loss tensor from graph
                loss_tensor = graph.get_tensor_by_name(self.loss_name)
    
                # Populate feed dictionary with placeholders and values
                fd = {}
                for key, value in self.feed_dict.items():
                    placeholder = graph.get_tensor_by_name(key)
                    fd[placeholder] = value
    
                return session_run_hook.SessionRunArgs({'step': self._global_step_tensor,
                                                        'loss': loss_tensor}, feed_dict=fd)
            else:
                return session_run_hook.SessionRunArgs({'step': self._global_step_tensor})
                                                        
        # Check if current loss is below tolerance for early stopping
        def after_run(self, run_context, run_values):
            if (self._step % self.stopping_step == 0) and \
               (not self._step == self._prev_step) and (self._step > self.start_step):
                global_step = run_values.results['step']
                current_loss = run_values.results['loss']
                print("Current stopping loss  =  %.10f\n" %(current_loss))
                
                if current_loss < self.tolerance:
                    print("[ Early Stopping Criterion Satisfied ]\n")
                    run_context.request_stop()
                self._prev_step = global_step            
            else:
                global_step = run_values.results['step']
                self._step = global_step
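    To see when the check above actually fires, the schedule can be simulated without TensorFlow. FakeRunContext and early_stopping_loop below are illustrative stand-ins that mirror the step % stopping_step and step > start_step conditions, not the hook itself:

```python
class FakeRunContext:
    """Stand-in for the run_context TensorFlow passes to a hook."""
    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


def early_stopping_loop(losses, tolerance=0.01, stopping_step=2, start_step=3):
    """Apply the same early-stopping schedule to a precomputed loss curve.

    Returns the step at which training stops, or len(losses) if the
    criterion is never met.
    """
    ctx = FakeRunContext()
    for step, loss in enumerate(losses):
        # Check only every stopping_step steps, and only after start_step
        if step % stopping_step == 0 and step > start_step:
            if loss < tolerance:
                ctx.request_stop()
        if ctx.stop_requested:
            return step
    return len(losses)


# A decaying loss curve: the check at step 4 sees 0.05 (too high),
# the check at step 6 sees 0.001 < 0.01 and stops.
print(early_stopping_loop([1.0, 0.5, 0.2, 0.1, 0.05, 0.005, 0.001]))  # -> 6
```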
    
    
