美文网首页NER
BiLSTM+CRF原理及代码分析

BiLSTM+CRF原理及代码分析

作者: IntoTheVoid | 来源:发表于2021-03-07 05:20 被阅读0次

    CRF基础

    • CRF是无向图模型,通过对MEMM进行改进,不直接计算状态间的转移概率,而是计算最大团势函数的乘积所得的归一化后的分值,如果要计算概率需要除以partition function, 即Z,下图中的Y表示给定的序列如下所示

    P(Y)=\frac{\prod_{C}\Psi_{C}(Y_C)}{Z} \\ Z = \sum_{Y}\prod_{C}\Psi_{C}(Y_C) \\ \Psi_{C}(Y_{C}) = exp(-E(Y_{C}))

    • CRF是判别模型,且属于log-linear model, 即给定序列X,求对应的Y序列的概率,表示如下:

    P(y|x)=\frac{exp\sum_{k=1}^{K}w_k f_k(y, x)}{Z(x)} \\ Z(x)=\sum_{y}exp\sum_{k=1}^{K}w_k f_k(y, x)

    模型的参数化形式

    img1.png
    • 其中转移状态函数:t_{k}(y_{i-1}, y_{i}, x, i)与状态特征函数:s_{l}(y_{i}, x_{i}, i)等价于f(y, x)
    • 其中转移特征的权重\lambda_{k}与状态特征的权重\mu_{l}统一用w表示
    • CRF和HMM最大的不同就是条件随机场中同一特征(即X中的某一个X_i,注意X是一个序列) 在各个位置(即Y中的某一个Y_i)都有定义,可以对同一个特征在各个位置求和,将局部特征函数转化为一个全局特征函数.而HMM某一个时刻只和上一个时刻相关

    BiLSTM+CRF建模

    1.数据定义

    规定在数据集中有两类实体,人名和组织机构名称:

    • B-Person(人名的开始部分)
    • I-Person(人名的中间部分)
    • B-Organization(组织机构的开始部分)
    • I-Organization(组织机构的中间部分)
    • O(非实体信息)

    为方便起见,我们给每个类别一个索引,如下表所示:


    tagindex.png

    训练数据集是由各种句子组成,假设所有句子中单词可以构成词典word2ix,给定单词w,可以获得其索引index_{w}

    2.BiLSTM-CRF 模型

    model.png
    • 模型输入:输入序列的单词索引
    • 模型输出:对应与输入序列的每个单词的标签索引
    • 模型组成:嵌入层>>>BiLSTM层>>>全连接层>>>CRF层

    3. CRF层的作用

    假设不加入CRF层的话,选择BiLSTM层每一个时刻分数最高的类别作为预测结果。图中所示情况得到的结果为:w0是“I-Organization”,w1是“I-Person”,w2是“O”,w3是 “B-Organization” ,w4是 “O”。显然,这次的分类结果并不准确

    CRF层可以加入一些约束来保证最终预测结果是有效的。这些约束可以在训练数据时被CRF层自动学习得到。

    • 句子的开头应该是“B-”或“O”,而不是“I-”
    • “B-label1 I-label2 I-label3…”,在该模式中,类别1,2,3应该是同一种实体类别。
    • “O I-label”是错误的,命名实体的开头应该是“B-”而不是“I-”

    4. BiLSTM-CRF 模型的损失函数

    CRF层中的损失函数包括两种类型的分数,而理解这两类分数的计算是理解CRF的关键。

    4.1.Emission score

    第一个类型的分数是发射分数(也成为状态分数)。这些状态分数来自BiLSTM层的输出。如上图黄色部分所示,w0被预测为B-Person的分数是1.5.

    定义 X_{i, y_j} ,其中i是单词的位置索引,y_j是类别的索引

    注意:BiLSTM层的输出矩阵即发射矩阵,维度时n*k, 其中n=输入序列的长度,k=所有的标签数, X_{i, :} \ne 1 因为不是概率

    4.2.Transition score

    定义 T_{y_i, y_j} 表示转移分数。例如,t_{B−Person,I−Person}=0.9表示从类别B−Person→I−Person的分数是0.9。因此,我们有一个所有类别间的转移分数矩阵。

    为了使转移分数矩阵更具鲁棒性,我们加上START 和 END两类标签。START代表一个句子的开始(不是句子的第一个单词),END代表一个句子的结束。下表是加上START和END标签的转移分数矩阵。

    要怎样得到这个转移矩阵呢?

    实际上,转移矩阵是BiLSTM-CRF模型的一个参数。在训练模型之前,你可以随机初始化转移矩阵的分数。这些分数将随着训练的迭代过程被更新,换句话说,CRF层可以自己学到这些约束条件。

    4.3.损失函数定义

    CRF损失函数由两部分组成,真实路径的分数 和 所有路径的总分数。真实路径的分数应该是所有路径中分数最高的。

    一个包含5个单词的句子,可能的类别序列如下:

      1. START B-Person B-Person B-Person B-Person B-Person END
      1. START B-Person I-Person B-Person B-Person B-Person END
    • …..
      1. START B-Person I-Person O B-Organization O END
    • N. O O O O O O O

    每种可能的路径的分数为Pi,共有N条路径,则路径的总分是

    P_{total} = P_1 + P_2 + ... + P_N = e^{S_1} + e^{S_2} + ... + e^{S_N} \\ e是常数e

    如果第十条路径是真实路径,也就是说第十条是正确预测结果,那么第十条路径的分数应该是所有可能路径里得分最高的。

    根据如下损失函数,在训练过程中,BiLSTM-CRF模型的参数值将随着训练过程的迭代不断更新,使得真实路径所占的比值越来越大。

    Loss Function = \frac{P_{realpath}}{P_1 + P_2 + ... + P_N}

    进一步变换,将其变成对数损失函数:

    Loss Function = log(\frac{P_{realpath}}{P_1 + P_2 + ... + P_N})

    由于我们的训练目标通常是最小化损失函数,所以我们加上负号:

    \begin{equation} \label{eq1} \begin{split} Loss Function & = -log(\frac{P_{realpath}}{P_1 + P_2 + ... + P_N}) \\ & = -log(\frac{e^{S_{realpath}}}{e^{S_1} + e^{S_2} + ... + e^{S_N}}) \\ & = -(log(e^{S_{realpath}}) - log(e^{S_1} + e^{S_2} + ... + e^{S_N})) \\ & = -(S_{realpath} - log(e^{S_1} + e^{S_2} + ... + e^{S_N})) \\ & = -(\sum_{i=1}^{N}X_{i, y_j} + \sum_{i=1}^{N-1}T_{y_i, y_j} - log(e^{S_1} + e^{S_2} + ... + e^{S_N})) \\ & = -(S_{realpath} - S_{allpath}) \end{split} \end{equation}

    where \quad{S_{realpath}} = Emission Score_{realpath} + Transition Score_{realpath} \\ S_{allpath} = log(e^{S_1} + e^{S_2} + ... + e^{S_N}) \\ Emission Score_{realpath} = \sum_{i=1}^{N}X_{i, y_j} \\ Transition Score_{realpath} = \sum_{i=1}^{N-1}T_{y_i, y_j}

    更新后的loss function,有两部分组成:

    • 1.给定序列的真实的tag序列的分数:即S_{realpath}
    • 2.给定序列的所有可能的tag序列的分数:即S_{allpath}
        # Compute loss function
        def neg_log_likelihood(self, sentence, tags):
            """
            sentence: token index at each timestamp
            tags: true label index at each timestamp
            """
            # Emission Matrix: feats, size=n*k, where n = len(sentence), k = len(tagsize)
            feats = self._get_lstm_features(sentence)
    
            # Real path score
            gold_score = self._score_sentence(feats, tags)
    
            # All path score
            forward_score = self._forward_alg(feats)
            
            # loss = - (S_realpath - S_allpath)
            loss = - (gold_score - forward_score) 
            return loss
    
    

    4.3.1 真实路径分数 S_{realpath}

    求真实路径分数的时候需要用到两个矩阵:

    • Emission 矩阵: 即BiLSTM的输出矩阵
    • Transition 矩阵: 可以先初始化,然后该矩阵会作为CRF的参数进行学习

    以“START B-Person I-Person O B-Organization O END”这条真实路径来说:

    句子中有5个单词,w1,w2,w3,w4,w5, 加上START和END 在句子的开始位置和结束位置,记为,w0,w6

    {S_i} = Emission Score + Transition Score

    Emission Score = \sum_{i=1}^{N}X_{i, y_j} = X_{0, START} + X_{1, B-Person} + X_{2, I-Person} + X_{3, O} + X_{4, B_Organization} + X_{5, O} + X_{6, END}

    Transition Score = \sum_{i=1}^{N-1}T_{y_i, y_j} = T_{START, B-person} + T_{B-person, I-person} + T_{I-person, O} + T_{O, B-organization} + T_{B-organization, O} + T_{O, END}

    注意:设X_{0, START}=X_{6, END} =0

        # compute real path score
        def _score_sentence(self, feats, tags):
            """gives the score of a provides tag sequence
            # feats: emission matrix from bilstm output, size = n*k
                # where n = len(sentence), k = len(tagsize)
            # tags: true label index at each timestamp
            """
    
            score = torch.zeros(1)
    
            # Put START_TAG at tag sequence head, such as [START_TAG, tag1, tag2... tagN]
            tags = torch.cat(
                [torch.tensor([self.tag2ix[START_TAG]], dtype=torch.long), tags])
    
            # Compute real path score : 
            # realpath score = each timestamp emission score + each timestamp transition score
            for i, feat in enumerate(feats):
                # transition score from i -> i+1: self.transitions[tags[i + 1], tags[i]]: 
                # emission score at i: feat[tags[i+1]], because START_TAG in tag sequence, index not i
                score = score + \
                    self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
            
            # Add value from last tag to END_TAG at score
            score = score + self.transitions[self.tag2ix[END_TAG], tags[-1]]
            return score
    
    
    

    4.3.2 所有路径分数 S_{allpath} = log(e^{S_1} + e^{S_2} + ... + e^{S_N})

    核心思想:整个过程是一个分数的积聚过程。利用动态规划,首先,i=0时刻所有路径的总分先被计算出来,然后,我们计算i=0时刻 --> i=1时刻的所有路径的得分,最后计算i=0时刻 --> i=1时刻 --> i=2时刻的所有路径的得分,也就是我们需要的结果。

    假设序列长度为L,

    • 原问题:计算从第0时刻到第L时刻的所有路径总分数
    • 子问题:计算从第0时刻到第i时刻的所有路径总分数

    log(e^{S_1} + e^{S_2} + ... + e^{S_i}) = log\left(\sum_{j\in{AllTags}} exp(\alpha_{i, j})\right)

    where \, \alpha_{i, j} = log\left(\sum_{j' \in{AllTags}} exp(\alpha_{i-1, j'} + T_{j', j} + X_{j, w_i})\right)

    • j'表示上一时刻的tag取值
    • j表示当前时刻的tag取值
    • alpha_{i-1, j'}表示上一个时刻的所有路径分数
    • T_{j', j} 表示从上一时刻的tagj'到当前时刻的tagj的转移分数,可以由transition matrix求得
    • X_{j, w_i} 表示当前时刻的单词w_i到tagj的发射分数,可以由emission matrix求得

    也可以通过下图很好的理解, 如果要求解下图中抵达i时刻,tag为j的所有路径分数

    forward_alg.png

    为了方便理解,此处定义一个函数log-sum-exp

    \begin{equation} \label{eq2} \begin{split} \alpha_{i, j} & = logSumExp(\\ & + \alpha_{i-1, 1} + T_{1,j} + X_{j, w_{i}} \\ & + \alpha_{i-1, 2} + T_{2,j} + X_{j, w_{i}} \\ & + \alpha_{i-1, 3} + T_{3,j} + X_{j, w_{i}} \\ & + \alpha_{i-1, 4} + T_{4,j} + X_{j, w_{i}} \\ & + \alpha_{i-1, 5} + T_{5,j} + X_{j, w_{i}} \\ & + \alpha_{i-1, 6} + T_{6,j} + X_{j, w_{i}} \\ & + \alpha_{i-1, 7} + T_{7,j} + X_{j, w_{i}}) \end{split} \end{equation}

    那么对于i时刻的每一个j,都可以计算上面的公式(上面的公式可以通过矩阵的方式计算,详见代码),就可以得到第i个时刻每一个tag下所有路径分数

    需要注意的两点:

    1. 对于i=0时,即第一个时刻,第一个时刻标签为START_TAG, 此时通过给定一个初始化\alpha代表\alpha_{i-1, j'},该初始化\alpha通过START_TAG定义
    2. 对于i=L时,即最后一个时刻,最后一个时刻标签为END_TAG,而END_TAG没有发射到具体的单词上,所有不存在X_{j, w_i},只需要计算上一个时刻的路径分数和上一个tag到END_TAG的转移分数
        # compute all path score
        def _forward_alg(self, feats):
            
            # for example: tagset = {START_TAG, tag1, tag2, tag3, END_TAG}
            # tensor([[-10000.,-10000.,-10000.,-10000.,-10000.]])
            init_alphas = torch.full((1, self.tagset_size), -10000.)
    
            # All path score at START_TAG
            # tensor([[-10000.,-10000.,-10000.,0,-10000.]])
            init_alphas[0][self.tag2ix[START_TAG]] = 0 
    
            # initial alpha at timestamp START_TAG
            forward_var = init_alphas
    
            # feats: emission matrix from bilstm output, size = n*k
            # where n = len(sentence), k = len(tagsize)
            for feat in feats:
                
                # alphas_t: a array to store score on each tag j at time i
                alphas_t = []
                for next_tag in range(self.tagset_size):
    
                    # feat[next_tag]: get emission score at tag j
                    # tensor([3]) -> tensor([[3,3,3,3,3]])
                    emit_score = feat[next_tag].view(
                        1, -1).expand(1, self.tagset_size)
    
                    # transitions[next_tag]: get transition scores from j' to j
                    trans_score = self.transitions[next_tag].view(1, -1)
    
                    # compute alpha_{i-1, j'} + T_{j', j} + X_{j, Wi}
                    next_tag_var = forward_var + trans_score + emit_score
    
                    # compute log_sum_exp on each tag j at time i and append to alphas_t
                    alphas_t.append(log_sum_exp(next_tag_var).view(1))
                
                # get all path score at time i for each tag j
                forward_var = torch.cat(alphas_t).view(1, -1)
    
            # get all path score at last time i (tag=END_TAG) for each tag j
            terminal_var = forward_var + self.transitions[self.tag2ix[END_TAG]]
    
            # get final all path score using log_sum_exp
            # alpha = S_{allpath}
            alpha = log_sum_exp(terminal_var)
            return alpha
    

    BiLSTM+CRF预测

    根据上面的模型,我们可以学习到CRF层的参数即转移矩阵,那么给定一个句子序列,输入到BiLSTM层可以获得对应的发射矩阵,那么可以通过viterbi算法求解该序列的最优路径,整体思路和所有路径分数的求解类似。

    主要不同如下:

    • 此处需要求最优路径(而非一个总体分值),因此不需要计算log_sum_exp,即只需要计算\alpha_{i-1, j'} + T_{j', j} + X_{j, w_i}
    • 在计算α的同时追踪抵达每一时刻的每个tag的最优“上一步”来自于哪里, 即下面代码中的bptrs_t.append(best_tag_id)
    • 最后需要倒序进行最优路径的回溯
        # Compute best path score and best path
        def _viterbi_decode(self, feats):
    
            backpointers = []
    
            # tensor([[-10000.,-10000.,-10000.,-10000.,-10000.]])
            init_vars = torch.full((1, self.tagset_size), -10000.)
            # tensor([[-10000.,-10000.,-10000.,0,-10000.]])
            init_vars[0][self.tag2ix[START_TAG]] = 0
    
            forward_var = init_vars
            for feat in feats:
                bptrs_t = []  # holds the back pointers for this step
                viterbivars_t = []  # holds the viterbi variables for this step
    
                for next_tag in range(self.tagset_size):
                    next_tag_var = forward_var + self.transitions[next_tag]
                    best_tag_id = argmax(next_tag_var)
                    bptrs_t.append(best_tag_id)
                    viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
                forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
                backpointers.append(bptrs_t)
    
            # Transition to STOP_TAG
            terminal_var = forward_var + self.transitions[self.tag2ix[END_TAG]]
            best_tag_id = argmax(terminal_var)
            path_score = terminal_var[0][best_tag_id]
    
            # Follow the back pointers to decode the best path.
            best_path = [best_tag_id]
            for bptrs_t in reversed(backpointers):
                best_tag_id = bptrs_t[best_tag_id]
                best_path.append(best_tag_id)
            
            # Pop off the start tag (we dont want to return that to the caller)
            start = best_path.pop()
            assert start == self.tag2ix[START_TAG]  # Sanity check
            best_path.reverse()
            return path_score, best_path
    

    相关文章

      网友评论

        本文标题:BiLSTM+CRF原理及代码分析

        本文链接:https://www.haomeiwen.com/subject/lseeqltx.html