美文网首页机器学习与数据挖掘机器学习机器学习和人工智能入门
集成学习系列2:AdaBoost算法和实例(例子简单易懂)

集成学习系列2:AdaBoost算法和实例(例子简单易懂)

作者: b424191ea349 | 来源:发表于2019-04-10 10:20 被阅读0次

    1. 概念

    首先我们看这样的一张图,很明显能够看出来,AdaBoost是boosting 家族的一员。


    既然是boosting家族的一员,那么我们看看AdaBoost如何解决boosting中的两个问题:

    1. 每一轮如何改变训练数据的权值或概率分布,以获取不同的弱分类器?
      AdaBoost通过提高那些被前一轮弱分类器错误分类样本的权值,降低那些被正确分类样本的权值来实现的。
    2. 如何将弱分类器组合成一个强分类器?
      AdaBoost通过加权多数表决,加大分类误差率小的弱分类器的权值,使其在表决中起较大的作用,减小分类误差率大的弱分类器的权值,使其在表决中起较小的作用。

    模型图如下:


    2. 具体算法

    输入:二分类的训练数据集 T=\left\{\left(x_{1}, y_{1}\right),\left(x_{2}, y_{2}\right), \cdots,\left(x_{N}, y_{N}\right)\right\}\;\;x_{i} \in \mathcal{X} \subseteq \mathbf{R}^{n} \quad y_{i} \in \mathcal{Y}=\{-1,+1\}

    这个很好理解,就是一组数据二分类的数据。
    输出:最终分类器G(x)

    步骤一:初始化训练数据的起始权值分布D_{1}=\left(w_{11}, \cdots, w_{1 i}, \cdots, w_{1 N}\right) \quad w_{1 i}=\frac{1}{N}, \quad i=1,2, \cdots, N
    说白了,就是初始化了一些权值w,这是第一步,所以w的行都是1,列表示N个数据,也就是第一行全部初始化成\frac1N

    步骤二:对第m个弱分类器 m=1,2,\dots M
    a首先:在权值D_m下训练数据集,得到弱分类器:G_{m}(x) : \mathcal{X} \rightarrow\{-1,+1\}
    b然后:计算G_m的训练误差e_{m}=P\left(G_{m}\left(x_{i}\right) \neq y_{i}\right)=\sum_{i=1}^{N} w_{m i} I\left(G_{m}\left(x_{i}\right) \neq y_{i}\right),这里写的很复杂,但实际上就是把所有算错的样本点对应的权值w都加在一起。
    c接着:计算G_m的系数:\alpha_{m}=\frac{1}{2} \log \frac{1-e_{m}}{e_{m}}
    d接着:更新训练数据集的权值分布
    \begin{array}{c}{D_{m+1}=\left(w_{m+1,1}, \cdots, w_{m+1, i}, \cdots, w_{m+1, N}\right)} \\ {w_{m+1, i}=\frac{w_{m i}}{Z_{m}} \exp \left(-\alpha_{m} y_{i} G_{m}\left(x_{i}\right)\right), \quad i=1,2, \cdots, N}\end{array}
    这里Z是规范化因子(这个东东是不是和CRF部分类似):
    Z_{m}=\sum_{i=1}^{N} w_{m i} \exp \left(-\alpha_{m} y_{i} G_{m}\left(x_{i}\right)\right)

    步骤三:构建弱分类器的线性组合
    f(x)=\sum_{m=1}^{M} \alpha_{m} G_{m}(x)
    得到最终分类器:
    G(x)=\operatorname{sign}(f(x))=\operatorname{sign}\left(\sum_{m=1}^{M} \alpha_{m} G_{m}(x)\right)

    这个算法复杂不?看起来复杂,但是一步步看下去,真的是一点也不复杂。

    深入的说明

    1. 步骤一中假设训练集具有均匀的权值分布,即每个训练样本在基本分类器中作用相同,这一假设保证了第一步能够在原始数据上学习基本分类器G_1(x)
    2. 步骤二Adaboost反复学习基本分类器,在每一轮m=1,2,\dots M,顺次地执行以下操作:
      (a). 使用当前分布加权的训练数据集,学习基本分类器G_m(x)
      (b). 计算基本分类器Gm(x)在加权训练数据集上的分类误差率:
        e_{m}=P\left(G_{m}\left(x_{i}\right) \neq y_{i}\right)=\sum_{G_{m}\left(x_{i}\right) \neq y_{i}} w_{m i}
      这里w_{mi}表示第m轮中第i个实例的权值,\sum_{i=1}^{N}w_{m i}=1。这表明G_m(x)在加权训练数据集上的分类误差率是被G_m(x)误分类样本的权值之和,由此可以看出数据权值分布D_m与基本分类器G_m(x)的分类误差率的关系。
      (c)计算基本分类器G_m(x)的系数\alpha_m\alpha_m表示G_m(x)在最终分类器中的重要性,由\alpha_m的计算公式可知,当e_m \le \frac12时,\alpha_m \ge 0。并且\alpha_m随着e_m的减小而增大,所以分类误差率小的基本分类器在最终分类器的作用越大,这也符合我们通常的认知。
      (d) 更新训练数据的权值为下一轮作准备:
      w_{m+1, i}=\left\{\begin{array}{ll}{\frac{w_{m i}}{Z_{m}} \mathrm{e}^{-\alpha_{m}},} & {G_{m}\left(x_{i}\right)=y_{i}} \\ {\frac{w_{m i}}{Z_{m}} \mathrm{e}^{\alpha_{m i}},} & {G_{m}\left(x_{i}\right) \neq y_{i}}\end{array}\right.
      由此可知,被基本分类器G_m(x)误分类的样本权值得以扩大,而正确分类的样本权值得以缩小,两者相比较,误分类的样本权值被放大\mathrm{e}^{2 \alpha_{m}}=\frac{e_{m}}{1-e_{m}}倍,因此误分类样本在下一轮学习中起更大作用,不改变所给的训练数据,而不断改变训练数据权值的分布,使得训练数据在基本分类器中的学习中起不同作用,这是Adaboost的一个特点。
    3. 步骤3线性组合f(x)实现M个基本分类器的加权表决。系数\alpha_m表示了基本分类器G_m(x)的重要性。这里所有\alpha_m之和并不为1,f(x)的符号决定实例x的类,f(x)的绝对值表示分类的确信度,利用基本分类器的线性组合构建最终分类器是AdaBoost的另一个特点。

    3. 一个例子

    给定下列训练样本,试用AdaBoost算法学习一个强分类器:


    首先初始化数据权值分布:
    \begin{array}{c}{D_{1}=\left(w_{11}, w_{12}, \cdots, w_{110}\right)} \\ {w_{1 i}=0.1, \quad i=1,2, \cdots, 10}\end{array}

    然后我们看到这里实际上由两个弱分类器选择,阈值以下我们预测为1,阈值以上我们预测为-1,阈值分别取2.5(错误分类点序号是7、8、9),8.5(错误分类点序号为4、5、6)。

    m=1来说:
    a、在权值分布为D1的数据集上,阈值取2.5,分类误差率最小,基本弱分类器:
    G_{1}(x)=\left\{\begin{array}{ll}{1,} & {x<2.5} \\ {-1,} & {x>2.5}\end{array}\right.
    注:关于这里为什么取2.5,对于两个弱分类器阈值分别为2.5和8.5,它的误差率都是0.3,所以无所谓了,取第一个就好了。

    b、G1(x)的误差率 :e_{1}=P\left(G_{1}\left(x_{i}\right) \neq y_{i}\right)=0.3
    c、G1(x)的系数:\alpha_{1}=\frac{1}{2} \log \frac{1-e_{1}}{e_{1}}=0.4236
    d、更新训练数据的权值分布:
    D_{2} =(w_{21}, \cdots, w_{2 i}, \cdots, w_{210})

    w_{2 i} =\frac{w_{1 i}}{Z_{1}} \exp \left(-\alpha_{1} y_{i} G_{1}\left(x_{i}\right)\right), \quad i=1,2, \cdots, 10

    \begin{array}{l}{D_{2}=(0.0715,0.0715,0.0715,0.0715,0.0715,0.0715,} \\ {0.1666,0.1666,0.1666,0.0715 )} \\ {f_{1}(x)=0.4236 G_{1}(x)}\end{array}
    弱基本分类器G_1(x)=sign[f_1(x)]在更新的数据集上有3个误分类点。

    此时样本点变成这样了:

    对m=2来说:
    a、在权值分布为D2的训练数据上,阈值v取8.5时误差率最低,故基本分类器为:
    G_{2}(x)=\left\{\begin{array}{ll}{1,} & {x<8.5} \\ {-1,} & {x>8.5}\end{array}\right.
    注:对于这里为什么取阈值为8.5而不是取2.5,当阈值为2.5是,错误分类的点序号是7、8、9,对应x=6、7、8,误差率为0.1666+0.1666+0.1666=0.4998,很明显比阈值为8.5的误差率0.2143大,所以我们选阈值为8.5的,当然不只是和阈值为2.5比较,也可以任意选择一个阈值比较,都是这个最小。

    b、误差率e_{2}=0.2143
    c、计算\alpha_{2}=0.6496
    d、更新权值分布:
    \begin{aligned} D_{3}=&(0.0455,0.0455,0.0455,0.1667,0.1667,0.1667,0.1060,0.1060,0.1060,0.0455 ) \\& \end{aligned}
    f_{2}(x)=0.4236 G_{1}(x)+0.6496 G_{2}(x)

    分类器G_2(x)=sign[f_2(x)]有三个误分类点。

    此时样本点变成了:


    对m=3来说:
    a、在权值分布D3上,阈值v=5.5时,分类误差率最低
    G_{3}(x)=\left\{\begin{array}{ll}{1,} & {x>5.5} \\ {-1,} & {x<5.5}\end{array}\right.
    注:这里取5.5同上面的取法一样的。

    b、误差率e_{3}=0.1820
    c、计算\alpha_{3}=0.7514
    d、更新权值分布
    D_{4}=(0.125,0.125,0.125,0.102,0.102,0.102,0.102,0.065,0.065,0.065,0.125)
    f_{3}(x)=0.4236 G_{1}(x)+0.6496 G_{2}(x)+0.7514 G_{3}(x)
    分类器sign(f3(x))在训练数据集上有0个误分类点:
    G(x)=\operatorname{sign}\left[f_{3}(x)\right]=\operatorname{sign}\left[0.4236 G_{1}(x)+0.6496 G_{2}(x)+0.7514 G_{3}(x)\right]

    注:很多同学不清楚为什么这里的误分类点是0个?详细解释以下:
    首先我们由三个弱分类器了,然后这三个弱分类器组合成了一个强分类器:
    G_{1}(x)=\left\{\begin{array}{ll}{1,} & {x<2.5} \\ {-1,} & {x>2.5}\end{array}\right.
    G_{2}(x)=\left\{\begin{array}{ll}{1,} & {x<8.5} \\ {-1,} & {x>8.5}\end{array}\right.
    G_{3}(x)=\left\{\begin{array}{ll}{1,} & {x>5.5} \\ {-1,} & {x<5.5}\end{array}\right.

    序号 1 2 3 4 5 6 7 8 9 10
    X 0 1 2 3 4 5 6 7 8 9
    G1(x) 1 1 1 -1 -1 -1 -1 -1 -1 -1
    G2(x) 1 1 1 1 1 1 1 1 1 -1
    G3(x) -1 -1 -1 -1 -1 -1 1 1 1 1
    fx 0.3218 0.3218 0.3218 -0.5254 -0.5254 -0.5254 0.9774 0.9774 0.9774 -0.3218
    预测 1 1 1 -1 -1 -1 1 1 1 -1
    真实 1 1 1 -1 -1 -1 1 1 1 -1

    可见,误差率确实是0。

    参考

    《统计学习方法》
    Adaboost算法原理分析和实例+代码(简明易懂)

    相关文章

      网友评论

        本文标题:集成学习系列2:AdaBoost算法和实例(例子简单易懂)

        本文链接:https://www.haomeiwen.com/subject/qwwgiqtx.html