
[TensorFlow] The Estimator High-Level API

Author: nlpming | Published: 2021-11-24 19:20

1. Introduction

  • tf.estimator is TensorFlow's high-level API. It wraps the entire workflow of training, evaluation, prediction, and exporting a model for serving. Another major advantage of tf.estimator is that it makes distributed training (for example, with parameter servers) simple to set up: the same code runs on CPU, GPU, or TPU without modification. tf.estimator also provides basic features such as model checkpointing and TensorBoard visualization.
  • When writing an application with Estimators, you must separate the data input pipeline from the model. This separation makes it easy to experiment with different datasets. tf.estimator defines the overall training, evaluation, and prediction loop, so the developer only needs to focus on defining the model (model_fn) and obtaining the input data (input_fn).
  • tf.estimator ships with pre-made Estimators for classification and regression tasks, such as tf.estimator.LinearClassifier, tf.estimator.DNNClassifier, and tf.estimator.DNNLinearCombinedClassifier (plus their Regressor counterparts); a minimal training-and-evaluation skeleton is sketched right after this list.
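
As a minimal sketch of that workflow (the train_input_fn and eval_input_fn here are assumed to be defined as in section 2.1; none of these names come from the original article), a pre-made Estimator can be driven with tf.estimator.train_and_evaluate, which is also the entry point that works unchanged for distributed training:

import tensorflow as tf

# Feature columns describe how raw input features map to model inputs.
feature_columns = [tf.feature_column.numeric_column("x", shape=[28 * 28])]

# A pre-made Estimator: only feature columns and hyper-parameters are supplied.
estimator = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    n_classes=10,
    model_dir="./tmp/demo_model")

# train_input_fn / eval_input_fn are assumed to be defined as in section 2.1.
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=None)

# The same call works locally and in a distributed (e.g. parameter-server)
# setup; in the distributed case the cluster layout is read from TF_CONFIG.
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)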

2. Pre-made Estimators

2.1 MNIST classification

#coding:utf-8
import tensorflow as tf
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data

# Set the logging level
tf.logging.set_verbosity(tf.logging.INFO)

# Load the MNIST data
mnist = input_data.read_data_sets('MNIST_data')

def input(dataset):
    # Return (images, int32 labels) for an MNIST DataSet split.
    return dataset.images, dataset.labels.astype(np.int32)

# Specify feature
feature_columns = [tf.feature_column.numeric_column("x", shape=[28, 28])]

# Build 2 layer DNN classifier
classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf.train.AdamOptimizer(1e-4),
    n_classes=10,
    dropout=0.1,
    model_dir="./tmp/mnist_model"
)

# Define the training inputs
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": input(mnist.train)[0]},
    y=input(mnist.train)[1],
    num_epochs=None,
    batch_size=50,
    shuffle=True
)

classifier.train(input_fn=train_input_fn, steps=100000)

# Define the test inputs
test_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": input(mnist.test)[0]},
    y=input(mnist.test)[1],
    num_epochs=1,
    shuffle=False
)

# Evaluate accuracy
accuracy_score = classifier.evaluate(input_fn=test_input_fn)["accuracy"]
print("\nTest Accuracy: {0:f}%\n".format(accuracy_score*100))

2.2 The Wide & Deep model

(1) First, define the feature_columns (a sketch follows this list);
(2) Construct the pre-made tf.estimator.LinearClassifier, tf.estimator.DNNClassifier, or tf.estimator.DNNLinearCombinedClassifier;
(3) Define an input_fn that supplies the input data;
(4) Train, evaluate, and test the model.
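
The snippet below references base_columns, crossed_columns, and deep_columns without defining them. As a rough, hedged sketch (the column names education, occupation, and age are assumptions based on the census income dataset usually paired with this example, not taken from the original article), they could be built like this:

# Sparse categorical columns, hashed into a fixed number of buckets.
education = tf.feature_column.categorical_column_with_hash_bucket(
    "education", hash_bucket_size=1000)
occupation = tf.feature_column.categorical_column_with_hash_bucket(
    "occupation", hash_bucket_size=1000)
# A continuous column.
age = tf.feature_column.numeric_column("age")

# Wide part: the sparse base columns plus feature crosses.
base_columns = [education, occupation]
crossed_columns = [
    tf.feature_column.crossed_column(
        ["education", "occupation"], hash_bucket_size=int(1e4)),
]

# Deep part: embeddings of the sparse columns plus the continuous columns.
deep_columns = [
    tf.feature_column.embedding_column(education, dimension=8),
    tf.feature_column.embedding_column(occupation, dimension=8),
    age,
]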

import tempfile

import pandas as pd
import tensorflow as tf

# CSV_COLUMNS, maybe_download, and the feature column lists (base_columns,
# crossed_columns, deep_columns) are defined elsewhere in the full script.
def build_estimator(model_dir, model_type):
  """Build an estimator."""
  if model_type == "wide":
    m = tf.estimator.LinearClassifier(
        model_dir=model_dir, feature_columns=base_columns + crossed_columns)
  elif model_type == "deep":
    m = tf.estimator.DNNClassifier(
        model_dir=model_dir,
        feature_columns=deep_columns,
        hidden_units=[100, 50])
  else:
    m = tf.estimator.DNNLinearCombinedClassifier(
        model_dir=model_dir,
        linear_feature_columns=crossed_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=[100, 50])
  return m


def input_fn(data_file, num_epochs, shuffle):
  """Input builder function."""
  df_data = pd.read_csv(
      tf.gfile.Open(data_file),
      names=CSV_COLUMNS,
      skipinitialspace=True,
      engine="python",
      skiprows=1)
  # remove NaN elements
  df_data = df_data.dropna(how="any", axis=0)
  labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
  return tf.estimator.inputs.pandas_input_fn(
      x=df_data,
      y=labels,
      batch_size=100,
      num_epochs=num_epochs,
      shuffle=shuffle,
      num_threads=5)


def train_and_eval(model_dir, model_type, train_steps, train_data, test_data):
  """Train and evaluate the model."""
  train_file_name, test_file_name = maybe_download(train_data, test_data)
  model_dir = tempfile.mkdtemp() if not model_dir else model_dir

  m = build_estimator(model_dir, model_type)
  # set num_epochs to None to get infinite stream of data.
  m.train(
      input_fn=input_fn(train_file_name, num_epochs=None, shuffle=True),
      steps=train_steps)
  # set steps to None to run evaluation until all data consumed.
  results = m.evaluate(
      input_fn=input_fn(test_file_name, num_epochs=1, shuffle=False),
      steps=None)
  print("model directory = %s" % model_dir)
  for key in sorted(results):
    print("%s: %s" % (key, results[key]))

3. Custom Estimators

(1) Define a model_fn that builds the model, the loss function, the optimizer, and so on; an Estimator object is then created from this model_fn.
(2) Define an input_fn;
(3) Drive the training process yourself (see the driver sketch at the end of section 3.2).

3.1 Custom model_fn

  • model_fn takes roughly the following parameters:
    (1) features and labels are required; both are supplied by input_fn;
    (2) mode indicates the current phase (an optional argument of model_fn): train, eval, and predict, corresponding to the three constants tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL, and tf.estimator.ModeKeys.PREDICT. A minimal mode-handling skeleton is sketched below, before the full example.
    (Figure: model_fn parameter overview)
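
The full model_fn shown below only builds the training path and does not branch on mode. As a hedged, minimal skeleton (not from the original article; N_LABELS is assumed to be the number of classes), the three phases would typically be handled like this:

def minimal_model_fn(features, labels, mode, params=None):
  '''Minimal sketch of branching on the three ModeKeys.'''
  net = tf.layers.flatten(features)
  logits = tf.layers.dense(net, units=N_LABELS)

  # PREDICT: only predictions are needed, no loss or train_op.
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'classes': tf.argmax(logits, axis=1),
                   'probabilities': tf.nn.softmax(logits)}
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))

  # EVAL: loss plus evaluation metrics.
  if mode == tf.estimator.ModeKeys.EVAL:
    accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1),
                                   tf.argmax(logits, axis=1))
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, eval_metric_ops={'accuracy': accuracy})

  # TRAIN: loss plus a train_op that advances the global step.
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
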
def model_body(features):
  net = tf.identity(features)
  net = tf.layers.conv2d(net, filters=32, kernel_size=5, activation=tf.nn.relu)
  net = tf.layers.max_pooling2d(net, pool_size=2, strides=2)
  net = tf.layers.conv2d(net, filters=64, kernel_size=5, activation=tf.nn.relu)
  net = tf.layers.max_pooling2d(net, pool_size=2, strides=2)
  net = tf.contrib.layers.flatten(net)
  net = tf.layers.dense(net, 1024, activation=tf.nn.relu)

  return net


def build_model_fn(hparams):
  '''Build the model function.'''
  def model_fn(features, labels, mode, params=None):
    '''Define the model graph.'''

    if params:
      hparams.override_from_dict(params)

    net = model_body(features)

    logits = tf.layers.dense(net, units=N_LABELS)
    xentropies = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                         labels=labels)
    loss = tf.reduce_mean(xentropies)

    if hparams.get('learning_rate_decay_scheme') == 'exponential':
      learning_rate = tf.train.exponential_decay(hparams.learning_rate,
                                                 tf.train.get_global_step(),
                                                 hparams.decay_steps,
                                                 hparams.decay_rate)
    else:
      learning_rate = hparams.learning_rate

    optimizer = tf.train.MomentumOptimizer(learning_rate,
                                           hparams.momentum)

    if hparams.use_tpu:
      # Aggregate gradients across TPU shards.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

    def metric_fn(labels, logits):
      predictions = tf.argmax(logits, axis=1)
      accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1), predictions)
      return {'accuracy': accuracy}
    # TPUEstimatorSpec expects a (metric_fn, tensors) tuple, while the standard
    # EstimatorSpec expects a dict of metric ops.
    eval_metrics = (metric_fn, [labels, logits])
    eval_metric_ops = metric_fn(labels, logits)

    if hparams.use_tpu:
      return tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                             loss=loss,
                                             train_op=train_op,
                                             eval_metrics=eval_metrics)
    else:
      return tf.estimator.EstimatorSpec(mode=mode,
                                        loss=loss,
                                        train_op=train_op,
                                        eval_metric_ops=eval_metric_ops)
  return model_fn

3.2 Custom input_fn

  • Define the data-reading pipeline yourself, using the tf.data API;
  • The (features, labels) pair returned by input_fn is what model_fn receives as input;
def build_input_fn(fnames, hparams, is_training=True, use_tpu=True):
  '''Build the input function.'''
  def parse_fn(proto):
    '''Parse a single Tensorflow example from a `TFRecord`.
    Args:
      proto: The serialized protobuf of the `tf.Example`
    Returns:
      A `Tensor` containing the image.
      A one-hot `Tensor` containing the label.
    '''

    features = {
        'im': tf.FixedLenFeature([28 * 28], tf.float32),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed_features = tf.parse_single_example(proto, features)
    im = tf.reshape(parsed_features['im'], [28, 28, 1])
    label = tf.one_hot(parsed_features['label'], N_LABELS)

    return im, label

  def input_fn(params):
    '''Feed input into the graph.'''
    with tf.variable_scope('image_preprocessing'):
      dataset = tf.data.TFRecordDataset(fnames)
      dataset = dataset.shuffle(len(fnames))
      dataset = dataset.map(parse_fn)
      if is_training:
        dataset = dataset.shuffle(hparams.shuffle_buffer_size)
        dataset = dataset.repeat()
      if use_tpu:
        dataset = dataset.apply(
            tf.contrib.data.batch_and_drop_remainder(params['batch_size']))
      else:
        dataset = dataset.batch(hparams.batch_size)
      dataset = dataset.prefetch(buffer_size=1)
      iterator = dataset.make_one_shot_iterator()
      features, label = iterator.get_next()

    return features, label
  return input_fn
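
Finally, step (3) from the list in section 3: wiring model_fn and input_fn into an Estimator and driving training and evaluation. This is a hedged sketch; the hyper-parameter values, file lists, and step counts are placeholders rather than values from the original article:

# Placeholder hyper-parameters; hparams.get() and override_from_dict() are
# used by build_model_fn above.
hparams = tf.contrib.training.HParams(
    learning_rate=0.01, momentum=0.9, batch_size=64,
    shuffle_buffer_size=1024, use_tpu=False,
    learning_rate_decay_scheme='', decay_steps=1000, decay_rate=0.96)

estimator = tf.estimator.Estimator(
    model_fn=build_model_fn(hparams),
    model_dir='./tmp/custom_mnist')

# train_files / eval_files are lists of TFRecord paths (placeholders here).
estimator.train(
    input_fn=build_input_fn(train_files, hparams, is_training=True,
                            use_tpu=False),
    max_steps=10000)

results = estimator.evaluate(
    input_fn=build_input_fn(eval_files, hparams, is_training=False,
                            use_tpu=False),
    steps=None)
print(results)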
