本文不详细叙述K近邻算法的原理,主要提供一种能够较快实现k近邻算法的数据结构。看过很多人写的关于这方面的文章,大多数人采取的做法是多次遍历搜索,寻找与目标点欧氏距离最近的k个点。本文采用kd树作为实现k近邻算法的数据结构(参考:《统计学习方法》)。kd树是二叉树,表示对k维空间的一个划分。
构造kd树算法:
1. 输入N*M数据集
2. 从树根先序遍历kd树,假设被遍历结点的深度是j,选择x[l][j mod k + 1](1 <= l <= N, (j mod k) < M)作为本次的切分点。从这N个数中选择中位数x_m,该结点即为x_m对应的行向量。若x[l][j mod k + 1] < x_m,则将x[l][1], x[l][2], ......, x[l][M - 2], x[l][M - 1]作为子区域划分到该中位数结点的左边;若x[l][j mod k + 1] >= x_m,则将x[l][1], x[l][2], ......, x[l][M - 2], x[l][M - 1]作为子区域划分到该中位数结点的右边。
3. 直到kd树的所有叶子结点左子区域和右子区域都没有结点,则构成了kd树。
下面是该算法的代码和该算法运行的结果。
#include <iostream>
#include <fstream>
#include <vector>
#include <string>
#include <map>
#include <algorithm>
using namespace std;
// Scalar type of a single coordinate of a sample.
using ElemType = int;
// The "k" value.
constexpr int K = 1;
// Number of dimensions, i.e. how many values make up one sample.
constexpr int N = 2;
// One node of the kd-tree.
//
// A node keeps the samples that fall into its region (`data`), the single
// sample chosen as the node's own point (`tree_data`), and links to its
// children and parent.
//
// Ownership: a node owns its left and right children and deletes them in
// its destructor. Children are expected to be heap nodes such as the ones
// returned by new_tree_node(). The parent link is non-owning. Copying is
// disabled because a copy would share (and later double-delete) the
// children.
class tree {
public:
    tree();
    // Releases the whole subtree below this node.
    ~tree() {
        delete left;
        delete right;
    }
    tree(const tree &) = delete;
    tree &operator=(const tree &) = delete;

    // Allocates a fresh, unlinked node on the heap.
    tree *new_tree_node();
    // Splits `all_data` along coordinate `index` and fills this node and its children.
    void insert_node(vector<vector<ElemType>> all_data, size_t index);
    // Recursively builds the tree below T from the samples stored in T.
    void create_kdtree(tree *T);
    // True when no samples are stored in this node.
    bool is_empty();
    // Copy of the samples stored in this node.
    vector<vector<ElemType>> get_value();
    tree *get_left_node();
    tree *get_right_node();
    tree *get_parent_node();
    // Copy of the sample chosen as this node's own point.
    vector<ElemType> get_tree_data();
    void set_left_node(tree *node);
    void set_right_node(tree *node);
    void set_parent_node(tree *node);
    // Replaces the samples stored in this node.
    void set_data(vector<vector<ElemType>> value);
    // Appends one sample to the samples stored in this node.
    void add_tree_node(vector<ElemType> value);
    // Records the label character that belongs to sample `v`.
    void add_lable(vector<ElemType> v, char lable);
    void set_lable(char lable);
    char get_lable();
    // Copy of the sample -> label table.
    map<vector<ElemType>, char> get_data_and_lable();
    // Prints the subtree rooted at T in pre-order.
    void print_kdtree(tree *T);

private:
    tree *left = nullptr;    // owned
    tree *right = nullptr;   // owned
    tree *parent = nullptr;  // non-owning back link
    vector<vector<ElemType>> data;               // samples in this node's region
    vector<ElemType> tree_data;                  // this node's own point
    map<vector<ElemType>, char> data_and_label;  // sample -> label
    char lable = '\0';                           // label of this node; '\0' until set
};
// Creates an unlinked node with no samples.
// `lable` is set to '\0' so that get_lable() never reads an indeterminate
// value on a node whose label was not assigned yet.
tree::tree() : left(nullptr), right(nullptr), parent(nullptr), lable('\0') {}
// Allocates a new node on the heap and returns it.
// The constructor already leaves left/right/parent as nullptr, so the node
// comes back unlinked.
class tree *tree::new_tree_node() {
    return new tree();
}
// Returns a copy of the samples stored in this node (`data`).
vector<vector<ElemType>> tree::get_value() {
    return data;
}
// Returns the left child, or nullptr when there is none.
class tree *tree::get_left_node() {
    return left;
}
// Returns the right child, or nullptr when there is none.
class tree *tree::get_right_node() {
    return right;
}
// Returns the parent node, or nullptr when no parent has been set.
class tree *tree::get_parent_node() {
    return parent;
}
void tree::set_left_node(class tree *node) {
this->left = node;
}
void tree::set_right_node(class tree *node) {
this->right = node;
}
void tree::set_parent_node(class tree *node) {
this->parent = node;
}
// Replaces the samples stored in this node with `value`.
void tree::set_data(vector<vector<ElemType>> value) {
    // `value` is our own by-value copy, so its contents can be taken over directly.
    data.swap(value);
}
// Appends one sample to the samples stored in this node.
void tree::add_tree_node(vector<ElemType> value) {
    data.emplace_back(value);
}
// Records `lable` as the label of sample `v`.
// If `v` already has an entry, the existing label is kept.
void tree::add_lable(vector<ElemType> v, char lable) {
    data_and_label.emplace(v, lable);
}
// Sets this node's label character.
void tree::set_lable(char lable) {
    // The parameter hides the member of the same name, hence the qualification.
    tree::lable = lable;
}
char tree::get_lable() {
return this->lable;
}
// Returns a copy of the sample -> label table.
map<vector<ElemType>, char> tree::get_data_and_lable() {
    return data_and_label;
}
// Returns a copy of the sample stored as this node's own point.
// The result is empty until insert_node() has run on this node.
vector<ElemType> tree::get_tree_data() {
    return tree_data;
}
// Splits `all_data` around its median along coordinate `index`.
//
// After the call:
//   * this node's own point (tree_data) is the median sample;
//   * samples whose coordinate is strictly smaller than the median's are
//     handed to the left child;
//   * all remaining samples (coordinate >= the median's) are handed to the
//     right child.
// A child is created only when it has samples to receive. A child that
// already exists and already holds samples is re-linked but its samples are
// left untouched.
//
// Samples that share the same split coordinate are all kept. They are
// ordered with a stable sort instead of being used as keys of a map, which
// would discard every sample after the first with a given coordinate.
//
// Parameters:
//   all_data - samples to distribute; an empty set is a no-op.
//   index    - coordinate to split on.
// Throws std::out_of_range (from vector::at) if a sample has no coordinate
// `index`; nothing has been modified at that point.
void tree::insert_node(vector<vector<ElemType>> all_data, size_t index) {
    if (all_data.empty()) {
        return;
    }

    // Trace output: the split coordinate of every sample, in input order.
    // at() also validates that every row is long enough, which makes the
    // unchecked operator[] accesses below safe.
    cout << "-----------------" << '\n';
    for (const vector<ElemType> &row : all_data) {
        cout << row.at(index) << '\n';
    }
    cout << "-----------------" << endl;

    stable_sort(all_data.begin(), all_data.end(),
                [index](const vector<ElemType> &a, const vector<ElemType> &b) {
                    return a[index] < b[index];
                });

    // Start at the middle element, then step back to the first sample that
    // shares its coordinate so that everything left of `mid` is strictly
    // smaller than the split value.
    size_t mid = all_data.size() / 2;
    while (mid > 0 && all_data[mid - 1][index] == all_data[mid][index]) {
        --mid;
    }

    this->tree_data = all_data[mid];

    if (mid > 0) {
        tree *left_child = (this->left != nullptr) ? this->left : new_tree_node();
        if (left_child->data.empty()) {
            left_child->data.assign(all_data.begin(), all_data.begin() + mid);
        }
        this->left = left_child;
        left_child->parent = this;
    }

    if (mid + 1 < all_data.size()) {
        tree *right_child = (this->right != nullptr) ? this->right : new_tree_node();
        if (right_child->data.empty()) {
            right_child->data.assign(all_data.begin() + mid + 1, all_data.end());
        }
        this->right = right_child;
        right_child->parent = this;
    }
}

// Legacy split-axis cursor. The builder below derives the axis from the
// depth of the node and never reads this variable; it is kept only so that
// code which assigns to it keeps compiling.
size_t index = 0;

// Recursively builds the kd-tree below T from the samples stored in T.
//
// The split coordinate of a node is (depth of the node) mod (number of
// coordinates per sample), where depth is the number of ancestors reached
// through the parent links. The axis therefore cycles through every
// dimension level by level and does not depend on the order in which nodes
// are visited.
//
// A null T, a node without samples, or samples without coordinates are
// ignored.
void tree::create_kdtree(class tree *T) {
    if (T == nullptr || T->data.empty()) {
        return;
    }
    const size_t dims = T->data.front().size();
    if (dims == 0) {
        return;
    }

    size_t depth = 0;
    for (tree *up = T->get_parent_node(); up != nullptr; up = up->get_parent_node()) {
        ++depth;
    }

    T->insert_node(T->data, depth % dims);

    tree *left_child = T->get_left_node();
    if (left_child != nullptr && !left_child->data.empty()) {
        create_kdtree(left_child);
    }
    tree *right_child = T->get_right_node();
    if (right_child != nullptr && !right_child->data.empty()) {
        create_kdtree(right_child);
    }
}
void tree::print_kdtree(class tree *T) {
if (T->tree_data.size() != 0) {
for (int i = 0; i < T->get_tree_data().size(); i++) {
vector<ElemType> tree_data = T->get_tree_data();
cout << tree_data[i] << " ";
}
//cout << T->get_lable();
cout << endl;
if (T->get_left_node() != nullptr && T->get_left_node()->get_tree_data().size() > 0) {
print_kdtree(T->get_left_node());
}
if (T->get_right_node() != nullptr && T->get_right_node()->get_tree_data().size() > 0) {
print_kdtree(T->get_right_node());
}
}
}
bool tree::is_empty() {
if (this->data.size() == 0) {
return true;
}
else {
return false;
}
}
void read_data_from_file(class tree *T, string file_name) {
ifstream in(file_name);
vector<vector<ElemType>> train_data;
if (!in.is_open()) {
cout << "can not open " << file_name << endl;
return;
}
while (!in.eof()) {
vector<ElemType> temp_data;
//char lable;
int i = 0;
while (i < N) {
ElemType temp = 0;
in >> temp;
temp_data.push_back(temp);
cout << temp << " ";
i++;
}
//in >> lable;
cout << endl;
train_data.push_back(temp_data);
//T->add_lable(temp_data, lable);
}
T->set_data(train_data);
}
class KNN {
public:
void creat_knn();
class tree *get_kdtree();
void print_knn();
private:
class tree KNN_Tree;
};
void KNN::creat_knn() {
KNN_Tree.create_kdtree(&KNN_Tree);
}
class tree *KNN::get_kdtree() {
return &(this->KNN_Tree);
}
void KNN::print_knn() {
this->KNN_Tree.print_kdtree(&KNN_Tree);
}
int main(int argc, char *argv[]) {
class KNN knn;
read_data_from_file(knn.get_kdtree(), "data.txt");
knn.creat_knn();
knn.print_knn();
index = 0;
while (1);
return 0;
}
输入数据:
2 3
5 4
9 6
4 7
8 1
7 2
3 6
1 3
0 9
6 6
kd树先序遍历输出数据(该结果是在K = 1、即每一层都按第1维切分的情况下得到的):
5 4
2 3
1 3
0 9
4 7
3 6
8 1
7 2
6 6
9 6
网友评论