Getting Started with TensorFlow (4): High-Level API Usage Example

Author: 迅速傅里叶变换 | Published: 2017-06-12 12:57

    tf.contrib.learn Quickstart

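    TensorFlow's high-level machine learning API (tf.contrib.learn) makes it easy to configure, train, and evaluate a variety of machine learning models. In this tutorial, you will use tf.contrib.learn to build a neural network classifier on the Iris data set. The code proceeds in the following five steps: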

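    1. Load the CSVs containing the Iris training and test data into TensorFlow Datasets
    2. Construct a neural network classifier
    3. Fit the model using the training data
    4. Evaluate the accuracy of the model
    5. Classify new samples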

    Complete Neural Network Source Code


    Here is the full source code for the neural network classifier:

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import os
    
    # urlopen lives in urllib.request on Python 3 and in urllib on Python 2.
    try:
      from urllib.request import urlopen
    except ImportError:
      from urllib import urlopen
    
    import numpy as np
    import tensorflow as tf
    
    # Data sets
    IRIS_TRAINING = "iris_training.csv"
    IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"
    
    IRIS_TEST = "iris_test.csv"
    IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
    
    def main():
      # If the training and test sets aren't stored locally, download them.
      # urlopen(...).read() returns bytes, so the files are written in binary mode.
      if not os.path.exists(IRIS_TRAINING):
        raw = urlopen(IRIS_TRAINING_URL).read()
        with open(IRIS_TRAINING, "wb") as f:
          f.write(raw)
    
      if not os.path.exists(IRIS_TEST):
        raw = urlopen(IRIS_TEST_URL).read()
        with open(IRIS_TEST, "wb") as f:
          f.write(raw)
    
      # Load datasets.
      training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
          filename=IRIS_TRAINING,
          target_dtype=np.int,
          features_dtype=np.float32)
      test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
          filename=IRIS_TEST,
          target_dtype=np.int,
          features_dtype=np.float32)
    
      # Specify that all features have real-value data
      feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
    
      # Build 3 layer DNN with 10, 20, 10 units respectively.
      classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                                  hidden_units=[10, 20, 10],
                                                  n_classes=3,
                                                  model_dir="/tmp/iris_model")
      # Define the training inputs
      def get_train_inputs():
        x = tf.constant(training_set.data)
        y = tf.constant(training_set.target)
    
        return x, y
    
      # Fit model.
      classifier.fit(input_fn=get_train_inputs, steps=2000)
    
      # Define the test inputs
      def get_test_inputs():
        x = tf.constant(test_set.data)
        y = tf.constant(test_set.target)
    
        return x, y
    
      # Evaluate accuracy.
      accuracy_score = classifier.evaluate(input_fn=get_test_inputs,
                                           steps=1)["accuracy"]
    
      print("\nTest Accuracy: {0:f}\n".format(accuracy_score))
    
      # Classify two new flower samples.
      def new_samples():
        return np.array(
          [[6.4, 3.2, 4.5, 1.5],
           [5.8, 3.1, 5.0, 1.7]], dtype=np.float32)
    
      predictions = list(classifier.predict(input_fn=new_samples))
    
      print(
          "New Samples, Class Predictions:    {}\n"
          .format(predictions))
    
    if __name__ == "__main__":
        main()
    

    Load the Iris CSV data to TensorFlow


    The Iris data set contains 150 rows of data, with 50 samples from each of three Iris species: Iris setosa, Iris virginica, and Iris versicolor.

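    Each row contains the following data for one flower sample: sepal length, sepal width, petal length, petal width, and flower species. The species is represented as an integer: 0 denotes Iris setosa, 1 denotes Iris versicolor, and 2 denotes Iris virginica.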

    Sepal Length  Sepal Width  Petal Length  Petal Width  Species
    5.1           3.5          1.4           0.2          0
    4.9           3.0          1.4           0.2          0
    4.7           3.2          1.3           0.2          0
    7.0           3.2          4.7           1.4          1
    6.4           3.2          4.5           1.5          1
    6.9           3.1          4.9           1.5          1
    6.5           3.0          5.2           2.0          2
    6.2           3.4          5.4           2.3          2
    5.9           3.0          5.1           1.8          2
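    To make the row layout concrete, here is a small pure-Python sketch (no TensorFlow needed). It splits the sample rows above into a 4-value feature vector and an integer label, then decodes each label using the 0/1/2 convention described above. SPECIES and parse_row are illustrative names, not part of any library.

```python
from collections import Counter

# Illustrative label convention for the Iris CSVs: index -> species name.
SPECIES = ("Iris setosa", "Iris versicolor", "Iris virginica")

# The nine sample rows from the table above: four features, then the label.
SAMPLE_ROWS = """\
5.1 3.5 1.4 0.2 0
4.9 3.0 1.4 0.2 0
4.7 3.2 1.3 0.2 0
7.0 3.2 4.7 1.4 1
6.4 3.2 4.5 1.5 1
6.9 3.1 4.9 1.5 1
6.5 3.0 5.2 2.0 2
6.2 3.4 5.4 2.3 2
5.9 3.0 5.1 1.8 2"""


def parse_row(line):
    """Hypothetical helper: split one row into (features, label); the label is the last column."""
    values = line.split()
    features = [float(v) for v in values[:-1]]
    label = int(values[-1])
    return features, label


rows = [parse_row(line) for line in SAMPLE_ROWS.splitlines()]
counts = Counter(SPECIES[label] for _, label in rows)

for features, label in rows[:2]:
    print(features, "->", SPECIES[label])
print(counts)
```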

    For this tutorial, the Iris data has been randomized and split into two separate CSV files:

    • a training set of 120 samples (iris_training.csv)
    • a test set of 30 samples (iris_test.csv).
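    The post does not say how the split was made. As an illustration only, a randomized 120/30 split of 150 row indices could look like the sketch below. The seed and the split_indices helper are assumptions, and this is not necessarily how iris_training.csv and iris_test.csv were produced.

```python
import random


def split_indices(n_rows, n_test, seed=0):
    """Hypothetical helper: shuffle row indices and split them into (train, test) lists."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)  # seeded RNG so the split is reproducible
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(n_rows=150, n_test=30)
print(len(train_idx), len(test_idx))  # 120 30
```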

    To get started, first import all the necessary modules, then define where to download and store the dataset:

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import os
    
    # urlopen lives in urllib.request on Python 3 and in urllib on Python 2.
    try:
      from urllib.request import urlopen
    except ImportError:
      from urllib import urlopen
    
    import tensorflow as tf
    import numpy as np
    
    IRIS_TRAINING = "iris_training.csv"
    IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"
    
    IRIS_TEST = "iris_test.csv"
    IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
    

    Next, if the training and test sets aren't already stored locally, download them:

    # urlopen(...).read() returns bytes, so write the files in binary mode.
    if not os.path.exists(IRIS_TRAINING):
      raw = urlopen(IRIS_TRAINING_URL).read()
      with open(IRIS_TRAINING, "wb") as f:
        f.write(raw)
    
    if not os.path.exists(IRIS_TEST):
      raw = urlopen(IRIS_TEST_URL).read()
      with open(IRIS_TEST, "wb") as f:
        f.write(raw)
    

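    Then load the training and test sets into Datasets using the load_csv_with_header() method in learn.datasets.base. load_csv_with_header() takes the following three arguments:

    • filename, the path to the CSV file
    • target_dtype, the numpy datatype of the dataset's target values
    • features_dtype, the numpy datatype of the dataset's feature values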

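    Here, the target (the value the model is trained to predict) is the flower species, an integer from 0 to 2, so the appropriate numpy datatype is np.int: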

    # Load datasets.
    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=IRIS_TRAINING,
        target_dtype=np.int,
        features_dtype=np.float32)
    test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=IRIS_TEST,
        target_dtype=np.int,
        features_dtype=np.float32)
    

    Datasets in tf.contrib.learn are named tuples. You can access the feature data and the target values through the data and target fields, for example training_set.data and training_set.target.
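    For intuition, the sketch below imitates load_csv_with_header() in plain Python and returns a namedtuple with data and target fields. It assumes the first CSV row is a header whose first two fields give the number of samples and the number of features (the layout the Iris CSVs use), and that the target is the last column. load_csv_like and the miniature CSV are hypothetical, and the sketch returns Python lists rather than numpy arrays.

```python
import csv
import io
from collections import namedtuple

# Same field names as the Dataset tuples returned by tf.contrib.learn.
Dataset = namedtuple("Dataset", ["data", "target"])

# Hypothetical miniature CSV built from three rows of the sample table.
# Header (assumed layout): sample count, feature count, then class names.
MINI_CSV = """3,4,setosa,versicolor,virginica
5.1,3.5,1.4,0.2,0
7.0,3.2,4.7,1.4,1
6.5,3.0,5.2,2.0,2
"""


def load_csv_like(text, target_type=int, feature_type=float, target_column=-1):
    """Hypothetical pure-Python stand-in for load_csv_with_header()."""
    reader = csv.reader(io.StringIO(text))
    header = next(reader)
    n_samples, n_features = int(header[0]), int(header[1])
    data, target = [], []
    for row in reader:
        target.append(target_type(row.pop(target_column)))  # remove the label column
        data.append([feature_type(v) for v in row])          # remaining columns are features
    if len(data) != n_samples or any(len(r) != n_features for r in data):
        raise ValueError("CSV body does not match the header counts")
    return Dataset(data=data, target=target)


mini_set = load_csv_like(MINI_CSV)
print(mini_set.target)   # [0, 1, 2]
print(mini_set.data[1])  # [7.0, 3.2, 4.7, 1.4]
```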

    Construct a Deep Neural Network Classifier


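    tf.contrib.learn offers a variety of predefined models, called Estimators, which you can use "out of the box" to run training and evaluation operations on your data. Here, you will configure a deep neural network classifier to fit the Iris data. With tf.contrib.learn, you can instantiate a tf.contrib.learn.DNNClassifier in just a few lines of code: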

    # Specify that all features have real-value data
    feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
    
    # Build 3 layer DNN with 10, 20, 10 units respectively.
    classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                                hidden_units=[10, 20, 10],
                                                n_classes=3,
                                                model_dir="/tmp/iris_model")
    

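    The code above first defines the model's feature columns, which specify the data type of the features in the data set. All of the feature data is continuous, so tf.contrib.layers.real_valued_column is used to construct the feature column. There are four features (sepal length, sepal width, petal length, and petal width), so dimension is set to 4.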

    Then the code creates a DNNClassifier with the following arguments:

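    • feature_columns=feature_columns. The feature columns defined above.
    • hidden_units=[10, 20, 10]. Three hidden layers, containing 10, 20, and 10 neurons, respectively.
    • n_classes=3. Three target classes, one for each Iris species.
    • model_dir=/tmp/iris_model. The directory where TensorFlow saves checkpoint data during model training.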
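    To put hidden_units=[10, 20, 10] in terms of model size, the sketch below counts the trainable parameters of a fully connected network. The network has 4 inputs, hidden layers of 10, 20, and 10 units, and a 3-unit output (logits) layer. The sketch assumes that every layer is a dense layer with a weight matrix and a bias vector. It does not inspect the actual DNNClassifier graph, and count_dense_params is a hypothetical helper.

```python
def count_dense_params(n_inputs, hidden_units, n_classes):
    """Hypothetical helper: count weights + biases in a stack of fully connected layers."""
    sizes = [n_inputs] + list(hidden_units) + [n_classes]
    per_layer = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        per_layer.append(fan_in * fan_out + fan_out)  # weight matrix + bias vector
    return per_layer, sum(per_layer)


per_layer, total = count_dense_params(4, [10, 20, 10], 3)
print(per_layer)  # [50, 220, 210, 33]
print(total)      # 513
```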

    Describe the training input pipeline


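    The tf.contrib.learn API uses input functions, which create the TensorFlow operations that generate data for the model. Here, the data is small enough to be stored in TensorFlow constants (tf.constant).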

    # Define the training inputs
    def get_train_inputs():
      x = tf.constant(training_set.data)
      y = tf.constant(training_set.target)
    
      return x, y
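    get_train_inputs and the get_test_inputs function used for evaluation differ only in the dataset they read. One common way to avoid that duplication is a factory that returns a zero-argument input function. The framework-free sketch below shows the shape of that pattern with plain Python values. make_input_fn is a hypothetical name. In the TensorFlow code, the inner function would wrap the arrays in tf.constant, as get_train_inputs does.

```python
from collections import namedtuple

Dataset = namedtuple("Dataset", ["data", "target"])


def make_input_fn(dataset):
    """Hypothetical factory: return a zero-argument function yielding (features, labels).

    In the TensorFlow version, the inner body would instead return
    tf.constant(dataset.data), tf.constant(dataset.target).
    """
    def input_fn():
        return dataset.data, dataset.target
    return input_fn


# Toy stand-ins for training_set / test_set (hypothetical, one row each).
toy_train = Dataset(data=[[5.1, 3.5, 1.4, 0.2]], target=[0])
toy_test = Dataset(data=[[7.0, 3.2, 4.7, 1.4]], target=[1])

get_train_inputs = make_input_fn(toy_train)
get_test_inputs = make_input_fn(toy_test)
x, y = get_train_inputs()
print(x, y)  # [[5.1, 3.5, 1.4, 0.2]] [0]
```

    The function is passed uncalled (input_fn=get_train_inputs) because the Estimator calls it itself when it builds its graph.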
    

    Fit the DNNClassifier to the Iris Training Data


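    Now that you have configured the DNN classifier, you can fit it to the Iris training data with the fit method. Pass get_train_inputs as input_fn and the number of training steps as steps (here, 2000):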

    # Fit model.
    classifier.fit(input_fn=get_train_inputs, steps=2000)
    

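    The model's state is preserved in the classifier (and checkpointed to model_dir), so you can also train iteratively. For example, the call above is equivalent to:

    classifier.fit(input_fn=get_train_inputs, steps=1000)
    classifier.fit(input_fn=get_train_inputs, steps=1000)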
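    The two forms are interchangeable for two reasons. First, the classifier keeps its parameters between calls. Second, every step here sees the same full batch, because the input function returns all 120 samples as constants. The toy sketch below illustrates the idea with ToyRegressor, a hypothetical 1-D linear model trained by plain full-batch gradient descent rather than DNNClassifier. Training for 2000 steps in one call and in two 1000-step calls produces exactly the same weights.

```python
class ToyRegressor:
    """Hypothetical 1-D linear model y ~ w * x + b, trained by full-batch gradient descent."""

    def __init__(self, learning_rate=0.01):
        self.w = 0.0
        self.b = 0.0
        self.global_step = 0  # persists across fit() calls, like a checkpoint would
        self.learning_rate = learning_rate

    def fit(self, xs, ys, steps):
        n = len(xs)
        for _ in range(steps):
            # Gradients of the mean squared error over the whole batch.
            errors = [self.w * x + self.b - y for x, y in zip(xs, ys)]
            grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / n
            grad_b = 2.0 * sum(errors) / n
            self.w -= self.learning_rate * grad_w
            self.b -= self.learning_rate * grad_b
            self.global_step += 1
        return self


# Toy data: petal length -> petal width for six rows of the sample table.
xs = [1.4, 1.3, 4.7, 4.5, 5.2, 5.4]
ys = [0.2, 0.2, 1.4, 1.5, 2.0, 2.3]

once = ToyRegressor().fit(xs, ys, steps=2000)
twice = ToyRegressor().fit(xs, ys, steps=1000).fit(xs, ys, steps=1000)
print(once.global_step, twice.global_step)      # 2000 2000
print(once.w == twice.w and once.b == twice.b)  # True
```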
    

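    If you want to track the model while it trains, you will likely want to use a TensorFlow monitor to perform logging operations instead. See the tutorial “Logging and Monitoring Basics with tf.contrib.learn” for more on this topic.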

    Evaluate Model Accuracy


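    Now that you have fit the model on the training data, you can evaluate its accuracy on the test set using the evaluate method. Like fit, evaluate takes an input function that builds its input pipeline. It returns a dict containing the evaluation results.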

    # Define the test inputs
    def get_test_inputs():
      x = tf.constant(test_set.data)
      y = tf.constant(test_set.target)
    
      return x, y
    
    # Evaluate accuracy.
    accuracy_score = classifier.evaluate(input_fn=get_test_inputs,
                                         steps=1)["accuracy"]
    
    print("\nTest Accuracy: {0:f}\n".format(accuracy_score))
    

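    Running the full script prints something like the following. The exact value may vary slightly between runs because the network weights are initialized randomly.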

    Test Accuracy: 0.966667
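    The accuracy reported by evaluate() is the fraction of test samples whose predicted class matches the true label. steps=1 is enough here because get_test_inputs returns the whole 30-sample test set in a single batch. On that test set, 0.966667 corresponds to 29 correct predictions out of 30. The plain-Python sketch below checks that arithmetic. The accuracy helper and the label lists are hypothetical, not the model's actual output.

```python
def accuracy(predicted, actual):
    """Hypothetical helper: fraction of positions where the predicted class equals the true class."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    correct = sum(p == a for p, a in zip(predicted, actual))
    return correct / len(actual)


# Hypothetical labels for a 30-sample test set with exactly one mistake.
actual = [0] * 10 + [1] * 10 + [2] * 10
predicted = list(actual)
predicted[15] = 2  # one versicolor sample misclassified as virginica

score = accuracy(predicted, actual)
print("Test Accuracy: {0:f}".format(score))  # Test Accuracy: 0.966667
```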
    

    Classify New Samples


    To classify new samples, use the predict() method. For example, suppose you have the following two new flower samples:

    Sepal Length  Sepal Width  Petal Length  Petal Width
    6.4           3.2          4.5           1.5
    5.8           3.1          5.0           1.7

    The predict() method returns a generator, which can easily be converted to a list:

    # Classify two new flower samples.
    def new_samples():
      return np.array(
        [[6.4, 3.2, 4.5, 1.5],
         [5.8, 3.1, 5.0, 1.7]], dtype=np.float32)
    
    predictions = list(classifier.predict(input_fn=new_samples))
    
    print(
        "New Samples, Class Predictions:    {}\n"
        .format(predictions))
    

    The output should look something like this:

    New Samples, Class Predictions:    [1, 2]
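    Because predict() returns a generator, it can be iterated only once. Wrapping it in list(), as the code above does, keeps the results available for reuse. The sketch below demonstrates this with fake_predict, a hypothetical generator that stands in for classifier.predict and simply yields the class ids printed above. It then maps the ids to species names using the label convention from the data section.

```python
# Label convention: index -> species name.
SPECIES = ("Iris setosa", "Iris versicolor", "Iris virginica")


def fake_predict():
    """Hypothetical stand-in for classifier.predict(): yields one class id per sample."""
    for class_id in (1, 2):
        yield class_id


gen = fake_predict()
predictions = list(gen)  # consumes the generator
print(predictions)       # [1, 2]
print(list(gen))         # [] (the generator is already exhausted)

names = [SPECIES[c] for c in predictions]
print(names)  # ['Iris versicolor', 'Iris virginica']
```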
    


        Article link: https://www.haomeiwen.com/subject/wddlqxtx.html