官方文档:
pytorch-crf — pytorch-crf 0.7.2 documentation
使用pytorch 实现的条件随机场(CRF)模型,基于 AllenNLP CRF 模块,关于 CRF 的原理理解可以看这篇:CRF-条件随机场 - 简书 (jianshu.com)
1. 安装:
pip install pytorch-crf
2. 导入模块使用:
import torch
from torchcrf import CRF
num_tags = 5 # number of tags is 5
model = CRF(num_tags , batch_first=True)
3. 计算转移概率:
seq_length = 3 # maximum sequence length in a batch
batch_size = 2 # number of samples in the batch
emissions = torch.randn(batch_size,seq_length, num_tags)
>>
tensor([[[ 0.3920, -2.0889, 1.0805, -0.6806, -0.0954],
[ 0.1010, 0.2014, -0.0918, -0.7187, -1.2575],
[-0.6948, 0.0528, -1.9853, 0.1679, -0.7857]],
[[-1.0272, -0.2852, -0.5759, 1.3462, 0.7249],
[ 0.6465, 0.1241, -0.9154, -0.6966, -0.0647],
[-1.4029, -1.0029, -1.1149, 0.9312, 0.0092]]])
>>
tags = torch.tensor([[0,2,3], [1,4,1]], dtype=torch.long) #(batch_size, seq_length)
model(emissions, tags)
>>
tensor(-9.8121, grad_fn=<SumBackward0>)
4. 如果输入有padding
# mask size is (batch_size,seq_length)
# the last sample has length of 1
mask = torch.tensor([[1, 1,1], [1, 1,0]], dtype=torch.uint8)
model(emissions, tags, mask=mask)
>>
tensor(-8.7959, grad_fn=<SumBackward0>)
5. 解码:
model.decode(emissions , mask=mask)
>>
[[1, 4, 4], [4, 1]]
网友评论