SMO算法实现

作者: lightt | 来源:发表于2020-07-17 15:38 被阅读0次

    这里根据SMO算法原论文中的伪代码实现了SMO算法。算法和数据已经上传到了git

    伪代码

    target = desired output vector
    point = training point matrix
    procedure takeStep(i1,i2)
        if (i1 == i2) return 0
        alph1 = Lagrange multiplier for i1
        y1 = target[i1]
        E1 = SVM output on point[i1] – y1 (check in error cache)
        s = y1*y2
        Compute L, H via equations (13) and (14)
        if (L == H)
            return 0
        k11 = kernel(point[i1],point[i1])
        k12 = kernel(point[i1],point[i2])
        k22 = kernel(point[i2],point[i2])
        eta = k11+k22-2*k12
        if (eta > 0)
        {
            a2 = alph2 + y2*(E1-E2)/eta
            if (a2 < L) a2 = L
            else if (a2 > H) a2 = H
        }
        else
        {
            Lobj = objective function at a2=L
            Hobj = objective function at a2=H
            if (Lobj < Hobj-eps)
                a2 = L
            else if (Lobj > Hobj+eps)
                a2 = H
            else
                a2 = alph2
        }
        if (|a2-alph2| < eps*(a2+alph2+eps))
            return 0
        a1 = alph1+s*(alph2-a2)
        Update threshold to reflect change in Lagrange multipliers
        Update weight vector to reflect change in a1 & a2, if SVM is linear
        Update error cache using new Lagrange multipliers
        Store a1 in the alpha array
        Store a2 in the alpha array
        return 1
    endprocedure
    
    procedure examineExample(i2)
        y2 = target[i2]
        alph2 = Lagrange multiplier for i2
        E2 = SVM output on point[i2] – y2 (check in error cache)
        r2 = E2*y2
        if ((r2 < -tol && alph2 < C) || (r2 > tol && alph2 > 0))
        {
            if (number of non-zero & non-C alpha > 1)
            {
                i1 = result of second choice heuristic (section 2.2)
                if takeStep(i1,i2)
                    return 1
            }
            loop over all non-zero and non-C alpha, starting at a random point
            {
                i1 = identity of current alpha
                if takeStep(i1,i2)
                    return 1
            }
            loop over all possible i1, starting at a random point
            {
                i1 = loop variable
            if takeStep(i1,i2)
                    return 1
            }
        }
        return 0
    endprocedure
    
    main routine:
        numChanged = 0
        examineAll = 1
        while (numChanged > 0 | examineAll)
        {
            numChanged = 0;
            if (examineAll)
                loop I over all training examples
                numChanged += examineExample(I)
            else
                loop I over examples where alpha is not 0 & not C
                numChanged += examineExample(I)
            if (examineAll == 1)
                examineAll = 0
            else if (numChanged == 0)
                examineAll = 1
        }
    

    python实现

    # @Author  : lightXu
    # @File    : smo_paper.py
    
    import numpy as np
    import matplotlib.pyplot as plt
    import random
    import copy
    
    
    class OptStruct:
        """Container for the state shared by all SMO routines.

        Attributes:
            X - training sample matrix, one row per sample
            label - class labels (+1/-1), one per sample
            C - slack penalty (box constraint)
            toler - KKT tolerance
            row - number of training samples
            alpha - Lagrange multipliers, initialised to zero
            b - threshold, initialised to zero
            e_cache - cached prediction errors, one slot per sample
        """

        def __init__(self, data_x, label, C, toler):
            sample_count = data_x.shape[0]
            self.X = data_x
            self.label = label
            self.C = C
            self.toler = toler
            self.row = sample_count
            self.alpha = np.zeros(sample_count)
            self.b = 0
            self.e_cache = np.zeros(sample_count)
    
    def cal_Ek(ost, k):
        """Compute the SVM output and prediction error for sample k.

        Parameters:
            ost - shared SMO state
            k - index of the sample
        Returns:
            (Ek, fxk) - rounded prediction error and rounded raw output
        """
        # Linear-kernel values K(x_j, x_k) for every training sample j.
        kernel_col = np.dot(ost.X, ost.X[k, :])
        fxk = np.dot((ost.alpha * ost.label).T, kernel_col) + ost.b
        error = fxk - ost.label[k]
        return round_float(error), round_float(fxk)
    
    
    def round_float(value):
        """Round to 8 decimal places to tame floating-point noise."""
        return round(value, ndigits=8)
    
    
    def load_data(file_name):
        """Load a tab-separated dataset: features in all but the last column.

        Parameters:
            file_name - path to a text file, one sample per line,
                        tab-separated, label in the last field
        Returns:
            (data_x, label) - float feature matrix and float label vector
        """
        data_x = []
        data_y = []

        with open(file_name, 'r') as f:
            for line in f:
                fields = line.strip().split('\t')
                data_x.append(fields[:-1])
                data_y.append(fields[-1])

        # BUG FIX: np.float was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin float is the documented replacement.
        data_x = np.array(data_x, dtype=float)
        label = np.array(data_y, dtype=float)

        return data_x, label
    
    
    def select_j_random(i, m):
        """Pick a random index in [0, m) that differs from i."""
        while True:
            candidate = int(random.uniform(0, m))
            if candidate != i:
                return candidate
    
    
    def select_j(eligible_list, i, ost, Ei):
        """Second-choice heuristic: pick the partner j maximizing |Ei - Ej|.

        As in section 2.2 of Platt's paper: if Ei is negative, the largest Ej
        wins; if positive, the smallest; if zero, the largest |Ej|.

        Parameters:
            eligible_list - candidate indices (non-bound samples)
            i - index of the first multiplier (excluded from candidates)
            ost - shared SMO state
            Ei - prediction error of sample i
        Returns:
            (max_k, Ej) - chosen index and its prediction error
        """
        eligible_list = list(eligible_list)
        if i in eligible_list:
            eligible_list.remove(i)

        E_list = [cal_Ek(ost, k)[0] for k in eligible_list]
        if Ei < 0:
            idx = E_list.index(max(E_list))
        elif Ei > 0:
            idx = E_list.index(min(E_list))
        else:
            # BUG FIX: the original looked up max(|Ek|) inside the *signed*
            # E_list, raising ValueError whenever the winner was negative.
            abs_list = [abs(e) for e in E_list]
            idx = abs_list.index(max(abs_list))
        max_k = eligible_list[idx]

        Ej, _ = cal_Ek(ost, max_k)
        return max_k, Ej
    
    
    def updateEk(ost, k):
        """Recompute the error for sample k and refresh its cache slot.

        Parameters:
            ost - shared SMO state
            k - sample index whose cached error is updated
        """
        error, _ = cal_Ek(ost, k)
        ost.e_cache[k] = error
    
    
    def clip_alpha(alpha, L, H):
        """Clip alpha into [L, H] (cap at H first, then raise to L)."""
        return max(min(alpha, H), L)
    
    
    def cal_w(data_x, label, alpha):
        """Recover the linear-SVM weight vector w = sum_i alpha_i * y_i * x_i."""
        coefficients = alpha * label
        return np.dot(coefficients.T, data_x)
    
    
    def objective_func(ost, i1, i2, alpha1, alpha2, L, H):
        """Evaluate the SMO objective at the two endpoints a2 = L and a2 = H.

        Implements equation (19) of Platt's paper for the eta <= 0 case.

        Parameters:
            ost - shared SMO state
            i1, i2 - indices of the pair being optimized
            alpha1, alpha2 - current (old) multipliers of i1 and i2
            L, H - feasible-segment endpoints for a2
        Returns:
            (obj_L, obj_H) - objective values at a2 = L and a2 = H
        """
        y1, y2 = ost.label[i1], ost.label[i2]
        s = y1 * y2
        k11 = np.dot(ost.X[i1, :], ost.X[i1, :])
        k12 = np.dot(ost.X[i1, :], ost.X[i2, :])
        k22 = np.dot(ost.X[i2, :], ost.X[i2, :])

        # BUG FIX: cal_Ek returns a (Ek, fxk) tuple; the original added the
        # whole tuple to ost.b, which raises TypeError. Unpack the error first.
        E1, _ = cal_Ek(ost, i1)
        E2, _ = cal_Ek(ost, i2)
        f1 = y1 * (E1 + ost.b) - alpha1 * k11 - s * alpha2 * k12
        # BUG FIX: the last term of f2 must use k22, not k12 (Platt eq. 19).
        f2 = y2 * (E2 + ost.b) - s * alpha1 * k12 - alpha2 * k22

        L1 = alpha1 + s * (alpha2 - L)
        H1 = alpha1 + s * (alpha2 - H)

        obj_L = L1 * f1 + L * f2 + 0.5 * L1 * L1 * k11 + 0.5 * L * L * k22 + s * L * L1 * k12
        obj_H = H1 * f1 + H * f2 + 0.5 * H1 * H1 * k11 + 0.5 * H * H * k22 + s * H * H1 * k12

        return obj_L, obj_H
    
    
    def take_step(ost, i1, i2, E2):
        """Jointly optimize the multiplier pair (alpha[i1], alpha[i2]).

        Follows the takeStep procedure of Platt's paper: compute the feasible
        segment [L, H], solve the 1-D quadratic for a2 (or fall back to the
        endpoint objectives when eta <= 0), then update a1, the threshold b
        and the error cache.

        Parameters:
            ost - shared SMO state
            i1, i2 - indices of the pair to optimize
            E2 - prediction error of sample i2 (passed in by the caller)
        Returns:
            1 if the pair was changed, 0 otherwise.
        """
        eps = 0.0001  # single tolerance (the original mixed 0.0001 and 0.001)
        if i1 == i2:
            return 0
        alpha1 = ost.alpha[i1].copy()
        alpha2 = ost.alpha[i2].copy()
        y1 = ost.label[i1]
        y2 = ost.label[i2]

        E1, _ = cal_Ek(ost, i1)
        s = y1 * y2

        # Feasible segment for a2 (Platt eq. 13 / 14).
        if y1 != y2:
            L = max(0, alpha2 - alpha1)
            H = min(ost.C, ost.C + alpha2 - alpha1)
        else:
            L = max(0, alpha2 + alpha1 - ost.C)
            H = min(ost.C, alpha2 + alpha1)
        if L == H:
            return 0

        k11 = np.dot(ost.X[i1, :], ost.X[i1, :])
        k12 = np.dot(ost.X[i1, :], ost.X[i2, :])
        k22 = np.dot(ost.X[i2, :], ost.X[i2, :])
        eta = round_float(k11 + k22 - 2 * k12)
        if eta > 0:
            # Unconstrained minimum along the segment, clipped into [L, H].
            a2 = alpha2 + y2 * (E1 - E2) / eta
            a2 = min(max(a2, L), H)
        else:
            # Degenerate curvature: compare the objective at both endpoints.
            # BUG FIX: pass the *current* alpha2 (the original substituted L
            # and H for it, distorting the endpoint objectives).
            Lobj, Hobj = objective_func(ost, i1, i2, alpha1, alpha2, L, H)
            if Lobj < Hobj - eps:
                a2 = L
            elif Lobj > Hobj + eps:
                a2 = H
            else:
                a2 = alpha2

        # Progress test from the paper's pseudocode.
        if abs(a2 - alpha2) < eps * (a2 + alpha2 + eps):
            return 0

        a1 = alpha1 + s * (alpha2 - a2)

        # BUG FIX: the original derived b1/b2 from (ost.alpha[i] - alpha_old)
        # BEFORE storing the new multipliers, so those deltas were always 0
        # and the threshold only drifted by -E. Use the new values directly.
        b1 = (ost.b - E1
              - y1 * (a1 - alpha1) * k11
              - y2 * (a2 - alpha2) * k12)
        b2 = (ost.b - E2
              - y1 * (a1 - alpha1) * k12
              - y2 * (a2 - alpha2) * k22)

        # BUG FIX: the non-bound test must use the NEW multipliers a1/a2,
        # not the not-yet-updated ost.alpha entries.
        if 0 < a1 < ost.C:
            ost.b = b1
        elif 0 < a2 < ost.C:
            ost.b = b2
        else:
            ost.b = (b1 + b2) / 2.0

        ost.alpha[i1] = round_float(a1)
        ost.alpha[i2] = round_float(a2)

        # BUG FIX: refresh the error cache AFTER alpha and b are updated;
        # the original refreshed it with the stale multipliers.
        updateEk(ost, i1)
        updateEk(ost, i2)

        return 1
    
    
    def violate_kkt(ost, alpha2, E2, fx2, y2):
        """Check whether a sample violates its KKT condition within toler.

        Uses the margin m = y*f(x) - 1 (note m = y*(f(x) - y) = y*E):
            alpha == 0      requires m >= -toler
            0 < alpha < C   requires |m| <= toler
            alpha == C      requires m <= toler

        Parameters:
            ost - shared SMO state (provides C and toler)
            alpha2 - multiplier of the sample
            E2 - prediction error (kept for interface compatibility; the
                 margin is derived from fx2 directly)
            fx2 - raw SVM output for the sample
            y2 - label of the sample
        Returns:
            True if the sample violates its KKT condition.
        """
        # Dead code removed: the original also computed an unused r2-based
        # check (Platt eq. 12) and an un-toleranced variant before returning
        # exactly this margin-based test.
        margin = y2 * fx2 - 1
        at_lower = alpha2 == 0 and margin < -ost.toler
        in_between = 0 < alpha2 < ost.C and abs(margin) > ost.toler
        at_upper = alpha2 == ost.C and margin > ost.toler
        return at_lower or in_between or at_upper
    
    
    def examine_example(ost, i2):
        """Try to optimize sample i2 against a partner (Platt's hierarchy).

        Partner choice order: (1) second-choice heuristic over the non-bound
        set, (2) all non-bound samples in random order, (3) the whole
        training set in random order.

        Parameters:
            ost - shared SMO state
            i2 - index of the sample being examined
        Returns:
            1 if a successful joint step was taken, 0 otherwise.
        """
        y2 = ost.label[i2]
        alpha2 = ost.alpha[i2]
        E2, fx2 = cal_Ek(ost, i2)

        # Only samples that violate the KKT conditions are worth optimizing.
        if not violate_kkt(ost, alpha2, E2, fx2, y2):
            return 0

        def try_partners(candidates):
            # Draw partners at random, without replacement, until one succeeds.
            pool = list(candidates)
            while pool:
                i1 = random.choice(pool)
                if take_step(ost, i1, i2, E2):
                    return True
                pool.remove(i1)
            return False

        non_bound = np.where((ost.alpha != 0) & (ost.alpha != ost.C))[0]

        # 1) second-choice heuristic (section 2.2) over the non-bound set.
        if len(non_bound) > 1:
            i1, _ = select_j(non_bound, i2, ost, E2)
            if take_step(ost, i1, i2, E2):
                return 1

        # 2) every non-bound sample, random order.
        if try_partners(non_bound):
            return 1

        # 3) the entire training set, random order.
        if try_partners(range(0, ost.row)):
            return 1

        return 0
    
    
    def main(dataMatIn, classLabels, C, toler, maxIter):
        """Platt's SMO outer loop.

        Alternates between full passes over the whole training set and passes
        over the non-bound samples (0 < alpha < C) until no multiplier changes
        during a non-bound pass, or maxIter iterations are exhausted.

        Parameters:
            dataMatIn - sample matrix
            classLabels - +1/-1 labels
            C - slack penalty (box constraint)
            toler - KKT tolerance
            maxIter - maximum number of outer-loop iterations
        Returns:
            (b, alpha) - learned threshold and Lagrange multipliers
        """
        ost = OptStruct(dataMatIn, classLabels, C, toler)
        iter_num = 0
        num_changed = 0
        examine_all = 1

        # BUG FIX: the original condition parsed as
        # ((iter_num < maxIter and num_changed > 0) or examine_all),
        # so maxIter never bounded the examine-all passes.
        while iter_num < maxIter and (num_changed > 0 or examine_all):
            num_changed = 0
            if examine_all:
                # First pass (and after a stalled non-bound pass): scan the
                # entire training set for KKT violators.
                for i in range(ost.row):
                    num_changed = num_changed + examine_example(ost, i)
                    print("全样本遍历:第%d次迭代 样本:%d, alpha优化次数:%d" % (iter_num, i, num_changed))
                iter_num += 1
            else:
                # Subsequent passes: only the non-bound samples (neither 0
                # nor C) are re-checked against the KKT conditions.
                non_bound_index = np.where((0 < ost.alpha) & (ost.alpha < C))[0]
                for i in non_bound_index:
                    num_changed = num_changed + examine_example(ost, i)
                    print("非边界:第%d次迭代 样本:%d, alpha优化次数:%d" % (iter_num, i, num_changed))
                iter_num += 1
            if examine_all:
                examine_all = 0
            elif num_changed == 0:
                examine_all = 1

        return ost.b, ost.alpha
    
    
    def show_classifier(data_x, label, w, b, alpha, seed):
        """Plot both classes, the separating line and the support vectors.

        Parameters:
            data_x - sample matrix with two feature columns
            label - +1/-1 labels
            w - weight vector (a1, a2) of the linear decision function
            b - threshold
            alpha - Lagrange multipliers (non-zero entries mark support vectors)
            seed - random seed of the run (kept for figure naming)
        """
        positives = data_x[np.where(label == 1)[0]]
        negatives = data_x[np.where(label == -1)[0]]

        plt.scatter(positives[:, 0], positives[:, 1],
                    s=30, alpha=0.7, c='green')  # positive samples
        plt.scatter(negatives[:, 0], negatives[:, 1],
                    s=30, alpha=0.7, c='pink')  # negative samples

        # Decision boundary a1*x + a2*y + b = 0, drawn between the x extremes.
        x_max = np.max(data_x, axis=0)[0]
        x_min = np.min(data_x, axis=0)[0]
        a1, a2 = w
        b = float(b)
        y1, y2 = (-b - a1 * x_max) / a2, (-b - a1 * x_min) / a2
        plt.plot([x_max, x_min], [y1, y2])

        # Circle the support vectors (samples with non-zero alpha).
        for i, alp in enumerate(alpha):
            if abs(alp) > 0:
                print(i)
                sv_x, sv_y = data_x[i]
                plt.scatter([sv_x], [sv_y], s=150, c='none', alpha=0.7, linewidth=1.5, edgecolor='red')

        plt.show()
        plt.close()
    
    
    def show_classifier1(data_x, label):
        """Scatter-plot the raw data: green for the +1 class, pink for -1."""
        positives = data_x[np.where(label == 1)[0]]
        negatives = data_x[np.where(label == -1)[0]]

        plt.scatter(positives[:, 0], positives[:, 1],
                    s=30, alpha=0.7, c='green')  # positive samples
        plt.scatter(negatives[:, 0], negatives[:, 1],
                    s=30, alpha=0.7, c='pink')  # negative samples

        plt.show()
    
    
    if __name__ == '__main__':
        # Fix the RNG so the alpha-pair selection (and hence the whole run)
        # is reproducible.
        seed = 10
        random.seed(seed)
        # Train a linear soft-margin SVM: C=0.6, KKT tolerance 1e-4,
        # at most 100 outer-loop iterations.
        dataMat, labelMat = load_data('testSet.txt')
        b, alphas = main(dataMat, labelMat, 0.6, 0.0001, 100)
        # Recover the explicit weight vector from the learned multipliers.
        w = cal_w(dataMat, labelMat, alphas)
        print(w, b)
        show_classifier(dataMat, labelMat, w, b, alphas, seed)
    
    

    分类结果如下:


    smo.png

    补充

    第一个参数选择需要判断是否违反原始KKT条件, 这部分的原理可以参考博客。论文作者在这部分引入了eps加快训练, 推导过程大家直接看代码就行了。

    相关文章

      网友评论

        本文标题:SMO算法实现

        本文链接:https://www.haomeiwen.com/subject/zapghktx.html