环境构建
前置环境安装:
pip install bert-serving-server
pip install bert-serving-client
及其依赖,如tensorflow>=1.10
GitHub地址:链接
预训练模型:chinese_L-12_H-768_A-12.zip
通过容器方式进行服务启动,容器名:bert-as-service ,Dockerfile(cpu)文件如下:
FROM tensorflow/tensorflow:1.12.0-py3
RUN pip install bert-serving-server
COPY ./ /app
#COPY ./docker/entrypoint.sh /app
WORKDIR /app
ENTRYPOINT ["/app/entrypoint.sh"]
CMD []
服务端启动命令样例
bert-serving-start -num_worker=4 -model_dir /data/tools/chinese_L-12_H-768_A-12 -max_seq_len=20
服务端可配置参数
(1)支持GPU、CPU计算,通过-cpu参数指定,默认为false,即默认是使用GPU进行计算;
(2)支持设置工作线程数,通过-num_worker,最好设置为对应的cpu或gpu核数
(3)指定预训练模型的路径,通过-model_dir,常将预训练好的BERT模型下载下来放到该目录,该项须指定
(4)指定微调模型的路径,通过-tuned_model_dir,通常将微调好的BERT模型生成到该目录
(5)指定checkpoint文件,通过-ckpt_name,默认文件名为bert_model.ckpt,常与预训练模型在同一目录
(6)指定配置文件,通过-config_name,默认文件名为bert_config.json,常与预训练模型在同一目录
(7)指定临时的graph文件,通过-graph_tmp_dir,默认是/tmp/XXXXXX
(8)指定最大序列长度,通过-max_seq_len,使用None来动态加载(mini)batch做为最大序列长度,默认取值25,实际在我的机器上面设置20即可,否则服务启动的时候会卡住
(9)指定每个worker能够处理的最大序列数,通过-max_batch_size,默认256
(10)指定高优先级处理的batch大小,通过-priority_batch_size,默认值16
(11)指定客户端向服务端push数据的端口号,通过-port,默认5555
(12)指定服务端向客户端发布结果的的端口号,通过-port_out,默认5556
(13)指定服务端接受http请求的端口号,通过-http_port,无默认值
(14)指定池化-采样策略,通过-pooling_strategy,默认取值REDUCE_MEAN,取值限制在:NONE, REDUCE_MEAN, REDUCE_MAX, REDUCE_MEAN_MAX, CLS_TOKEN, FIRST_TOKEN, SEP_TOKEN, LAST_TOKEN ,为了获取sequence中的每个token应该将其置为NONE
(15)指定池化-采样层,通过-pooling_layer,默认取值[-2], -1表示最后一层, -2 倒数第二层, [-1, -2] 表示连接最后两层的结果
(16)指定XLA compiler 进行graph的优化,通过-xla,默认取值false
客户端可配置参数
(1)指定bert服务端的ip,通过-ip指定,默认是localhost
(2)指定客户端向服务端push数据的端口号,通过-port,默认5555,需与服务端配置一致
(3)指定服务端向客户端发布结果的的端口号,通过-port_out,默认5556,需与服务端配置一致
(4)指定 sentence编码格式,通过-output_fmt ,默认ndarray,也可使用list
(5)指定是否在开始连接的时候显示服务端配置,通过-show_server_config
(6)指定是否要求服务端与客户端版本保持一致性,通过-check_version,默认取值false
(7)指定multi-casting模式下客户端的uuid以对多个客户端进行区分,通过-identity,默认取值None
(8)指定客户端接受操作的超时时间,默认-1(没有超时限制),单位ms
文本相似性度量
基于余弦夹角进行相似度评估,源码借鉴网上的思路,具体如下:
class Nlper:
def __init__(self, bert_client):
self.bert_client = bert_client
def get_text_similarity(self, base_text, compaired_text, algorithm='cosine'):
if isinstance(algorithm, str) and algorithm.lower() == 'cosine':
arrays = self.bert_client.encode([base_text, compaired_text])
norm_1 = np.linalg.norm(arrays[0])
norm_2 = np.linalg.norm(arrays[1])
dot_product = np.dot(arrays[0], arrays[1])
similarity = round(0.5 + 0.5 * (dot_product / (norm_1 * norm_2)), 2)
return similarity
if __name__ == '__main__':
bc = BertClient()
nlper = Nlper(bert_client=bc)
similarity = nlper.get_text_similarity('你好', '您好')
print(similarity)
备注
直接使用默认的max_seq_len,可能会导致server无法启动,开发环境定义-max_seq_len=20正常启动
网友评论