美文网首页
bert用于文本相似性计算

bert用于文本相似性计算

作者: 佛系小懒 | 来源:发表于2020-02-08 19:37 被阅读0次

    环境构建

    前置环境安装:

    pip install bert-serving-server

    pip install bert-serving-client

    及其依赖,如tensorflow>=1.10

    GitHub地址:链接

    预训练模型:chinese_L-12_H-768_A-12.zip

    通过容器方式进行服务启动,容器名:bert-as-service ,Dockerfile(cpu)文件如下:

    FROM tensorflow/tensorflow:1.12.0-py3

    RUN pip install bert-serving-server

    COPY ./ /app

    #COPY ./docker/entrypoint.sh /app

    WORKDIR /app

    ENTRYPOINT ["/app/entrypoint.sh"]

    CMD []

    服务端启动命令样例

    bert-serving-start -num_worker=4 -model_dir /data/tools/chinese_L-12_H-768_A-12 -max_seq_len=20

    服务端可配置参数

    (1)支持GPU、CPU计算,通过-cpu参数指定,默认为false,即默认是使用GPU进行计算;

    (2)支持设置工作线程数,通过-num_worker,最好设置为对应的cpu或gpu核数

    (3)指定预训练模型的路径,通过-model_dir,常将预训练好的BERT模型下载下来放到该目录,该项须指定

    (4)指定微调模型的路径,通过-tuned_model_dir,通常将微调好的BERT模型生成到该目录

    (5)指定checkpoint文件,通过-ckpt_name,默认文件名为bert_model.ckpt,常与预训练模型在同一目录

    (6)指定配置文件,通过-config_name,默认文件名为bert_config.json,常与预训练模型在同一目录

    (7)指定临时的graph文件,通过-graph_tmp_dir,默认是/tmp/XXXXXX

    (8)指定最大序列长度,通过-max_seq_len,使用None来动态加载(mini)batch做为最大序列长度,默认取值25,实际在我的机器上面设置20即可,否则服务启动的时候会卡住

    (9)指定每个worker能够处理的最大序列数,通过-max_batch_size,默认256

    (10)指定高优先级处理的batch大小,通过-priority_batch_size,默认值16

    (11)指定客户端向服务端push数据的端口号,通过-port,默认5555

    (12)指定服务端向客户端发布结果的的端口号,通过-port_out,默认5556

    (13)指定服务端接受http请求的端口号,通过-http_port,无默认值

    (14)指定池化-采样策略,通过-pooling_strategy,默认取值REDUCE_MEAN,取值限制在:NONE, REDUCE_MEAN, REDUCE_MAX, REDUCE_MEAN_MAX, CLS_TOKEN, FIRST_TOKEN, SEP_TOKEN, LAST_TOKEN ,为了获取sequence中的每个token应该将其置为NONE

    (15)指定池化-采样层,通过-pooling_layer,默认取值[-2], -1表示最后一层, -2 倒数第二层, [-1, -2] 表示连接最后两层的结果

    (16)指定XLA compiler 进行graph的优化,通过-xla,默认取值false

    客户端可配置参数

    (1)指定bert服务端的ip,通过-ip指定,默认是localhost

    (2)指定客户端向服务端push数据的端口号,通过-port,默认5555,需与服务端配置一致

    (3)指定服务端向客户端发布结果的的端口号,通过-port_out,默认5556,需与服务端配置一致

    (4)指定 sentence编码格式,通过-output_fmt ,默认ndarray,也可使用list

    (5)指定是否在开始连接的时候显示服务端配置,通过-show_server_config

    (6)指定是否要求服务端与客户端版本保持一致性,通过-check_version,默认取值false

    (7)指定multi-casting模式下客户端的uuid以对多个客户端进行区分,通过-identity,默认取值None

    (8)指定客户端接受操作的超时时间,默认-1(没有超时限制),单位ms

    文本相似性度量

    基于余弦夹角进行相似度评估,源码借鉴网上的思路,具体如下:

    class Nlper:

        def __init__(self, bert_client):

            self.bert_client = bert_client

        def get_text_similarity(self, base_text, compaired_text, algorithm='cosine'):

            if isinstance(algorithm, str) and algorithm.lower() == 'cosine':

                arrays = self.bert_client.encode([base_text, compaired_text])

                norm_1 = np.linalg.norm(arrays[0])

                norm_2 = np.linalg.norm(arrays[1])

                dot_product = np.dot(arrays[0], arrays[1])

                similarity = round(0.5 + 0.5 * (dot_product / (norm_1 * norm_2)), 2)

                return similarity

    if __name__ == '__main__':

        bc = BertClient()

        nlper = Nlper(bert_client=bc)

        similarity = nlper.get_text_similarity('你好', '您好')

        print(similarity)

    备注

    直接使用默认的max_seq_len,可能会导致server无法启动,开发环境定义-max_seq_len=20正常启动

    相关文章

      网友评论

          本文标题:bert用于文本相似性计算

          本文链接:https://www.haomeiwen.com/subject/mztwxhtx.html