美文网首页
通过分析鸢尾花数据学习K-近邻算法

通过分析鸢尾花数据学习K-近邻算法

作者: 芦志强 | 来源:发表于2020-01-03 23:51 被阅读0次

    一、算法整体思路

    1. 按照比例切分测试集和训练集
    2. 选取特征值,对训练集建模
    3. 对于任一测试数据样本,通过计算该样本到每个样本见的距离,从而选出离该样本最近的K个样本
    4. 通过选出的K个样本投票,从而决定测试数据样本的分类
    5. 通过对比测试样本的预测值与真实值来评价预测效果,从而选出最优K值和最优样本特征

    二、通过python方式实现K-近邻算法

    话不多说直接上代码

    import os
    import csv
    import random
    import math
    import operator
    
    
    #数据文件所在路径
    base_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(base_dir, 'iris.data')
    
    
    #加载数据并分隔测试集和训练集
    def load_data(split,train_set=[],test_set=[]):
        with open(file_path,'r') as file:
            lines = csv.reader(file)
            ds = list(lines)
            for x in range(len(ds)-1):
                for y in range(4):
                    ds[x][y] = float(ds[x][y])
                if random.random() < split:
                    train_set.append(ds[x])
                else:
                    test_set.append(ds[x])
    
    
    #计算距离
    def get_distance(instance1, instance2, length):
        distance = 0
        for x in range(length):
            distance += pow((instance1[x] - instance2[x]), 2)
        return math.sqrt(distance)
    
    
    #获取最近的k个样本
    def get_neighbors(trainingSet, testInstance, k):
        distances = []
        length = len(testInstance)-1
        # (1) 计算测试样本和每个训练样本的欧式距离
        for x in range(len(trainingSet)):
            dist = get_distance(testInstance, trainingSet[x], length)
            distances.append((trainingSet[x], dist))
        # (2) 对距离进行排序
        distances.sort(key=operator.itemgetter(1))
        neighbors = []
        # (3) 返回最近的K个邻居
        for x in range(k):
            neighbors.append(distances[x][0])
        return neighbors
    
    
    #通过最近的k个样本 决定预测样本属于哪个分类
    def get_response(neighbors):
        classVotes = {}
        # (1) 遍历K个最近的邻居中每个邻居
        for x in range(len(neighbors)):
            # 统计最近邻居中所有的类别标签数量
            response = neighbors[x][-1]
            if response in classVotes:
                classVotes[response] += 1
            else:
                classVotes[response] = 1
        sortedVotes = sorted(classVotes.items(), key=operator.itemgetter(1), reverse=True)
        return sortedVotes[0][0]
    
    
    #计算准确度 (正确数量/样本总量)*100
    def getAccuracy(testSet, predictions):
        correct = 0
        # 遍历每个测试集的元素,计算预测值和真实值是否相等,计算准确度
        for x in range(len(testSet)):
            if testSet[x][-1] == predictions[x]:
                correct += 1
        return (correct/float(len(testSet))) * 100.0
    
    
    # 开始预测
    trainingSet=[]
    testSet=[]
    split = 0.67
    load_data(split, trainingSet, testSet)
    print('Train set: ' + repr(len(trainingSet)))
    print('Test set: ' + repr(len(testSet)))
    # generate predictions
    predictions=[]
    k = 3
        # (0) 遍历每个测试样本
    for x in range(len(testSet)):
            # (1) 对每个测试样本找到训练集中的最近的K邻居
        neighbors = get_neighbors(trainingSet, testSet[x], k)
            # (2) 统计K个邻居的类别
        result = get_response(neighbors)
            # (3) 记录结果
        predictions.append(result)
        print('> predicted=' + repr(result) + ', actual=' + repr(testSet[x][-1]))
    accuracy = getAccuracy(testSet, predictions)
    print('Accuracy: ' + repr(accuracy) + '%')
    

    输出节选

    ...
    > predicted='Iris-virginica', actual='Iris-virginica'
    > predicted='Iris-virginica', actual='Iris-virginica'
    > predicted='Iris-virginica', actual='Iris-virginica'
    > predicted='Iris-virginica', actual='Iris-virginica'
    > predicted='Iris-virginica', actual='Iris-virginica'
    > predicted='Iris-virginica', actual='Iris-virginica'
    > predicted='Iris-virginica', actual='Iris-virginica'
    > predicted='Iris-virginica', actual='Iris-virginica'
    Accuracy: 95.83333333333334%
    

    三、通过sklearn库实现k-近邻算法

    代码

    from sklearn.model_selection import cross_val_score
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn import metrics
    # 导入数据 这里直接导入了sklearn框架内的鸢尾花示例数据
    iris = load_iris()
    # 分隔训练集合测试集(花瓣+花萼) test_size 0.3 表示 测试集占比30% 训练集占比 70%
    trainX, testX, trainY, testY = train_test_split(
        iris.data, iris.target, test_size=0.3)
    
    
    # 开始训练
    # n_neighbors=5 表示要找5个最近样本进行投票
    model = KNeighborsClassifier(n_neighbors=5)
    model.fit(trainX, trainY)
    # 预测
    predict = model.predict(testX)
    # 打印预测结果及评分
    print('预测结果:{}'.format(predict))
    print('真实样本:{}'.format(testY))
    print('预测结果得分:{}'.format(metrics.accuracy_score(testY, predict)))
    
    
    # 为了防止因为仅一次训练集、测试集分割,导致训练评价不准确的情况,
    # 可以通过交叉验证的方式解决
    # 交叉验证的主要思想就是 将数据集分成N份 进行N次计算 每次计算都有其中的一份作为测试集
    # sklearn框架提供了交叉验证的方法
    scores = cross_val_score(
        model, iris.data, iris.target, cv=10, scoring='accuracy')  # cv=10 表示将数据集分成10份
    print('交叉验证得分:{}'.format(scores))
    

    输出

    预测结果:[0 0 1 1 2 2 1 1 0 0 0 2 2 2 2 1 0 2 1 0 2 1 0 1 0 2 1 2 0 0 0 0 1 1 1 2 0
     2 1 1 0 2 1 1 0]
    真实样本:[0 0 1 1 2 2 1 1 0 0 0 2 1 2 2 1 0 2 1 0 2 1 0 1 0 2 1 2 0 0 0 0 1 1 1 2 0
     2 1 1 0 2 1 1 0]
    预测结果得分:0.9777777777777777
    交叉验证得分:[1.         0.93333333 1.         1.         0.86666667 0.93333333
     0.93333333 1.         1.         1.        ]
    

    四、总结

    通过以上示例,基本了解了运用K-近邻算法解决分类问题的方法。通过运用sklearn库,节省了大量的代码编写,可以让精力更多的放在解决实际问题上。而且sklearn库还提供了完善的交叉验证方法,相对于直接运用Python编程实现,优势非常明显。

    相关文章

      网友评论

          本文标题:通过分析鸢尾花数据学习K-近邻算法

          本文链接:https://www.haomeiwen.com/subject/vwztactx.html