Building a TensorFlow Estimator

Author: Hugo_Ng_7777 | Published: 2019-02-15 17:58

TensorFlow provides a programming stack made up of several API layers.

estimator api.png

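  The high-level Estimator API is what lets us make better use of the GPU. In the author's experience, a training loop built on plain session.run() calls often keeps GPU utilization below 30%, sometimes below 10%, even when GPU memory is completely full.
  Combining Estimator with the tf.data API can push GPU utilization close to 100%. The gain comes from pipelining: input preprocessing and model execution overlap instead of running one after the other, as the figure below shows: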

流水线.png

To get this overlap, use the prefetch function of tf.data.
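
The idea behind prefetching can be sketched without TensorFlow. The following is a minimal pure-Python illustration, not how tf.data is implemented: a background thread prepares items ahead of the consumer, so preparation and consumption overlap. The names prefetch, preprocess, train_step and run, as well as the 0.01 s delays, are made up for this sketch.

```python
import queue
import threading
import time


def prefetch(iterable, buffer_size):
    """Yield items from `iterable`, preparing up to `buffer_size` items ahead in a background thread."""
    buf = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel marking the end of the stream

    def producer():
        for item in iterable:
            buf.put(item)  # blocks while the buffer is full
        buf.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is done:
            return
        yield item


def preprocess(n, delay=0.01):
    """Stand-in for input preprocessing: each batch takes `delay` seconds to prepare."""
    for i in range(n):
        time.sleep(delay)
        yield i


def train_step(batch, delay=0.01):
    """Stand-in for one model step: takes `delay` seconds and returns a fake result."""
    time.sleep(delay)
    return batch * 2


def run(batches):
    """Consume every batch and return (results, elapsed_seconds)."""
    start = time.perf_counter()
    results = [train_step(b) for b in batches]
    return results, time.perf_counter() - start


sequential_results, sequential_time = run(preprocess(10))
pipelined_results, pipelined_time = run(prefetch(preprocess(10), buffer_size=2))
print("sequential: %.3fs, pipelined: %.3fs" % (sequential_time, pipelined_time))
```

Both runs produce the same results in the same order; only the wall-clock time differs, because in the pipelined run the next batch is being prepared while the current step executes.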

 

1. Converting InputExamples to a TFRecord file

import collections

import tensorflow as tf  # TensorFlow 1.x API (tf.python_io, tf.logging)


def create_int_feature(values):
    """Wrap an iterable of ints in an int64 tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))


def file_based_convert_examples_to_features(examples,
                                            label_list,
                                            max_seq_length,
                                            tokenizer,
                                            output_file):
    """Convert a set of `InputExample`s to a TFRecord file."""

    writer = tf.python_io.TFRecordWriter(output_file)  ## e.g. output_file="train.tf_record"

    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

        # convert_single_example is defined elsewhere (not shown in this article).
        feature = convert_single_example(ex_index,
                                         example,
                                         label_list,
                                         max_seq_length,
                                         tokenizer)

        features = collections.OrderedDict()
        features["input_ids"] = create_int_feature(feature.input_ids)
        features["input_mask"] = create_int_feature(feature.input_mask)
        features["segment_ids"] = create_int_feature(feature.segment_ids)
        features["label_ids"] = create_int_feature([feature.label_id])
        features["is_real_example"] = create_int_feature([int(feature.is_real_example)])

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        writer.write(tf_example.SerializeToString())
    writer.close()
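
The parser in section 2 declares input_ids, input_mask and segment_ids as FixedLenFeature([seq_length]), so every record written here must carry exactly max_seq_length integers per field. The article does not show convert_single_example, so the sketch below is only a simplified stand-in for its padding step. The helper name pad_to_length and the token ids are hypothetical, and the convention that 1 marks a real token and 0 marks padding is an assumption based on how BERT-style masks are usually built.

```python
def pad_to_length(token_ids, segment_ids, max_seq_length):
    """Hypothetical helper: truncate or zero-pad each field to exactly `max_seq_length` entries."""
    if len(token_ids) != len(segment_ids):
        raise ValueError("token_ids and segment_ids must have the same length")
    token_ids = list(token_ids)[:max_seq_length]
    segment_ids = list(segment_ids)[:max_seq_length]
    input_mask = [1] * len(token_ids)  # assumed convention: 1 = real token, 0 = padding
    padding = [0] * (max_seq_length - len(token_ids))
    return {
        "input_ids": token_ids + padding,
        "input_mask": input_mask + padding,
        "segment_ids": segment_ids + padding,
    }


# Made-up token ids, used only to show the resulting shapes.
record = pad_to_length([101, 2023, 2003, 102], [0, 0, 0, 0], max_seq_length=8)
for name, values in record.items():
    print(name, values)
```

Each of the three lists can then be passed to create_int_feature as in the function above.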

 

2. Writing the model's input function, train_input_fn

WX20190215-145003.png

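 As the figure above shows, Dataset has three subclasses that can be used directly (in tf.data these are TextLineDataset, TFRecordDataset and FixedLengthRecordDataset). The input function covers the whole ETL process: extracting the data (Extract), transforming it (Transform), and loading it (Load).
So train_input_fn can be implemented as follows: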

def file_based_input_fn_builder(input_file,
                                seq_length,
                                is_training,
                                drop_remainder):
    tf.logging.info("*** file_based_input_fn_builder ***")

    name_to_features = {
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "label_ids": tf.FixedLenFeature([], tf.int64),
        "is_real_example": tf.FixedLenFeature([], tf.int64),
    }

    def _decode_record(record, name_to_features):
        """Decodes a record to a TensorFlow example."""
        example = tf.parse_single_example(record, name_to_features)
        # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
        for name in list(example.keys()):
            t = example[name]
            if t.dtype == tf.int64:
                t = tf.to_int32(t)
            example[name] = t
        return example

    def input_fn(params):
        tf.logging.info("*** input_fn ***")
        batch_size = params["batch_size"]

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        d = tf.data.TFRecordDataset(input_file, num_parallel_reads=4)  ## input_file may also be a list of files
        if is_training:
            d = d.repeat()  ## with no count given, repeats indefinitely
            d = d.shuffle(buffer_size=100)  ## better to shuffle first, then map
            # TODO: d.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=100, count=-1))

        d = d.apply(tf.contrib.data.map_and_batch(lambda record: _decode_record(record, name_to_features),
                                                  batch_size=batch_size,
                                                  drop_remainder=drop_remainder))  ## map parses every record with _decode_record
        d = d.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)  ## is this needed? It is the prefetch step described above
        ## TODO: d = d.apply(tf.contrib.data.prefetch_to_device("/gpu:0"))
        return d

    return input_fn
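
Two parameters in the input function above are easy to misread: shuffle(buffer_size=100) and drop_remainder. The pure-Python sketch below imitates their behaviour as commonly documented for tf.data; it is an illustration under that assumption, not the library's code, and the names buffered_shuffle and batch are made up here. It shows that a small shuffle buffer only mixes records locally, and that drop_remainder=True discards a final short batch so that every batch has the same size.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Shuffle a stream using a fixed-size buffer, in the style of Dataset.shuffle."""
    rng = random.Random(seed)
    buf = []
    for item in items:
        if len(buf) < buffer_size:
            buf.append(item)  # fill the buffer before emitting anything
            continue
        idx = rng.randrange(buffer_size)
        yield buf[idx]   # emit a random buffered element...
        buf[idx] = item  # ...and put the new element in its slot
    rng.shuffle(buf)     # drain what is left in random order
    yield from buf


def batch(items, batch_size, drop_remainder):
    """Group a stream into lists of `batch_size`; optionally drop a short final batch."""
    current = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current and not drop_remainder:
        yield current


records = list(range(10))
shuffled = list(buffered_shuffle(records, buffer_size=4, seed=0))
batches = list(batch(shuffled, batch_size=4, drop_remainder=True))
print(shuffled)
print(batches)
```

With a buffer of 4, the element emitted at position k can only come from the first k + 4 inputs, so a record near the end of the file never appears near the start of the output. A buffer much smaller than the dataset therefore gives only a weak shuffle.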

 

3. Writing the model function, model_fn

def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
    """Returns `model_fn` closure for TPUEstimator."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        is_real_example = None
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)  ## mode is set by the Estimator method that is called, e.g. estimator.train

        # Build the model.
        (total_loss, per_example_loss, logits, probabilities) = create_model(bert_config,
                                                                             is_training,
                                                                             input_ids,
                                                                             input_mask,
                                                                             segment_ids,
                                                                             label_ids,
                                                                             num_labels,
                                                                             use_one_hot_embeddings)

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names) = \
                modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:
                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                ## assignment_map: keys are variable names in the checkpoint, values are variable names in the current graph
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape, init_string)

        tf.logging.info("****Mode Starting ****")
        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(total_loss,
                                                     learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps,
                                                     use_tpu)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                          loss=total_loss,
                                                          train_op=train_op,
                                                          scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            def metric_fn(per_example_loss, label_ids, logits, is_real_example):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)
                loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
                return {"eval_accuracy": accuracy,
                        "eval_loss": loss,
                        }

            eval_metrics = (
                metric_fn, [per_example_loss, label_ids, logits, is_real_example])  ##a tuple of metric_fn and tensors
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                          loss=total_loss,
                                                          eval_metrics=eval_metrics,
                                                          scaffold_fn=scaffold_fn)
        else:
            ## predictions is the dict of results that estimator.predict should return
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                          predictions={"probabilities": probabilities},
                                                          scaffold_fn=scaffold_fn)
        return output_spec

    return model_fn
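
In metric_fn above, is_real_example is passed as weights to both tf.metrics.accuracy and tf.metrics.mean. Judging by its name and use, it marks padding examples (weight 0.0) so that they do not affect the metrics; that reading is an assumption, since the article does not explain it. The pure-Python sketch below shows the arithmetic of a weighted accuracy. The function name weighted_accuracy and the sample values are made up, and returning 0.0 when all weights are zero is a choice made for this sketch.

```python
def weighted_accuracy(labels, predictions, weights):
    """Weighted fraction of correct predictions: sum(w * correct) / sum(w)."""
    total_weight = sum(weights)
    if total_weight == 0:
        return 0.0  # choice made for this sketch: no real examples means accuracy 0.0
    correct_weight = sum(w for l, p, w in zip(labels, predictions, weights) if l == p)
    return correct_weight / total_weight


labels = [1, 0, 1, 0]
predictions = [1, 0, 0, 1]
is_real_example = [1.0, 1.0, 1.0, 0.0]  # the last example is treated as padding

masked = weighted_accuracy(labels, predictions, is_real_example)
unmasked = weighted_accuracy(labels, predictions, [1.0] * 4)
print("masked: %.4f, unmasked: %.4f" % (masked, unmasked))
```

Here two of the three real examples are correct, so the masked accuracy is 2/3, while counting the padding example as well would give 2/4.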

 

4. The overall training flow with TPUEstimator:

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(FLAGS.tpu_name,
                                                                              zone=FLAGS.tpu_zone,
                                                                              project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=FLAGS.iterations_per_loop,
                                            num_shards=FLAGS.num_tpu_cores,
                                            per_host_input_for_training=is_per_host))  ## TPU RunConfig settings

    model_fn = model_fn_builder(
        bert_config=bert_config,
        num_labels=len(label_list),
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu)

    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        file_based_convert_examples_to_features(train_examples,
                                                label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer,
                                                train_file)  ## write the examples to a TFRecord file
        train_input_fn = file_based_input_fn_builder(input_file=train_file,
                                                     seq_length=FLAGS.max_seq_length,
                                                     is_training=True,
                                                     drop_remainder=True)  ## read the TFRecord file back
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
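
The snippet above uses num_train_steps and num_warmup_steps without showing where they come from. One common way to derive them, and to the best of my knowledge the way the BERT reference script does it, is from the dataset size, the batch size, the number of epochs and a warmup proportion. The sketch below assumes that approach; the function name compute_steps, the parameter names num_train_epochs and warmup_proportion, and all the numbers are illustrative assumptions, not values from the article.

```python
def compute_steps(num_examples, train_batch_size, num_train_epochs, warmup_proportion):
    """Derive the total number of training steps and the number of warmup steps."""
    num_train_steps = int(num_examples / train_batch_size * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)
    return num_train_steps, num_warmup_steps


# Illustrative values only.
num_train_steps, num_warmup_steps = compute_steps(
    num_examples=10000,
    train_batch_size=32,
    num_train_epochs=3.0,
    warmup_proportion=0.1)
print(num_train_steps, num_warmup_steps)
```

Because the training dataset repeats indefinitely, the max_steps argument of estimator.train is what ends training, so this number effectively sets how many epochs are run.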

 

5. A trick for debugging the high-level API

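  Because tf.estimator takes the place of explicit session.run() calls, you can no longer call session.run yourself to inspect a tensor while debugging. A simple workaround is to splice a tf.Print op into the graph: it passes its input through unchanged and prints the tensor's actual values when the graph executes, rather than just static information such as its shape. For example: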


Graph.png
node1 = tf.add(input1, input2)
print_node1 = tf.Print(node1, [node1])   ## note the capital P in tf.Print
output = tf.multiply(print_node1, input3)
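
The important detail in the example above is that output is computed from print_node1, not from node1. A print op only runs if something downstream depends on it. The toy lazy graph below illustrates that behaviour in pure Python; it is an analogy, not TensorFlow, and the names Node, constant and print_node are hypothetical.

```python
class Node:
    """A lazily evaluated graph node: nothing runs until `eval()` is called."""

    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs

    def eval(self):
        return self.fn(*(node.eval() for node in self.inputs))


def constant(value):
    """A node with no inputs that always evaluates to `value`."""
    return Node(lambda: value)


def print_node(node, log):
    """Identity node that records its input value in `log` when it is evaluated."""
    def passthrough(value):
        log.append(value)
        return value
    return Node(passthrough, node)


log = []
input1, input2, input3 = constant(2), constant(3), constant(4)
node1 = Node(lambda a, b: a + b, input1, input2)
print_node1 = print_node(node1, log)

# Not wired in: the output depends on node1 directly, so the print node never runs.
unwired = Node(lambda a, b: a * b, node1, input3)
unwired_result = unwired.eval()
log_after_unwired = list(log)

# Wired in: the output consumes print_node1, so evaluating it records node1's value.
wired = Node(lambda a, b: a * b, print_node1, input3)
wired_result = wired.eval()
print(unwired_result, log_after_unwired, wired_result, log)
```

Both outputs have the same value, but only the wired version records anything. In the same way, a tf.Print op that is created but never used by a later op produces no output.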

 
 
 
