美文网首页
transformers中的bert用法

transformers中的bert用法

作者: 雪糕遇上夏天 | 来源:发表于2021-07-24 16:40 被阅读0次

    1. Bert模型下载

    这里直接使用huggingface提供的pre-trained的bert模型,直接去官网即可搜索想要的模型并下载:https://huggingface.co/models

    这里以bert-base-chinese为例。首先将其下载到本地

    git lfs install
    git clone https://huggingface.co/bert-base-chinese
    

    注意此时下载的模型,还不完成,需要我们手动下载pytorch_model.bin到模型目录下。

    image.png

    具体做法是点击Files and versions,下载pytorch_model.bin,覆盖掉模型目录原有的同名文件。

    image.png

    至此呢,我们就把模型准备好了。

    2. 在transformers中使用

    在正式使用之前,首先要安装transformers包。

    pip install transformers
    

    然后既可以正式使用啦,首先根据模型所在目录加载tokenizer和model。

    import torch
    from transformers import BertModel, BertConfig, BertTokenizer
    
    modle_path = '/xxx/bert-base-chinese'
    tokenizer = BertTokenizer.from_pretrained(modle_path)
    model = BertModel.from_pretrained(modle_path)
    input_ids = torch.tensor([tokenizer.encode("五福临门", add_special_tokens=True)])
    with torch.no_grad():
      output = model(input_ids)
      last_hidden_state = output[0]
      pooler_output = output[1]
      print(last_hidden_state[:, 0, :])
    

    然后通过tokenizer将我们想encode的句子编码成id,注意[CLS]和[SEP]。

    input_ids = torch.tensor([tokenizer.encode("五福临门", add_special_tokens=True)])
    
    input_ids
    tensor([[ 101,  758, 4886,  707, 7305,  102]])
    
    input_ids
    tensor([[ 101,  758, 4886,  707, 7305,  102]])
    

    可以看到input_ids的长度跟输入的“五福临门”并不一样,这是为什么呢,我们继续看一下:

    tokenizer.convert_ids_to_tokens(tokenizer.encode('五福临门'))
    ['[CLS]', '五', '福', '临', '门', '[SEP]']
    

    原来在tokenizer帮我们把句子转换成id是,已经为我们添加好了[CLS],[SEP]等信息。

    有了input_ids之后,就可以进一步进行编码了。

    output = model(input_ids)
    last_hidden_state = output[0]
    pooler_output = output[1]
    
    print(last_hidden_state.shape)
    torch.Size([1, 6, 768])
    print(pooler_output.shape)
    torch.Size([1, 768])
    

    last_hidden_state为句子中每个字的编码,包括[CLS],pooler_output是经过pool之后的输出。

    有的同学可能会有疑问,Bert的输入不是还有attenion_masks和token_type_ids吗。

    if attention_mask is None:
      attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
    
    if token_type_ids is None:
      if hasattr(self.embeddings, "token_type_ids"):
        buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
        buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
      token_type_ids = buffered_token_type_ids_expanded
    else:
      token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
    
    

    可以看到框架内不已经进行了处理。

    如果我们不想用默认值,也可以用tokenzider.encode_plus()。

    sent_code = tokenizer.encode_plus('今天是周末', '要在家好好学习哦')
    input_ids = torch.tensor([sent_code['input_ids']])
    token_type_ids = torch.tensor([sent_code['token_type_ids']])
    
    model(input_ids=input_ids, token_type_ids=token_type_ids)
    
    with torch.no_grad():
     ouptput = model(input_ids, token_type_ids=token_type_ids)
     last_hidden_state, pooler_output = ouptput[0], ouptput[1]
    

    相关文章

      网友评论

          本文标题:transformers中的bert用法

          本文链接:https://www.haomeiwen.com/subject/nokvmltx.html