美文网首页我爱编程
TensorFlow入门18: 创建自定义的Estimator

TensorFlow入门18: 创建自定义的Estimator

作者: LabVIEW_Python | 来源:发表于2018-06-16 21:11 被阅读67次

上节,详细分析了在自定义模型函数中,创建神经网络的三个步骤:创建输入层、创建隐藏层、创建输出层。本节主要介绍自定义模型函数的最后一步:编写实现预测、评估和训练的分支代码

回忆一下:《TensorFlow入门16: 创建自定义的Estimator 2

                  1,Model_fn的返回值是: tf.estimator.EstimatorSpec

                  2,Estimator对象的三个方法train、evaluate、predict都会调用model_fn给Estimator传参数。

                  3,当Estimator对象调用 train、evaluate 或 predict 方法时,Estimator 对象会在调用模型函数前,将 mode 参数设置为对应的值:ModeKeys.TRAIN、ModeKeys.EVAL、ModeKeys.PREDICT。

由此,model_fn函数创建好神经网络后,检测mode值,根据不同的mode,实现对应的代码,并返回: tf.estimator.EstimatorSpec,具体的实现,参考下图:

model_fn的代码实现

完成model_fn函数编写后,回到main函数,可以发现,只有创建classifier对象的代码,略有不同,其余代码一模一样,如下图所示

创建classifier对象代码比较

相关文章

网友评论

    本文标题:TensorFlow入门18: 创建自定义的Estimator

    本文链接:https://www.haomeiwen.com/subject/jwabeftx.html