
2021-12-06 bert model

Author: Cipolee | Published 2021-12-22 19:21

    How attention_mask is used

    • attention_mask is a List[int] (or tensor) in which 0 means "mask this token" and 1 means "attend to it". It is passed as the attention_mask argument of forward(...), and inside BertModel.forward it is handed to the encoder (a usage sketch follows the snippet):
    encoder_outputs = self.encoder(
                embedding_output,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
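
    A minimal usage sketch (assuming the Hugging Face transformers package is installed; the checkpoint name is just an example): tokenizing a padded batch gives a 0/1 attention_mask that is passed straight into BertModel.

    import torch
    from transformers import BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")

    # padding=True pads the shorter sentence, so its attention_mask ends in 0s
    inputs = tokenizer(
        ["a short sentence", "a noticeably longer sentence that forces padding"],
        padding=True,
        return_tensors="pt",
    )
    print(inputs["attention_mask"])  # [batch_size, seq_length], 1 = attend, 0 = ignore

    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    print(outputs.last_hidden_state.shape)  # [batch_size, seq_length, hidden_size]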
    
    • extended_attention_mask: inside BertModel.forward, the 0/1 mask is first expanded by
    extended_attention_mask: torch.Tensor = \
    self.get_extended_attention_mask(attention_mask, input_shape, device)
    
     def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
            """
            Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
            Arguments:
                attention_mask (:obj:`torch.Tensor`):
                    Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
                input_shape (:obj:`Tuple[int]`):
                    The shape of the input to the model.
                device: (:obj:`torch.device`):
                    The device of the input to the model.
            Returns:
            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
            """
            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
            # ourselves in which case we just need to make it broadcastable to all heads.
            if attention_mask.dim() == 3:
                extended_attention_mask = attention_mask[:, None, :, :]
            elif attention_mask.dim() == 2:
                # Provided a padding mask of dimensions [batch_size, seq_length]
                # - if the model is a decoder, apply a causal mask in addition to the padding mask
                # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
                if self.config.is_decoder:
                    batch_size, seq_length = input_shape
                    seq_ids = torch.arange(seq_length, device=device)
                    causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                    # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                    # causal and attention masks must have same type with pytorch version < 1.3
                    causal_mask = causal_mask.to(attention_mask.dtype)
    
                    if causal_mask.shape[1] < attention_mask.shape[1]:
                        prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                        causal_mask = torch.cat(
                            [
                                torch.ones(
                                    (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
                                ),
                                causal_mask,
                            ],
                            axis=-1,
                        )
    
                    extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                else:
                    extended_attention_mask = attention_mask[:, None, None, :]
            else:
                raise ValueError(
                    f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
                )
    
            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
            return extended_attention_mask
    
    
    • After get_extended_attention_mask runs, the mask that is passed on to the encoder is the extended one, i.e. attention_mask = extended_attention_mask.
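
    A minimal standalone sketch of the 2-D (encoder) branch above: a [batch_size, seq_length] padding mask is broadcast to [batch_size, 1, 1, seq_length] and turned into an additive mask.

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])      # 1 = keep, 0 = padding
    extended = attention_mask[:, None, None, :].float()   # shape [1, 1, 1, 5]
    extended = (1.0 - extended) * -10000.0
    print(extended)
    # tensor([[[[0., 0., 0., -10000., -10000.]]]])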

    • When the model is a decoder (is_decoder), the cross-attention mask is handled separately: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask).

    In the decoder branch, the causal mask forms a lower-triangular matrix (a sketch follows).
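
    A small sketch of the causal-mask construction from the decoder branch, showing the lower-triangular pattern:

    import torch

    seq_length = 4
    seq_ids = torch.arange(seq_length)
    causal_mask = seq_ids[None, None, :].repeat(1, seq_length, 1) <= seq_ids[None, :, None]
    print(causal_mask.long()[0])
    # tensor([[1, 0, 0, 0],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 0],
    #         [1, 1, 1, 1]])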

    The final mask takes effect inside BertSelfAttention.

    • In BertSelfAttention.forward, after the attention scores are computed, the mask is applied by adding it to the scores before the softmax (a small numerical check follows the snippet):
    if attention_mask is not None:
                # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
                attention_scores = attention_scores + attention_mask
    attention_probs = nn.functional.softmax(attention_scores, dim=-1)
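
    A tiny standalone check of why this works: the -10000.0 entries push the corresponding softmax probabilities to effectively zero.

    import torch
    import torch.nn.functional as F

    attention_scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])
    attention_mask = torch.tensor([[0.0, 0.0, -10000.0, -10000.0]])  # already extended
    attention_probs = F.softmax(attention_scores + attention_mask, dim=-1)
    print(attention_probs)  # the last two positions get ~0 probability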
    

    The attention_mask is passed into BertModel; by this point it has already been transformed by get_extended_attention_mask inside BertModel.forward. get_extended_attention_mask itself is defined in modeling_utils.py, and its key step is:

            extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    

    The goal: positions that are 0 in attention_mask become a large negative number (-10000.0), while positions that are 1 become 0.

    • At this point the attention_mask handed to the encoder has already been transformed (encoder_attention_mask is only passed when the model acts as a decoder).
    • The encoder is a BertEncoder(config).
    • BertEncoder wraps num_hidden_layers BertLayer modules.
    • BertLayer wraps BertAttention, BertIntermediate, and BertOutput.
    • BertAttention wraps BertSelfAttention and BertSelfOutput (the nesting can be inspected with the sketch below).
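
    Assuming a standard transformers BertModel is loaded, the nesting described above can be checked by inspecting the submodules directly:

    from transformers import BertModel

    model = BertModel.from_pretrained("bert-base-uncased")
    print(len(model.encoder.layer))               # num_hidden_layers (12 for bert-base)
    layer = model.encoder.layer[0]                # a BertLayer
    print(type(layer.attention).__name__)         # BertAttention
    print(type(layer.attention.self).__name__)    # BertSelfAttention
    print(type(layer.attention.output).__name__)  # BertSelfOutput
    print(type(layer.intermediate).__name__)      # BertIntermediate
    print(type(layer.output).__name__)            # BertOutput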

    An open question: what exactly does BertModel's __init__ initialize?

    • BertModel's __init__ includes:

    super().__init__(config)
    self.post_init()
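
    As a rough answer to the question above, a sketch reconstructed from the transformers source of that era (not an exact quote): BertModel.__init__ builds the embeddings, the encoder, and an optional pooler before calling post_init().

    from transformers.models.bert.modeling_bert import (
        BertEmbeddings, BertEncoder, BertPooler, BertPreTrainedModel,
    )

    class BertModel(BertPreTrainedModel):
        def __init__(self, config, add_pooling_layer=True):
            super().__init__(config)
            self.config = config

            self.embeddings = BertEmbeddings(config)   # token/position/segment embeddings
            self.encoder = BertEncoder(config)         # stack of num_hidden_layers BertLayer
            self.pooler = BertPooler(config) if add_pooling_layer else None

            # Initialize weights and apply final processing
            self.post_init()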
    

    In QA, the attention_mask in the inputs returned by the tokenizer is still all ones, so it has to be adjusted manually:
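
    A hedged sketch of that manual adjustment, assuming BertForQuestionAnswering and an example checkpoint; the masked index range is purely illustrative:

    import torch
    from transformers import BertForQuestionAnswering, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")

    inputs = tokenizer(
        "Who proposed BERT?",
        "BERT was proposed by researchers at Google.",
        return_tensors="pt",
    )
    print(inputs["attention_mask"])      # all ones: nothing is masked yet

    # Manually ignore some context tokens (indices 5:8 are illustrative only)
    inputs["attention_mask"][0, 5:8] = 0
    outputs = model(**inputs)
    print(outputs.start_logits.shape, outputs.end_logits.shape)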
