tf2 HuggingFace Transformer2.0 b

Author: 阿东7 | Published 2020-02-17 16:19

    There is not much material online about running HuggingFace Transformers 2.0 on tf2, so I am writing these notes for myself.

    Word-embedding theory is not covered here

    BERT internals are not covered here

    BERT's input parameters

    1. input_ids
    2. token_type_ids
    3. attention_mask

    BERT's output

    1. the per-character vectors of the sentence
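
    A minimal sketch of what these inputs and outputs look like, assuming the public bert-base-chinese weights and the transformers 2.x TF API (this snippet is not part of the original code):

    import tensorflow as tf
    from transformers import BertTokenizer, TFBertModel

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    bert = TFBertModel.from_pretrained('bert-base-chinese')

    enc = tokenizer.encode_plus('中国的华先生', add_special_tokens=True, max_length=128)
    # enc['input_ids'] and enc['token_type_ids'] are plain Python lists;
    # the attention_mask (1 for real tokens, 0 for padding) is built by hand in the training code below
    sequence_output, pooled_output = bert({'input_ids': tf.constant([enc['input_ids']]),
                                           'token_type_ids': tf.constant([enc['token_type_ids']])})
    # sequence_output: shape (batch, seq_len, 768), one vector per character
    # pooled_output:   shape (batch, 768), a sentence-level vector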

    Python environment

    1. tensorflow2
    2. torch1.4.1
    3. transformers2.4.1
      torch is not actually used by the program
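
    For reference, an install along these lines should work (the exact versions are my assumption, beyond what the list above states):

    pip install tensorflow==2.1.0 transformers==2.4.1
    pip install torch   # optional here; only needed if you load PyTorch checkpoints with from_pt=True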



    Data preparation

    1. Data needed by HuggingFace Transformers
      For BERT's files, see
      https://www.cnblogs.com/lian1995/p/11947522.html
      BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
      'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
      'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
      'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
      'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
      'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
      'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
      'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
      'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
      'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
      'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
      'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
      'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
      'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
      'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
      'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
      }

    PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
    'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
    'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
    'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
    'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
    'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
    'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
    }
    }

    But here I use the tf2 model files:
    TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tf_model.h5",
    "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-tf_model.h5",
    "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-tf_model.h5",
    "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-tf_model.h5",
    "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-tf_model.h5",
    "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-tf_model.h5",
    "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-tf_model.h5",
    "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-tf_model.h5",
    "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-tf_model.h5",
    "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-tf_model.h5",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5",
    "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-tf_model.h5",
    "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-tf_model.h5",
    "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-tf_model.h5",
    "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-tf_model.h5",
    "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-tf_model.h5",
    "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-tf_model.h5",
    "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/tf_model.h5",
    "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/tf_model.h5",
    "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/tf_model.h5",
    }


    Training and test data for the NER task

    https://codeload.github.com/liushaoweihua/keras-bert-ner/zip/master
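
    The files use the usual one-character-per-line format, with a blank line separating sentences; this is what the parsing code below expects, e.g.:

    中 B-ORG
    共 I-ORG
    中 I-ORG
    央 I-ORG

    各 O
    位 O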

    Training code

    import os
    # os.environ["CUDA_VISIBLE_DEVICES"]="-1"
    import tensorflow as tf
    import numpy as np
    from transformers import *
    
    #the set of NER labels
    label_category = ['B-PER', 'B-ORG', 'B-LOC', 'I-ORG', 'I-PER', 'I-LOC', 'O', 'CLS', 'SEP', 'PAD', 'UNK']
    #number of label classes
    label_category_total_num = len(label_category)
    #sentence length, i.e. how many characters a sentence holds
    max_length = 128
    #sentences shorter than 128 are padded and masked with 0
    mask_padding_with_zero=True
    pad_token=0
    pad_token_segment_id=0
    #dummy data used when the HuggingFace Transformers 2.0 BERT model is instantiated; the default shape is (3, 5),
    #but keras.build must see 128 positions here, so the dummy inputs have shape (3, 128)
    DUMMY_INPUTS = [np.random.randint(0,128,size=128), np.random.randint(0,128,size=128), np.random.randint(0,128,size=128)]
    
    #this loss function is just copied from tf to make debugging easier; it forces from_logits=True,
    #because the Dense classifiers below return raw logits
    def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
        return tf.keras.backend.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True, axis=axis)
    
    #subclass the HuggingFace Transformers 2.0 base class so we can take BERT's raw output and extend it ourselves
    class TFMyBertModel(TFBertPreTrainedModel):
        @property
        def dummy_inputs(self):
            """ Dummy inputs to build the network.
    
            Returns:
                tf.Tensor with dummy inputs
            """
            return {"input_ids": tf.constant(DUMMY_INPUTS)}
    
        def __init__(self, config, *inputs, **kwargs):
            super().__init__(config, *inputs, **kwargs)
            self.bert = TFBertMainLayer(config, name="bert")
            #one classification head per position in the sentence, 128 of them here
            self.classifiers = []
            for i in range(max_length):
                self.classifiers.append(tf.keras.layers.Dense(label_category_total_num, name="classifier"+str(i)))
    
        def call(self, inputs, **kwargs):
            sequence_output, pooled_output = self.bert(inputs, **kwargs)
            print(sequence_output.shape)
    
            #the input is [None, 128, 768]; split it into a list of 128 tensors of shape [None, 768],
            #because every character has to be classified into one of the NER labels in label_category
            words_output = tf.split(sequence_output, sequence_output.shape[1], 1)
            #words_output has one entry per position in the sentence
            logits = []
            for i,o in enumerate(words_output):
                #drop the extra size-1 axis that each per-character tensor keeps after tf.split
                ot = tf.squeeze(o, axis=1)
                #classification result for this position
                logit = self.classifiers[i](ot)
                logits.append(logit)
            return logits
    
    train_path = './data_ner/train.txt'
    valid_path = './data_ner/dev.txt'
    
    #the BERT directory needed by HuggingFace Transformers 2.0; here it holds the tf2 h5 files
    pretrained_path = './my-bert-base-chinese'
    config_path = os.path.join(pretrained_path, 'config.json')
    vocab_path = os.path.join(pretrained_path, 'vocab.txt')
    
    tokenizer = BertTokenizer.from_pretrained(vocab_path)
    # load the config
    config = BertConfig.from_json_file(config_path)
    # load the original tf model
    model = TFMyBertModel.from_pretrained(pretrained_path,from_pt=False, config=config) #from_pt says whether the weights come from pytorch; we use tf here, so it is set to False
    
    #file format
    # 中 B-ORG
    # 共 I-ORG
    # 中 I-ORG
    # 央 I-ORG
    # collect the label types that occur in a file
    def nerCategory(pathfile):
        srctext = tf.data.TextLineDataset(pathfile)
        label_category = set([])
        for lineText in srctext:
            # lineText format: 中 B-ORG
            char_label = tf.strings.split(lineText)
            # char_label format: ['中', 'B-ORG']
            if len(char_label) >= 2: #fewer than 2 fields means an empty line, i.e. the end of a sentence
                label_category.add(char_label[1].numpy())
        return label_category
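    # (note: nerCategory is never called in the rest of this script; it is only a helper
    #  you could use to check which labels actually occur in a file, e.g.
    #  print(nerCategory(train_path)) )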
    
    #look up the index of a label in label_category = ['B-PER', 'B-ORG', 'B-LOC', 'I-ORG', 'I-PER', 'I-LOC', 'O', 'CLS', 'SEP', 'PAD', 'UNK']
    #the index is used as the class id; labels not in the list fall back to the last index ('UNK')
    def labelNum(label):
        i = label_category.index(label) if label in label_category else (len(label_category) -1)
        return float(i)
    
    # sample of the data file:
    # 各 O
    # 位 O
    # 代 O
    # 表 O
    # 、 O
    # 各 O
    # 位 O
    # 同 O
    # 志 O
    # : O
    #
    # 在 O
    # 中 B-ORG
    # 国 I-ORG
    # 致 I-ORG
    #build the datasets
    def fromNer(pathfile):
        input_ids_dataset = []
        attention_mask_dataset = []
        token_type_ids_datasest = []
        label_dataset = [[] for _ in range(max_length)] #one label column per position, 128 (max_length) columns in total

        tmp_sentence = [] #characters of the current sentence, e.g. ['各','位']
        tmp_label = [] #label ids of the current sentence, e.g. [6, 6]
        srctext = tf.data.TextLineDataset(pathfile)
        k = 0
        word_i = 0
        for lineText in srctext:
            # lineText format: 中 B-ORG
            word_i += 1
            char_label = tf.strings.split(lineText)
            # char_label format: ['中', 'B-ORG']
            if (len(char_label) <= 0): #an empty line marks the end of a sentence
                #the maximum length is 128, but two positions must be kept for BERT's CLS/SEP tokens,
                #so a sentence can actually hold only 126 characters
                if len(tmp_sentence) > (max_length-2):
                    tmp_sentence = tmp_sentence[:(max_length-2)]
                if len(tmp_label) > (max_length-2):
                    tmp_label = tmp_label[:(max_length-2)]
    
                #build input_ids, attention_mask and token_type_ids in the form BERT expects
                inputs = tokenizer.encode_plus("".join(tmp_sentence), add_special_tokens=True, max_length=max_length, )
                input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
                attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
                padding_length = max_length - len(input_ids)
                input_ids = input_ids + ([pad_token] * padding_length)
                attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
                token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
    
                #the labels are padded the same way with CLS, SEP and PAD
                label_ids = [labelNum('CLS')] \
                            + tmp_label + [labelNum('SEP')] \
                            + [labelNum('PAD')] * (max_length-len(tmp_label)-2)
    
                for i, l in enumerate(label_ids):
                    label_dataset[i].append(l)
    
                #tf inputs must be tensors
                input_ids = tf.constant(input_ids)
                attention_mask = tf.constant(attention_mask)
                token_type_ids = tf.constant(token_type_ids)
    
                input_ids_dataset.append(input_ids)
                attention_mask_dataset.append(attention_mask)
                token_type_ids_datasest.append(token_type_ids)
    
                tmp_sentence.clear()
                tmp_label.clear()
                k += 1
                if k % 100 == 0:
                    print('line: %d; word: %d'%(k, word_i))
            else:
                tmp_sentence.append(char_label[0].numpy().decode(encoding='UTF-8'))
                tmp_label.append(labelNum(char_label[1].numpy().decode(encoding='UTF-8')))
    
        for i, ls in enumerate(label_dataset):
            label_dataset[i] = tf.cast(ls, dtype=tf.float32) #tf labels must be tensors
    
        return (
            tf.convert_to_tensor(input_ids_dataset),
            tf.convert_to_tensor(attention_mask_dataset),
            tf.convert_to_tensor(token_type_ids_datasest),
            label_dataset
        )
    
    #build BERT's input data
    input_ids_t_dataset, attention_mask_t_dataset, token_type_ids_t_datasest, label_t_dataset = fromNer(train_path)
    input_ids_v_dataset, attention_mask_v_dataset, token_type_ids_v_datasest, label_v_dataset = fromNer(valid_path)
    #
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
    loss = [sparse_categorical_crossentropy for _ in range(max_length)]
    metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
    model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
    model.summary()
    
    model.fit(x={'input_ids':input_ids_t_dataset, 'attention_mask':attention_mask_t_dataset, 'token_type_ids':token_type_ids_t_datasest},
              y=label_t_dataset,
              epochs=2,
              steps_per_epoch=1080,
              validation_steps=7,
              validation_data=([input_ids_v_dataset, attention_mask_v_dataset, token_type_ids_v_datasest], label_v_dataset))
    
    model.save_pretrained("./mysavener")
    

    Notes on the code (points that are rarely covered online)

    1. The weights here are in tf2 keras h5 format, hence the parameter from_pt=False
    model = TFMyBertModel.from_pretrained(pretrained_path,from_pt=False, config=config) #from_pt says whether the weights come from pytorch; we use tf here, so it is set to False
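
    If the directory only contained PyTorch weights (pytorch_model.bin), the same call should work with from_pt=True instead; this is my reading of the transformers API rather than something shown in this post, and it also requires torch to be installed:

    # model = TFMyBertModel.from_pretrained(pretrained_path, from_pt=True, config=config)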
    
    2. The training input can be passed either as the dict {'input_ids':input_ids_t_dataset, 'attention_mask':attention_mask_t_dataset, 'token_type_ids':token_type_ids_t_datasest} or as the list [input_ids_t_dataset, attention_mask_t_dataset, token_type_ids_t_datasest]; see the HuggingFace Transformers source for how the inputs are unpacked.
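
    Both forms should also be accepted by a direct call to the model, roughly like this (a sketch based on how TFBertMainLayer unpacks its inputs, not code from the post; with the list form the order must be input_ids, attention_mask, token_type_ids):

    logits = model({'input_ids': input_ids_t_dataset,
                    'attention_mask': attention_mask_t_dataset,
                    'token_type_ids': token_type_ids_t_datasest})
    logits = model([input_ids_t_dataset, attention_mask_t_dataset, token_type_ids_t_datasest])
    # (calling on the full training set is only illustrative; model.fit below does the batching)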


    3. Structure of the training targets: there are 128 of them, because a sentence is defined here as at most 128 characters.
      The structure of label_t_dataset:


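      A rough sketch of that structure, inferred from how fromNer builds label_dataset above (the shapes are my reading of that code, not output copied from the post):

      print(len(label_t_dataset))         # 128, one entry per sentence position
      print(label_t_dataset[0].shape)     # (number_of_sentences,)
      print(label_t_dataset[0][:3])       # e.g. [7. 7. 7.] -> index 7 is 'CLS', the label of position 0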

    label_t_dataset is converted into the tensor format that the keras model needs


    model.fit(x={'input_ids':input_ids_t_dataset, 'attention_mask':attention_mask_t_dataset, 'token_type_ids':token_type_ids_t_datasest},
              y=label_t_dataset,
              epochs=2,
              steps_per_epoch=1080,
              validation_steps=7,
              validation_data=([input_ids_v_dataset, attention_mask_v_dataset, token_type_ids_v_datasest], label_v_dataset))
    
    4. The targets in label_t_dataset correspond one-to-one to the logits returned by the model
    #subclass the HuggingFace Transformers 2.0 base class so we can take BERT's raw output and extend it ourselves
    class TFMyBertModel(TFBertPreTrainedModel):
        @property
        def dummy_inputs(self):
            """ Dummy inputs to build the network.
    
            Returns:
                tf.Tensor with dummy inputs
            """
            return {"input_ids": tf.constant(DUMMY_INPUTS)}
    
        def __init__(self, config, *inputs, **kwargs):
            super().__init__(config, *inputs, **kwargs)
            self.bert = TFBertMainLayer(config, name="bert")
            #one classification head per position in the sentence, 128 of them here
            self.classifiers = []
            for i in range(max_length):
                self.classifiers.append(tf.keras.layers.Dense(label_category_total_num, name="classifier"+str(i)))
    
        def call(self, inputs, **kwargs):
            sequence_output, pooled_output = self.bert(inputs, **kwargs)
            print(sequence_output.shape)
    
            #the input is [None, 128, 768]; split it into a list of 128 tensors of shape [None, 768],
            #because every character has to be classified into one of the NER labels in label_category
            words_output = tf.split(sequence_output, sequence_output.shape[1], 1)
            #words_output has one entry per position in the sentence
            logits = []
            for i,o in enumerate(words_output):
                #drop the extra size-1 axis that each per-character tensor keeps after tf.split
                ot = tf.squeeze(o, axis=1)
                #classification result for this position
                logit = self.classifiers[i](ot)
                logits.append(logit)
            return logits
    

    Prediction code

    import os
    # os.environ["CUDA_VISIBLE_DEVICES"]="-1"
    import tensorflow as tf
    import numpy as np
    from transformers import *
    
    #the set of NER labels
    label_category = ['B-PER', 'B-ORG', 'B-LOC', 'I-ORG', 'I-PER', 'I-LOC', 'O', 'CLS', 'SEP', 'PAD', 'UNK']
    #number of label classes
    label_category_total_num = len(label_category)
    #sentence length, i.e. how many characters a sentence holds
    max_length = 128
    #sentences shorter than 128 are padded and masked with 0
    mask_padding_with_zero=True
    pad_token=0
    pad_token_segment_id=0
    #dummy data used when the HuggingFace Transformers 2.0 BERT model is instantiated; the default shape is (3, 5),
    #but keras.build must see 128 positions here, so the dummy inputs have shape (3, 128)
    DUMMY_INPUTS = [np.random.randint(0,128,size=128), np.random.randint(0,128,size=128), np.random.randint(0,128,size=128)]
    
    #this loss function is just copied from tf to make debugging easier; it forces from_logits=True,
    #because the Dense classifiers below return raw logits
    def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
        return tf.keras.backend.sparse_categorical_crossentropy(
            y_true, y_pred, from_logits=True, axis=axis)
    
    #subclass the HuggingFace Transformers 2.0 base class so we can take BERT's raw output and extend it ourselves
    class TFMyBertModel(TFBertPreTrainedModel):
        @property
        def dummy_inputs(self):
            """ Dummy inputs to build the network.
    
            Returns:
                tf.Tensor with dummy inputs
            """
            return {"input_ids": tf.constant(DUMMY_INPUTS)}
    
        def __init__(self, config, *inputs, **kwargs):
            super().__init__(config, *inputs, **kwargs)
            self.bert = TFBertMainLayer(config, name="bert")
            #one classification head per position in the sentence, 128 of them here
            self.classifiers = []
            for i in range(max_length):
                self.classifiers.append(tf.keras.layers.Dense(label_category_total_num, name="classifier"+str(i)))
    
        def call(self, inputs, **kwargs):
            sequence_output, pooled_output = self.bert(inputs, **kwargs)
            print(sequence_output.shape)
    
            #the input is [None, 128, 768]; split it into a list of 128 tensors of shape [None, 768],
            #because every character has to be classified into one of the NER labels in label_category
            words_output = tf.split(sequence_output, sequence_output.shape[1], 1)
            #words_output has one entry per position in the sentence
            logits = []
            for i,o in enumerate(words_output):
                #drop the extra size-1 axis that each per-character tensor keeps after tf.split
                ot = tf.squeeze(o, axis=1)
                #classification result for this position
                logit = self.classifiers[i](ot)
                logits.append(logit)
            return logits
    
    train_path = './data_ner/train.txt'
    valid_path = './data_ner/dev.txt'
    
    #the BERT directory needed by HuggingFace Transformers 2.0; here it holds the tf2 h5 files
    pretrained_path = './my-bert-base-chinese'
    config_path = os.path.join(pretrained_path, 'config.json')
    vocab_path = os.path.join(pretrained_path, 'vocab.txt')
    
    tokenizer = BertTokenizer.from_pretrained(vocab_path)
    # load the config
    config = BertConfig.from_json_file(config_path)
    # load the fine-tuned model that the training script saved with save_pretrained
    model = TFMyBertModel.from_pretrained('./mysavener/',from_pt=False) #from_pt says whether the weights come from pytorch; we use tf here, so it is set to False
    
    #the sentence to predict
    text = '中国的华先生,我和他谈笑风生。'
    
    #build BERT's inputs: input_ids, attention_mask, token_type_ids
    inputs = tokenizer.encode_plus(text, add_special_tokens=True, max_length=max_length, )
    input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
    attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
    padding_length = max_length - len(input_ids)
    input_ids = input_ids + ([pad_token] * padding_length)
    attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
    token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
    
    input_ids = tf.convert_to_tensor(input_ids)[None, :]
    attention_mask = tf.convert_to_tensor(attention_mask)[None, :]
    
    pred = model(inputs={"input_ids":input_ids, 'attention_mask':attention_mask})
    print(pred[0].numpy().argmax().item())
    print(pred[1].numpy().argmax().item())
    print(pred[2].numpy().argmax().item())
    print(pred[3].numpy().argmax().item())
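
    To turn the whole prediction back into readable labels, one possible decoding loop is the following (a hypothetical addition, not from the original post); position 0 is the [CLS] token, so the i-th character of text lines up with pred[i+1]:

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"])
    for i, token in enumerate(tokens):
        label_idx = int(pred[i].numpy().argmax())
        print(token, label_category[label_idx])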
    
    

    The full code is on github

    https://github.com/wengmingdong/tf2-stu/tree/master/bert4huggingface4tran

    The data is on a Baidu netdisk share

    link: https://pan.baidu.com/s/1dvAMo59FffwC4nKDKG-5zQ
    extraction code: heeh
