1_get_started

Author: happy_19 | Published: 2018-06-12 20:05

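    This article introduces the TensorFlow programming environment and uses the iris classification problem to show how TensorFlow's high-level API can be applied to a practical task.
    TensorFlow provides a programming environment made up of several API layers, as the figure below shows: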


    [Figure: tensorflow_programming_environment]

    0 Using the high-level Estimator API

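    An Estimator is TensorFlow's high-level representation of a complete model, and every Estimator class inherits from tf.estimator.Estimator. TensorFlow ships a set of pre-made Estimators (for example DNNClassifier, DNNRegressor, and LinearClassifier); beyond these, you can also write a custom Estimator (to be covered in a later note). This article covers only pre-made Estimators. A TensorFlow program built on a pre-made Estimator must follow these steps: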

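    1. Create one or more input functions.
    2. Define the model's feature columns.
    3. Instantiate the Estimator, specifying the feature columns and the hyperparameters.
    4. Call one or more methods on the Estimator object, passing an appropriate input function as the source of data.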

    The rest of this article follows these steps to solve the iris classification problem; the full source code is given at the end.

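    1. Create the input function

    Training, evaluating, and predicting with a model all need an input function as the source of data.
    An input function returns a tf.data.Dataset object that yields two-element tuples of the following form: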

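    • features: a Python dictionary
      • each key is the name of a feature
      • each value is an array holding that feature's values for all examples
    • label: an array holding the label of every example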
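
To make this structure concrete, here is a plain-Python sketch (it does not call TensorFlow) that checks a column-oriented `(features, labels)` pair and splits it into one pair per example, which is roughly what `tf.data.Dataset.from_tensor_slices` does with the same input. The helper name `slice_examples` is hypothetical and exists only for this illustration.

```python
def slice_examples(features, labels):
    """Split column-oriented data into one (feature_dict, label) pair per example."""
    n = len(labels)
    for name, values in features.items():
        if len(values) != n:
            raise ValueError(
                "feature %r has %d values, expected %d" % (name, len(values), n))
    return [({name: values[i] for name, values in features.items()}, labels[i])
            for i in range(n)]


# Two iris examples, stored column by column.
features = {
    'SepalLength': [6.4, 5.0],
    'SepalWidth': [2.8, 2.3],
    'PetalLength': [5.6, 3.3],
    'PetalWidth': [2.2, 1.0],
}
labels = [2, 1]

examples = slice_examples(features, labels)
first_features, first_label = examples[0]
print(first_features['SepalLength'], first_label)  # prints: 6.4 2
```

Each element of `examples` pairs a dictionary holding one value per feature with a single label, which is the per-example view the model consumes.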

    An input function can be written in a simple form like this:

    import numpy as np

    def input_evaluation_set():
        features = {'SepalLength': np.array([6.4, 5.0]),
                    'SepalWidth':  np.array([2.8, 2.3]),
                    'PetalLength': np.array([5.6, 3.3]),
                    'PetalWidth':  np.array([2.2, 1.0])}
        labels = np.array([2, 1])
        return features, labels
    

    Although the function above can serve as a model's input function, it is strongly recommended to use TensorFlow's Dataset API instead, as shown below:

    def train_input_fn(features, labels, batch_size):
        """An input function for training"""
        # Convert the inputs to a Dataset.
        dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    
        # Shuffle, repeat, and batch the examples.
        return dataset.shuffle(1000).repeat().batch(batch_size)
    

    To keep data handling simple, the features and labels arguments are usually pandas data structures.
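
The `shuffle(1000).repeat().batch(batch_size)` chain can be hard to picture, so the following plain-Python generator imitates its behaviour. It is only a sketch of the semantics, not how tf.data is implemented: it assumes the shuffle buffer (1000) is at least as large as the dataset, so every pass is fully shuffled. The name `shuffle_repeat_batch` and the `seed` parameter are hypothetical.

```python
import itertools
import random


def shuffle_repeat_batch(examples, batch_size, seed=0):
    """Yield batches forever: reshuffle on every pass, then cut the stream into batches."""
    rng = random.Random(seed)

    def stream():
        while True:                      # repeat(): the data never runs out
            epoch = list(examples)
            rng.shuffle(epoch)           # shuffle(): a new order on each pass
            for example in epoch:
                yield example

    source = stream()
    while True:                          # batch(): group consecutive examples
        yield list(itertools.islice(source, batch_size))


batches = shuffle_repeat_batch(range(5), batch_size=3)
first_three = [next(batches) for _ in range(3)]
print(first_three)
```

Because repetition happens before batching, a batch may straddle the boundary between two passes over the data, and every batch has exactly `batch_size` elements.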

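    2. Define the feature columns

    Feature columns tell the model how to use the raw input data in the features dictionary. When building an Estimator, you pass it a list of feature columns covering every feature the model uses. Each entry in the list is a tf.feature_column object. For the iris problem all four features are numeric, so the list can be built as follows: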

    my_feature_columns = []
    for key in train_x.keys():
        my_feature_columns.append(tf.feature_column.numeric_column(key=key))
    

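    3. Instantiate the Estimator

    Iris classification is a typical classification problem. TensorFlow provides several pre-made classifier Estimators, including:

    • tf.estimator.DNNClassifier: a deep neural network model for multi-class classification
    • tf.estimator.DNNLinearCombinedClassifier: a wide and deep classification model
    • tf.estimator.LinearClassifier: a classifier based on a linear model

    Here tf.estimator.DNNClassifier is sufficient. The code is as follows: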

    classifier = tf.estimator.DNNClassifier(
        feature_columns=my_feature_columns,
        # Two hidden layers of 10 nodes each.
        hidden_units=[10, 10],
        # The model must choose between 3 classes.
        n_classes=3)
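
To get a feel for how small this network is, the sketch below counts its trainable parameters. It assumes the classifier is a plain stack of fully connected layers with biases (4 inputs, two hidden layers of 10 units, 3 outputs) and nothing else, such as batch normalization; the helper `dense_parameter_count` is hypothetical.

```python
def dense_parameter_count(n_inputs, hidden_units, n_classes):
    """Count weights and biases in a stack of fully connected layers."""
    sizes = [n_inputs] + list(hidden_units) + [n_classes]
    total = 0
    for fan_in, fan_out in zip(sizes, sizes[1:]):
        total += fan_in * fan_out + fan_out   # weight matrix plus bias vector
    return total


# 4 numeric iris features, hidden_units=[10, 10], n_classes=3
print(dense_parameter_count(4, [10, 10], 3))  # prints 193
```

The three layers contribute 50, 110, and 33 parameters respectively.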
    

    4. Train, evaluate, and predict

    4.1 Train the model

    Calling the Estimator's train method trains the model, as shown below:

    classifier.train(
        input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
        steps=args.train_steps)
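
Since the training input function repeats the data indefinitely, it is the `steps` argument that ends training. The arithmetic below estimates how many passes over the data that amounts to with the script's defaults (`--train_steps 1000`, `--batch_size 100`). The figure of 120 training rows is an assumption about iris_training.csv, and `epochs_seen` is a hypothetical helper.

```python
def epochs_seen(train_steps, batch_size, n_examples):
    """Number of passes over the training set made by `train_steps` batches."""
    return train_steps * batch_size / n_examples


# Defaults from the script: --train_steps 1000, --batch_size 100.
# 120 training rows is an assumption about iris_training.csv.
print(round(epochs_seen(1000, 100, 120), 1))  # prints 833.3
```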
    

    4.2 Evaluate the trained model

    Once the model is trained, its performance can be evaluated. The following code computes the model's accuracy on the test set:

    eval_result = classifier.evaluate(
        input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))
    print("\nTest set accuracy: {accuracy:0.3f}\n".format(**eval_result))
    

    Running it produces the following output:

    Test set accuracy: 0.967
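
Accuracy here is simply the fraction of examples whose predicted class matches the true label. The sketch below computes it in plain Python on hypothetical labels (30 examples with a single mistake), which is one way a figure such as 0.967 can arise; it is not taken from the actual evaluation run.

```python
def accuracy(predicted, actual):
    """Fraction of positions where the predicted class equals the true class."""
    correct = sum(1 for p, a in zip(predicted, actual) if p == a)
    return correct / len(actual)


# Hypothetical labels: 30 test examples, one of them misclassified.
actual = [0] * 10 + [1] * 10 + [2] * 10
predicted = list(actual)
predicted[15] = 2

print('Test set accuracy: {:0.3f}'.format(accuracy(predicted, actual)))  # 0.967
```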
    

    4.3 Predict

    With a trained model, you can predict the class of unlabeled examples, as shown below:

    expected = ['Setosa', 'Versicolor', 'Virginica']
    predict_x = {
        'SepalLength': [5.1, 5.9, 6.9],
        'SepalWidth': [3.3, 3.0, 3.1],
        'PetalLength': [1.7, 4.2, 5.4],
        'PetalWidth': [0.5, 1.5, 2.1],
    }
    predictions = classifier.predict(
        input_fn=lambda:iris_data.eval_input_fn(
            predict_x,
            labels=None,
            batch_size=args.batch_size
        )
    )
    

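    The predict method returns a Python iterable that yields one prediction dictionary per example. The following code prints each prediction together with its probability: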

    for pred_dict, expec in zip(predictions, expected):
        template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')
    
        class_id = pred_dict['class_ids'][0]
        probability = pred_dict['probabilities'][class_id]
    
        print(template.format(iris_data.SPECIES[class_id],
                              100 * probability, expec))
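
To show how the `class_ids` and `probabilities` entries relate, this plain-Python sketch builds a dictionary shaped like one prediction result from raw class scores, using a softmax followed by an argmax. The scores are hypothetical and `make_prediction` is an illustrative helper; the real result holds NumPy arrays and may contain further keys.

```python
import math

SPECIES = ['Setosa', 'Versicolor', 'Virginica']


def make_prediction(logits):
    """Turn raw class scores into a dict shaped like one predict() result."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]       # shift for numerical stability
    total = sum(exps)
    probabilities = [e / total for e in exps]        # softmax: values sum to 1
    class_id = probabilities.index(max(probabilities))
    return {'class_ids': [class_id], 'probabilities': probabilities}


pred_dict = make_prediction([0.5, 4.0, 1.0])         # hypothetical scores
class_id = pred_dict['class_ids'][0]
print(SPECIES[class_id], round(100 * pred_dict['probabilities'][class_id], 1))
```

The reported percentage is the probability of the winning class, which is why the loop above indexes `probabilities` with `class_id`.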
    

    The output is:

    Prediction is "Setosa" (99.8%), expected "Setosa"
    
    Prediction is "Versicolor" (99.7%), expected "Versicolor"
    
    Prediction is "Virginica" (96.9%), expected "Virginica"
    

    5 Source code

    # iris_data.py
    import pandas as pd
    import tensorflow as tf
    
    TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
    TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
    
    CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
                        'PetalLength', 'PetalWidth', 'Species']
    SPECIES = ['Setosa', 'Versicolor', 'Virginica']
    
    def maybe_download():
        train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
        test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)
    
        return train_path, test_path
    
    def load_data(y_name='Species'):
        """Returns the iris dataset as (train_x, train_y), (test_x, test_y)."""
        train_path, test_path = maybe_download()
    
        train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
        train_x, train_y = train, train.pop(y_name)
    
        test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
        test_x, test_y = test, test.pop(y_name)
    
        return (train_x, train_y), (test_x, test_y)
    
    
    def train_input_fn(features, labels, batch_size):
        """An input function for training"""
        # Convert the inputs to a Dataset.
        dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    
        # Shuffle, repeat, and batch the examples.
        dataset = dataset.shuffle(1000).repeat().batch(batch_size)
    
        # Return the dataset.
        return dataset
    
    
    def eval_input_fn(features, labels, batch_size):
        """An input function for evaluation or prediction"""
        features=dict(features)
        if labels is None:
            # No labels, use only features.
            inputs = features
        else:
            inputs = (features, labels)
    
        # Convert the inputs to a Dataset.
        dataset = tf.data.Dataset.from_tensor_slices(inputs)
    
        # Batch the examples
        assert batch_size is not None, "batch_size must not be None"
        dataset = dataset.batch(batch_size)
    
        # Return the dataset.
        return dataset
    
    
    # The remainder of this file contains a simple example of a csv parser,
    #     implemented using the `Dataset` class.
    
    # `tf.decode_csv` sets the types of the outputs to match the examples given in
    #     the `record_defaults` argument.
    CSV_TYPES = [[0.0], [0.0], [0.0], [0.0], [0]]
    
    def _parse_line(line):
        # Decode the line into its fields
        fields = tf.decode_csv(line, record_defaults=CSV_TYPES)
    
        # Pack the result into a dictionary
        features = dict(zip(CSV_COLUMN_NAMES, fields))
    
        # Separate the label from the features
        label = features.pop('Species')
    
        return features, label
    
    
    def csv_input_fn(csv_path, batch_size):
        # Create a dataset containing the text lines.
        dataset = tf.data.TextLineDataset(csv_path).skip(1)
    
        # Parse each line.
        dataset = dataset.map(_parse_line)
    
        # Shuffle, repeat, and batch the examples.
        dataset = dataset.shuffle(1000).repeat().batch(batch_size)
    
        # Return the dataset.
        return dataset
    
    #premade_estimator.py
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import argparse
    import tensorflow as tf
    
    import iris_data
    
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=100, type=int, help='batch size')
    parser.add_argument('--train_steps', default=1000, type=int,
                        help='number of training steps')
    
    def main(argv):
        args = parser.parse_args(argv[1:])
    
        # Fetch the data
        (train_x, train_y), (test_x, test_y) = iris_data.load_data()
    
        # Feature columns describe how to use the input.
        my_feature_columns = []
        for key in train_x.keys():
            my_feature_columns.append(tf.feature_column.numeric_column(key=key))
    
        # Build 2 hidden layer DNN with 10, 10 units respectively.
        classifier = tf.estimator.DNNClassifier(
            feature_columns=my_feature_columns,
            # Two hidden layers of 10 nodes each.
            hidden_units=[10, 10],
            # The model must choose between 3 classes.
            n_classes=3)
    
        # Train the Model.
        classifier.train(
            input_fn=lambda:iris_data.train_input_fn(train_x, train_y,
                                                     args.batch_size),
            steps=args.train_steps)
    
        # Evaluate the model.
        eval_result = classifier.evaluate(
            input_fn=lambda:iris_data.eval_input_fn(test_x, test_y,
                                                    args.batch_size))
    
        print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
    
        # Generate predictions from the model
        expected = ['Setosa', 'Versicolor', 'Virginica']
        predict_x = {
            'SepalLength': [5.1, 5.9, 6.9],
            'SepalWidth': [3.3, 3.0, 3.1],
            'PetalLength': [1.7, 4.2, 5.4],
            'PetalWidth': [0.5, 1.5, 2.1],
        }
    
        predictions = classifier.predict(
            input_fn=lambda:iris_data.eval_input_fn(predict_x,
                                                    labels=None,
                                                    batch_size=args.batch_size))
    
        template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')
    
        for pred_dict, expec in zip(predictions, expected):
            class_id = pred_dict['class_ids'][0]
            probability = pred_dict['probabilities'][class_id]
    
            print(template.format(iris_data.SPECIES[class_id],
                                  100 * probability, expec))
    
    if __name__ == '__main__':
        tf.logging.set_verbosity(tf.logging.INFO)
        tf.app.run(main)
    
