美文网首页
从数据的角度理解TensorFlow鸢尾花分类程序7

从数据的角度理解TensorFlow鸢尾花分类程序7

作者: LabVIEW_Python | 来源:发表于2018-06-14 11:55 被阅读36次

上节,本节继续分析:4,训练模型:

如代码所示:

classifier = tf.estimator.DNNClassifier(

        feature_columns = my_feature_column,

        #Two hidden layers of 10 nodes each.

        hidden_units = [10, 10],

        #The model must choose between 3 classes

        n_classes = 3

    )

实例化 tf.Estimator.DNNClassifier 后,会创建一个用于学习模型的框架,存储在classifier对象中

调用classifier对象的train方法,可以实现模型训练,如下图所示:

classifier.train

这里,重点分析一下输入参数input_fn

input_fn: A function that provides input data for training as minibatches。一个为模型训练提供输入数据的函数。该函数返回:

一个 “tf.data.Dataset”对象,该对象是由 (features, labels)构成的tuple,如下图所示:

train_input_fn

train方法传递给函数 train_input_fn(features, labels, batch_size)的参数是:train_x, train_yargs.batch_size

Dataframe类型的变量train_x储存训练数据集的特征值,Series类型的变量train_y储存训练数据集的特征标签,args.batch_size定义批次大小

train_input_fn 函数依赖于 Dataset API。这是High-Level TensorFlow API,用于读取数据并将其转换为 train 方法所需的格式。以下调用会将输入特征和标签转换为 tf.data.Dataset 对象,代码实现如下所示:

dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

tf.dataset 类提供很多用于准备训练样本的实用函数。如果训练样本是随机排列的,则训练效果最好。要对样本进行随机化处理,调用shuffle方法,buffer_size 设置为大于样本数 (120) 的值可确保数据得到充分的随机化处理。

在训练期间,train 方法通常会多次处理样本。在不使用任何参数的情况下调用repeat 方法可确保 train 方法拥有无限量的训练集样本(现已得到随机化处理)。

train 方法一次处理一样本。tf.data.Dataset.batch 方法通过组合多个样本来创建一个批次。一般来说,较小的批次大小通常会使 train 方法(有时)以牺牲准确率为代价来加快训练模型

shuffle+repeat+batch 组合代码实现,如下所示:

dataset = dataset.shuffle(buffer_size=1000).repeat(count=None).batch(batch_size)

经过处理的数据,通过return语句,返回调用方(train方法)。

到此:4,训练模型 分析完毕。

相关文章

网友评论

      本文标题:从数据的角度理解TensorFlow鸢尾花分类程序7

      本文链接:https://www.haomeiwen.com/subject/pugeeftx.html