前言
众所周知,电影可以按照题材分类,那么如何判定某部电影属于某个题材呢?同一题材的电影具有那些公共特征?这是需要在分类电影时考虑的问题。
那么动作片具有哪些公共特征?又与爱情片存在哪些明显差别呢?动作片也存在亲吻镜头,爱情片也存在打斗镜头,所以不能单纯从是否存在打斗或者亲吻镜头来判断。但爱情片亲吻镜头比动作片多,同样动作片打斗镜头更多。下面将基于电影中出现的亲吻、打斗次数,使用kNN(k-Nearest-Neighbor,k-近邻) 算法划分电影题材类型。
kNN容易理解掌握,本文首先使用电影分类讲解kNN算法的基本概念,了解其基本理论。在最后会利用实际例子讲解如何用kNN改进约会配对和手写数字识别。
kNN概述
- 优点:精度高、对异常值不敏感、无数据输入假定
- 缺点:计算复杂度高、空间复杂度高
- 适用数据范围:数值型和标称型
kNN算法通过测量不同特征值之间的距离来进行分类。
工作原理:存在一个已知类别的训练集(即每个数据都存在标签)。输入没有标签的新数据,将新数据的每个特征与训练集每个数据对应的特征进行比较,计算距离,选择前k个最相似的数据(k-近邻中k的出处),提取分类标签,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。
用中国俗语解释就是“近朱者赤近墨者黑”。
现在回到电影分类的例子,有人曾统计过很多电影的打斗镜头和亲吻镜头,如下表显示了6部电影的打斗和亲吻镜头数,还有一部只知道镜头数的未知电影。
电影名称 | 打斗镜头 | 亲吻镜头 | 电影类型 |
---|---|---|---|
California Man | 3 | 104 | 爱情片 |
He's Not Really into Dudes | 2 | 100 | 爱情片 |
Beautiful Woman | 1 | 81 | 爱情片 |
Kevin Longblade | 101 | 10 | 动作片 |
Robo Slayer 30000 | 99 | 5 | 动作片 |
Amped II | 98 | 2 | 动作片 |
? | 18 | 90 | 未知 |
接下来,我们将使用kNN来对未知电影进行分类。首先计算未知电影与其他电影的距离,如下表。
电影名称 | 与未知电影的距离 |
---|---|
California Man | 20.5 |
He's Not Really into Dudes | 18.7 |
Beautiful Woman | 19.2 |
Kevin Longblade | 115.3 |
Robo Slayer 30000 | 117.4 |
Amped II | 118.9 |
得到样本集中所有电影与未知电影的距离后,按距离递增排序,找到k个距离最近的电影。假定k=3,则三个最接近的电影是California Man、He's Not Really into Dudes和Beautiful Woman,而这三部电影都是爱情片,因此我们判定未知电影类型是爱情片。
现在我们大概了解了kNN的工作原理。接下来就是代码实践了。
kNN算法代码实现
import numpy as np
import operator
def classify0(inX, dataSet, labels, k):
'''
inX:未知分类的输入向量
dataSet:训练样本集
labels:训练集的标签向量
k:最近邻居的数目
'''
dataSetSize = dataSet.shape[0]
# 距离计算
# tile:重复inX
diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis = 1)
distances = sqDistances ** 0.5
# 排序,返回排序后index
sortedDistIndicies = distances.argsort()
classCount = {}
# 选择距离最小的k个点
for i in range(k):
voteIlabel = labels[sortedDistIndicies[I]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
# 根据出现次数排序
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse=True)
# 返回出现频率最高的标签
return sortedClassCount[0][0]
代码中的距离计算使用欧式距离公式,计算A和B两个点之间的距离:
接下来就是应用这个kNN分类器,看看它的分类效果如何。
示例:约会网站的配对
样本主要有三个特征:
- 每年获得的飞行常客里程数
- 玩视频游戏所耗时间百分比
- 每周消费的冰淇淋公升数
import numpy as np
# 加载数据
def loadDataSet(filename):
fr = open(filename)
arrayOLines = fr.readlines()
numberOfLines = len(arrayOLines) # 数据行数
returnMat = np.zeros((numberOfLines, 3))
classLabelVector = []
index = 0
# 解析文件数据
for line in arrayOLines:
line = line.strip().split('\t')
returnMat[index, :] = line[0:3]
classLabelVector.append(int(line[-1]))
index += 1
return returnMat, classLabelVector
使用matplotlib创建散点图
# 数据展示
import matplotlib.pyplot as plt
# 解决中文图例乱码
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False
datingDataMat, datingLabels = loadDataSet('datingTestSet2.txt')
x1, y1 = [],[]
x2, y2 = [],[]
x3, y3 = [],[]
for i in range(datingDataMat.shape[0]):
if datingLabels[i] == 1:
x1.append(datingDataMat[i,0])
y1.append(datingDataMat[i,1])
elif datingLabels[i] == 2:
x2.append(datingDataMat[i,0])
y2.append(datingDataMat[i,1])
else:
x3.append(datingDataMat[i,0])
y3.append(datingDataMat[i,1])
fig = plt.figure()
ax = fig.add_subplot(111)
type1 = ax.scatter(x1, y1, s=20, c='red')
type2 = ax.scatter(x2, y2, s=30, c='green')
type3 = ax.scatter(x3, y3, s=50, c='blue')
ax.legend([type1, type2, type3], ["不喜欢", "魅力一般", "极具魅力"], loc=2)
ax.axis([-5000,100000,-2,25])
plt.xlabel('每年获取的飞行常客里程数')
plt.ylabel('玩视频游戏所耗时间百分比')
plt.show()

归一化数据
下表给出样本集中的两组数据。
每年获得的飞行常客里程数 | 玩视频游戏所耗时间百分比 | 每周消费的冰淇淋公升数 | 样本分类 |
---|---|---|---|
5914 | 2.216246 | 0.587095 | 2 |
14851 | 14.305636 | 0.632317 | 3 |
两组数据距离的计算式为
我们很容易发现,上述式子中数字差值最大的属性对计算结果影响最大,但实际上这三种特征是同等重要的。在处理这种不同取值范围的特征值时,我们通常是将数值归一化,如将取值范围处理为0到1或者-1到1之间。
import numpy as np
# 将数值范围转化为0到1
def autoNorm(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = np.zeros(dataSet.shape)
m = dataSet.shape[0]
normDataSet = dataSet - np.tile(minVals, (m,1))
normDataSet = normDataSet/np.tile(ranges, (m,1))
return normDataSet, ranges, minVals
测试算法
def datingClassTest():
# 用于划分测试集和训练集的比例
hoRatio = 0.1
datingDataMat, datingLabels = loadDataSet('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
# 测试集数目
numTestVecs = int(m * hoRatio)
errorCount = 0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
if classifierResult != datingLabels[I]:
errorCount += 1
print("分类器返回结果:", classifierResult, "实际结果:", datingLabels[I])
print("错误率为:", (errorCount/float(numTestVecs)))
datingClassTest()
首先使用loadDataSet
和autoNorm
从文件中读取数据并进行归一化。然后计算测试集的数量,将测试集和训练集输入到kNN分类器classify0
。最后计算错误率并输出结果。执行结果如下。
分类器返回结果: 3 实际结果: 2
分类器返回结果: 3 实际结果: 1
分类器返回结果: 3 实际结果: 1
分类器返回结果: 2 实际结果: 3
分类器返回结果: 3 实际结果: 1
错误率为: 0.05
示例:手写识别系统
为了简单起见,这里只能识别数字0到9。

将图像转换为测试向量
实际图像存储在两个目录内,目录trainingDigits包含大约2000个例子,每个例子如上图所示;目录testDigits包含900个测试数据。
为了方便计算,需要将32x32的二进制图像矩阵转换成1x1024的向量。
import numpy as np
# 将图像转换为向量
def img2vector(filename):
returnVect = np.zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
测试算法:使用kNN识别手写数字
import numpy as np
from os import listdir
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits') # 获取目录内容
m = len(trainingFileList) # 文件数
trainingMat = np.zeros((m,1024))
for i in range(m):
# 从文件名中解析分类数字
fileNameStr = trainingFileList[I]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
testFileList = listdir('testDigits')
errorCount = 0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[I]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
if (classifierResult != classNumStr):
errorCount += 1
print("分类器返回结果:", classifierResult, "实际结果:", classNumStr)
print('总错误数:', errorCount)
print('错误率:', errorCount/float(mTest))
所有文件按照规则命名,如文件7_34.txt的分类是7,是数字7的第34个实例,可以利用这个规则提取分类数字,然后存储在hwLabels
中。
手写识别这里不需要用到归一化,因为值只有0和1。
执行代码,可以得到如下结果:
分类器返回结果: 1 实际结果: 8
分类器返回结果: 3 实际结果: 8
分类器返回结果: 7 实际结果: 9
分类器返回结果: 9 实际结果: 3
分类器返回结果: 1 实际结果: 8
分类器返回结果: 1 实际结果: 9
分类器返回结果: 1 实际结果: 8
分类器返回结果: 7 实际结果: 1
分类器返回结果: 6 实际结果: 5
分类器返回结果: 3 实际结果: 5
分类器返回结果: 6 实际结果: 8
总错误数: 11
错误率: 0.011627906976744186
kNN算法识别手写数字的错误率为1.2%。通过改变k值、训练集和测试集的数目,都会对错误率产生影响。
小结
kNN是分类数据最简单有效的算法。
kNN必须保存全部数据集,如果训练集很大,必须使用大量的存储空间。
由于必须对数据集中的每个数据计算距离,实际使用会非常耗时。
都看到最后了,要不~点个赞?加波关注?
网友评论