TensorFlow Serving Java

Author: 晴天哥_王志 | Published: 2019-03-03 11:50

    Background

     This article is a reference example for the TensorFlow Serving (TFS) Java API. It walks through how to use the core TFS APIs and has three parts:

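    • Dynamic model update: load models into TFS while it is running.
    • Getting model information: retrieve basic information (the metadata) about a loaded model.
    • Online prediction: run online prediction, classification and similar operations. The focus here is on prediction.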

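    A prediction request has to reference the model's internals: the signature name, the input/output tensor names, and their dtypes and shapes. So first fetch the model's metadata through the TFS REST API, and only then build the TFS RPC request object.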

    Getting started with TFS

    Fetching model metadata

    curl http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]/metadata
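    The curl call can also be made from plain Java. The following is a minimal sketch that uses only the JDK 11+ java.net.http client. The class and method names are hypothetical. host and port are whatever your TFS REST endpoint uses; the official TFS Docker image exposes REST on 8501 and gRPC on 8500.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Hypothetical helper: fetches TFS model metadata over the REST API
public class MetadataFetcher {

    // Builds http://host:port/v1/models/{name}[/versions/{version}]/metadata;
    // a null version omits the /versions part and lets the server pick the version
    public static URI metadataUri(String host, int port, String modelName, Long version) {
        String versionPart = (version == null) ? "" : "/versions/" + version;
        return URI.create("http://" + host + ":" + port
                + "/v1/models/" + modelName + versionPart + "/metadata");
    }

    // Sends a GET request and returns the raw JSON body (model_spec plus signature_def tree)
    public static String fetchMetadata(String host, int port, String modelName, Long version)
            throws IOException, InterruptedException {
        HttpClient client = HttpClient.newHttpClient();
        HttpRequest request = HttpRequest.newBuilder(metadataUri(host, port, modelName, version))
                .GET()
                .build();
        HttpResponse<String> response =
                client.send(request, HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            throw new IOException("HTTP " + response.statusCode() + ": " + response.body());
        }
        return response.body();
    }
}
```

    For example, fetchMetadata("localhost", 8501, "wdl_model", null) should return JSON shaped like the TF model structure listed at the end of this article.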
    

    The same metadata can also be fetched over gRPC from Java:

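        import io.grpc.ManagedChannel;
        import io.grpc.ManagedChannelBuilder;
        import java.util.Arrays;
        // Generated from tensorflow_serving/apis/*.proto; these protos set no java_package,
        // so protoc's default Java package is tensorflow.serving (adjust if your build differs)
        import tensorflow.serving.GetModelMetadata;
        import tensorflow.serving.Model;
        import tensorflow.serving.PredictionServiceGrpc;

        // host and port are assumed to be fields of the enclosing class (TFS gRPC port, e.g. 8500)
        public static void getModelStatus() throws Exception {

            // 1. Open a plaintext gRPC channel to the TFS host and port
            ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
            try {
                // 2. Create a blocking PredictionService stub
                PredictionServiceGrpc.PredictionServiceBlockingStub predictionServiceBlockingStub =
                        PredictionServiceGrpc.newBlockingStub(channel);

                // 3. Specify the model to query
                Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder()
                        .setName("wdl_model").build();

                // 4. Build the metadata request; "signature_def" is the metadata field TFS supports
                GetModelMetadata.GetModelMetadataRequest modelMetadataRequest =
                        GetModelMetadata.GetModelMetadataRequest.newBuilder()
                                .setModelSpec(modelSpec)
                                .addAllMetadataField(Arrays.asList("signature_def"))
                                .build();

                // 5. Fetch the metadata
                GetModelMetadata.GetModelMetadataResponse getModelMetadataResponse =
                        predictionServiceBlockingStub.getModelMetadata(modelMetadataRequest);

                // The map value is a protobuf Any wrapping a SignatureDefMap; unpack it to list the signatures
                GetModelMetadata.SignatureDefMap signatureDefMap = getModelMetadataResponse
                        .getMetadataOrThrow("signature_def")
                        .unpack(GetModelMetadata.SignatureDefMap.class);
                System.out.println(signatureDefMap.getSignatureDefMap().keySet());
            } finally {
                // Always release the channel, even if the call fails
                channel.shutdownNow();
            }
        }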
    

    Notes:

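    • Model.ModelSpec.newBuilder() sets the name of the model to query.
    • addAllMetadataField on GetModelMetadataRequest requests the signature_def field, the same signature_def that appears in the metadata returned by the curl command.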

    Dynamic model update

        import io.grpc.ManagedChannel;
        import io.grpc.ManagedChannelBuilder;
        // Generated TFS classes; package tensorflow.serving assumed (protoc default for these protos)
        import tensorflow.serving.ModelManagement;
        import tensorflow.serving.ModelServerConfigOuterClass;
        import tensorflow.serving.ModelServiceGrpc;

        // host and port are assumed to be fields of the enclosing class (TFS gRPC port)
        public static void addNewModel() {
            // 1. Config for the model to add; model_platform "tensorflow" replaces the deprecated model_type
            ModelServerConfigOuterClass.ModelConfig modelConfig1 =
                    ModelServerConfigOuterClass.ModelConfig.newBuilder()
                            .setBasePath("/models/new_model")
                            .setName("new_model")
                            .setModelPlatform("tensorflow")
                            .build();

            // 2. Config for wdl_model (already being served), which must be listed again
            ModelServerConfigOuterClass.ModelConfig modelConfig2 =
                    ModelServerConfigOuterClass.ModelConfig.newBuilder()
                            .setBasePath("/models/wdl_model")
                            .setName("wdl_model")
                            .setModelPlatform("tensorflow")
                            .build();

            // 3. Collect both configs in a ModelConfigList
            ModelServerConfigOuterClass.ModelConfigList modelConfigList =
                    ModelServerConfigOuterClass.ModelConfigList.newBuilder()
                            .addConfig(modelConfig1)
                            .addConfig(modelConfig2)
                            .build();

            // 4. Wrap the ModelConfigList in a ModelServerConfig
            ModelServerConfigOuterClass.ModelServerConfig modelServerConfig =
                    ModelServerConfigOuterClass.ModelServerConfig.newBuilder()
                            .setModelConfigList(modelConfigList)
                            .build();

            // 5. Build the ReloadConfigRequest carrying the ModelServerConfig
            ModelManagement.ReloadConfigRequest reloadConfigRequest =
                    ModelManagement.ReloadConfigRequest.newBuilder()
                            .setConfig(modelServerConfig)
                            .build();

            // 6. Create a blocking ModelService stub and send the reload request
            ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
            try {
                ModelServiceGrpc.ModelServiceBlockingStub modelServiceBlockingStub =
                        ModelServiceGrpc.newBlockingStub(channel);

                ModelManagement.ReloadConfigResponse reloadConfigResponse =
                        modelServiceBlockingStub.handleReloadConfigRequest(reloadConfigRequest);

                // error_code OK with an empty error_message means the reload succeeded
                System.out.println(reloadConfigResponse.getStatus().getErrorCode() + " "
                        + reloadConfigResponse.getStatus().getErrorMessage());
            } finally {
                channel.shutdownNow();
            }
        }
    

    Notes:

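    • A dynamic model update is a full reload: TFS replaces its entire model config with the one in the request. If model A is already published and you want to publish model B, you must send the configs of both A and B. Any model left out of the list is unloaded.
    • To stress it again: the update is a full replacement. Full replacement, full replacement!!!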
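    Every reload replaces the whole config. One way to avoid unloading a model by accident is to keep the full set of deployed models on the client and always send all of it. The following is a minimal pure-Java sketch of that idea. ModelRegistry is a hypothetical class, and the step that turns each entry into a ModelConfig appears only as a comment.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical client-side registry of every model that should stay loaded in TFS.
// Each snapshot is the FULL list to put into one ReloadConfigRequest, e.g. per entry:
//   ModelConfig.newBuilder().setName(name).setBasePath(basePath).setModelPlatform("tensorflow")
public class ModelRegistry {

    // model name -> base path, kept in insertion order
    private final Map<String, String> models = new LinkedHashMap<>();

    // Adds (or re-points) a model and returns the full list to send
    public synchronized Map<String, String> add(String name, String basePath) {
        models.put(name, basePath);
        return snapshot();
    }

    // Drops a model; sending the returned list would make TFS unload it
    public synchronized Map<String, String> remove(String name) {
        models.remove(name);
        return snapshot();
    }

    // Read-only copy of the current full model list
    public synchronized Map<String, String> snapshot() {
        return Collections.unmodifiableMap(new LinkedHashMap<>(models));
    }
}
```

    Suppose new_model is already registered. Publishing wdl_model then becomes registry.add("wdl_model", "/models/wdl_model"). The returned list still contains new_model, so a reload built from it keeps both models loaded.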

    Online prediction

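        import com.google.protobuf.ByteString;
        import io.grpc.ManagedChannel;
        import io.grpc.ManagedChannelBuilder;
        import java.util.HashMap;
        import java.util.List;
        import java.util.Map;
        import org.tensorflow.example.BytesList;
        import org.tensorflow.example.Example;
        import org.tensorflow.example.Feature;
        import org.tensorflow.example.Features;
        import org.tensorflow.example.FloatList;
        import org.tensorflow.framework.DataType;
        import org.tensorflow.framework.TensorProto;
        import org.tensorflow.framework.TensorShapeProto;
        // Generated TFS classes; package tensorflow.serving assumed (protoc default for these protos)
        import tensorflow.serving.Model;
        import tensorflow.serving.Predict;
        import tensorflow.serving.PredictionServiceGrpc;

        // host and port are assumed to be fields of the enclosing class (TFS gRPC port)
        public static void doPredict() throws Exception {

            // 1. Build one tf.train.Example from named features (feature(...) helpers are defined below)
            Map<String, Feature> featureMap = new HashMap<>();
            featureMap.put("match_type", feature(""));
            featureMap.put("position", feature(0.0f));
            featureMap.put("brand_prefer_1d", feature(0.0f));
            featureMap.put("brand_prefer_1m", feature(0.0f));
            featureMap.put("brand_prefer_1w", feature(0.0f));
            featureMap.put("brand_prefer_2w", feature(0.0f));
            featureMap.put("browse_norm_score_1d", feature(0.0f));
            featureMap.put("browse_norm_score_1w", feature(0.0f));
            featureMap.put("browse_norm_score_2w", feature(0.0f));
            featureMap.put("buy_norm_score_1d", feature(0.0f));
            featureMap.put("buy_norm_score_1w", feature(0.0f));
            featureMap.put("buy_norm_score_2w", feature(0.0f));
            featureMap.put("cate1_prefer_1d", feature(0.0f));
            featureMap.put("cate1_prefer_2d", feature(0.0f));
            featureMap.put("cate1_prefer_1m", feature(0.0f));
            featureMap.put("cate1_prefer_1w", feature(0.0f));
            featureMap.put("cate1_prefer_2w", feature(0.0f));
            featureMap.put("cate2_prefer_1d", feature(0.0f));
            featureMap.put("cate2_prefer_1m", feature(0.0f));
            featureMap.put("cate2_prefer_1w", feature(0.0f));
            featureMap.put("cate2_prefer_2w", feature(0.0f));
            featureMap.put("cid_prefer_1d", feature(0.0f));
            featureMap.put("cid_prefer_1m", feature(0.0f));
            featureMap.put("cid_prefer_1w", feature(0.0f));
            featureMap.put("cid_prefer_2w", feature(0.0f));
            featureMap.put("user_buy_rate_1d", feature(0.0f));
            featureMap.put("user_buy_rate_2w", feature(0.0f));
            featureMap.put("user_click_rate_1d", feature(0.0f));
            featureMap.put("user_click_rate_1w", feature(0.0f));

            Features features = Features.newBuilder().putAllFeature(featureMap).build();
            Example example = Example.newBuilder().setFeatures(features).build();

            // 2. Create the Predict request builder
            Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();

            // 3. Set the ModelSpec: model name plus the signature to call ("predict" in the metadata)
            Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
            modelSpecBuilder.setName("wdl_model");
            modelSpecBuilder.setSignatureName("predict");
            predictRequestBuilder.setModelSpec(modelSpecBuilder);

            // 4. Describe the input tensor: a 1-D DT_STRING tensor holding 300 serialized Examples
            TensorShapeProto.Dim dim = TensorShapeProto.Dim.newBuilder().setSize(300).build();
            TensorShapeProto shapeProto = TensorShapeProto.newBuilder().addDim(dim).build();
            TensorProto.Builder tensor = TensorProto.newBuilder();
            tensor.setTensorShape(shapeProto);
            tensor.setDtype(DataType.DT_STRING);

            // 5. Fill the batch: the demo sends the same Example 300 times; the count must equal the Dim size
            for (int i = 0; i < 300; i++) {
                tensor.addStringVal(example.toByteString());
            }
            // "examples" is the input key of the "predict" signature
            predictRequestBuilder.putInputs("examples", tensor.build());

            // 6. Create the blocking PredictionService stub
            ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
            try {
                PredictionServiceGrpc.PredictionServiceBlockingStub predictionServiceBlockingStub =
                        PredictionServiceGrpc.newBlockingStub(channel);

                // 7. Run the prediction
                Predict.PredictResponse predictResponse =
                        predictionServiceBlockingStub.predict(predictRequestBuilder.build());

                // 8. Parse the result: "probabilities" has shape [-1, 2], so this flat list holds 300 x 2 floats
                List<Float> floatList = predictResponse
                        .getOutputsOrThrow("probabilities")
                        .getFloatValList();
                System.out.println(floatList.size());
            } finally {
                // The original never closed the channel; release it here
                channel.shutdownNow();
            }
        }

        // Helper not shown in the original (sketch): wraps one string value as a bytes_list Feature
        private static Feature feature(String value) {
            return Feature.newBuilder()
                    .setBytesList(BytesList.newBuilder().addValue(ByteString.copyFromUtf8(value)))
                    .build();
        }

        // Helper not shown in the original (sketch): wraps one float value as a float_list Feature
        private static Feature feature(float value) {
            return Feature.newBuilder()
                    .setFloatList(FloatList.newBuilder().addValue(value))
                    .build();
        }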
    

    Notes:

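    • The parameters of a TFS RPC request must match the TF model's structure. The signature name, input key, dtype and shape all come from the model metadata shown below.
    • TFS RPC calls can be made synchronously or asynchronously. Only the synchronous style (the blocking stub) is shown above.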
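    In step 8, the probabilities output has shape [-1, 2] according to the metadata below. getFloatValList() therefore returns a flat, row-major list of batch * 2 values, which is 600 floats for the 300-example request. The following is a small pure-Java sketch for splitting that list per example. The class name is hypothetical.

```java
import java.util.List;

// Hypothetical helper: reshapes a flat row-major TensorProto float_val list
public class ProbabilityReshaper {

    // Splits flat into rows of numClasses values: row i holds example i's class probabilities
    public static float[][] reshape(List<Float> flat, int numClasses) {
        if (numClasses <= 0 || flat.size() % numClasses != 0) {
            throw new IllegalArgumentException(
                    "size " + flat.size() + " is not a multiple of " + numClasses);
        }
        float[][] rows = new float[flat.size() / numClasses][numClasses];
        for (int i = 0; i < flat.size(); i++) {
            rows[i / numClasses][i % numClasses] = flat.get(i);
        }
        return rows;
    }
}
```

    Applied to the response from step 8, reshape(floatList, 2)[i][1] would be example i's probability for class id 1.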

    TF model structure (the metadata returned for wdl_model by the curl command above)

    {
        "model_spec": {
            "name": "wdl_model",
            "signature_name": "",
            "version": "4"
        },
        "metadata": {
            "signature_def": {
                "signature_def": {
                    "predict": {
                        "inputs": {
                            "examples": {
                                "dtype": "DT_STRING",
                                "tensor_shape": {
                                    "dim": [{
                                        "size": "-1",
                                        "name": ""
                                    }],
                                    "unknown_rank": false
                                },
                                "name": "input_example_tensor:0"
                            }
                        },
                        "outputs": {
                            "logistic": {
                                "dtype": "DT_FLOAT",
                                "tensor_shape": {
                                    "dim": [{
                                            "size": "-1",
                                            "name": ""
                                        },
                                        {
                                            "size": "1",
                                            "name": ""
                                        }
                                    ],
                                    "unknown_rank": false
                                },
                                "name": "head/predictions/logistic:0"
                            },
                            "class_ids": {
                                "dtype": "DT_INT64",
                                "tensor_shape": {
                                    "dim": [{
                                            "size": "-1",
                                            "name": ""
                                        },
                                        {
                                            "size": "1",
                                            "name": ""
                                        }
                                    ],
                                    "unknown_rank": false
                                },
                                "name": "head/predictions/ExpandDims:0"
                            },
                            "probabilities": {
                                "dtype": "DT_FLOAT",
                                "tensor_shape": {
                                    "dim": [{
                                            "size": "-1",
                                            "name": ""
                                        },
                                        {
                                            "size": "2",
                                            "name": ""
                                        }
                                    ],
                                    "unknown_rank": false
                                },
                                "name": "head/predictions/probabilities:0"
                            },
                            "classes": {
                                "dtype": "DT_STRING",
                                "tensor_shape": {
                                    "dim": [{
                                            "size": "-1",
                                            "name": ""
                                        },
                                        {
                                            "size": "1",
                                            "name": ""
                                        }
                                    ],
                                    "unknown_rank": false
                                },
                                "name": "head/predictions/str_classes:0"
                            },
                            "logits": {
                                "dtype": "DT_FLOAT",
                                "tensor_shape": {
                                    "dim": [{
                                            "size": "-1",
                                            "name": ""
                                        },
                                        {
                                            "size": "1",
                                            "name": ""
                                        }
                                    ],
                                    "unknown_rank": false
                                },
                                "name": "add:0"
                            }
                        },
                        "method_name": "tensorflow/serving/predict"
                    },
                    "classification": {
                        "inputs": {
                            "inputs": {
                                "dtype": "DT_STRING",
                                "tensor_shape": {
                                    "dim": [{
                                        "size": "-1",
                                        "name": ""
                                    }],
                                    "unknown_rank": false
                                },
                                "name": "input_example_tensor:0"
                            }
                        },
                        "outputs": {
                            "classes": {
                                "dtype": "DT_STRING",
                                "tensor_shape": {
                                    "dim": [{
                                            "size": "-1",
                                            "name": ""
                                        },
                                        {
                                            "size": "2",
                                            "name": ""
                                        }
                                    ],
                                    "unknown_rank": false
                                },
                                "name": "head/Tile:0"
                            },
                            "scores": {
                                "dtype": "DT_FLOAT",
                                "tensor_shape": {
                                    "dim": [{
                                            "size": "-1",
                                            "name": ""
                                        },
                                        {
                                            "size": "2",
                                            "name": ""
                                        }
                                    ],
                                    "unknown_rank": false
                                },
                                "name": "head/predictions/probabilities:0"
                            }
                        },
                        "method_name": "tensorflow/serving/classify"
                    },
                    "regression": {
                        "inputs": {
                            "inputs": {
                                "dtype": "DT_STRING",
                                "tensor_shape": {
                                    "dim": [{
                                        "size": "-1",
                                        "name": ""
                                    }],
                                    "unknown_rank": false
                                },
                                "name": "input_example_tensor:0"
                            }
                        },
                        "outputs": {
                            "outputs": {
                                "dtype": "DT_FLOAT",
                                "tensor_shape": {
                                    "dim": [{
                                            "size": "-1",
                                            "name": ""
                                        },
                                        {
                                            "size": "1",
                                            "name": ""
                                        }
                                    ],
                                    "unknown_rank": false
                                },
                                "name": "head/predictions/logistic:0"
                            }
                        },
                        "method_name": "tensorflow/serving/regress"
                    },
                    "serving_default": {
                        "inputs": {
                            "inputs": {
                                "dtype": "DT_STRING",
                                "tensor_shape": {
                                    "dim": [{
                                        "size": "-1",
                                        "name": ""
                                    }],
                                    "unknown_rank": false
                                },
                                "name": "input_example_tensor:0"
                            }
                        },
                        "outputs": {
                            "classes": {
                                "dtype": "DT_STRING",
                                "tensor_shape": {
                                    "dim": [{
                                            "size": "-1",
                                            "name": ""
                                        },
                                        {
                                            "size": "2",
                                            "name": ""
                                        }
                                    ],
                                    "unknown_rank": false
                                },
                                "name": "head/Tile:0"
                            },
                            "scores": {
                                "dtype": "DT_FLOAT",
                                "tensor_shape": {
                                    "dim": [{
                                            "size": "-1",
                                            "name": ""
                                        },
                                        {
                                            "size": "2",
                                            "name": ""
                                        }
                                    ],
                                    "unknown_rank": false
                                },
                                "name": "head/predictions/probabilities:0"
                            }
                        },
                        "method_name": "tensorflow/serving/classify"
                    }
                }
            }
        }
    }
    


        Article link: https://www.haomeiwen.com/subject/peyjuqtx.html