-
images, labels = catdog_input.distorted_input(FLAGS.tfrecords_name, BATCH_SIZE)
首先获取数据。 - 接着用模型里面的不同函数,返回不同的op。
logits = catdog_model.inference(images)
loss = catdog_model.loss(logits, labels)
# accuracy = catdog_model.accuracy(logits, labels)
train_op = catdog_model.train(loss)
- 那么应该在
model = Model()
的时候就把所有的部件都初始化了。 - 所以要把所有需要交互的 op 暴露在 main.py 里面的 train 函数里面,这样比较易读。
- 把需要打印的信息放在一个函数里面,用类
_LoggerHook
中继承自类tf.train.SessionRunHook
,里面的before_run
和after_run
函数分别训练前和开始训练后执行。可以用来返回所有在运行中想要查看的信息,比如loss或者accuracy,以及打印信息。
def train():
    """Build the input pipeline and model graph, then train with a
    MonitoredTrainingSession.

    Side effects: reads TFRecords named by ``FLAGS.tfrecords_name`` and
    runs ``train_op`` until ``FLAGS.max_step`` global steps are reached
    (or a NaN loss is detected by NanTensorHook).
    """
    # StopAtStepHook watches the global step, so it must exist in the graph.
    global_step = tf.train.get_or_create_global_step()

    # Input pipeline and model ops.
    images, labels = catdog_input.distorted_input(FLAGS.tfrecords_name, BATCH_SIZE)
    logits = catdog_model.inference(images)
    loss = catdog_model.loss(logits, labels)
    # accuracy = catdog_model.accuracy(logits, labels)
    train_op = catdog_model.train(loss)

    class _LoggerHook(tf.train.SessionRunHook):
        """Log loss and throughput every `display_step` steps."""

        def begin(self):
            self._step = -1
            self._start_time = time.time()

        def before_run(self, run_context):
            # Called automatically before every session.run(); the fetches
            # requested here are delivered to after_run via run_values.
            # A list (e.g. [loss, accuracy]) may be passed as well.
            self._step += 1
            return tf.train.SessionRunArgs(loss)

        def after_run(self, run_context, run_values):
            display_step = 10  # logging interval, in steps
            if self._step % display_step == 0:
                current_time = time.time()
                duration = current_time - self._start_time
                self._start_time = current_time
                # run_values.results mirrors the fetches from before_run:
                # a single tensor was requested, so a single value comes back.
                # (Renamed from `loss` to avoid shadowing the graph tensor.)
                loss_value = run_values.results
                examples_per_sec = display_step * BATCH_SIZE / duration
                sec_per_batch = float(duration / display_step)
                format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), self._step, loss_value,
                                    examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_step),
                   tf.train.NanTensorHook(loss),
                   _LoggerHook()],  # pass the hook defined above
            config=tf.ConfigProto(
                log_device_placement=False)) as sess:
        # MonitoredTrainingSession starts queue runners and owns a
        # coordinator internally, so the original manual
        # tf.train.Coordinator / start_queue_runners / request_stop / join
        # bookkeeping was redundant (and joining after the session loop can
        # hang on an already-stopping session). It has been removed.
        while not sess.should_stop():
            sess.run(train_op)
另外一种是直接把sess直接给模型,然后就不用暴露op了
（见上图 image.png）- 因为进入 MonitoredTrainingSession 之后计算图就冻结了不能更改,所以要在外面完成建图。设置了
checkpoint_dir=./
之后,会自动加载最新的 checkpoint 继续训练。
image.png - 训练在
MonitoredTrainingSession
里面。
image.png
hook的内部函数的执行顺序
image.png定义一个early_stopping Hook
from tensorflow.python.training.session_run_hook \
import SessionRunHook,SessionRunArgs
class EarlyStoppingHook(SessionRunHook):
    """Request a training stop once the monitored loss drops below a tolerance.

    Every `stopping_step` steps the loss tensor is looked up by name in the
    session graph, evaluated (optionally with extra feeds), and compared
    against `tolerance`.
    """

    def __init__(self, loss_name, feed_dict=None,
                 tolerance=0.001, stopping_step=1000):
        # `feed_dict=None` sentinel avoids the shared mutable-default
        # pitfall of the original `feed_dict={}` signature.
        self.loss_name = loss_name          # graph name of the loss tensor, e.g. "loss:0"
        self.feed_dict = {} if feed_dict is None else feed_dict
        self.tolerance = tolerance          # stop when loss < tolerance
        self.stopping_step = stopping_step  # evaluation interval, in steps

    def begin(self):
        self._global_step = tf.train.get_global_step()
        if self._global_step is None:
            raise RuntimeError('global step must be defined')
        self._step = 0

    def before_run(self, run_context):
        if self._step % self.stopping_step == 0:
            graph = run_context.session.graph
            loss = graph.get_tensor_by_name(self.loss_name)
            # BUG FIX: the original initialized `td` but populated `fd`,
            # and misspelled `palceholder` — both raised NameError at runtime.
            fd = {}
            for key, value in self.feed_dict.items():
                # keys are tensor names; resolve them to graph tensors
                placeholder = graph.get_tensor_by_name(key)
                fd[placeholder] = value
            return SessionRunArgs({"step": self._global_step, "loss": loss},
                                  feed_dict=fd)
        else:
            return SessionRunArgs({"step": self._global_step})

    def after_run(self, run_context, run_values):
        if self._step % self.stopping_step == 0:
            global_step = run_values.results["step"]
            current_loss = run_values.results["loss"]
            if current_loss < self.tolerance:
                run_context.request_stop()
        else:
            global_step = run_values.results["step"]
        # Mirror the true global step so multiple run calls per step
        # do not desynchronize the interval check.
        self._step = global_step
一个比较完善的 early_stopping hook 如下。
class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Session hook that requests a stop when the named loss tensor falls
    below a tolerance.

    The loss is evaluated every `stopping_step` steps, but only once per
    step (guarding against multiple run calls within one step) and only
    after `start_step` warm-up steps have elapsed.
    """

    def __init__(self, loss_name, feed_dict={}, tolerance=0.01, stopping_step=50, start_step=100):
        self.loss_name = loss_name          # graph name of the loss tensor
        self.feed_dict = feed_dict          # tensor-name -> value feeds for the check
        self.tolerance = tolerance          # stop threshold on the loss
        self.stopping_step = stopping_step  # evaluation interval, in steps
        self.start_step = start_step        # warm-up before any check fires

    def begin(self):
        # Track both the graph-level global step and an internal counter.
        self._global_step_tensor = training_util._get_or_create_global_step_read()
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created to use EarlyStoppingHook.")
        self._prev_step = -1
        self._step = 0

    def _check_due(self):
        # Whether an early-stopping evaluation should happen this step.
        on_interval = self._step % self.stopping_step == 0
        fresh_step = self._step != self._prev_step
        past_warmup = self._step > self.start_step
        return on_interval and fresh_step and past_warmup

    def before_run(self, run_context):
        if not self._check_due():
            return session_run_hook.SessionRunArgs({'step': self._global_step_tensor})
        print("\n[ Early Stopping Check ]")
        graph = run_context.session.graph
        loss_tensor = graph.get_tensor_by_name(self.loss_name)
        # Resolve each tensor name in feed_dict to its graph tensor.
        fd = {graph.get_tensor_by_name(name): value
              for name, value in self.feed_dict.items()}
        return session_run_hook.SessionRunArgs(
            {'step': self._global_step_tensor, 'loss': loss_tensor},
            feed_dict=fd)

    def after_run(self, run_context, run_values):
        if self._check_due():
            global_step = run_values.results['step']
            current_loss = run_values.results['loss']
            print("Current stopping loss = %.10f\n" %(current_loss))
            if current_loss < self.tolerance:
                print("[ Early Stopping Criterion Satisfied ]\n")
                run_context.request_stop()
            self._prev_step = global_step
        else:
            global_step = run_values.results['step']
        self._step = global_step
网友评论