美文网首页
pytorch-crf

pytorch-crf

作者: 三方斜阳 | 来源:发表于2021-05-17 09:36 被阅读0次

    官方文档:
    pytorch-crf — pytorch-crf 0.7.2 documentation
    使用pytorch 实现的条件随机场(CRF)模型,基于 AllenNLP CRF 模块,关于 CRF 的原理理解可以看这篇:CRF-条件随机场 - 简书 (jianshu.com)

    1. 安装:

    pip install pytorch-crf
    

    2. 导入模块使用:

    import torch
    from torchcrf import CRF
    num_tags = 5  # number of tags is 5
    model = CRF(num_tags , batch_first=True)
    

    3. 计算转移概率:

    seq_length = 3  # maximum sequence length in a batch
    batch_size = 2  # number of samples in the batch
    emissions = torch.randn(batch_size,seq_length, num_tags)
    >>
    tensor([[[ 0.3920, -2.0889,  1.0805, -0.6806, -0.0954],
             [ 0.1010,  0.2014, -0.0918, -0.7187, -1.2575],
             [-0.6948,  0.0528, -1.9853,  0.1679, -0.7857]],
    
            [[-1.0272, -0.2852, -0.5759,  1.3462,  0.7249],
             [ 0.6465,  0.1241, -0.9154, -0.6966, -0.0647],
             [-1.4029, -1.0029, -1.1149,  0.9312,  0.0092]]])
    >>
    tags = torch.tensor([[0,2,3], [1,4,1]], dtype=torch.long)  #(batch_size, seq_length)
    model(emissions, tags)
    >>
    tensor(-9.8121, grad_fn=<SumBackward0>)
    

    4. 如果输入有padding

    # mask size is (batch_size,seq_length)
    # the last sample has length of 1
    mask = torch.tensor([[1, 1,1], [1, 1,0]], dtype=torch.uint8)
    model(emissions, tags, mask=mask)
    >>
    tensor(-8.7959, grad_fn=<SumBackward0>)
    

    5. 解码:

    model.decode(emissions , mask=mask)
    >>
    [[1, 4, 4], [4, 1]]
    

    相关文章

      网友评论

          本文标题:pytorch-crf

          本文链接:https://www.haomeiwen.com/subject/krndjltx.html