以TensorFlow Sample中的iris分类任务为例。
sample git仓库地址:https://github.com/tensorflow/models
iris目录:models/samples/core/get_started/premade_estimator.py
这个例子创建一个鸢尾花分类模型。该模型将鸢尾花分为三个类型:Setosa(山鸢尾)、Versicolor(杂色鸢尾)、Virginica(弗吉尼亚鸢尾),模型根据[萼片和花瓣]的[长度和宽度]这四个属性来推断是哪种鸢尾花。种类的翻译可能有误,但是无所谓,就是一种类别...
训练样本和测试样本的数据是从csv中读取的,前四个是属性,最后一个是分类标签:
CSV_COLUMN_NAMES = ['SepalLength','SepalWidth','PetalLength','PetalWidth','Species']
#如下四组数据:
#5.9, 3.0, 4.2, 1.5, 1
#6.9, 3.1, 5.4, 2.1, 2
#5.1, 3.3, 1.7, 0.5, 0
#6.0, 3.4, 4.5, 1.6, 1
Sample中使用具有10*10隐藏层的三分类DNN分类器,默认激活函数是relu,损失函数是softmax_cross_entropy_loss。
Sample中的训练数据为120组,经过shuffle打乱,无限重复,然后分批次(每批次batch_size=100),按批迭代训练1000次(train_steps)。训练数据量:100 * 1000(即batch_size * train_steps)。
Sample中没有输出模型和输出模型的使用,在此补上相关知识。
- CLI方式导出
feature_map = {}
#定义cli输入属性名,并映射为模型的输入Tensor名
for i in range(len(iris_data.CSV_COLUMN_NAMES) -1):
feature_map[iris_data.CSV_COLUMN_NAMES[i]] = tf.placeholder(tf.float32,shape=[None],name=iris_data.CSV_COLUMN_NAMES[i])
#将训练后的模型导出到目录./iris/下
classifier.export_savedmodel(export_dir_base='./iris/',serving_input_receiver_fn=tf.estimator.export.build_raw_serving_input_receiver_fn(feature_map))
导出后的模型位于当前目录下的iris文件夹下,以时间戳为名的文件夹中。
在TensorFlow安装目录的bin文件夹下有saved_model_cli
,是cli执行文件。可在shell中执行如下指令,来使用导出的model:
saved_model_cli run --dir iris/1524451065/ --tag_set serve --signature_def predict --input_exprs 'SepalLength=[5.1, 5.9, 6.9];PetalLength=[1.7, 4.2, 5.4];PetalWidth=[0.5, 1.5, 2.1];SepalWidth=[3.3, 3.0, 3.1]'
//输出如下
Result for output key class_ids:
[[0]
[1]
[2]]
Result for output key classes:
[['0']
['1']
['2']]
Result for output key logits:
[[ 4.9489717 -0.6004308 -24.400116 ]
[ -5.10651 2.5732546 -3.476701 ]
[ -7.8528028 1.4337913 5.3207707]]
Result for output key probabilities:
[[9.9612528e-01 3.8747080e-03 1.7871768e-13]
[4.6078415e-04 9.9718791e-01 2.3513364e-03]
[1.8619706e-06 2.0095062e-02 9.7990304e-01]]
上面在cli中输入了三组数据,运用模型预测结果分别输入分类0,1,2,logits和probabilities是模型输出结果和softmax结果,代表三组数据属于三个分类的概率。
也可以运行如下指令来查看导出model的SignatureDef,所谓SignatureDef就是模型的输入输出情况:
saved_model_cli show --dir iris/1524451065/ --tag_set serve --signature_def predict
//输出情况大致如下:
The given SavedModel SignatureDef contains the following input(s):
inputs['PetalLength'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: PetalLength:0
inputs['PetalWidth'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: PetalWidth:0
inputs['SepalLength'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: SepalLength:0
inputs['SepalWidth'] tensor_info:
dtype: DT_FLOAT
shape: (-1)
name: SepalWidth:0
The given SavedModel SignatureDef contains the following output(s):
outputs['class_ids'] tensor_info:
dtype: DT_INT64
shape: (-1, 1)
name: dnn/head/predictions/ExpandDims:0
outputs['classes'] tensor_info:
dtype: DT_STRING
shape: (-1, 1)
name: dnn/head/predictions/str_classes:0
outputs['logits'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 3)
name: dnn/logits/BiasAdd:0
outputs['probabilities'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 3)
name: dnn/head/predictions/probabilities:0
Method name is: tensorflow/serving/predict
- Serving方式
根据官网提供的Installing from source方法安装Serving。系统平台osx 10.11.6。主要步骤如下
# 在tensorFlow目录下拉代码:/Users/xxx/mycode/tensorFlow
cd tensorFlow
# 切换python环境virtualenv
source bin/activate
git clone --recurse-submodules https://github.com/tensorflow/serving
cd serving
# 开始编译
bazel build -c opt tensorflow_serving/...
#注意:
#1. mac上xcode的版本要在8.0版本以上,不然编译报错pthread。
#2. 编译过程中可能出现python依赖模块的缺失,安装好继续编译就好。
#3. python最好切换到virtualenv环境下,在系统环境下有些模块的版本无法升级,导致编译不通过。
#4. 编译产物在bazel-bin下,model server位置:
# bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server
网友评论