Code environment: TensorFlow r1.12
1. Review of the tf.estimator.EstimatorSpec API
tf.estimator.EstimatorSpec
Interface: the same in r1.12, r1.13, and r2.0
@staticmethod
__new__(
    cls,
    mode,
    predictions=None,
    loss=None,
    train_op=None,
    eval_metric_ops=None,
    export_outputs=None,
    training_chief_hooks=None,
    training_hooks=None,
    scaffold=None,
    evaluation_hooks=None,
    prediction_hooks=None
)
1.1 Official API description
Args:
- mode: A ModeKeys. Specifies if this is training, evaluation or prediction.
- predictions: Predictions Tensor or dict of Tensor.
- loss: Training loss Tensor. Must be either scalar, or with shape [1].
- train_op: Op for the training step.
- eval_metric_ops: Dict of metric results keyed by name. The values of the dict can be one of the following: (1) instance of Metric class. (2) Results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.
- export_outputs: Describes the output signatures to be exported to SavedModel and used during serving. A dict {name: output} where:
  - name: An arbitrary name for this output.
  - output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY. If no entry is provided, a default PredictOutput mapping to predictions will be created.
- training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.
- training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.
- scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.
- evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.
- prediction_hooks: Iterable of tf.train.SessionRunHook objects to run during predictions.
Four of these arguments are worth noting here: training_hooks, scaffold, evaluation_hooks, and prediction_hooks.
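To make the role of mode a bit more concrete, here is a rough plain-Python sketch of which arguments each mode needs. To my understanding, EstimatorSpec expects loss and train_op in TRAIN mode, loss in EVAL mode, and predictions in PREDICT mode. REQUIRED_FIELDS and missing_fields are hypothetical names made up for this illustration; they are not part of TensorFlow, and the sketch does not import it.

```python
# Plain-Python stand-in; nothing here imports or calls TensorFlow.
# The strings match the values of tf.estimator.ModeKeys.
TRAIN, EVAL, PREDICT = 'train', 'eval', 'infer'

# Hypothetical lookup table: the fields each mode is expected to provide.
REQUIRED_FIELDS = {
    TRAIN: ('loss', 'train_op'),
    EVAL: ('loss',),
    PREDICT: ('predictions',),
}


def missing_fields(mode, **spec_kwargs):
    """Return the required fields that are absent or None for this mode."""
    if mode not in REQUIRED_FIELDS:
        raise ValueError('Unknown mode: %r' % (mode,))
    return [name for name in REQUIRED_FIELDS[mode]
            if spec_kwargs.get(name) is None]
```

For example, missing_fields(TRAIN, loss=0.5) reports that train_op is still missing. The hook and scaffold arguments are optional in every mode, which is why they do not appear in the table.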
1.2 Official description of the Scaffold class
Structure to create or gather pieces commonly needed to train a model.

When you build a model for training you usually need ops to initialize variables, a Saver to checkpoint them, an op to collect summaries for the visualizer, and so on.

Various libraries built on top of the core TensorFlow library take care of creating some or all of these pieces and storing them in well known collections in the graph. The Scaffold class helps pick these pieces from the graph collections, creating and adding them to the collections if needed.

If you call the scaffold constructor without any arguments, it will pick pieces from the collections, creating default ones if needed when scaffold.finalize() is called. You can pass arguments to the constructor to provide your own pieces. Pieces that you pass to the constructor are not added to the graph collections.
In other words, through scaffold we can control the saver, variable initialization, summaries, and so on, much as we would with a hook.
2. Example: loading parameters through scaffold
We only need to implement an init_fn and pass it to the scaffold.
import tensorflow as tf


def get_init_fn_for_scaffold(checkpoint_path, model_dir, checkpoint_exclude_scopes,
                             ignore_missing_vars, use_v1=False):
    flags_checkpoint_path = checkpoint_path
    # Warn the user if a checkpoint exists in the model_dir. Then ignore.
    if tf.train.latest_checkpoint(model_dir):
        tf.logging.info('Ignoring --checkpoint_path because a checkpoint already exists in %s' % model_dir)
        return None
    if flags_checkpoint_path is None:
        return None

    # Scope prefixes whose variables should not be restored.
    exclusions = []
    if checkpoint_exclude_scopes:
        exclusions = [scope.strip() for scope in checkpoint_exclude_scopes.split(',')]

    variables_to_restore = []
    for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)

    # Accept either a checkpoint directory or a single checkpoint file.
    if tf.gfile.IsDirectory(flags_checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(flags_checkpoint_path)
    else:
        checkpoint_path = flags_checkpoint_path
    tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' % (checkpoint_path, ignore_missing_vars))

    if not variables_to_restore:
        raise ValueError('variables_to_restore cannot be empty')
    if ignore_missing_vars:
        # Keep only the variables that are actually present in the checkpoint.
        reader = tf.train.NewCheckpointReader(checkpoint_path)
        if isinstance(variables_to_restore, dict):
            var_dict = variables_to_restore
        else:
            var_dict = {var.op.name: var for var in variables_to_restore}
        available_vars = {}
        for var in var_dict:
            if reader.has_tensor(var):
                available_vars[var] = var_dict[var]
            else:
                tf.logging.warning('Variable %s missing in checkpoint %s', var, checkpoint_path)
        variables_to_restore = available_vars

    if variables_to_restore:
        saver = tf.train.Saver(variables_to_restore, reshape=False,
                               write_version=tf.train.SaverDef.V1 if use_v1 else tf.train.SaverDef.V2)
        saver.build()

        # Scaffold calls init_fn as init_fn(scaffold, session).
        def callback(scaffold, session):
            saver.restore(session, checkpoint_path)
        return callback
    else:
        tf.logging.warning('No Variables to restore')
        return None
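The variable selection in the function above boils down to two filters on variable names: drop anything under an excluded scope, then (optionally) drop anything the checkpoint does not contain. The sketch below replays that logic on plain strings, without TensorFlow, so it is easy to try out. select_names_to_restore and the sample variable names are hypothetical, invented for this illustration.

```python
def select_names_to_restore(graph_names, checkpoint_names,
                            exclude_scopes=None, ignore_missing=True):
    """Pick the variable names to restore, mirroring the two filters above."""
    exclusions = []
    if exclude_scopes:
        exclusions = [s.strip() for s in exclude_scopes.split(',')]
    # Filter 1: drop names that start with any excluded scope prefix.
    kept = [name for name in graph_names
            if not any(name.startswith(e) for e in exclusions)]
    # Filter 2: optionally drop names that the checkpoint does not contain.
    if ignore_missing:
        available = set(checkpoint_names)
        kept = [name for name in kept if name in available]
    return kept


# Hypothetical names: a backbone with a replaced logits layer and a new head.
graph_names = ['resnet/conv1/kernel', 'resnet/logits/kernel', 'head/dense/kernel']
checkpoint_names = ['resnet/conv1/kernel', 'resnet/logits/kernel']
restored = select_names_to_restore(graph_names, checkpoint_names, 'resnet/logits')
```

Here restored contains only 'resnet/conv1/kernel': the logits layer is excluded by scope, and the new head is skipped because the checkpoint has no such tensor. One thing to watch for, in this sketch and in the original function alike: a trailing comma in the scope string produces an empty prefix, and since every name starts with the empty string, everything gets excluded.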
To use it, build an init_fn like the one above, instantiate a Scaffold with it, and pass that Scaffold to EstimatorSpec.
# Inside model_fn: set up the fine-tuning scaffold
scaffold = tf.train.Scaffold(init_fn=get_init_fn_for_scaffold(
    params['checkpoint_path'], params['model_dir'],
    params['checkpoint_exclude_scopes'], params['ignore_missing_vars']))
# Create the estimator training spec
return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=loss,
    train_op=train_op,
    eval_metric_ops=metrics,
    scaffold=scaffold)
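According to the Scaffold documentation, init_fn is called as init_fn(scaffold, session) after the init op has run. That is why the factory returns a two-argument closure that already holds the saver and the checkpoint path. The sketch below imitates this hand-off in plain Python. RecordingSaver, make_restore_callback, and the checkpoint path are hypothetical stand-ins for illustration only; no TensorFlow is involved.

```python
class RecordingSaver(object):
    """Hypothetical stand-in for tf.train.Saver that only records calls."""

    def __init__(self):
        self.calls = []

    def restore(self, session, save_path):
        self.calls.append((session, save_path))


def make_restore_callback(saver, checkpoint_path):
    """Return an init_fn-shaped closure: callable(scaffold, session)."""
    def callback(scaffold, session):
        # The scaffold argument is unused; only the signature has to match.
        saver.restore(session, checkpoint_path)
    return callback


saver = RecordingSaver()
init_fn = make_restore_callback(saver, '/tmp/pretrained/model.ckpt')  # made-up path
init_fn(None, 'fake-session')  # roughly what the training session does once
```

Because the closure captures saver and checkpoint_path when it is built, the training loop needs to know nothing about fine-tuning. When the factory returns None instead, Scaffold falls back to its default initialization, since init_fn defaults to None.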
References:
- TensorFlow official documentation:
  https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/estimator/EstimatorSpec
  https://www.tensorflow.org/api_docs/python/tf/train/Scaffold
  https://www.tensorflow.org/api_docs/python/tf/train/SessionRunHook
- Jianshu reference: https://www.jianshu.com/p/1df991a4b815
- tf.fashionAI GitHub project: https://github.com/HiKapok/tf.fashionAI