Deploying TensorFlow Serving with Python and Java clients

Author: wxrg2012 | Published 2019-01-30 17:28

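    This article deploys TensorFlow Serving with Docker and provides Python and Java client examples. (It draws on many blog posts and the official TensorFlow documentation. The aim is to cover the pitfalls most blog posts leave out and to cut down the verbosity of the official docs.)
    Compiling the source with Bazel is a big pitfall: it throws odd errors, mostly because the versions of the various dependencies do not match. To avoid it, this article deploys TensorFlow Serving directly from the Docker image.
    Note: following the steps one by one takes you from zero to a working deployment. The example used at the end is a text classification model.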

    1 Installing Docker

    1.1 Installing on Mac

    Follow the reference site. A manual installation is recommended. Once it is installed, choose "Check for Updates" to update to the latest version.

    1.2 Installing on CentOS

    Prerequisites: on CentOS 7 the system must be 64-bit and the kernel version must be 3.10 or later. Check your kernel version with uname -r.
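    You can also check the prerequisite from Python before installing. The helpers below are hypothetical and are not part of Docker or the CentOS tooling. kernel_at_least parses a release string in the format printed by uname -r. is_64bit checks the machine name printed by uname -m. On Linux, platform.release() and platform.machine() return those same two values.

```python
import platform
import re


def kernel_at_least(release, minimum=(3, 10)):
    """Return True if a kernel release string such as '3.10.0-957.el7.x86_64'
    (the format printed by `uname -r`) is at least `minimum` as (major, minor)."""
    match = re.match(r"(\d+)(?:\.(\d+))?", release)
    if match is None:
        raise ValueError("unrecognised kernel release: %r" % release)
    major = int(match.group(1))
    minor = int(match.group(2) or 0)
    return (major, minor) >= minimum


def is_64bit(machine):
    """Treat the usual 64-bit machine names reported by `uname -m` as 64-bit."""
    return machine.lower() in ("x86_64", "amd64", "aarch64", "arm64")


if __name__ == "__main__":
    release, machine = platform.release(), platform.machine()
    print("kernel %s (>= 3.10: %s), machine %s (64-bit: %s)"
          % (release, kernel_at_least(release), machine, is_64bit(machine)))
```

    The comparison uses only major and minor numbers, so "3.10.0-957.el7.x86_64" passes and "2.6.32-754.el6.x86_64" (CentOS 6) does not.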

    Remove old versions:

    $ sudo yum remove docker \
                      docker-client \
                      docker-client-latest \
                      docker-common \
                      docker-latest \
                      docker-latest-logrotate \
                      docker-logrotate \
                      docker-selinux \
                      docker-engine-selinux \
                      docker-engine
    

    Install the dependencies:
    sudo yum install -y yum-utils device-mapper-persistent-data lvm2
    Add the repository (Aliyun mirror):
    sudo yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo
    Update the yum cache:
    sudo yum makecache fast
    Install Docker CE:
    sudo yum -y install docker-ce
    Start the Docker daemon:
    sudo systemctl start docker
    Test by running hello-world:
    sudo docker run hello-world   (or just check the version: docker --version)

    2 Deploying Serving

    2.1 Pulling the Serving image

    docker pull tensorflow/serving
    When it finishes, list the installed images:
    docker images

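    2.2 Exporting the model

    Serving cannot load models saved as HDF5 or .ckpt files directly; they must first be converted to the SavedModel format. This article uses a Keras HDF5 file as the example. Converting a .ckpt file works much the same way, and readers can look up the details themselves.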

    # TensorFlow 1.x with standalone Keras (tf.saved_model.builder, K.get_session)
    import os

    import tensorflow as tf
    from keras import backend as K
    from keras.models import load_model


    def save_model_to_serving(model, export_version, export_path='prod_models'):
        """Export a Keras model as a SavedModel under export_path/export_version."""
        print(model.input, model.output)
        # 'textdata' and 'market' are the input/output keys that clients must use
        signature = tf.saved_model.signature_def_utils.predict_signature_def(
            inputs={'textdata': model.input}, outputs={'market': model.output})
        export_path = os.path.join(
            tf.compat.as_bytes(export_path),
            tf.compat.as_bytes(str(export_version)))
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)
        legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
        builder.add_meta_graph_and_variables(
            sess=K.get_session(),
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'market_classification': signature,  # signature name used by the clients
            },
            legacy_init_op=legacy_init_op)
        builder.save()


    K.set_learning_phase(0)  # build the graph in inference mode (dropout disabled)
    model = load_model('/path/to/blistm-checkpoint-02e-val_acc_0.96.hdf5')  # replace with your own path
    save_model_to_serving(model, "1", "bgru_serving")  # the converted model is written to bgru_serving/1/
    

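    After the conversion, bgru_serving/1/ contains the exported SavedModel: a saved_model.pb file and a variables/ directory.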


    2.3 Running the container

    docker run -p 8500:8500 \
          --mount type=bind,source=/absolute/path/to/bgru_serving/,target=/models/market_blstm \
          -e MODEL_NAME=market_blstm -t tensorflow/serving
    

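    Note: port 8500 is recommended for testing. The source path of the mount must be an absolute path (important).
    What each parameter means:

    • -p 8500:8500: publishes port 8500, TensorFlow Serving's gRPC port, on the host
    • --mount type=bind,source=<absolute path>/bgru_serving/,target=/models/market_blstm: bind-mounts your exported local model directory into the container at /models/market_blstm. TensorFlow Serving looks for the model in that directory inside the container
    • -e MODEL_NAME=market_blstm: sets the model name. The image serves the model it finds in /models/${MODEL_NAME}, so the name must match the mount target, and clients must use the same name in their requests
    • -t: allocates a pseudo-terminal. The final argument, tensorflow/serving, is the image to run. You can replace it with another tag such as tensorflow/serving:latest-gpu, but you must pull that image first with docker pull tensorflow/serving:latest-gpu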
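    The parameters above can also be assembled from a script, which makes the absolute-path requirement explicit. The sketch below is hypothetical and is not part of Docker's or TensorFlow Serving's tooling. It builds the same argument list as the command in 2.3 and rejects relative paths, because docker's --mount type=bind refuses a relative source.

```python
import os
import shlex


def build_docker_run_args(model_dir, model_name, image="tensorflow/serving", host_port=8500):
    """Assemble the `docker run` arguments from 2.3 as a list (hypothetical helper).

    model_dir is the exported model directory on the host (e.g. .../bgru_serving);
    it is bind-mounted to /models/<model_name>, where the image looks for MODEL_NAME.
    """
    if not os.path.isabs(model_dir):
        # bind mounts need an absolute source path
        raise ValueError("model_dir must be an absolute path, got %r" % model_dir)
    return [
        "docker", "run",
        "-p", "%d:8500" % host_port,  # host port -> container gRPC port 8500
        "--mount", "type=bind,source=%s,target=/models/%s" % (model_dir, model_name),
        "-e", "MODEL_NAME=%s" % model_name,
        "-t", image,
    ]


def format_command(args):
    """Quote each argument so the command can be pasted into a shell."""
    return " ".join(shlex.quote(arg) for arg in args)
```

    For example, subprocess.run(build_docker_run_args(os.path.abspath("bgru_serving"), "market_blstm"), check=True) starts the container in the foreground. Returning a list instead of a single string avoids shell-quoting problems with paths that contain spaces.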

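    3 Clients

    3.1 Python example

    Note: Python 3.5+ is preferable; otherwise newer versions of TensorFlow will raise errors.
    Install the dependency: sudo pip3 install tensorflow-serving-api (the client below also needs jieba: sudo pip3 install jieba)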
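    TensorFlow Serving needs a moment to start after docker run. Before running the client, you can wait until the container's gRPC port accepts connections. The helper below is a hypothetical stdlib-only sketch and is not part of tensorflow-serving-api. It only checks that the TCP port is open, not that the model has finished loading.

```python
import socket
import time


def wait_for_port(host, port, timeout=30.0, interval=0.5):
    """Retry a TCP connection to host:port until it succeeds or `timeout` seconds pass.

    A successful connection only shows that something is listening on the port
    (e.g. the container's gRPC port 8500); it does not prove the model has loaded."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)
```

    For example, call wait_for_port("127.0.0.1", 8500) before creating the gRPC channel, and abort if it returns False.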
    Client code:

    import argparse
    import json
    import re

    import grpc
    import jieba
    import numpy as np
    import tensorflow as tf
    from tensorflow_serving.apis import predict_pb2
    from tensorflow_serving.apis import prediction_service_pb2_grpc


    def loadData(filename):  # load a JSON file into a dict
        with open(filename, 'r', encoding='utf-8') as fr:
            return json.load(fr)


    vocab = loadData('vocab_bgru.dict')  # word -> index vocabulary, format "中国": 12045


    def denoise(text):
        """Clean one text, segment it with jieba and map each word to its vocabulary index.
        Returns [[id, id, ...]], or [[0]] if no word is in the vocabulary."""
        if isinstance(text, bytes):
            text = text.decode('utf-8', 'ignore')
        pattern = re.compile(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', re.S)
        line = pattern.sub('', text.strip())
        line = re.sub(r'{url(.*)网页链接}', '', line)  # drop "{url...网页链接}" link placeholders
        line = line.replace('\\', '').replace('\n', ' ').replace('https://', ' ')
        wordlist = [w.strip() for w in jieba.cut(line) if len(w.strip()) >= 2]
        word_ids = [vocab[w] for w in wordlist if w in vocab]  # skip out-of-vocabulary words
        return [word_ids] if word_ids else [[0]]


    def pad_sequences(x_train_word_ids, maxlen=64):
        """Left-pad with 0, or keep only the last maxlen ids, so the sequence fits the LSTM input."""
        ids = x_train_word_ids[0]
        if len(ids) > maxlen:
            return ids[-maxlen:]
        return [0] * (maxlen - len(ids)) + ids


    parser = argparse.ArgumentParser()
    parser.add_argument('--server', default='127.0.0.1:8500',
                        help='PredictionService host:port; replace the IP with your server')
    args = parser.parse_args()

    channel = grpc.insecure_channel(args.server)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'market_blstm'  # must match the server's MODEL_NAME (--model_name)
    request.model_spec.signature_name = 'market_classification'  # must match the signature key from 2.2

    text_list = ['吴亦凡同款 Sup扎染卫衣 全身顶级数码直喷 印花带做旧感 就是看起来脏脏的 一件衣服印花大几十块 完美还原面料为420G毛圈轻捉毛 质感很好',"360儿童5周年不止5折# 360儿童手表五周年&双十一特惠! 喜欢![失望]"]

    x_train = np.array([pad_sequences(denoise(text)) for text in text_list], dtype=np.float32)
    # 'textdata' matches the input key from 2.2; the shape [number of texts, 64] matches model.input
    request.inputs['textdata'].CopyFrom(
        tf.make_tensor_proto(x_train, shape=x_train.shape, dtype=tf.float32))
    result = stub.Predict(request, 10.0)  # 10-second timeout
    reslist = result.outputs['market'].float_val
    print(reslist)
    

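    Result:
    [0.013646061532199383, 0.9863539338111877, 0.16853764653205872, 0.8314623832702637]
    Each consecutive pair of values belongs to one text. For example, 0.013646061532199383 and 0.9863539338111877 mean that the first text in text_list belongs to class 0 with probability 0.013646061532199383 and to class 1 with probability 0.9863539338111877.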
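    The pairing follows from the shape of the output tensor. Assuming the model ends in a 2-class output of shape [number of texts, 2], float_val holds those values row by row. A minimal sketch (hypothetical helper) that regroups them per text and picks the more likely class:

```python
def split_predictions(float_val, num_classes=2):
    """Regroup a flat, row-major float_val into one row of class probabilities per
    text and return (rows, predicted_class_index_per_row)."""
    values = list(float_val)  # accepts a protobuf repeated field or a plain list
    if len(values) % num_classes != 0:
        raise ValueError("%d values cannot be split into rows of %d"
                         % (len(values), num_classes))
    rows = [values[i:i + num_classes] for i in range(0, len(values), num_classes)]
    predicted = [max(range(num_classes), key=row.__getitem__) for row in rows]
    return rows, predicted


rows, predicted = split_predictions(
    [0.013646061532199383, 0.9863539338111877, 0.16853764653205872, 0.8314623832702637])
print(predicted)  # [1, 1]
```

    With TensorFlow installed, tf.make_ndarray(result.outputs['market']) also returns the output as a NumPy array that already has the tensor's shape.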

    3.2 Java example

    Dependencies in pom.xml:

    <dependencies>
            <dependency>
                <groupId>com.yesup.oss</groupId>
                <artifactId>tensorflow-client</artifactId>
                <version>1.4-2</version>
            </dependency>
    
            <dependency>
                <groupId>io.grpc</groupId>
                <artifactId>grpc-netty</artifactId>
                <version>1.7.0</version>
            </dependency>
    
            <dependency>
                <groupId>io.netty</groupId>
                <artifactId>netty-tcnative-boringssl-static</artifactId>
                <version>2.0.7.Final</version>
            </dependency>
    
            <dependency>
                <groupId>com.huaban</groupId>
                <artifactId>jieba-analysis</artifactId>
                <version>1.0.2</version>
            </dependency>
    
            <dependency>
                <groupId>net.sf.json-lib</groupId>
                <artifactId>json-lib</artifactId>
                <version>2.4</version>
                <classifier>jdk15</classifier>
            </dependency>
    
            <dependency>
                <groupId>commons-io</groupId>
                <artifactId>commons-io</artifactId>
                <version>2.6</version>
            </dependency>
        </dependencies>
    

    Code:

    import com.huaban.analysis.jieba.JiebaSegmenter;
    import com.huaban.analysis.jieba.WordDictionary;
    import io.grpc.ManagedChannel;
    import io.grpc.ManagedChannelBuilder;
    import net.sf.json.JSONObject;
    import org.apache.commons.io.FileUtils;
    import org.tensorflow.framework.DataType;
    import org.tensorflow.framework.TensorProto;
    import org.tensorflow.framework.TensorShapeProto;
    import tensorflow.serving.Model;
    import tensorflow.serving.Predict;
    import tensorflow.serving.PredictionServiceGrpc;

    import java.io.File;
    import java.nio.file.Paths;
    import java.util.ArrayList;
    import java.util.List;

    public class TensorServClient {

        private final PredictionServiceGrpc.PredictionServiceBlockingStub stub;
        private final JiebaSegmenter segmenter;
        private JSONObject json;

        private static final int maxlen = 64;   // sequence length after padding/truncation

        public TensorServClient() {
            ManagedChannel channel = ManagedChannelBuilder.forAddress("127.0.0.1", 8500).usePlaintext(true).build();
            // use a blocking stub for now
            stub = PredictionServiceGrpc.newBlockingStub(channel);

            WordDictionary dictAdd = WordDictionary.getInstance();
            dictAdd.loadUserDict(Paths.get("jiebaextradic_java.dict"));   // load a custom jieba user dictionary
            segmenter = new JiebaSegmenter();

            try {
                // word -> index vocabulary, format "中国": 12045; read with commons-io, parsed with json-lib
                json = JSONObject.fromObject(FileUtils.readFileToString(new File("vocab_bgru.dict"), "UTF-8"));
            } catch (Exception ex) {
                ex.printStackTrace();
            }
        }

        // clean one text, segment it with jieba and map each known word to its vocabulary index
        private List<Integer> denoise(String line) {
            List<Integer> wordIds = new ArrayList<Integer>();
            line = line.replaceAll("(http|ftp|https):\\/\\/[\\w\\-_]+(\\.[\\w\\-_]+)+([\\w\\-\\.,@?^=%&:/~\\+#]*[\\w\\-\\@?^=%&/~\\+#])?", "");
            line = line.replaceAll("\\{url(.*)网页链接\\}", "");
            line = line.replaceAll("\\\\", "").replaceAll("\\r|\\n", "").replaceAll("https://", "");
            for (String token : segmenter.sentenceProcess(line)) {
                String word = token.trim();
                if (word.length() < 2) {
                    continue;   // same filter as the Python client: keep words of 2+ characters
                }
                try {
                    if (json.containsKey(word)) {
                        wordIds.add(json.getInt(word));
                    }
                } catch (Exception e) {
                    wordIds.add(0);
                }
            }
            return wordIds;
        }

        // left-pad with 0, or keep only the last maxlen ids, like pad_sequences in the Python client
        private float[] padSequences(List<Integer> wordIds) {
            float[] res = new float[maxlen];
            int lenX = wordIds.size();
            if (lenX > maxlen) {
                for (int i = lenX - maxlen, j = 0; i < lenX; i++, j++) {
                    res[j] = wordIds.get(i);
                }
            } else {
                for (int i = 0; i < lenX; i++) {
                    res[maxlen - lenX + i] = wordIds.get(i);
                }
            }
            return res;
        }

        // one padded row per text
        private float[][] gen_predict_data(String[] textlist) {
            float[][] predictData = new float[textlist.length][];
            for (int i = 0; i < textlist.length; i++) {
                predictData[i] = padSequences(denoise(textlist[i]));
            }
            return predictData;
        }

        public List<Float> predict(String[] textlist) {
            // build the request
            Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
            // model name and signature name: must match MODEL_NAME and the signature key from 2.2
            Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
            modelSpecBuilder.setName("market_blstm");
            modelSpecBuilder.setSignatureName("market_classification");
            // the latest model version is used by default; to pin a version, set the
            // version field of the ModelSpec (modelSpecBuilder.setVersion(...)), not the TensorProto
            predictRequestBuilder.setModelSpec(modelSpecBuilder);

            // input tensor: DT_FLOAT with shape [number of texts, maxlen]
            TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
            tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
            TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(textlist.length));
            tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(maxlen));
            tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());

            float[][] featuresTensorData = gen_predict_data(textlist);
            for (float[] row : featuresTensorData) {
                for (float value : row) {
                    tensorProtoBuilder.addFloatVal(value);   // values are appended row by row
                }
            }
            predictRequestBuilder.putInputs("textdata", tensorProtoBuilder.build());

            // call the server and read the output tensor
            Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build());
            TensorProto result = predictResponse.getOutputsOrThrow("market");
            return result.getFloatValList();
        }

        public static void main(String[] args) throws Exception {
            TensorServClient tensorServClient = new TensorServClient();
            // one text per element; the whole array is sent as a single batch
            String[] textlist = {"吴亦凡同款 Sup扎染卫衣 全身顶级数码直喷 印花带做旧感 就是看起来脏脏的 一件衣服印花大几十块 完美还原面料为420G毛圈轻捉毛 质感很好","360儿童5周年不止5折# 360儿童手表五周年&双十一特惠! 喜欢![失望]", /* .... */};
            List<Float> reslist = tensorServClient.predict(textlist);
            System.out.println(reslist);
        }
    }
    

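    Note: in the Java example, each element of textlist is one text, and the whole array is sent as a single batch. As in the Python example, the results come back as two values per text.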

          Article link: https://www.haomeiwen.com/subject/hjlcsqtx.html