美文网首页
[tf]一个比较好的代码框架以及各种Hook

[tf]一个比较好的代码框架以及各种Hook

作者: VanJordan | 来源:发表于2019-01-16 17:17 被阅读17次
  • images, labels = catdog_input.distorted_input(FLAGS.tfrecords_name, BATCH_SIZE)首先获取数据。
  • 接着用模型里面的不同函数,返回不同的op。
    logits = catdog_model.inference(images)
    loss = catdog_model.loss(logits, labels)
    # accuracy = catdog_model.accuracy(logits, labels)

    train_op = catdog_model.train(loss)
  • 那么应该在model = Model()的时候就把所有的部件都初始化了。
  • 所以要把所有需要交互的op暴露在main.py里面的train函数里面,这样比较易读。
  • 把需要打印的信息放在一个类里面:定义类_LoggerHook(继承自tf.train.SessionRunHook),它的before_run和after_run函数分别在每步训练前和训练后执行。可以用来返回所有在运行中想要查看的信息,比如loss或者accuracy,以及打印信息。
def train():
    """Build the cat/dog graph and train it under a MonitoredTrainingSession.

    Relies on module-level names defined elsewhere in this file:
    `catdog_input`, `catdog_model`, `FLAGS`, `BATCH_SIZE`.
    """
    # StopAtStepHook requires a global-step tensor to exist in the graph.
    global_step = tf.train.get_or_create_global_step()

    # Input pipeline (presumably reads from TFRecords -- TODO confirm
    # against catdog_input.distorted_input).
    images, labels = catdog_input.distorted_input(FLAGS.tfrecords_name, BATCH_SIZE)

    logits = catdog_model.inference(images)
    loss = catdog_model.loss(logits, labels)
    # accuracy = catdog_model.accuracy(logits, labels)

    train_op = catdog_model.train(loss)

    class _LoggerHook(tf.train.SessionRunHook):
        """Log loss and throughput every `display_step` training steps."""

        def begin(self):
            self._step = -1
            self._start_time = time.time()

        def before_run(self, run_context):
            # Called automatically before each session.run(). Whatever is
            # requested here (a tensor, or a list like [loss, accuracy])
            # is fetched alongside train_op and handed to after_run().
            self._step += 1
            return tf.train.SessionRunArgs(loss)

        def after_run(self, run_context, run_values):
            # How many steps between log lines.
            display_step = 10
            if self._step % display_step == 0:
                current_time = time.time()
                duration = current_time - self._start_time
                self._start_time = current_time
                # run_values.results mirrors before_run()'s request: a single
                # tensor was requested, so a single value comes back (a list
                # request would return a list).
                loss_value = run_values.results

                # NOTE(review): at step 0 `duration` covers only one step,
                # so the very first throughput figure is overstated.
                examples_per_sec = display_step * BATCH_SIZE / duration
                sec_per_batch = float(duration / display_step)
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), self._step, loss_value,
                                    examples_per_sec, sec_per_batch))

    # MonitoredTrainingSession creates its own Coordinator and starts the
    # graph's queue runners itself, so the manual tf.train.Coordinator() /
    # start_queue_runners() bookkeeping the original carried here was
    # redundant (and was never wired into the hooks).
    with tf.train.MonitoredTrainingSession(
            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_step),
                   tf.train.NanTensorHook(loss),  # abort if loss becomes NaN
                   _LoggerHook()],                # custom logging hook above
            config=tf.ConfigProto(
                log_device_placement=False)) as sess:

        while not sess.should_stop():
            sess.run(train_op)

另外一种做法是直接把sess交给模型,这样就不用暴露op了

image.png
  • 因为进入MonitoredTrainingSession里面模型就冻住了不能更改,所以在外面完成建图,设置了checkpoint_dir=./那么会自动加载一个最新的checkpoint继续训练。
    image.png
  • 训练在MonitoredTrainingSession里面。
    image.png

hook的内部函数的执行顺序

image.png

定义一个early_stopping Hook

from tensorflow.python.training.session_run_hook \
import SessionRunHook,SessionRunArgs

class EarlyStoppingHook(SessionRunHook):
    """Request a stop once the monitored loss drops below `tolerance`.

    The loss tensor is looked up by name and fetched every `stopping_step`
    global steps; on other steps only the global step is fetched.
    """

    def __init__(self, loss_name, feed_dict=None,
                 tolerance=0.001, stopping_step=1000):
        # loss_name:   graph name of the loss tensor, e.g. "loss:0".
        # feed_dict:   optional {placeholder_name: value} mapping used when
        #              fetching the loss. `None` default avoids the shared
        #              mutable-default-argument pitfall (was `{}`).
        # tolerance:   stop when loss < tolerance.
        # stopping_step: check interval in steps.
        self.loss_name = loss_name
        self.feed_dict = {} if feed_dict is None else feed_dict
        self.tolerance = tolerance
        self.stopping_step = stopping_step

    def begin(self):
        self._global_step = tf.train.get_global_step()
        if self._global_step is None:
            raise RuntimeError('global step must be defined')
        self._step = 0

    def before_run(self, run_context):
        # On a checking step, also fetch the loss, resolving any feed
        # placeholders from their names in the session's graph.
        if self._step % self.stopping_step == 0:
            graph = run_context.session.graph
            loss = graph.get_tensor_by_name(self.loss_name)
            # BUGFIX: original created `td = {}` but wrote to an undefined
            # `fd`, and misspelled `placeholder` as `palceholder` -- a
            # NameError whenever feed_dict was non-empty.
            fd = {}
            for key, value in self.feed_dict.items():
                placeholder = graph.get_tensor_by_name(key)
                fd[placeholder] = value
            return SessionRunArgs({"step": self._global_step, "loss": loss},
                                  feed_dict=fd)
        else:
            return SessionRunArgs({"step": self._global_step})

    def after_run(self, run_context, run_values):
        if self._step % self.stopping_step == 0:
            global_step = run_values.results["step"]
            current_loss = run_values.results["loss"]
            if current_loss < self.tolerance:
                run_context.request_stop()
        else:
            global_step = run_values.results["step"]
        # Keep the internal counter in sync with the true global step.
        self._step = global_step

一个比较完善的early_stopping hook

class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Early-stopping hook: request_stop() when the loss named `loss_name`
    falls below `tolerance`.

    The check runs every `stopping_step` global steps, but only after
    `start_step`, and the `_prev_step` guard prevents the check from firing
    more than once for the same global step when several session.run()
    calls happen per step.
    """

    def __init__(self, loss_name, feed_dict=None, tolerance=0.01, stopping_step=50, start_step=100):
        # feed_dict: `None` default avoids the shared mutable-default
        # pitfall (was `{}`); semantics are unchanged.
        self.loss_name = loss_name
        self.feed_dict = {} if feed_dict is None else feed_dict
        self.tolerance = tolerance
        self.stopping_step = stopping_step
        self.start_step = start_step

    # Initialize global and internal step counts
    def begin(self):
        # NOTE(review): _get_or_create_global_step_read() is a private TF
        # API; tf.train.get_or_create_global_step() is the public route.
        self._global_step_tensor = training_util._get_or_create_global_step_read()
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created to use EarlyStoppingHook.")
        self._prev_step = -1
        self._step = 0

    # Evaluate the early-stopping loss every `stopping_step` steps
    # (the `_prev_step` comparison avoids repeating the check when
    # multiple run calls are made within a single global step).
    def before_run(self, run_context):
        if (self._step % self.stopping_step == 0) and \
           (not self._step == self._prev_step) and (self._step > self.start_step):

            print("\n[ Early Stopping Check ]")

            # Get graph from run_context session
            graph = run_context.session.graph

            # Retrieve loss tensor from graph
            loss_tensor = graph.get_tensor_by_name(self.loss_name)

            # Populate feed dictionary with placeholders and values
            fd = {}
            for key, value in self.feed_dict.items():
                placeholder = graph.get_tensor_by_name(key)
                fd[placeholder] = value

            return session_run_hook.SessionRunArgs({'step': self._global_step_tensor,
                                                    'loss': loss_tensor}, feed_dict=fd)
        else:
            return session_run_hook.SessionRunArgs({'step': self._global_step_tensor})

    # Check if current loss is below tolerance for early stopping
    def after_run(self, run_context, run_values):
        if (self._step % self.stopping_step == 0) and \
           (not self._step == self._prev_step) and (self._step > self.start_step):
            global_step = run_values.results['step']
            current_loss = run_values.results['loss']
            print("Current stopping loss  =  %.10f\n" %(current_loss))

            if current_loss < self.tolerance:
                print("[ Early Stopping Criterion Satisfied ]\n")
                run_context.request_stop()
            # Record that this global step was checked; `_step` is
            # deliberately left unchanged so the next after_run() falls
            # through to the else branch and re-syncs it.
            self._prev_step = global_step
        else:
            global_step = run_values.results['step']
            self._step = global_step

相关文章

网友评论

      本文标题:[tf]一个比较好的代码框架以及各种Hook

      本文链接:https://www.haomeiwen.com/subject/sqomdqtx.html