美文网首页
优化算法应用(五)优化支持向量机(SVM)

优化算法应用(五)优化支持向量机(SVM)

作者: stronghorse | 来源:发表于2022-12-09 20:03 被阅读0次

    一. 目标描述

    支持向量机(Support Vector Machine,SVM)是一种有监督的机器学习算法。其原理简单,但是通用的公式推到及求解过程异常的复杂,这里将直接用优化算法来对支持向量机进行求解。
    (该篇没有对偶问题转化推导,没有核函数,如需要了解自行搜索。)

    二. 支持向量机简介

    支持向量机是二分类算法,它的主要目标是找到两个数据集的支持向量及分割平面,使分割平面间的距离最大化。


      如上图,红蓝点分别代表两个数据集,其中红色线和蓝色线表示由支持向量决定的两个平行的线(超平面),超平面之间的距离为2d。我们的目标就是找到最佳的斜率和支持向量,让d的值最大。

    1. 硬间隔

    硬间隔表示所有的数据点都需要在超平面的一侧。


      如上图,确定L的斜率后,蓝色数据集可以确定4个超平面L1,L2,L3,L4。由于硬间隔的要求,蓝色数据集只能在超平面的一侧,故只有L4能被选作超平面。
      选定超平面后就可以计算两个超平面之间的距离。
      设蓝色数据集为(x1,y1), (x2,y2), (x3,y3), (x4,y4),红色数据集(m1,n1), (m2,n2), (m3,n3), (m4,n4)。
      假设(x1,y1)和(m1,n1)为两个支持向量,则超平面由下式确定


      点(a,b)到直线Ax+By+C=0的距离计算公式如下:


      则可计算两个超平面间间隔2d如下


    2. 软间隔

    当数据集中有一定异常数据,数据无法被两个超平面分为两类时,将引入松弛参数,并重新计算数据集之间的最大距离。


      如图,若蓝色有两个分类错误的数据点时,分别计算其到对应超平面的距离,整个数据集间的距离如下:


      其中k为松弛参数,k为非负数,当k取值为0时,则表示可以允许所有数据点分类错误,当k取值为正无穷时,表示不允许任何数据点分类错误,此时,等同于硬间隔。

    三. 支持向量机适应度函数

    1. 适应度函数设计

    从支持向量机的定义知,这是一个最大化问题。其中的变量为超平面,而超平面由支持向量决定。所以适应度函数的输入应该是超平面的各个维度的斜率,输出则是两个数据集之间的距离。其流程如下:



      以上面的例子来说计算步骤如下
      设蓝色数据集为(x1,y1), (x2,y2), (x3,y3), (x4,y4),红色数据集(m1,n1), (m2,n2), (m3,n3), (m4,n4)。
      确定一组斜率(A,B),可以计算出8个超平面,分别由C1,C2,C3,C4,C5,C6,C7,C8。

    数据点 截距
    (x1,y1) C1
    (x2,y2) C2
    (x3,y3) C3
    (x4,y4) C4
    (m1,n1) C5
    (m2,n2) C6
    (m3,n3) C7
    (m4,n4) C8

    然后计算各个超平面组合之间的距离,即在下表中选出最大的D作为适应度值返回。

    C5 C6 C7 C8
    C1 D15 D16 D17 D18
    C2 D25 D26 D27 D28
    C3 D35 D36 D37 D38
    C4 D45 D46 D47 D48

    2. 适应度函数优化

    从适应度函数可以看出,如果数据集1有i个数据,数据集2有j个数据,那么为了选出最佳的支持向量,我们需要计算i*j次数据集之间的距离,当数据集中的数据量很大时,这将是会耗费非常多的时间。
      下面,将适应度函数优化一下,每次只计算单次数据集间距离。
      可以看到,原适应度函数中,若数据集中数据的维度为2时,输入的维度也为2,输入中的每一维都可以认为是该维度斜率(如A,B相除得到斜率),然后会根据斜率和数据点计算出该超平面的截距(如C1,C2),然后选出最佳的截距组合来确定超平面之间的距离。
      这里优化一下,将适应度函数的输入确定为两个超平面,如此每次只需计算这两个超平面之间的距离即可。
      具体如下:
      若数据集中数据的维度为2,那么适应度函数输入的维度为2+2(如(A,B,C1,C2)), 若数据集中数据的维度为N时,适应度函数输入的维度为N+2,其中最后两维代表这两个超平面的截距。
    这样一来,适应度函数输入的就是两个确定的超平面,而不是超平面的截距,但是此时,该超平面不一定会穿过数据集中的数据点,或者可以说该超平面不存在支持向量。我愿称之为广义支持向量机。


      如图,优化后的适应度函数计算出的超平面不一定经过数据点。

    四. 代码实现

    1 .原SVM模型

    文件路径:\optimization algorithm\application_svm\SVM_Model.m

    classdef SVM_Model < handle
        properties
            % 数据集m行*n列的矩阵,n = dim
            dataset1 = [];
            dataset2 = [];
            % 惩罚系数
            c = 5;
        end
        
        methods
            % 构造函数
            function self = SVM_Model(dataset1,dataset2)
                self.dataset1 = dataset1;
                self.dataset2 = dataset2;
            end
            
            % 输入为向量,输出为适应度值
            % 输入的x为角度
            function value = fit_function(self,x) 
                % 计算哪个数据集在上方
                is_up = self.get_is_up(x);
                
                % 计算结果及最佳支持向量截距
                [value,intercept1,intercept2] = self.get_value_and_intercept(x,is_up);
            end
            
            % 绘图,num为图片编号
            function draw(self,input,num)      
                bound_max = max(abs([self.dataset1(1,:),self.dataset1(2,:),self.dataset2(1,:),self.dataset2(2,:)]));
                set(gca,'XLim',[-bound_max bound_max]);
                set(gca,'YLim',[-bound_max bound_max]);
                % 绘制左下右上两点,保持图像区域不变
                scatter(-bound_max,-bound_max,1,'w');
                hold on;
                scatter(bound_max,bound_max,1,'w');
                hold on;
                axis square;
                
                for i = 1:length(self.dataset1)
                    scatter(self.dataset1(i,1),self.dataset1(i,2),10,'b','filled');
                    hold on;
                end
                for i = 1:length(self.dataset2)
                    scatter(self.dataset2(i,1),self.dataset2(i,2),10,'r','filled');
                    hold on;
                end
                
                % 有数据则绘制分割线
                if ~isempty(input)
                    is_up = self.get_is_up(input);
                    [value,intercept1,intercept2] = self.get_value_and_intercept(input,is_up);
                    x =-bound_max:0.1:bound_max;
                    k = -(input(1))/(input(2));
                    x1=x;
                    y1 = k*x1+intercept1/-(input(2));
                    % 剔除直线上超出边界的点
                    I = y1>bound_max;
                    y1(I) = [];
                    x1(I) = [];
                    J = y1<-bound_max;
                    y1(J) = [];
                    x1(J) = [];
                    plot(x1,y1,'b');
                    hold on;
                    
                    y2 = k*x+intercept2/-(input(2));
                    x2=x;
                    % 剔除直线上超出边界的点
                    I = y2>bound_max;
                    y2(I) = [];
                    x2(I) = [];
                    J = y2<-bound_max;
                    y2(J) = [];
                    x2(J) = [];
                    plot(x2,y2,'r');
                    hold on;
                end
                text(bound_max*0.9+-bound_max*0.1,bound_max*0.9-bound_max*0.1,num2str(num),'FontSize',20);
            end
            
        end
        
        % 受保护的方法,继承用
        methods (Access = protected)
            
            % 计算该斜率下哪个数据集在上哪个数据集在下,
            % 若数据集1 在上则返回true
            function value = get_is_up(self,x)
                % 计算两个数据集的平均位置
                mean_data1 = sum(self.dataset1)/length(self.dataset1);
                mean_data2 = sum(self.dataset2)/length(self.dataset2);
                
                % 根据数据集的平均位置来计算该斜率下的截距,用来判断哪个数据集在上,哪个数据集在下
                mean_intercept1 = self.get_intercept(mean_data1,x);
                mean_intercept2 = self.get_intercept(mean_data2,x);
                value = mean_intercept1 < mean_intercept2;
            end
            
            % 计算当前斜率下的最佳截距和结果
            % 输入的gradient为斜率
            function [value,intercept1,intercept2] = get_value_and_intercept(self,x,is_up)
                value = -realmax('double');
                % 遍历数据集,找出最佳的支持向量组合
                % 数据集之间距离越大越优
                for i = 1:length(self.dataset1)
                    for j = 1:length(self.dataset2)
                        tmp_intercept1 = self.get_intercept(self.dataset1(i,:),x);
                        tmp_intercept2 = self.get_intercept(self.dataset2(j,:),x);
                        svm_value = self.get_svm_value(x,tmp_intercept1,tmp_intercept2,is_up);
                        % 记录着两个支持向量对应的截距
                        if value < svm_value
                            intercept1 = tmp_intercept1;
                            intercept2 = tmp_intercept2;
                            value = svm_value;
                        end
                    end
                end
            end
            
            % 根据数据和超平面斜率,计算超平面截距
            % a*x+b*y+c = 0 --> c = -(a*x+b*y)
            function value = get_intercept(self,data,x)
                value = -sum(data.*x);
            end
            
            % 计算一个整个数据集在该超平面上的距离
            function value = get_svm_value(self,x,intercept1,intercept2,is_up)
                % 直线的分母
                deno = sqrt(sum(x.^2));
                if is_up
                    % 如果数据集1在数据集2的下方
                    % 则计算数据集1中在支持向量上方的数据的距离
                    value = -(intercept1-intercept2)/deno;
                    for i = 1:length(self.dataset1)
                        temp_intercept1 = self.get_intercept(self.dataset1(i,:),x);
                        if  temp_intercept1 > intercept1
                            value = value - self.c*abs(temp_intercept1-intercept1)/deno;
                        end
                    end
                    % 如果数据集2在数据集1的上方
                    % 则计算数据集2中在支持向量下方的数据的距离
                    for i = 1:length(self.dataset2)
                        temp_intercept2 = self.get_intercept(self.dataset2(i,:),x);
                        if temp_intercept2 < intercept2
                            value = value - self.c*abs(temp_intercept2-intercept2)/deno;
                        end
                    end
                else
                    % 如果数据集1在数据集2的上方
                    % 则计算数据集1中在支持向量下方的数据的距离
                    value = (intercept1-intercept2)/deno;
                    for i = 1:length(self.dataset1)
                        temp_intercept1 = self.get_intercept(self.dataset1(i,:),x);
                        if  temp_intercept1 < intercept1
                            value = value - self.c*abs(temp_intercept1-intercept1)/deno;
                        end
                    end
                    % 如果数据集2在数据集1的下方
                    % 则计算数据集2中在支持向量上方的数据的距离
                    for i = 1:length(self.dataset2)
                        temp_intercept2 = self.get_intercept(self.dataset2(i,:),x);
                        if temp_intercept2 > intercept2
                            value = value - self.c*abs(temp_intercept2-intercept2)/deno;
                        end
                    end
                end
    
            end
    
        end
        
    end
    

    测试代码
    文件路径:\optimization algorithm\application_svm\Test.m

    %% 清理之前的数据
    % 清除所有数据
    clear all;
    % 清除窗口输出
    clc;
    
    a = unifrnd(-pi,pi,20,1);
    ra = unifrnd(0,5,20,1);
    b = unifrnd(-pi,pi,20,1);
    rb = unifrnd(0,5,20,1);
    
    data1 = [sin(a(:)).*ra(:)+5,cos(a(:)).*ra(:)+5;];
    data2 = [sin(b(:)).*rb(:)-5,cos(b(:)).*rb(:)-5;];
    
    % 数据维度为2
    data_dim = 2;
    model = SVM_Model(data1,data2);
    
    range_max = ones(1,data_dim);
    range_min = ones(1,data_dim)*-1;
    
    %% 添加目录
    % 将上级目录中的frame文件夹加入路径
    addpath('../frame')
    % 引入差分进化算法
    addpath('../algorithm_differential_evolution')
    %% 算法实例
    dim = data_dim;
    % 种群数量
    size = 10;
    % 最大迭代次数
    iter_max = 50;
    % 取值范围上界
    range_max_list = range_max;
    % 取值范围下界
    range_min_list = range_min;
    % 实例化差分进化算法类
    base = DE_Impl(dim,size,iter_max,range_min_list,range_max_list);
    base.is_cal_max = true;
    % 确定适应度函数
    base.fitfunction = @model.fit_function;
    % 运行
    base.run();
    disp(['复杂度',num2str(base.cal_fit_num)]);
    
    disp(model.fit_function(base.position_best));
    
    %% 下面绘制动态图
    % 绘制每一代的路径
    for i = 1:length(base.position_best_history)
        model.draw(base.position_best_history(i,:),i);
        % 每0.01绘制一次
        pause = 0.01;
        %下面是保存为GIF的程序
        frame=getframe(gcf);
        % 返回单帧颜色图像
        imind=frame2im(frame);
        % 颜色转换
        [imind,cm] = rgb2ind(imind,256);
        filename = ['svm.gif'];
        if i==1
             imwrite(imind,cm,filename,'gif', 'Loopcount',inf,'DelayTime',1e-4);
        else
             imwrite(imind,cm,filename,'gif','WriteMode','append','DelayTime',pause);
        end
    
        if i <length(base.position_best_history)
            % 如果不是最后一张图就清除窗口
            clf;
        end
    end
    

    运行结果:

    从图中可以看出,每一代的超平面都会进过数据点。

    2. 优化后SVM模型

    文件路径:\optimization algorithm\application_svm\SVM_Model_Broad.m

    % 广义的svm,超平面不一定会经过数据点
    % 集成至svm
    classdef SVM_Model_Broad < SVM_Model
    
        properties
        end
        
        methods
            % 构造函数
            function self = SVM_Model_Broad(dataset1,dataset2)
                 % 调用父类构造函数
                self@SVM_Model(dataset1,dataset2);
            end
            
            % 输入为向量,输出为适应度值
            % 输入的x为角度,x维度 =数据维度+2,最后两维为截距
            function value = fit_function(self,x) 
                
                % 取出除后2维的其他维
                k = x(1:length(x)-2);
                
                 % 计算哪个数据集在上方
                is_up = self.get_is_up(k);
                
                % 数据集1的截距
                intercept1 = x(length(x)-1);
                % 数据集2的截距
                intercept2 = x(length(x));
                
                % 计算结果
                value = self.get_svm_value(k,intercept1,intercept2,is_up);
            end
            
            % 绘图,num为图片编号,输入的input为角度
            function draw(self,input,num)      
                bound_max = max(abs([self.dataset1(1,:),self.dataset1(2,:),self.dataset2(1,:),self.dataset2(2,:)]));
                set(gca,'XLim',[-bound_max bound_max]);
                set(gca,'YLim',[-bound_max bound_max]);
                % 绘制左下右上两点,保持图像区域不变
                scatter(-bound_max,-bound_max,1,'w');
                hold on;
                scatter(bound_max,bound_max,1,'w');
                hold on;
                axis square;
                
                for i = 1:length(self.dataset1)
                    scatter(self.dataset1(i,1),self.dataset1(i,2),10,'b','filled');
                    hold on;
                end
                for i = 1:length(self.dataset2)
                    scatter(self.dataset2(i,1),self.dataset2(i,2),10,'r','filled');
                    hold on;
                end
                
                % 有数据则绘制分割线
                if ~isempty(input)
                    % 获取两个数据集的截距
                    intercept1 = input(length(input)-1);
                    intercept2 = input(length(input));
                    
                    % 获取斜率
                    k = -(input(1))/(input(2));
                    x =-bound_max:0.1:bound_max;
                    x1=x;
                    y1 = k*x1+intercept1/-(input(2));
                    % 剔除直线上超出边界的点
                    I = y1>bound_max;
                    y1(I) = [];
                    x1(I) = [];
                    J = y1<-bound_max;
                    y1(J) = [];
                    x1(J) = [];
                    plot(x1,y1,'b');
                    hold on;
                    
                    x2=x;
                    y2 = k*x2+intercept2/-(input(2));
                    % 剔除直线上超出边界的点
                    I = y2>bound_max;
                    y2(I) = [];
                    x2(I) = [];
                    J = y2<-bound_max;
                    y2(J) = [];
                    x2(J) = [];
                    plot(x2,y2,'r');
                    hold on;
                end
                text(bound_max*0.9+-bound_max*0.1,bound_max*0.9-bound_max*0.1,num2str(num),'FontSize',20);
            end
            
        end
        
    end
    

    测试代码
    文件路径:\optimization algorithm\application_svm\Test_Broad.m

    %% 清理之前的数据
    % 清除所有数据
    clear all;
    % 清除窗口输出
    clc;
    
    a = unifrnd(-pi,pi,20,1);
    ra = unifrnd(0,5,20,1);
    b = unifrnd(-pi,pi,20,1);
    rb = unifrnd(0,5,20,1);
    
    data1 = [sin(a(:)).*ra(:)+5,cos(a(:)).*ra(:)+5;];
    data2 = [sin(b(:)).*rb(:)-5,cos(b(:)).*rb(:)-5;];
    
    data_dim = 2;
    model = SVM_Model_Broad(data1,data2);
    model.c = 10;
    range_max = ones(1,data_dim+2);
    range_min = -ones(1,data_dim+2);
    
    %% 添加目录
    % 将上级目录中的frame文件夹加入路径
    addpath('../frame')
    % 引入差分进化算法
    addpath('../algorithm_differential_evolution')
    %% 算法实例
    dim = data_dim+2;
    % 种群数量
    size = 40;
    % 最大迭代次数
    iter_max = 200;
    % 取值范围上界
    range_max_list = range_max;
    % 取值范围下界
    range_min_list = range_min;
    % 实例化差分进化算法类
    base = DE_Impl(dim,size,iter_max,range_min_list,range_max_list);
    base.is_cal_max = true;
    % 确定适应度函数
    base.fitfunction = @model.fit_function;
    % 运行
    base.run();
    disp(['复杂度',num2str(base.cal_fit_num)]);
    
    disp(model.fit_function(base.position_best));
    
    %% 下面绘制动态图
    % 绘制每一代的路径
    for i = 1:length(base.position_best_history)
        model.draw(base.position_best_history(i,:),i);
        % 每0.01绘制一次
        pause = 0.01;
        %下面是保存为GIF的程序
        frame=getframe(gcf);
        % 返回单帧颜色图像
        imind=frame2im(frame);
        % 颜色转换
        [imind,cm] = rgb2ind(imind,256);
        filename = ['svm_broad.gif'];
        if i==1
             imwrite(imind,cm,filename,'gif', 'Loopcount',inf,'DelayTime',1e-4);
        else
             imwrite(imind,cm,filename,'gif','WriteMode','append','DelayTime',pause);
        end
    
        if i <length(base.position_best_history)
            % 如果不是最后一张图就清除窗口
            clf;
        end
    end
    

      出图中可以看出,超平面不一定会经过数据点。

    五. 总结

    这次介绍了如何使用优化算法来优化支持向量机。直接使用了支持向量机的定义作为适应度函数模型,避免了大量的对偶问题转换。同时为了减少计算量,使用了广义的支持向量机,让超平面不必一定经过支持向量,当数据集中数据较多时效果会非常明显。
      文中使用的差分进化算法实现可以看优化算法matlab实现(七)差分进化算法matlab实现。如果想使用其他优化算法,则引入相关的优化算法路径后,实例化即可。

    文件目录如下:

    \optimization algorithm\application_svm\SVM_Model.m
    \optimization algorithm\application_svm\Test.m
    \optimization algorithm\application_svm\SVM_Model_Broad.m
    \optimization algorithm\application_svm\Test_Broad.m
    \optimization_algorithm\frame\Unit.py
    \optimization_algorithm\frame\Algorithm_Impl.py
    \optimization_algorithm\algorithm_differential_evolution\DE_Unit.py
    \optimization_algorithm\algorithm_differential_evolution\DE_Base.py
    \optimization_algorithm\algorithm_differential_evolution\DE_Impl.py

    相关文章

      网友评论

          本文标题:优化算法应用(五)优化支持向量机(SVM)

          本文链接:https://www.haomeiwen.com/subject/riuafdtx.html