接上节,本节继续分析:4,训练模型:
如代码所示:
classifier = tf.estimator.DNNClassifier(
feature_columns = my_feature_column,
#Two hidden layers of 10 nodes each.
hidden_units = [10, 10],
#The model must choose between 3 classes
n_classes = 3
)
实例化 tf.Estimator.DNNClassifier 后,会创建一个用于学习模型的框架,存储在classifier对象中。
调用classifier对象的train方法,可以实现模型训练,如下图所示:

这里,重点分析一下输入参数input_fn
input_fn: A function that provides input data for training as minibatches。一个为模型训练提供输入数据的函数。该函数返回:
一个 “tf.data.Dataset”对象,该对象是由 (features, labels)构成的tuple,如下图所示:

train方法传递给函数 train_input_fn(features, labels, batch_size)的参数是:train_x, train_y 和args.batch_size
Dataframe类型的变量train_x储存训练数据集的特征值,Series类型的变量train_y储存训练数据集的特征标签,args.batch_size定义批次大小。
train_input_fn 函数依赖于 Dataset API。这是High-Level TensorFlow API,用于读取数据并将其转换为 train 方法所需的格式。以下调用会将输入特征和标签转换为 tf.data.Dataset 对象,代码实现如下所示:
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
tf.dataset 类提供很多用于准备训练样本的实用函数。如果训练样本是随机排列的,则训练效果最好。要对样本进行随机化处理,调用shuffle方法,buffer_size 设置为大于样本数 (120) 的值可确保数据得到充分的随机化处理。
在训练期间,train 方法通常会多次处理样本。在不使用任何参数的情况下调用repeat 方法可确保 train 方法拥有无限量的训练集样本(现已得到随机化处理)。
train 方法一次处理一批样本。tf.data.Dataset.batch 方法通过组合多个样本来创建一个批次。一般来说,较小的批次大小通常会使 train 方法(有时)以牺牲准确率为代价来加快训练模型
shuffle+repeat+batch 组合代码实现,如下所示:
dataset = dataset.shuffle(buffer_size=1000).repeat(count=None).batch(batch_size)
经过处理的数据,通过return语句,返回调用方(train方法)。
到此:4,训练模型 分析完毕。
网友评论