美文网首页
TensorFlow Model export

TensorFlow Model export

作者: 倪伟_2131 | 来源:发表于2018-04-23 14:26 被阅读0次

    以TensorFlow Sample中的iris分类任务为例。
    sample git仓库地址:https://github.com/tensorflow/models
    iris目录:models/samples/core/get_started/premade_estimator.py

    这个例子创建一个鸢尾花分类模型。该模型将鸢尾花分为三个类型:Setosa(山鸢尾)、Versicolor(杂色鸢尾)、Virginica(弗吉尼亚鸢尾),模型根据[萼片和花瓣]的[长度和宽度]这四个属性来推断是哪种鸢尾花。种类的翻译可能有误,但是无所谓,就是一种类别...
    训练样本和测试样本的数据是从csv中读取的,前四个是属性,最后一个是分类标签:

    CSV_COLUMN_NAMES = ['SepalLength','SepalWidth','PetalLength','PetalWidth','Species']
    #如下四组数据:
    #5.9, 3.0, 4.2, 1.5, 1
    #6.9, 3.1, 5.4, 2.1, 2
    #5.1, 3.3, 1.7, 0.5, 0
    #6.0, 3.4, 4.5, 1.6, 1
    

    Sample中使用具有10*10隐藏层的三分类DNN分类器,默认激活函数是relu,损失函数是softmax_cross_entropy_loss。

    Sample中的训练数据为120组,经过shuffle打乱,无限重复,然后分批次(每批次batch_size=100),按批迭代训练1000次(train_steps)。训练数据量:100 * 1000(即batch_size * train_steps)。
    Sample中没有输出模型和输出模型的使用,在此补上相关知识。

    1. CLI方式导出
    feature_map = {}
    #定义cli输入属性名,并映射为模型的输入Tensor名
    for i in range(len(iris_data.CSV_COLUMN_NAMES) -1):
        feature_map[iris_data.CSV_COLUMN_NAMES[i]] = tf.placeholder(tf.float32,shape=[None],name=iris_data.CSV_COLUMN_NAMES[i])
    #将训练后的模型导出到目录./iris/下
    classifier.export_savedmodel(export_dir_base='./iris/',serving_input_receiver_fn=tf.estimator.export.build_raw_serving_input_receiver_fn(feature_map))
    

    导出后的模型位于当前目录下的iris文件夹下,以时间戳为名的文件夹中。
    在TensorFlow安装目录的bin文件夹下有saved_model_cli,是cli执行文件。可在shell中执行如下指令,来使用导出的model:

    saved_model_cli run --dir iris/1524451065/ --tag_set serve --signature_def predict --input_exprs 'SepalLength=[5.1, 5.9, 6.9];PetalLength=[1.7, 4.2, 5.4];PetalWidth=[0.5, 1.5, 2.1];SepalWidth=[3.3, 3.0, 3.1]'
    //输出如下
    Result for output key class_ids:
    [[0]
     [1]
     [2]]
    Result for output key classes:
    [['0']
     ['1']
     ['2']]
    Result for output key logits:
    [[  4.9489717  -0.6004308 -24.400116 ]
     [ -5.10651     2.5732546  -3.476701 ]
     [ -7.8528028   1.4337913   5.3207707]]
    Result for output key probabilities:
    [[9.9612528e-01 3.8747080e-03 1.7871768e-13]
     [4.6078415e-04 9.9718791e-01 2.3513364e-03]
     [1.8619706e-06 2.0095062e-02 9.7990304e-01]]
    

    上面在cli中输入了三组数据,运用模型预测结果分别输入分类0,1,2,logits和probabilities是模型输出结果和softmax结果,代表三组数据属于三个分类的概率。

    也可以运行如下指令来查看导出model的SignatureDef,所谓SignatureDef就是模型的输入输出情况:

    saved_model_cli show --dir iris/1524451065/ --tag_set serve --signature_def predict
    
    //输出情况大致如下:
    The given SavedModel SignatureDef contains the following input(s):
    inputs['PetalLength'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: PetalLength:0
    inputs['PetalWidth'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: PetalWidth:0
    inputs['SepalLength'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: SepalLength:0
    inputs['SepalWidth'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: SepalWidth:0
    The given SavedModel SignatureDef contains the following output(s):
    outputs['class_ids'] tensor_info:
        dtype: DT_INT64
        shape: (-1, 1)
        name: dnn/head/predictions/ExpandDims:0
    outputs['classes'] tensor_info:
        dtype: DT_STRING
        shape: (-1, 1)
        name: dnn/head/predictions/str_classes:0
    outputs['logits'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 3)
        name: dnn/logits/BiasAdd:0
    outputs['probabilities'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 3)
        name: dnn/head/predictions/probabilities:0
    Method name is: tensorflow/serving/predict
    
    1. Serving方式
      根据官网提供的Installing from source方法安装Serving。系统平台osx 10.11.6。主要步骤如下
    # 在tensorFlow目录下拉代码:/Users/xxx/mycode/tensorFlow
    cd tensorFlow
    # 切换python环境virtualenv
    source bin/activate
    
    git clone --recurse-submodules https://github.com/tensorflow/serving
    cd serving
    # 开始编译
    bazel build -c opt tensorflow_serving/...
    
    #注意:
    #1. mac上xcode的版本要在8.0版本以上,不然编译报错pthread。
    #2. 编译过程中可能出现python依赖模块的缺失,安装好继续编译就好。
    #3. python最好切换到virtualenv环境下,在系统环境下有些模块的版本无法升级,导致编译不通过。
    #4. 编译产物在bazel-bin下,model server位置:
    # bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server
    

    相关文章

      网友评论

          本文标题:TensorFlow Model export

          本文链接:https://www.haomeiwen.com/subject/yjtllftx.html