How to add padding mask to nn.TransformerEncoder module?

Author: 魏鹏飞 | Published 2020-03-03 08:54

    I think, when using src_mask, we need to provide a matrix of shape (S, S), where S is our source sequence length. For example:

    import torch
    import torch.nn as nn
    
    q = torch.randn(3, 1, 10) # source sequence length 3, batch size 1, embedding size 10
    attn = nn.MultiheadAttention(10, 1) # embedding size 10, one head
    attn(q, q, q) # self attention
    
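    For reference, nn.MultiheadAttention returns a tuple of (attn_output, attn_output_weights); with this (S, N, E) = (3, 1, 10) input, the output keeps that shape and the (head-averaged) attention weights come out as (N, S, S) = (1, 3, 3), which is why the masks below are 3x3:

    out, attn_weights = attn(q, q, q) # unmasked self attention
    print(out.shape)          # torch.Size([3, 1, 10]) -> (S, N, E)
    print(attn_weights.shape) # torch.Size([1, 3, 3])  -> (N, S, S)
    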

    for attn_mask, we need a matrix of shape (S, S):

    def src_mask(sz):
      mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
      mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
      return mask
    
    print(src_mask(3))
    
    # Output
    tensor([[0., -inf, -inf],
            [0., 0., -inf],
            [0., 0., 0.]])
    
    out = attn(q, q, q, attn_mask=src_mask(3))[1] # attention output weights
    print(out)
    
    # Output
    tensor([[[1.0000, 0.0000, 0.0000],
             [0.2497, 0.7503, 0.0000],
             [0.1139, 0.2764, 0.6097]]], grad_fn=<DivBackward0>)
    

    if we look at F.multi_head_attention_forward, what attn_mask is doing is:

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        attn_output_weights += attn_mask
    
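    In other words, the mask is simply added to the raw attention scores before the softmax. Below is a minimal sketch of that idea, reusing src_mask from above; it is not a bit-exact reproduction of nn.MultiheadAttention, which also applies learned input/output projections:

    import torch
    import torch.nn.functional as F
    
    d = 10
    x = torch.randn(3, d)               # 3 tokens, embedding size 10
    scores = x @ x.t() / d ** 0.5       # raw scaled dot-product scores, shape (S, S)
    scores = scores + src_mask(3)       # add the 0 / -inf mask
    weights = F.softmax(scores, dim=-1) # masked (upper-triangle) positions become exactly zero
    print(weights)
    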

    since we added float('-inf') to some of the weights, the softmax returns zero for those positions, for example:

    a = nn.Softmax(dim=-1)
    b = torch.tensor([3., 4., float('-inf')])
    print(a(b))
    
    # Output
    tensor([0.2689, 0.7311, 0.0000])
    

    which means that we are not considering some words when finding the representation for a word. For example, when finding attn_weights for the first word in our source sentence, we do not want to consider the following words; when finding attn_weights for the second word, we want to consider only the first and second words, and not the third.
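
    PyTorch also ships a helper that builds the same triangular mask (in recent versions it is a static method; in older versions it has to be called on an nn.Transformer instance):

    import torch.nn as nn
    
    # Same matrix as src_mask(3) above: 0 on and below the diagonal, -inf above it.
    # (In older PyTorch versions, call this on an nn.Transformer instance instead.)
    causal_mask = nn.Transformer.generate_square_subsequent_mask(3)
    print(causal_mask)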

    as for src_key_padding_mask, it has to be of shape (N, S), where N is the batch size and S is the source sentence length.

    I think it is there so that we do not consider any padded words when finding the representation of other words.

    for example, if we do not want to consider the third word in our source sequence when finding attention weights, then (with a batch size of 1):

    src_key_padding_mask = torch.tensor([[0, 0, 1]]).bool()
    out = attn(q, q, q, attn_mask=src_mask(3), key_padding_mask=src_key_padding_mask)[1]
    print(out)
    
    # Output
    tensor([[[1.0000, 0.0000, 0.0000],
             [0.2497, 0.7503, 0.0000],
             [0.2919, 0.7081, 0.0000]]], grad_fn=<DivBackward0>)
    

    the third column is always zero, as we did not consider what impact the third word has on the representation of other words.
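
    Putting the two masks together for nn.TransformerEncoder (the module from the title), here is a minimal sketch that reuses src_mask from above; PAD_IDX and the token ids are made up for illustration, and in practice the padding mask is built from wherever your batch stores its pad token:

    import torch
    import torch.nn as nn
    
    PAD_IDX = 0                                # hypothetical pad token id
    tokens = torch.tensor([[5, 7, PAD_IDX]])   # (N, S) = (1, 3), last position is padding
    src_key_padding_mask = (tokens == PAD_IDX) # (N, S) bool mask, True = ignore this position
    
    encoder_layer = nn.TransformerEncoderLayer(d_model=10, nhead=1)
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
    
    src = torch.randn(3, 1, 10)                # (S, N, E), the layout nn.TransformerEncoder expects
    out = encoder(src,
                  mask=src_mask(3),            # (S, S) additive attention mask
                  src_key_padding_mask=src_key_padding_mask)
    print(out.shape)                           # torch.Size([3, 1, 10])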

    Reference:
    https://discuss.pytorch.org/t/how-to-add-padding-mask-to-nn-transformerencoder-module/63390
