美文网首页
序列比对(十六)——Baum-Welch算法估算HMM参数

序列比对(十六)——Baum-Welch算法估算HMM参数

作者: 生信了 | 来源:发表于2019-11-14 11:21 被阅读0次

原创:hxj7

本文介绍了如何用Baum-Welch算法来估算HMM模型中的概率参数。

Baum-Welch算法应用于HMM的效果

前文《序列比对(15)EM算法以及Baum-Welch算法的推导》介绍了EM算法Baum-Welch算法的推导过程。Baum-Welch算法是EM算法的一个特例,用来估算HMM模型中的概率参数。其具体步骤如下:

image
<center> 图片引自《生物序列分析》</center>

本文给出了Baum-Welch算法的C代码,还是以投骰子为例,估算出了转移概率以及发射概率。

具体效果如图:
(下面几张图中的 Real 表示真实的转移概率以及发射概率,而Baum-Welch表示用Baum-Welch算法估算的转移概率以及发射概率。)
首先是当若干条序列总长度300时:

image
image
然后是当若干条序列总长度30000时:
image
image
可以看出总长度为30000时已经很接近真实值了。但是,Baum-Welch算法的结果时一个局部最优值,很依赖初始值的设定。所以,当初始值不同时,也有可能会出现这种结果:
image
image
小结一下:
  • Baum-Welch算法通过多次迭代来估算HMM模型中的概率参数。
  • 本文代码设定了迭代的终止条件:当“归一化后的平均对数似然”的变化小于预先设定的阈值时或者迭代次数超出最大迭代次数时,迭代终止。
  • Baum-Welch算法的最终结果非常依赖初始值的设定。
  • 本文代码中的初始值是随机值。
  • 在计算期望次数时,使用了伪计数。

代码中所用公式及其推导

其中的A_{kl}指的是a_{kl}在所有训练序列中出现的期望次数,而E_k(b)指的是e_k(b)在所有训练序列中出现的期望次数。用符号表示就是(其中x^j表示第j条符号序列):
\begin{align} \displaystyle A_{kl} & = \sum_{j} \sum_{\pi} P(\pi^j|x^j,\theta) A_{kl}(\pi^j) \\ & = \sum_{j} \sum_i P(\pi_i^j=k, \pi_{i+1}^j=l|x^j,\theta) \tag{1.1} \end{align}
\begin{align} \displaystyle E_k(b) & = \sum_j \sum_\pi P(\pi^j|x^j, \theta) E_k(b, \pi^j) \\ & = \sum_{j} \sum_i P(\pi_i^j=k, x_i^j=b|x^j,\theta) \\ & = \sum_{j} \sum_{\{i|x_i^j=b\}} P(\pi_i^j=k|x^j,\theta) \tag{1.2} \end{align}

我们可以推导出,对某一条序列x^j有如下结论:
P(\pi_i=k, \pi_{i+1}=l|x,\theta) = \tilde{f}_k(i) a_{kl} e_l(x_{i+1}) \tilde{b}_l(i+1) \tag{2.1}
P(\pi_i=k|x,\theta) = \tilde{f}_k(i) \tilde{b}_k(i) s_i \tag{2.2}

其中\tilde{f}_k(i)\tilde{b}_k(i) 以及 s_i 的定义在前文《序列比对(12):计算后验概率》中已经给出(下文给出了计算公式)。

公式(2.1)的推导如下:
\begin{align} & P(\pi_i=k, \pi_{i+1}=l|x,\theta) \\ & = \frac {P(\pi_i=k, \pi_{i+1}=l,x|\theta)} {P(x|\theta)} \\ & = \frac {P(\pi_i=k, \pi_{i+1}=l, x_1, ..., x_i, x_{i+1},...,x_L|\theta)} {P(x|\theta)} \\ & = \frac {P(x_1,...,x_i,\pi_i=k|\theta) P(x_{i+1},...,x_L,\pi_{i+1}=l|x_1,...,x_i,\pi_i=k,\theta)} {P(x|\theta)} \\ & = \frac {f_k(i) P(x_{i+1},...,x_L,\pi_{i+1}=l|x_1,...,x_i,\pi_i=k,\theta)}{P(x|\theta)} \\ & = \frac {f_k(i) P(x_{i+1},...,x_L,\pi_{i+1}=l|\pi_i=k,\theta)}{P(x|\theta)} \\ & = \frac {f_k(i) P(x_{i+2},...,x_L,\pi_{i+1}=l|x_{i+1},\pi_{i+1}=l,\pi_i=k,\theta) P(x_{i+1},\pi_{i+1}=l|\pi_i=k,\theta)}{P(x|\theta)} \\ & = \frac {f_k(i) P(x_{i+2},...,x_L,\pi_{i+1}=l|\pi_{i+1}=l,\theta) P(x_{i+1},\pi_{i+1}=l|\pi_i=k,\theta)} {P(x|\theta)} \\ & = \frac {f_k(i) b_l(i+1) P(x_{i+1},\pi_{i+1}=l|\pi_i=k,\theta)} {P(x|\theta)} \\ & = \frac {f_k(i) b_l(i+1) P(\pi_{i+1}=l|\pi_i=k,\theta) P(x_{i+1}|\pi_i=k,\pi_{i+1}=l,\theta)} {P(x|\theta)} \\ & = \frac {f_k(i) b_l(i+1) a_{kl} P(x_{i+1}|\pi_{i+1}=l,\theta)} {P(x|\theta)} \\ & = \frac {f_k(i) b_l(i+1) a_{kl} e_l(x_{i+1})} {P(x|\theta)} \end{align}
同时,由我们知道:
\displaystyle f_k(i) = \tilde{f}_k(i) \prod_{r=1}^{i} s_r \\ \displaystyle b_k(i) = \tilde{b}_k(i) \prod_{r=i}^{L} s_r \\ \displaystyle P(x) = \prod_{r=1}^L s_r
所以:
\begin{aligned} P( & \pi_i=k, \pi_{i+1}=l|x,\theta) \\ & = \frac {f_k(i) b_l(i+1) a_{kl} e_l(x_{i+1})} {P(x|\theta)} \\ & = \frac {\bigg[ \tilde{f}_k(i) \displaystyle \prod_{r=1}^{i} s_r \bigg] \bigg[ \tilde{b}_l(i+1) \prod_{r=i+1}^{L} s_r \bigg] a_{kl} e_l(x_{i+1})} {\displaystyle \prod_{r=1}^L s_r} \\ & = \tilde{f}_k(i) a_{kl} e_l(x_{i+1}) \tilde{b}_l(i+1) \end{aligned}
公式(2.2)的证明已经在前文《序列比对(12):计算后验概率》中给出过了。

由式子(1.1)、(1.2)、(2.1)、(2.2),我们可以得到:
\displaystyle A_{kl} = \sum_{j} \sum_i \tilde{f}^{j}_k(i) a_{kl} e_l(x^j_{i+1}) \tilde{b}^j_l(i+1) \tag{3.1}
\displaystyle E_k(b) = \sum_{j} \sum_{\{i|x_i^j=b\}} \tilde{f}^j_k(i) \tilde{b}^j_k(i) s^j_i \tag{3.2}

实际上,代码中使用了状态0,构建了初始概率向量。假设以B代表初始向量的“转移”期望次数,那么它是A_{kl}当k=0时的一个特例:
\displaystyle B_{0l} = \sum_j \tilde{b}^j_l(1) a_{0l} e_l(x^j_1) \tag{3.3}

由于我们使用了伪计数r_{kl} 以及 r_k(b),所以:
A'_{kl} = A_{kl} + r_{kl} \tag{4.1}
E'_{k}(b) = E_{k}(b) + r_{k}(b) \tag{4.2}
B'_{0l} = B_{0l} + r_{0l} \tag{4.3}

最终,我们可以估算转移概率和发射概率:
a_{kl} = \frac {A'_{kl}} {\displaystyle \sum_{l'} A'_{kl'}} \tag{5.1}

e_k(b) = \frac {E'_k(b)} {\displaystyle \sum_{b'} E'_k(b')} \tag{5.2}

本文代码实际使用的计算公式就是(5.1)和(5.2)。

具体代码

具体代码如下:
(本文代码利用结构体重新梳理了过程,与之前文章中的代码相比,更工整了。)

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>

typedef char State;
typedef char Symbol;
struct MarkovChain {
    double* b;   // 初始概率向量
    double** a;
    double** e;
    Symbol* symb;
    State* st;
    int* idx;    // 每个符号向量所对应的序号
    int L;      // 符号向量的长度
    double** fscore;
    double** bscore;
    double* scale;
    double logScaleSum;
};
typedef struct MarkovChain* MChain;

State state[] = {'F', 'L'};   // 所有的可能状态
Symbol symbol[] = {'1', '2', '3', '4', '5', '6'};   // 所有的可能符号
double init[] = {0.9, 0.1};    // 初始状态的概率向量
double emission[][6] = {   // 发射矩阵:行对应着状态,列对应着符号
    1.0/6, 1.0/6, 1.0/6, 1.0/6, 1.0/6, 1.0/6,
    0.1, 0.1, 0.1, 0.1, 0.1, 0.5
};
double trans[][2] = {   // 转移矩阵:行和列都是状态
    0.95, 0.05,
    0.1, 0.9
};
const int nstate = 2;
const int nsymbol = 6;

MChain create(const int n);
int random(double* prob, const int n);  // 根据一个概率向量随机生成一个0 ~ n - 1的整数
void randSeq(MChain mc);
void getSymbolIndex(MChain mc);
void forward(MChain mc);
void backward(MChain mc);
void printState(State* st, const int n);
void printSymbol(Symbol* symb, const int n);
void printMChain(MChain mc);
void destroy(MChain mc);
void toz(double* a, const int n);   // 将概率数组除以其和,使得新的概率的和为1
void BaumWelch(MChain* amc, const int n);

int main(void) {
    int nchain = 3;
    int initLen = 80;
    int step = 20;
    int i;
    MChain* amc;
    MChain mc;
    if ((amc = (MChain*) malloc(sizeof(MChain) * nchain)) == NULL) {
        fputs("Error: out of space!\n", stderr);
        exit(1);
    }
    for (i = 0; i < nchain; i++) {
        mc = create(initLen + step * i);
        randSeq(mc);
        getSymbolIndex(mc);
        //printMChain(mc);
        amc[i] = mc;
    }
    BaumWelch(amc, nchain);
    for (i = 0; i < nchain; i++)
        destroy(amc[i]);
    free(amc);
    return 0;
}

MChain create(const int n) {
    int k;
    MChain mc;
    if ((mc = (MChain) malloc(sizeof(struct MarkovChain))) == NULL) {
        fputs("Error: out of space!\n", stderr);
        exit(1);        
    }
    mc->L = n;
    if ((mc->symb = (Symbol*) malloc(sizeof(Symbol) * mc->L)) == NULL || \
        (mc->st = (State*) malloc(sizeof(State) * mc->L)) == NULL || \
        (mc->idx = (int*) malloc(sizeof(int) * mc->L)) == NULL || \
        (mc->fscore = (double**) malloc(sizeof(double*) * nstate)) == NULL || \
        (mc->bscore = (double**) malloc(sizeof(double*) * nstate)) == NULL || \
        (mc->scale = (double*) malloc(sizeof(double) * mc->L)) == NULL) {
        fputs("Error: out of space!\n", stderr);
        exit(1);
    }
    for (k = 0; k < nstate; k++) {
        if ((mc->fscore[k] = (double*) malloc(sizeof(double) * mc->L)) == NULL || \
            (mc->bscore[k] = (double*) malloc(sizeof(double) * mc->L)) == NULL) {
            fputs("Error: out of space!\n", stderr);
            exit(1);        
        }
    }
    return mc;
}

int random(double* prob, const int n) {
    int i;
    double p = rand() / 1.0 / (RAND_MAX + 1);
    for (i = 0; i < n - 1; i++) {
        if (p <= prob[i])
            break;
        p -= prob[i];
    }
    return i;
}

void randSeq(MChain mc) {
    int i, ls, lr;
    srand((unsigned int) time(NULL));
    ls = random(init, nstate);
    lr = random(emission[ls], nsymbol);
    mc->st[0] = state[ls];
    mc->symb[0] = symbol[lr];
    for (i = 1; i < mc->L; i++) {
        ls = random(trans[ls], nstate);
        lr = random(emission[ls], nsymbol);
        mc->st[i] = state[ls];
        mc->symb[i] = symbol[lr];
    }
}

void getSymbolIndex(MChain mc) {
    int i;
    for (i = 0; i < mc->L; i++)
        mc->idx[i] = mc->symb[i] - symbol[0];
}

void forward(MChain mc) {
    int i, l, k, idx;
    double logpx;
    // 缩放因子向量初始化
    for (i = 0; i < mc->L; i++)
        mc->scale[i] = 0;
    // 计算第0列分值
    idx = mc->idx[0];
    for (l = 0; l < nstate; l++) {
        mc->fscore[l][0] = mc->e[l][idx] * mc->b[l];
        mc->scale[0] += mc->fscore[l][0];
    }
    for (l = 0; l < nstate; l++)
        mc->fscore[l][0] /= mc->scale[0];
    // 计算从第1列开始的各列分值
    for (i = 1; i < mc->L; i++) {
        idx = mc->idx[i];
        for (l = 0; l < nstate; l++) {
            mc->fscore[l][i] = 0;
            for (k = 0; k < nstate; k++) {
                mc->fscore[l][i] += mc->fscore[k][i - 1] * mc->a[k][l];
            }
            mc->fscore[l][i] *= mc->e[l][idx];
            mc->scale[i] += mc->fscore[l][i];
        }
        for (l = 0; l < nstate; l++)
            mc->fscore[l][i] /= mc->scale[i];
    }
    // P(x) = product(scale)
    // P(x)就是缩放因子向量所有元素的乘积
    logpx = 0;
    for (i = 0; i < mc->L; i++)
        logpx += log(mc->scale[i]);
    mc->logScaleSum = logpx;
    /*
    // 打印结果
    printf("forward: logP(x) = %f\n", logpx);
    for (l = 0; l < nstate; l++) {
        for (i = 0; i < mc->L; i++)
            printf("%f ", mc->fscore[l][i]);
        printf("\n");
    }
    */
}

void backward(MChain mc) {
    int i, l, k, idx;
    double tx, logpx;
    // 计算最后一列分值
    for (l = 0; l < nstate; l++)
        mc->bscore[l][mc->L - 1] = 1 / mc->scale[mc->L - 1];
    // 计算从第n - 2列开始的各列分值
    for (i = mc->L - 2; i >= 0; i--) {
        idx = mc->idx[i + 1];
        for (k = 0; k < nstate; k++) {
            mc->bscore[k][i] = 0;
            for (l = 0; l < nstate; l++) {
                mc->bscore[k][i] += mc->bscore[l][i + 1] * mc->a[k][l] * mc->e[l][idx];
            }
        }
        for (l = 0; l < nstate; l++)
            mc->bscore[l][i] /= mc->scale[i];
    }
    /*
    // 计算P(x)
    tx = 0;
    idx = mc->idx[0];
    for (l = 0; l < nstate; l++)
        tx += mc->b[l] * mc->e[l][idx] * mc->bscore[l][0];
    logpx = log(tx) + mc->logScaleSum;
    // 打印结果
    printf("backward: logP(x) = %f\n", logpx);
    for (l = 0; l < nstate; l++) {
        for (i = 0; i < mc->L; i++)
            printf("%f ", mc->bscore[l][i]);
        printf("\n");
    }
    */
}

void printState(State* st, const int n) {
    int i;
    for (i = 0; i < n; i++)
        printf("%c", st[i]);
    printf("\n");
}

void printSymbol(Symbol* symb, const int n) {
    int i;
    for (i = 0; i < n; i++)
        printf("%c", symb[i]);
    printf("\n");
}

void printMChain(MChain mc) {
    int k;
    int ll = 60;
    int nl = mc->L / ll;
    int nd = mc->L % ll;
    for (k = 0; k < nl; k++) {
        printf("Rolls\t");
        printSymbol(mc->symb + k * ll, ll);
        printf("Die\t");
        printState(mc->st + k * ll, ll);
        printf("\n");
    }
    if (nd > 0) {
        printf("Rolls\t");
        printSymbol(mc->symb + k * ll, nd);
        printf("Die\t");
        printState(mc->st + k * ll, nd);
        printf("\n"); 
    }
    printf("\n\n");
}

void destroy(MChain mc) {
    int i;
    free(mc->symb);
    free(mc->st);
    free(mc->idx);
    free(mc->scale);
    for (i = 0; i < nstate; i++) {
        free(mc->fscore[i]);
        free(mc->bscore[i]);
    }
    free(mc->fscore);
    free(mc->bscore);
    free(mc);
}

void toz(double* a, const int n) {
    int i;
    double sum;
    for (i = 0, sum = 0; i < n; i++)
        sum += a[i];
    if (sum == 0) {
        for (i = 0; i < n; i++)
            a[i] = 1.0 / n;
    } else {
        for (i = 0; i < n; i++)
            a[i] /= sum;
    }
}

void BaumWelch(MChain* amc, const int n) {
    int i, k, j, l;
    double* b;   // 初始概率向量
    double** e;
    double** a;
    double* B;
    double** A;
    double** E;
    double* rb;   // 伪计数
    double** ra;
    double** re;
    int maxIter = 500;   // 最大迭代次数
    int niter;
    int totalLen;     // 序列总长度
    double minLogDiff = 1e-6;      // 终止阈值
    double loglh1, loglh2;   // log likelyhood
    double tmp, sum;
    // 初始化空间
    if ((b = (double*) malloc(sizeof(double) * nstate)) == NULL || \
        (e = (double**) malloc(sizeof(double*) * nstate)) == NULL || \
        (a = (double**) malloc(sizeof(double*) * nstate)) == NULL || \
        (B = (double*) malloc(sizeof(double) * nstate)) == NULL || \
        (A = (double**) malloc(sizeof(double*) * nstate)) == NULL || \
        (E = (double**) malloc(sizeof(double*) * nstate)) == NULL || \
        (rb = (double*) malloc(sizeof(double) * nstate)) == NULL || \
        (ra = (double**) malloc(sizeof(double*) * nstate)) == NULL || \
        (re = (double**) malloc(sizeof(double*) * nstate)) == NULL) {
        fputs("Error: out of space!\n", stderr);
        exit(1);     
    }
    for (k = 0; k < nstate; k++) {
        if ((e[k] = (double*) malloc(sizeof(double) * nsymbol)) == NULL || \
            (a[k] = (double*) malloc(sizeof(double) * nstate)) == NULL || \
            (E[k] = (double*) malloc(sizeof(double) * nsymbol)) == NULL || \
            (A[k] = (double*) malloc(sizeof(double) * nstate)) == NULL || \
            (re[k] = (double*) malloc(sizeof(double) * nsymbol)) == NULL || \
            (ra[k] = (double*) malloc(sizeof(double) * nstate)) == NULL) {
            fputs("Error: out of space!\n", stderr);
            exit(1);        
        }
    }
    // 序列总长度
    for (i = 0, totalLen = 0; i < n; i++)
        totalLen += amc[i]->L;
    // 初始化参数值,概率使用随机数,次数使用伪计数
    srand((unsigned int) time(NULL));
    for (k = 0; k < nstate; k++) {
        rb[k] = 0;
        b[k] = rand() / (float) RAND_MAX;
    }
    toz(b, nstate);  // 将概率向量的和转换为1
    for (k = 0; k < nstate; k++) {
        for (l = 0; l < nstate; l++) {
            ra[k][l] = 1;
            a[k][l] = rand() / (float) RAND_MAX;
        }
        toz(a[k], nstate);
    }
    for (k = 0; k < nstate; k++) {
        for (i = 0; i < nsymbol; i++) {
            re[k][i] = 1;
            e[k][i] = rand() / (float) RAND_MAX;
        }
        toz(e[k], nsymbol);    
    }
    // 开始迭代过程
    for (j = 0, loglh2 = 0; j < n; j++) {
        amc[j]->e = e;
        amc[j]->a = a;
        amc[j]->b = b;
        forward(amc[j]);
        backward(amc[j]);
        loglh2 += amc[j]->logScaleSum;
    }
    loglh2 = loglh2 * 1000 / totalLen;    // 用序列总长度归一化,得到每个符号的平均log-likelyhood
    loglh1 = loglh2 - minLogDiff - 1;
    for (niter = 0; niter < maxIter && loglh2 - loglh1 > minLogDiff; niter++) {
        loglh1 = loglh2;
        // 使用伪计数赋值给初始次数
        for (k = 0; k < nstate; k++)
            B[k] = rb[k];
        for (k = 0; k < nstate; k++) {
            for (l = 0; l < nstate; l++) 
                A[k][l] = ra[k][l];
        }
        for (k = 0; k < nstate; k++) {
            for (i = 0; i < nsymbol; i++) 
                E[k][i] = re[k][i];   
        }
        // 利用旧参数计算期望次数
        for (j = 0; j < n; j++) {
            for (k = 0; k < nstate; k++) {
                B[k] += amc[j]->bscore[k][0] * b[k] * e[k][amc[j]->idx[0]];
            }
            for (k = 0; k < nstate; k++)
                for (l = 0; l < nstate; l++)
                    for (i = 0; i < amc[j]->L - 1; i++)
                        A[k][l] += amc[j]->fscore[k][i] * amc[j]->bscore[l][i + 1] * a[k][l] * e[l][amc[j]->idx[i + 1]];
            for (k = 0; k < nstate; k++)
                for (i = 0; i < amc[j]->L; i++)
                    E[k][amc[j]->idx[i]] += amc[j]->fscore[k][i] * amc[j]->bscore[k][i] * amc[j]->scale[i];
        } 
        // 利用期望次数计算新参数
        for (k = 0, sum = 0; k < nstate; k++)
            sum += B[k];
        for (k = 0; k < nstate; k++)
            b[k] = B[k] / sum;
        for (k = 0; k < nstate; k++) {
            for (l = 0, sum = 0; l < nstate; l++)
                sum += A[k][l];
            for (l = 0; l < nstate; l++)
                a[k][l] = A[k][l] / sum;
        }
        for (k = 0; k < nstate; k++) {
            for (i = 0, sum = 0; i < nsymbol; i++)
                sum += E[k][i];
            for (i = 0; i < nsymbol; i++)
                e[k][i] = E[k][i] / sum;
        }
        // 计算新的log-likelyhood
        for (j = 0, loglh2 = 0; j < n; j++) {
            amc[j]->e = e;
            amc[j]->a = a;
            amc[j]->b = b;
            forward(amc[j]);
            backward(amc[j]);
            loglh2 += amc[j]->logScaleSum;
        }
        loglh2 = loglh2 * 1000 / totalLen;
    }
    // 输出结果
    printf("num_of_seq = %d\n", n);
    printf("total_seq_len = %d\n", totalLen);
    printf("max_iter_num = %d\n", maxIter);
    printf("num_of_iter = %d\n", niter);
    printf("min_log_diff = %f\n", minLogDiff);
    printf("final_log_diff = %f\n", loglh2 - loglh1);
    printf("\n");
    printf("Real trans:\n");
    for (k = 0; k < nstate; k++) {
        printf("  ");
        for (l = 0; l < nstate; l++)
            printf("%f ", trans[k][l]);
        printf("\n");
    }  
    printf("Baum-Welch trans:\n");
    for (k = 0; k < nstate; k++) {
        printf("  ");
        for (l = 0; l < nstate; l++)
            printf("%f ", a[k][l]);
        printf("\n");
    }
    printf("\n");
    printf("Real emission:\n");
    for (k = 0; k < nstate; k++) {
        printf("  ");
        for (i = 0; i < nsymbol; i++)
            printf("%f ", emission[k][i]);
        printf("\n");
    }
    printf("Baum-Welch emission:\n");
    for (k = 0; k < nstate; k++) {
        printf("  ");
        for (i = 0; i < nsymbol; i++)
            printf("%f ", e[k][i]);
        printf("\n");
    }
    printf("\n");    
    // 释放空间
    free(b);
    free(B);
    free(rb);
    for (k = 0; k < nstate; k++) {
        free(ra[k]);
        free(re[k]);
        free(A[k]);
        free(E[k]);
        free(a[k]);
        free(e[k]);
    }
    free(ra);
    free(re);
    free(A);
    free(E);
    free(a);
    free(e);
}

(公众号:生信了)

相关文章

网友评论

      本文标题:序列比对(十六)——Baum-Welch算法估算HMM参数

      本文链接:https://www.haomeiwen.com/subject/lstrictx.html