美文网首页
tf.train.SessionRunHook 让 estima

tf.train.SessionRunHook 让 estima

作者: theoqian | 来源:发表于2019-05-28 15:45 被阅读0次

    estimator

    estimator 是 tensorflow 提供的使用非常方便的模型封装。estimator 中提供了许多内置的模型,例如 LinearClassifier、DNNLinearCombinedClassifier、LinearRegressor等。用户也可以通过 model_fn 定制模型结构。在 estimator 对象的基础上任何模型都可以直接调用 train 和 eval 函数进行训练和测试,用户无需手动地创建 session 和 run session。estimator 的具体使用方式可以参考[1]。 estimator.png

    dataset

    tensorflow 底层 API 中都是使用 placeholder 和 feed_dict 向模型输入数据的,这样的方式效率较低。我们可以利用 dataset 库,这里提供了高效读取数据并且输入给模型训练的方式。

    可以直接用 numpy 数组创建 dataset。直接用数组创建 dataset 的一个问题是 tensorflow 会直接把 dataset 中的数据写到 graph 中,当数据量较大时会报错,因为 graph 在序列化到 pb 文件时现在最大2GB。

    def input_fn():
      features, labels = (np.random.sample((100,2)), np.random.sample((100,1)))
      dataset = tf.data.Dataset.from_tensor_slices((features,labels))
      dataset = dataset.shuffle(100000).repeat().batch(batch_size)
      return dataset
    
    ...
    estimator.train(input_fn)
    

    为了在大数据量时使用 dataset,我们可以用 placeholder 创建 dataset。这时数据就不会直接写到 graph 中,graph 中只有一个 placeholder 占位符。但是,用了 placeholder 就需要我们在一开始对它进行初始化填数据,需要调用 sess.run(iter.initializer, feed_dict={ x: data })。更多关于 dataset 的使用介绍可以参考文献[2]。

    def input_fn():
      x = tf.placeholder(tf.float32, shape=[None,2])
      dataset = tf.data.Dataset.from_tensor_slices(x)
      dataset = dataset.shuffle(100000).repeat().batch(batch_size)
      iter = dataset.make_initializable_iterator()
      return iter.get_next()
    

    SessionRunHook

    既然前面说到 estimator 是 tensorflow 对模型的一种封装,我们不需要也无法拿到训练和测试时创建的 session,那么我们如何在 estimator 中对上一节使用 placeholder 的 dataset 的 initializeble_iterator 调用 sess.run 进行初始化呢?这时候就要用到 SessionRunHook 了。
    先从字面意思理解一下 SessionRunHook 这个类。Session 就是 tensorflow 运行模型计算时的会话,Run就是整个 session 运行过程,Hook 是挂钩的意思即把某些事情挂在这个对象上可以理解为回调。

    再看一下 SessionRunHook 源码[3]中的定义:
    A SessionRunHook extends session.run() calls for the MonitoredSession.
    SessionRunHooks are useful to track training, report progress, request early
    stopping and more. SessionRunHooks use the observer pattern and notify at the
    following points:

    • when a session starts being used
    • before a call to the session.run()
    • after a call to the session.run()
    • when the session closed
    class SessionRunHook(object):
      """Hook to extend calls to MonitoredSession.run()."""
    
      def begin(self):
        """Called once before using the session.
        When called, the default graph is the one that will be launched in the
        session.  The hook can modify the graph by adding new operations to it.
        After the `begin()` call the graph will be finalized and the other callbacks
        can not modify the graph anymore. Second call of `begin()` on the same
        graph, should not change the graph.
        """
        pass
    
      def after_create_session(self, session, coord):  # pylint: disable=unused-argument
        """Called when new TensorFlow session is created.
        This is called to signal the hooks that a new session has been created. This
        has two essential differences with the situation in which `begin` is called:
        * When this is called, the graph is finalized and ops can no longer be added
            to the graph.
        * This method will also be called as a result of recovering a wrapped
            session, not only at the beginning of the overall session.
        Args:
          session: A TensorFlow Session that has been created.
          coord: A Coordinator object which keeps track of all threads.
        """
        pass
    
      def before_run(self, run_context):  # pylint: disable=unused-argument
        """Called before each call to run().
        You can return from this call a `SessionRunArgs` object indicating ops or
        tensors to add to the upcoming `run()` call.  These ops/tensors will be run
        together with the ops/tensors originally passed to the original run() call.
        The run args you return can also contain feeds to be added to the run()
        call.
        The `run_context` argument is a `SessionRunContext` that provides
        information about the upcoming `run()` call: the originally requested
        op/tensors, the TensorFlow Session.
        At this point graph is finalized and you can not add ops.
        Args:
          run_context: A `SessionRunContext` object.
        Returns:
          None or a `SessionRunArgs` object.
        """
        return None
    
      def after_run(self,
                    run_context,  # pylint: disable=unused-argument
                    run_values):  # pylint: disable=unused-argument
        """Called after each call to run().
        The `run_values` argument contains results of requested ops/tensors by
        `before_run()`.
        The `run_context` argument is the same one send to `before_run` call.
        `run_context.request_stop()` can be called to stop the iteration.
        If `session.run()` raises any exceptions then `after_run()` is not called.
        Args:
          run_context: A `SessionRunContext` object.
          run_values: A SessionRunValues object.
        """
        pass
    
      def end(self, session):  # pylint: disable=unused-argument
        """Called at the end of session.
        The `session` argument can be used in case the hook wants to run final ops,
        such as saving a last checkpoint.
        If `session.run()` raises exception other than OutOfRangeError or
        StopIteration then `end()` is not called.
        Note the difference between `end()` and `after_run()` behavior when
        `session.run()` raises OutOfRangeError or StopIteration. In that case
        `end()` is called but `after_run()` is not called.
        Args:
          session: A TensorFlow Session that will be soon closed.
        """
        pass
    

    我们看到 SessionRunHook 源码中为 5 中不同的事件提供了回调函数,用户只需要继承 SessionRunHook 这个类并且具体实现想要的回调函数即可,具体用法看下一节。

    estimator 结合 SessionRunHook 实现 placeholder 初始化

    仔细看一下 estimator 的 train 和 evaluate 函数定义可以发现它们都接收 hooks 参数,这个参数的定义是:List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the training loop. 就是上一节提到的用户继承自 SessionRunHook 的类的实例对象列表。

    train(
        input_fn,
        hooks=None,
        steps=None,
        max_steps=None,
        saving_listeners=None
    )
    

    我们现在想要在训练之前初始化 dataset 的 placeholder,那么我们就应该具体实现 SessionRunHook 的after_create_session 成员函数:

    class IteratorInitializerHook(tf.train.SessionRunHook):
       def __init__(self):
           super(IteratorInitializerHook, self).__init__()
           self.iterator_initializer_fn = None
    
       def after_create_session(self, session, coord):
           del coord
           self.iterator_initializer_fn(session)
    
    def make_input_fn():
       iterator_initializer_hook = IteratorInitializerHook()
    
       def input_fn():
           x = tf.placeholder(tf.float32, shape=[None,2])
           dataset = tf.data.Dataset.from_tensor_slices(x)
           dataset = dataset.shuffle(100000).repeat().batch(batch_size)
           iter = dataset.make_initializable_iterator()
           data = np.random.sample((100,2))
           iterator_initializer_hook.iterator_initializer_fn = (
               lambda sess: sess.run(iter.initializer, feed_dict={x: data})
           )
           return iter.get_next()
       return input_fn, iterator_initializer_hook
    
    ...
    input_fn, iterator_initializer_hook = make_input_fn()
    estimator.train(input_fn, hooks=[iterator_initializer_hook])
    

    当然,SessionRunHook 不光能用在初始化上,还有许多应用场景,可以参考源码[3]中提供的几个内置 Hook 和文献[4]。

    [1] https://github.com/tensorflow/models/tree/master/samples/core/get_started
    [2] https://www.jiqizhixin.com/articles/03137
    [3] https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/session_run_hook.py
    [4] https://blog.csdn.net/mrr1ght/article/details/81011280

    相关文章

      网友评论

          本文标题:tf.train.SessionRunHook 让 estima

          本文链接:https://www.haomeiwen.com/subject/ekwrtctx.html