美文网首页C语言
C++实现K近邻算法搜索树

C++实现K近邻算法搜索树

作者: 牛顿学计算机 | 来源:发表于2018-09-09 22:04 被阅读74次
  本文不详细叙述K近邻算法的原理,主要提供一种比较快的实现k近邻算法的数据结构。看过很多人写的关于这方面的文章,大多数人采取的做法是多次遍历搜索,寻找与目标点欧氏距离最近的k个点。本文采用kd树这一数据结构来实现k近邻算法(参考:《统计学习方法》),kd树是二叉树,表示对k维空间的一个划分。

    构造kd树算法:

    1. 输入N*M数据集

    2. 从树根开始先序遍历kd树,假设被遍历结点的深度是j,选择x[l][j mod k + 1](1 <= l <= N, (j mod k) < M)作为本次的切分点。从这N个数中选出中位数x_m,该结点即为x_m对应的行向量。若x[l][j mod k + 1] < x_m,则将行向量x[l][1], x[l][2], ......, x[l][M - 1], x[l][M]作为子区域划分到该中位数结点的左边;若x[l][j mod k + 1] >= x_m,则将行向量x[l][1], x[l][2], ......, x[l][M - 1], x[l][M]作为子区域划分到该中位数结点的右边。

    3. 递归重复步骤2,直到所有叶子结点的左子区域和右子区域都没有结点为止,此时kd树构造完成。

    下面是该算法的代码和该算法运行的结果。
#include <iostream>
#include <fstream>
#include <vector>
#include <string>
#include <map>
#include <algorithm>

using namespace std;

// Scalar type used for every coordinate of a sample point.
using ElemType = int;
// "k" value; create_kdtree() wraps its split-axis counter at this bound.
constexpr int K = 1;
// Number of dimensions, i.e. how many coordinates are read per sample row.
constexpr int N = 2;

/// A single kd-tree node. The root node doubles as the handle for the whole
/// tree: the tree-level operations (create_kdtree, print_kdtree) are members
/// that take the node to work on as an argument.
///
/// NOTE(review): child nodes are created with `new` (see new_tree_node) and
/// are held through raw pointers; no destructor is declared here, so nothing
/// in this class releases them.
class tree {
public:
    /// Creates a node with no children and no parent.
    tree();

    // ---- building -------------------------------------------------------
    /// Allocates a detached node on the heap and returns it.
    tree *new_tree_node();
    /// Splits `all_data` at its median along coordinate `index`: the median
    /// point is stored in this node, the remaining points are handed to the
    /// left / right child.
    void insert_node(vector<vector<ElemType>> all_data, size_t index);
    /// Recursively applies insert_node() to `T` and to its children.
    void create_kdtree(tree *T);
    /// Prints the split points of the subtree rooted at `T` in pre-order.
    void print_kdtree(tree *T);

    // ---- links ----------------------------------------------------------
    tree *get_left_node();
    tree *get_right_node();
    tree *get_parent_node();
    void set_left_node(tree *node);
    void set_right_node(tree *node);
    void set_parent_node(tree *node);

    // ---- point storage --------------------------------------------------
    /// True when no points are stored in `data`.
    bool is_empty();
    /// Returns a copy of all points stored in this node.
    vector<vector<ElemType>> get_value();
    /// Returns a copy of this node's split point.
    vector<ElemType> get_tree_data();
    /// Replaces the stored points with `value`.
    void set_data(vector<vector<ElemType>> value);
    /// Appends one point to the stored points.
    void add_tree_node(vector<ElemType> value);

    // ---- labels ("lable" spelling is part of the existing interface) ----
    /// Records the label of point `v` in the point -> label map.
    void add_lable(vector<ElemType> v, char lable);
    void set_lable(char lable);
    char get_lable();
    /// Returns a copy of the point -> label map.
    map<vector<ElemType>, char> get_data_and_lable();

private:
    tree *left;                                  // left child, or nullptr
    tree *right;                                 // right child, or nullptr
    tree *parent;                                // parent node, or nullptr
    vector<vector<ElemType>> data;               // points assigned to this node
    vector<ElemType> tree_data;                  // split point chosen by insert_node()
    map<vector<ElemType>, char> data_and_label;  // filled by add_lable()
    char lable;                                  // value given to set_lable()
};

/// Builds an empty node: no children, no parent, no points.
///
/// The label is zero-initialised as well, so get_lable() on a node that never
/// had set_lable() called returns '\0' instead of reading an indeterminate
/// char. Initialisers are listed in member declaration order.
tree::tree() : left(nullptr), right(nullptr), parent(nullptr), lable('\0') {
}

class tree *tree::new_tree_node() {
    class tree *node = new class tree;
    node->left = nullptr;
    node->right = nullptr;
    node->parent = nullptr;

    return node;
}

/// Returns (by value) every point stored in this node.
vector<vector<ElemType>> tree::get_value() {
    return data;
}

class tree *tree::get_left_node() {
    return this->left;
}

class tree *tree::get_right_node() {
    return this->right;
}

class tree *tree::get_parent_node() {
    return this->parent;
}

void tree::set_left_node(class tree *node) {
    this->left = node;
}

void tree::set_right_node(class tree *node) {
    this->right = node;
}

void tree::set_parent_node(class tree *node) {
    this->parent = node;
}

/// Replaces the points stored in this node with a copy of `value`.
void tree::set_data(vector<vector<ElemType>> value) {
    data = value;
}

/// Appends the point `value` to the points stored in this node.
void tree::add_tree_node(vector<ElemType> value) {
    data.emplace_back(value);
}

/// Associates point `v` with `lable` in the point -> label map. If `v` is
/// already present, the existing entry is kept.
void tree::add_lable(vector<ElemType> v, char lable) {
    data_and_label.emplace(v, lable);
}

/// Stores `lable` as this node's label.
void tree::set_lable(char lable) {
    // The parameter shadows the member, hence the qualified member name.
    tree::lable = lable;
}

char tree::get_lable() {
    return this->lable;
}

/// Returns (by value) the point -> label map filled by add_lable().
map<vector<ElemType>, char> tree::get_data_and_lable() {
    return data_and_label;
}

/// Returns (by value) this node's split point; empty until insert_node() has
/// run on the node.
vector<ElemType> tree::get_tree_data() {
    return tree_data;
}

/// Splits `all_data` around its median along coordinate `index` and wires the
/// result into this node.
///
/// After the call:
///   - `tree_data` holds the median point;
///   - points whose coordinate is strictly smaller than the median's go to the
///     left child, all remaining points (coordinate >= median) go to the right
///     child;
///   - a child is only created when it receives at least one point, and an
///     already existing child that holds points is left untouched.
///
/// Differences from the previous implementation:
///   - points are ordered with a stable sort instead of being keyed into a
///     std::map by coordinate value, which silently discarded every point that
///     shared a coordinate value with an earlier one;
///   - no empty right child is allocated for leaves;
///   - a leaf no longer appends its single point to `data` a second time;
///   - rows are accessed with at() once, so a row shorter than `index + 1`
///     raises std::out_of_range instead of reading out of bounds.
///
/// @param all_data  points to distribute (taken by value; sorted locally).
/// @param index     zero-based coordinate used as the split axis.
///
/// NOTE(review): children are allocated with new_tree_node() and referenced by
/// raw pointers; no matching delete is visible in this file.
void tree::insert_node(vector<vector<ElemType>> all_data, size_t index) {
    if (all_data.empty()) {
        return;
    }

    // Debug trace: the split coordinate of every incoming point, in input
    // order. at() doubles as the bounds check for the unchecked accesses below.
    cout << "-----------------" << endl;
    for (const vector<ElemType> &row : all_data) {
        cout << row.at(index) << endl;
    }
    cout << "-----------------" << endl;

    // Order by the split coordinate; stable so that ties keep input order.
    stable_sort(all_data.begin(), all_data.end(),
                [index](const vector<ElemType> &a, const vector<ElemType> &b) {
                    return a[index] < b[index];
                });

    // Start at the middle element, then slide down to the first point that
    // carries the median value so the "left < median <= right" rule holds even
    // when several points share that value.
    size_t mid = all_data.size() / 2;
    while (mid > 0 && all_data[mid - 1][index] == all_data[mid][index]) {
        --mid;
    }

    if (mid > 0) {
        tree *lower = (this->left != nullptr) ? this->left : new_tree_node();
        if (lower->data.empty()) {
            for (size_t i = 0; i < mid; ++i) {
                lower->add_tree_node(all_data[i]);
            }
        }
        this->left = lower;
        lower->parent = this;
    }
    else if (this->data.empty()) {
        // Nothing sorts below the median and the node holds no points yet:
        // record the point so the node does not report itself as empty.
        this->data.push_back(all_data[mid]);
    }

    this->tree_data = all_data[mid];

    if (mid + 1 < all_data.size()) {
        tree *upper = (this->right != nullptr) ? this->right : new_tree_node();
        if (upper->data.empty()) {
            for (size_t i = mid + 1; i < all_data.size(); ++i) {
                upper->add_tree_node(all_data[i]);
            }
        }
        this->right = upper;
        upper->parent = this;
    }
}

// Legacy traversal counter. main() still assigns to it, so the definition is
// kept, but create_kdtree() no longer reads or advances it.
// NOTE(review): POSIX <strings.h> declares a function named index(); a global
// variable of the same name may fail to compile on such platforms. Renaming it
// also requires changing main().
size_t index = 0;

/// Recursively builds the kd-tree below `T`.
///
/// For every node that holds points, insert_node() is called with the split
/// axis `depth mod K`, where depth is the number of parent links between the
/// node and the root. Deriving the axis from the depth (rather than from a
/// global counter that advances once per visited node) guarantees that all
/// nodes on the same level split on the same axis. With K == 1 the axis is
/// always 0, exactly as before.
///
/// @param T  root of the subtree to build; nullptr or a node without points is
///           a no-op.
void tree::create_kdtree(class tree *T) {
    if (T == nullptr || T->data.empty()) {
        return;
    }

    size_t depth = 0;
    for (tree *up = T->parent; up != nullptr; up = up->parent) {
        ++depth;
    }
    // K <= 0 would make the modulo ill-defined; fall back to axis 0, which is
    // what the old counter-based code degenerated to in that case.
    const size_t axis = (K > 0) ? depth % static_cast<size_t>(K) : 0;

    T->insert_node(T->data, axis);

    // The recursive calls perform the null / empty checks themselves.
    create_kdtree(T->left);
    create_kdtree(T->right);
}

void tree::print_kdtree(class tree *T) {
    if (T->tree_data.size() != 0) {
        for (int i = 0; i < T->get_tree_data().size(); i++) {
            vector<ElemType> tree_data = T->get_tree_data();
            cout << tree_data[i] << " ";
        }
        //cout << T->get_lable();
        cout << endl;
        if (T->get_left_node() != nullptr && T->get_left_node()->get_tree_data().size() > 0) {
            print_kdtree(T->get_left_node());
        }
        if (T->get_right_node() != nullptr && T->get_right_node()->get_tree_data().size() > 0) {
            print_kdtree(T->get_right_node());
        }
    }
}

bool tree::is_empty() {
    if (this->data.size() == 0) {
        return true;
    }
    else {
        return false;
    }
}

void read_data_from_file(class tree *T, string file_name) {
    ifstream in(file_name);
    vector<vector<ElemType>> train_data;

    if (!in.is_open()) {
        cout << "can not open " << file_name << endl;
        return;
    }
    while (!in.eof()) {
        vector<ElemType> temp_data;
        //char lable;
        int i = 0;
        while (i < N) {
            ElemType temp = 0;
            in >> temp;
            temp_data.push_back(temp);
            cout << temp << " ";
            i++;
        }
        //in >> lable;
        cout << endl;
        train_data.push_back(temp_data);
        //T->add_lable(temp_data, lable);
    }
    T->set_data(train_data);
}

class KNN {
public:
    void creat_knn();
    class tree *get_kdtree();
    void print_knn();

private:
    class tree KNN_Tree;
};

void KNN::creat_knn() {
    KNN_Tree.create_kdtree(&KNN_Tree);
}

class tree *KNN::get_kdtree() {
    return &(this->KNN_Tree);
}

void KNN::print_knn() {
    this->KNN_Tree.print_kdtree(&KNN_Tree);
}

int main(int argc, char *argv[]) {
    class KNN knn;

    read_data_from_file(knn.get_kdtree(), "data.txt");
    knn.creat_knn();
    knn.print_knn();
    index = 0;
    while (1);

    return 0;
}

输入数据:
2 3
5 4
9 6
4 7
8 1
7 2
3 6
1 3
0 9
6 6

kd树先序遍历输出数据:
5 4
2 3
1 3
0 9
4 7
3 6
8 1
7 2
6 6
9 6

相关文章

网友评论

    本文标题:C++实现K近邻算法搜索树

    本文链接:https://www.haomeiwen.com/subject/wladgftx.html