美文网首页
k近邻算法及其实现

k近邻算法及其实现

作者: 青帝花神 | 来源:发表于2016-07-07 19:12 被阅读0次

    1. KNN (k-Nearest Neighbor)

    k近邻算法是一种基本分类与回归方法。k近邻法假设给定一个训练数据集,其中的实例类别一定。分类时,对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决等方法进行预测。因此k近邻算法不具有显式的学习过程。k近邻实际上是利用训练数据集对特征向量空间进行划分,并作为其分类的模型。
    k近邻的三个基本要素是:k值的选择,距离的度量以及分类决策规则。

    1.1 距离的度量

    特征空间中两个实例点的距离是两个实例点相似程度的反映,常见的距离度量有:欧式距离,Lp距离等等(距离度量可以参考这篇博文: 从K近邻算法、距离度量谈到KD树、SIFT+BBF算法 - July_ - 博客园)。不同的距离度量得到的结果可能是不一样的。

    1.2 k值的选择

    如果选择较小的k,就相当于用较小的领域中的训练实例进行预测,只有与输入实例较近的训练实例才会对预测结果起作用,但是这样会导致预测结果对近邻点非常敏感。如果近邻的实例点恰巧是噪声,预测就会出错。也就是说,k值的减少就意味着整体模型变得复杂,容易过拟合。
    如果选择较大的k值,与输入实例较远的(不相似的)训练实例也会对预测起作用,使得预测发生错误。k值的增加意味着整体模型变得简单。

    1.3分类决策规则

    可以选择多数表决规则,甚至加上距离的远近(即把距离当做权重),决定输入实例是哪个类别。

    2.kd树

    实现k近邻算法是,主要考虑的问题是如何对训练数据进行快速k近邻搜索。这点在特征空间的维数大及训练数据容量大时尤其必要。为了提高k近邻搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少距离计算次数。可以采用kd-tree。
    k近邻搜索算法思路如下:
    输入:已构造的kd树:目标点x;(辅助结构,数组)
    输出:x的k近邻
    公共操作P:在访问每个结点时,若数组容量不足k,则将该结点加入数组,若堆容量以达到k,则比较当前节点是否比数组尾元素与x的距离更近,若更近则以当前节点代替数组尾结点,并调整数组。
    (1)从根节点出发,递归地向下访问kd树,若目标x当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点,知道结点为叶节点为止。执行公共操作P。
    (2)递归的向上回退,在每个节点进行以下操作:
    (a)执行公共操作P。
    (b)检查该子结点的兄弟结点区域是否有比堆顶元素更近的点或堆容量未满。具体的,检查另一子结点对应的区域是否与以目标点为求心,以目标点与堆顶元素距离为半径的球体相交。
    如果相交或容量未满,以另一子结点为根节点执行(1)。
    (4)当回退到根节点时,搜索结束,堆中实例即为所求实例。

    注:前几天刚做完机器学习的大作业,实现了KNN算法,是针对iris数据集的。特此总结

    代码实现:

    代码不友好!!!!
    kd_tree.h
    #include<stdlib.h>
    #include<vector>
    #include<math.h>
    #include<algorithm>
    #include<iostream>
    using namespace std;

    #define  K   4    ////输入数据的维度
    
    class kd_tree_node{
    //成员对象
    public:
      vector<float> node_data;    //存储该节点样本数据
      string node_type;           //是叶节点还是树干(树枝)
      int numpoints;              //训练数据的个数,或者说这个二叉树有多少个节点
      int index;                  //节点数据在原数据中的索引位置
      int splitdim;               //该节点进行分裂是的,选择的分裂维度
      double splitval;            //该节点选择的分裂值
      kd_tree_node* left_node, *right_node,*parents;
    };
    vector<int> median_data(vector<vector<float>>data, vector<int> index, int splitdim_num);//排    序函数,返回排好序的索引序列
    
    / /递归实现创建kd_tree
    kd_tree_node* create_kd_tree(vector<vector<float>>data,int split_dim_num,vector<int>index,kd_tree_node *parent){
    
    //初始化,构造根节点。创建一个节点kd_tree_node;
    kd_tree_node * root = new kd_tree_node;
    root->numpoints = data.size();
    
    //判断结束条件
    if (index.size() == 1){
        //设置成员变量
        root->left_node = NULL;
        root->right_node = NULL;
        root->node_type = "leaf";
        root->splitdim = -1;
        root->splitval = 0;
        root->parents = parent;
        root->node_data = data[index[0]];
        root->index = index[0];
    }
    else{
        //排序,分裂
        index = median_data(data, index, split_dim_num);
        int length = index.size();
        vector<int>left, right;
        for (int i = 0; i < index.size(); i++){
            if (i < length  / 2)
                left.push_back(index[i]);
            else{
                if (i>length/2)
                    right.push_back(index[i]);
            }
        }
        //设置类成员变量
        if (left.size() >= 1){
            root->left_node = create_kd_tree(data, split_dim_num  % K + 1, left, root);
        }
        else
            root->left_node = NULL;
        if (right.size() >= 1){
            root->right_node = create_kd_tree(data, split_dim_num  % K + 1, right, root);
        }
        else
            root->right_node = NULL;
        root->node_type = "body";
        root->splitdim = split_dim_num;
        root->splitval = data[index[length/2]][split_dim_num - 1];  //(<)
        root->parents = parent;
        root->node_data = data[index[length/2]];
        root->index = index[length / 2];
    }
    return root;
    

    }

    //排序函数,返回排好序的索引序列
    vector<int> median_data(vector<vector<float>>data, vector<int> index, int splitdim_num){
    vector<float>temp;
    int length = index.size();
    for (int i = 0; i < length; i++){
        temp.push_back(data[index[i]][splitdim_num - 1]);
    }
    //升序排序,冒泡法
    int index_temp = 0;
    float a = 0;
    
    for (int i = 0; i < length - 1; i++){
        for (int j = 0; j < length -i- 1; j++){
            if (temp[j]>temp[j + 1]){
                a = temp[j + 1];
                temp[j + 1] = temp[j];
                temp[j] = a;
    
                index_temp = index[j + 1];
                index[j + 1] = index[j];
                index[j] = index_temp;
            }
        }
    }
    return index;
    

    }

    //k-近邻搜索算法
    /*****公共操作P:在访问每个结点时,若最大堆容量不足k,则将该结点加入最大堆,若堆容量以达到k,则    比较当前节点是否比堆顶元素与x的距离更近,若更近则以当前节点代替堆顶结点,并调整堆。
    (1)从根节点出发,递归地向下访问kd树,若目标x当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点,知道结点为叶节点为止。执行公共操作P。
    (2)递归的向上回退,在每个节点进行以下操作:
    (a)执行公共操作P。
    (b)检查该子结点的兄弟结点区域是否有比堆顶元素更近的点或堆容量未满。具体的,检查另一子结点对应的区域是否与以目标点为求心,以目标点与堆顶元素距离为半径的球体相交。
    如果相交或容量未满,以另一子结点为根节点执行(1)。
    (4)当回退到根节点时,搜索结束,堆中实例即为所求实例。
      ****/
      /*************
    function: knn_k_search()
    
    input:
      test_data:测试数据
      near_num:需要寻找几个近邻元素,near_num
      root:kd树的根节点
    
    output: 返回找到原数据中near_num个近邻点在原数据中的index(索引)数组。
    *************/
    vector<int> knn_k_search(vector<float>test_data, int near_num, kd_tree_node *root){
      vector<int> near_k_node_index(0);           //记录下k个近邻点的索引
      vector<double>near_k_nodedist(0);           //记录下k个紧邻点的距离
      vector<kd_tree_node*> near_k_nodepoint;     //记录下k个近邻点的kd_tree指针
    
    if (root->numpoints < near_num){
        cout << "do not have enough points" << endl;
        return near_k_node_index;
    }
    
    //首先找到叶节点,并记录下搜索的路径
    kd_tree_node * leaf_node = NULL;
    int split_dim = 1;
    leaf_node = root;
    vector<kd_tree_node*>path;
    path.push_back(leaf_node);
    while (leaf_node->node_type != "leaf"){
        split_dim = leaf_node->splitdim;
        if (test_data[split_dim - 1] <= leaf_node->splitval)//分裂
            leaf_node = leaf_node->left_node;
        else{
            if (leaf_node->right_node == NULL)
                leaf_node = leaf_node->left_node;//如果只有左子树,那么叶节点就选是左子树
            else
                leaf_node = leaf_node->right_node;
        }
        path.push_back(leaf_node);
    }
    path.pop_back();
    
    //copy一份路径
    vector<kd_tree_node*>path_copy = path;
    
    //k近邻搜索,回溯,,找到K个最接近给定测试数据的样本,统计出现频率
    //计算两点之间的距离,从叶子节点开始;  
    
    //test_data所在的叶节点指针一直存储在leaf_node中
    double dist1 = 0, max_dist = 0;
    
    //计算距离
    for (int i = 0; i < test_data.size(); i++){
        dist1 += (leaf_node->node_data[i] - test_data[i])*(leaf_node->node_data[i] - test_data[i]);
    }
    dist1 = sqrt(dist1);
    max_dist = dist1;
    
    //压入数据
    near_k_nodepoint.push_back(leaf_node);
    near_k_nodedist.push_back(max_dist);
    
    //定义一个指针,该值针,指向上一个分支。
    kd_tree_node * rl_node = leaf_node;  //也就是表示该分支已经被访问过了
    while (path.size() != 0){
    
        //回溯到父节点(不一定是父节点,是搜索队列中,栈顶元素)
        kd_tree_node *back_point = path[path.size() - 1];
        path.pop_back();
        int split_s = back_point->splitdim - 1; 
        double dist2 = 0;
        for (int i = 0; i < test_data.size(); i++){
            dist2 += (back_point->node_data[i] - test_data[i])*(back_point->node_data[i] - test_data[i]);
        }
        dist2 = sqrt(dist2);
    
        //判断是否加入队列,两个:队列是否已满?未满直接加入,更新最大距离,已满的话判断是否大于最大距离
        if (near_k_nodepoint.size() == near_num && dist2 < max_dist)//队列已满,且小于最大距离
        {
            near_k_nodepoint.pop_back();
            near_k_nodedist.pop_back();
            //此时队列是不满的
        }
    
        if (near_k_nodepoint.size() < near_num)//如果队列未满的话,压入队列
        {
            if (near_k_nodepoint.size() == 0){   // 当队列为空时
                near_k_nodepoint.push_back(back_point);
                near_k_nodedist.push_back(dist2);
                max_dist = dist2;
            }
            else{
                int i = 0;
                while (dist2>near_k_nodedist[i]){
                    i++;
                    if (i == near_k_nodepoint.size())
                        break;
                }
                //更新最大距离
                max_dist = near_k_nodedist[near_k_nodedist.size() - 1];
                if (i == near_k_nodepoint.size())
                    max_dist = dist2;
                //插入对i之前,对near_k_nodepoint和near_k_nodepoint;
                near_k_nodepoint.insert(near_k_nodepoint.begin() + i, back_point);
                near_k_nodedist.insert(near_k_nodedist.begin() + i, dist2);
            }
    
        }
    
        if (back_point->node_type == "leaf"){
            continue;//到达叶节点就继续下一轮
        }
        double dist3 = abs(test_data[split_s] - back_point->node_data[split_s]);
        //判断是否需要进入另一个分支
        if (near_k_nodepoint.size() < near_num || (dist3<max_dist)){
            
            //判断back_point 是否是test_data搜索路径中某个节点
            bool flag = false;
            for (int i = path_copy.size()-1; i >=0; i--){
                if (back_point == path_copy[i]){
                    flag = true;
                }
            }
            if (flag){
                double flag = test_data[split_s];
                double flag2 = back_point->node_data[split_s];
                if (flag <= flag2){
                    if (back_point->right_node != NULL)
                        back_point = back_point->right_node;//可能只有左子树,//如果只有左子树,那么叶节点就选是左子树
                    else
                        back_point = back_point->left_node;
                }
                else{
                    back_point = back_point->left_node;
                }
                path.push_back(back_point);
            }
            else{
                if (back_point->right_node != NULL)  //右节点压入栈中
                    path.push_back(back_point->right_node);
                if (back_point->left_node != NULL)   //左节点压入栈中
                    path.push_back(back_point->left_node);
            }
        }
    }
    //返回索引向量
    for (int i = 0; i < near_k_nodepoint.size(); i++){
        near_k_node_index.push_back(near_k_nodepoint[i]->index);
    }
    return near_k_node_index;
    

    }

    knn.cpp

    #include"kd_tree.h"
    #include<fstream>
    #include<string>
    
    #define  label_type  3   //有三种样本
    
    using namespace std;
    string iris_name[label_type] = {"Iris-setosa","Iris-versicolor","Iris-virginica"}; //三种iris花的名字
    
    void main(){
    
    //读取数据阶段
    /**数据分为train.txt和test.txt
        每个数据有五个分量,最后一个分量是样本所属的类型
        读得数据分别存储在data和label里,分为train_data,train_label.
        分隔符是空格符
    **/
    //训练数据
    string train_file = "train2.txt";
    ifstream ist(train_file.c_str());
    vector<vector<float>>train_data;
    vector<int>train_label;
    while (!ist.eof()){
        vector<float> single_data;
        for (int i = 0; i < K; i++){
            float temp = 0;
            ist >> temp;
            single_data.push_back(temp);
        }
        int label = 0;
        ist >> label;
        train_label.push_back(label);
        train_data.push_back(single_data);
        single_data.resize(0);
    }
    ist.close();
    
    //测试数据
    string test_file = "test2.txt";
    ifstream ist2(test_file.c_str());
    vector<vector<float>>test_data;
    vector<int>test_label;
    while (!ist2.eof()){
        vector<float> single_data;
        for (int i = 0; i < K; i++){
            float temp = 0;
            ist2 >> temp;
            single_data.push_back(temp);
        }
        int label = 0;
        ist2 >> label;
        test_label.push_back(label);
        test_data.push_back(single_data);
        single_data.resize(0);
    }
    ist2.close();
    
    int NUM = 0;                 //NUM是K近邻的所选取的近邻点的数目
    for (NUM = 1; NUM < 121; NUM++){
        //创建kd树
        kd_tree_node *iris_kd_tree = NULL;
        int numpoints = train_label.size();
        vector<int>index;
        for (int i = 0; i < numpoints; i++){
            index.push_back(i);
        }
        iris_kd_tree = create_kd_tree(train_data, 1, index, NULL);  //根据训练数据创建kd_tree
    
        //测试样本的准确率
        int sum_num[label_type];     //各类样本的总数
        int right_num[label_type];   //各类样本的正确判断数目
        int error_num[label_type];   //各类样本的错误识别率
    
        //初始化
        for (int i = 0; i < label_type; i++){
            sum_num[i] = 0;
            right_num[i] = 0;
            error_num[i] = 0;
        }
    
        //k近邻搜索,判断样本类型
        for (int i = 0; i < test_label.size(); i++){
            vector<int>k_index;
            vector<int>count_label;
            for (int j = 0; j < label_type; j++){
                count_label.push_back(0);
            }
            k_index = knn_k_search(test_data[i], NUM, iris_kd_tree);//k近邻搜索
    
            for (int j = 0; j < k_index.size(); j++){
                int flag = train_label[k_index[j]];
                count_label[flag]++;   //统计k近邻各类样本出现的次数
            }
            int max = count_label[0];
            int label_flag = 0;
            for (int j = 1; j < label_type; j++){
                if (max < count_label[j]){
                    max = count_label[j];
                    label_flag = j;
                }
            }
            
            if (label_flag == test_label[i]){
                right_num[test_label[i]]++;
            }
            else{
                error_num[label_flag]++;
            }
            sum_num[test_label[i]]++;
        }
    
        //统计结果,并打印出结果
        int sum = 0;
        int error = 0;
        for (int i = 0; i < label_type; i++){
            sum += sum_num[i];
            error += error_num[i];
        }
        cout << NUM << ":" << endl;
        for (int i = 0; i < label_type; i++){
            cout << iris_name[i] << "测试样本总数为:" << sum_num[i] << ",正确率为:" << right_num[i] / (sum_num[i] * 1.0) << ",错误识别为该样本的数目为:" << error_num[i] << endl;
        }
        cout << "总的正确率为:" << 1-error*1.0/sum<<endl;
        cout << endl;
    }
    //画出kd_树(选做)
    system("pause");
    

    }

    相关文章

      网友评论

          本文标题:k近邻算法及其实现

          本文链接:https://www.haomeiwen.com/subject/qxpljttx.html