1、k-近邻算法介绍
k近邻法(k-nearest neighbor, k-NN)是1967年由Cover T和Hart P提出的一种基本分类与回归方法。它的工作原理是:存在一个样本数据集合,也称作为训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一个数据与所属分类的对应关系。输入没有标签的新数据后,将新的数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。
1. K-近邻算法.png举个例子,使用k-近邻算法分类一个电影是爱情片还是动作片。
每部电影的打斗镜头数、接吻镜头数以及电影类型首先使用欧氏距离计算未知电影与样本集中其他电影的距离。
已知电影与未知电影的距离将距离列表排序,选出前k个最相似的样本。此处我们假设k=3,将上表中的距离进行排序后前3分别是:He’s Not Really into Dudes,Beautiful Woman,California Man。统计最相似样本的分类,此处很容易知道这3个样本均为爱情片。将分类最多的类别作为未知电影的分类。那么我们就得出结论,未知电影属于爱情片。
欧氏距离公式2、Python3.6 代码实现
算法一般流程(1)准备数据集
# -*- coding: UTF-8 -*-
from numpy import *
from numpy import array
# 操作符包
import operator
def creatDataSet():
"""
Function:
创建训练数据集
Parameters:
无
Returns:
group - 数据集
labels - 分类标签
Modify:
2018-07-23
"""
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
lables = ['A','A','B','B']
return group, lables
运行结果
(2)k-近邻算法
def classify0(inX,dataSet,labels,k):
"""
Function:
kNN算法
Parameters:
inX - 测试集
dataSet - 训练集
labes - 分类标签
k - 选择距离最小k个点
Returns:
sortedClassCount[0][0] - 分类结果
Modify:
2018-07-23
"""
dataSetSize = dataSet.shape[0]
# tile(A, reps)返回一个shape=reps的矩阵,矩阵的每个元素是A
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistance = sqDiffMat.sum(axis=1)
distance = sqrt(sqDistance)
# argsort()返回数组值从小到大的索引值
sortedDistIndicies = distance.argsort()
classCount = {}
for i in range(k):
voteIlable = labels[sortedDistIndicies[i]]
# dict.get(key,default=None),返回指定键的值,如果值不在字典中返回默认值。
classCount[voteIlable] = classCount.get(voteIlable, 0) + 1
# dict.items()以列表返回可遍历的(键, 值)元组数组
# 使用字典的第二个元素进行降序排序
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
# 返回次数最多的类别
return sortedClassCount[0][0]
kNN算法运行结果
3、示例:k-近邻算法实战之约会网站配对效果
海伦收集约会数据存放在文本文件datingTestSet.txt中,有1000行的约会数据,样本主要包括以下3种特征:
- 每年获得的飞行常客里程数
- 玩视频游戏所耗时间百分比
- 每周消费的冰淇淋公升数
她约会过的对象可以进行如下分类:
- 不喜欢的人 didntLike
- 魅力一般的人 smallDoses
- 极具魅力的人 largeDoses
实现代码
def file2Matrix(filename):
"""
Function:
打开并解析文件
Parameters:
filename - 文件名
Returns:
returnMat - 特征矩阵
classLabelVector - 分类Label向量
Modify:
2018-07-23
"""
f = open(filename, 'r')
# 读取文件所有内容
arrayOlines = f.readlines()
numberOlines = len(arrayOlines)
returntMat = zeros((numberOlines, 3))
classLabelVector = []
index = 0
for line in arrayOlines:
# strip(rm)删除所有rm字符。当rm空时默认删除空白符(包括'\n', '\r', '\t', ' ')
# split(str)将字符串根据str分隔符进行切片,返回列表。
listFromLine = line.strip().split('\t')
returntMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returntMat, classLabelVector
def autoNorm(dataSet):
"""
Function:
归一化数据
Parameters:
dataSet - 特征矩阵
Returns:
normDataSet - 归一化后的特征矩阵
ranges - 数据范围
minVals - 数据集每列最小值
Modify:
2018-07-23
"""
# 获取数据集中每一列的最小数值
minVals = dataSet.min(0)
# 获取数据集中每一列的最大数值
maxVals = dataSet.max(0)
ranges = maxVals - minVals
m = shape(dataSet)[0]
normDataSet = dataSet - tile(minVals, (m, 1))
normDataSet = normDataSet / tile(ranges, (m, 1))
return normDataSet, ranges, minVals
def datingClassTest():
"""
Function:
分类器测试函数
Parameters:
无
Returns:
normDataSet - 归一化后的特征矩阵
ranges - 数据范围
minVals - 数据集每列最小值
Modify:
2018-07-23
"""
# 设定测试样本占比
hoRatio = 0.1
datingDataMat, datingLabels = file2Matrix( 'D:/PyCharm WorkPlace/machinelearninginaction/Ch02/datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
# 测试样本数量
numTestVecs = int(m * hoRatio)
errorCount = 0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print("the classifier came back with: %d, the real answer is: %d, result is :%s" % (
classifierResult, datingLabels[i], classifierResult == datingLabels[i]))
if (classifierResult != datingLabels[i]):
errorCount += 1.0
print("the total error rate is: %f" % (errorCount / float(numTestVecs)))
def classifyPerson():
"""
函数说明:
构建可手动输入系统,通过输入一个人的三维特征进行分类输出
Parameters:
无
Returns:
无
Modify:
2018-07-23
"""
# 定义预测结果
resultList = ['didntLike', 'smallDoses', 'largeDoses']
# input()接收任意任性输入,将所有输入默认为字符串处理,并返回字符串类型。
percentTats = float(input( "percentage of time spent playing video games?"))
ffMiles = float(input("frequent filer miles earned per year?"))
iceCream = float(input("liters of ice cream consumed per year?"))
# 将输入的数值放在数组中
inArr = array([ffMiles, percentTats, iceCream])
datingDataMat, datingLabels = file2Matrix('D:/PyCharm WorkPlace/machinelearninginaction/Ch02/datingTestSet2.txt')
normMat, ranges, minValues = autoNorm(datingDataMat)
classifierResult = classify0((inArr - minValues) / ranges, normMat, datingLabels, 3)
print("you will probably like this person:", resultList[classifierResult - 1])
数据可视化结果
分类器测试函数运行结果
手动输入三维特征运行结果
4、示例:手写识别数字
def img2Vector(filename):
"""
Function:
将32x32的二进制图像转换为1x1024向量。
Parameters:
filename - 文件名
Returns:
returnVect - 返回的二进制图像的1x1024向量
Modify:
2018-07-23
"""
returnVect = zeros((1, 1024))
f = open(filename, 'r')
for i in range(32):
lineStr = f.readline()
for j in range(32):
returnVect[0,32 * i + j] = int(lineStr[j])
return returnVect
def handwritingClassTest():
"""
Function:
手写识别数字分类测试
Parameters:
无
Returns:
无
Modify:
2018-07-23
"""
hwLabels = []
trainingFileDir = 'D:/PyCharm WorkPlace/machinelearninginaction/Ch02/digits/trainingDigits'
testFileDir = 'D:/PyCharm WorkPlace/machinelearninginaction/Ch02/digits/testDigits'
trainingFileList = listdir(trainingFileDir)
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2Vector(trainingFileDir + '/' + fileNameStr)
testFileList = listdir(testFileDir)
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2Vector(testFileDir + '/' + fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print("the classifierResult came back with: %d,the real answer is: %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr ): errorCount += 1.0
print("the total number of errors is: %d" % errorCount)
print("the total error rate is: %f" % (errorCount / float(mTest)))
手写识别数字分类测试运行结果
5、应用scikit-learn库实现手写识别数字
代码与之前类似
def img2Vector(filename):
"""
Function:
将32x32的二进制图像转换为1x1024向量。
Parameters:
filename - 文件名
Returns:
returnVect - 返回的二进制图像的1x1024向量
Modify:
2018-07-23
"""
returnVect = zeros((1, 1024))
f = open(filename, 'r')
for i in range(32):
lineStr = f.readline()
for j in range(32):
returnVect[0,32 * i + j] = int(lineStr[j])
return returnVect
def SklearnHandwritingClassTestS():
"""
Function:
手写识别数字分类测试
Parameters:
无
Returns:
无
Modify:
2018-07-23
"""
hwLabels = []
trainingFileDir = 'D:/PyCharm WorkPlace/machinelearninginaction/Ch02/digits/trainingDigits'
testFileDir = 'D:/PyCharm WorkPlace/machinelearninginaction/Ch02/digits/testDigits'
trainingFileList = listdir(trainingFileDir)
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2Vector(trainingFileDir + '/' + fileNameStr)
neigh = kNN(n_neighbors=3, algorithm='auto')
neigh.fit(trainingMat, hwLabels)
testFileList = listdir(testFileDir)
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2Vector(testFileDir + '/' + fileNameStr)
# classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
classifierResult = neigh.predict(vectorUnderTest)
print("the classifierResult came back with: %d,the real answer is: %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr ): errorCount += 1.0
print("the total number of errors is: %d" % errorCount)
print("the total error rate is: %f" % (errorCount / float(mTest)))
sklearn运行结果
6、小结
k近邻算法是分类数据最简单最有效的算法。k紧邻算法必须保存全部数据集,如果训练数据集很大,必须使用大量的存储空间。此外,由于必须对数据集中的每个数据计算距离值,实际使用时可能非常耗时。k近邻算法的另外一个缺陷是它无法给出任何数据的基础结构信息,因此我们也无法知晓平均实例样本和典型实例样本具有什么特征。
网友评论