美文网首页
KNN算法应用

KNN算法应用

作者: foochane | 来源:发表于2018-01-17 15:11 被阅读46次

    1. 使用Iris数据集演示KNN算法(基于scikit-learn)

    1.1 Iris数据集介绍

    Iris数据集是常用的分类实验数据集,由Fisher于1936年收集整理。Iris也称鸢尾花卉数据集,是一个多变量分析数据集。数据集共包含150个样本,分为3类,每类50个样本,每个样本包含4个特征属性。可通过花萼长度、花萼宽度、花瓣长度、花瓣宽度这4个属性,预测鸢尾花卉属于Setosa、Versicolour、Virginica三个种类中的哪一类。

    该数据集的每条记录包含5列(4个数值特征和1个类别标签):

    • Sepal.Length(花萼长度),单位是cm;
    • Sepal.Width(花萼宽度),单位是cm;
    • Petal.Length(花瓣长度),单位是cm;
    • Petal.Width(花瓣宽度),单位是cm;
    • 种类:Iris Setosa(山鸢尾)、Iris Versicolour(杂色鸢尾),以及Iris Virginica(维吉尼亚鸢尾)。

    1.2 代码实现

    代码:

    from sklearn import datasets
    from sklearn import neighbors
    
    # Load the bundled Iris dataset (150 samples, 4 numeric features, 3 classes).
    iris = datasets.load_iris()
    
    print("数据:")
    print(iris)
    
    # Build a k-nearest-neighbours classifier with the library-default
    # hyperparameters and fit it on the whole dataset.
    knn = neighbors.KNeighborsClassifier()
    knn.fit(iris.data, iris.target)
    
    # One query flower: [sepal length, sepal width, petal length, petal width] in cm.
    sample = [[7.2,  3.6,  6.1,  2.5]]
    predictedLabel = knn.predict(sample)  # array of class indices into iris.target_names
    print("预测结果:")
    print(predictedLabel)
    

    运行结果:

    D:\dev\Anaconda3\python.exe D:/code/python/PyCharm/MachineLearning/KNN/iris.py
    数据:
    {'data': array([[ 5.1,  3.5,  1.4,  0.2],
           [ 4.9,  3. ,  1.4,  0.2],
           [ 4.7,  3.2,  1.3,  0.2],
           [ 4.6,  3.1,  1.5,  0.2],
           [ 5. ,  3.6,  1.4,  0.2],
           [ 5.4,  3.9,  1.7,  0.4],
           [ 4.6,  3.4,  1.4,  0.3],
           [ 5. ,  3.4,  1.5,  0.2],
           [ 4.4,  2.9,  1.4,  0.2],
           [ 4.9,  3.1,  1.5,  0.1],
           [ 5.4,  3.7,  1.5,  0.2],
           [ 4.8,  3.4,  1.6,  0.2],
           [ 4.8,  3. ,  1.4,  0.1],
           [ 4.3,  3. ,  1.1,  0.1],
           [ 5.8,  4. ,  1.2,  0.2],
           [ 5.7,  4.4,  1.5,  0.4],
           [ 5.4,  3.9,  1.3,  0.4],
           [ 5.1,  3.5,  1.4,  0.3],
           [ 5.7,  3.8,  1.7,  0.3],
           [ 5.1,  3.8,  1.5,  0.3],
           [ 5.4,  3.4,  1.7,  0.2],
           [ 5.1,  3.7,  1.5,  0.4],
           [ 4.6,  3.6,  1. ,  0.2],
           [ 5.1,  3.3,  1.7,  0.5],
           [ 4.8,  3.4,  1.9,  0.2],
           [ 5. ,  3. ,  1.6,  0.2],
           [ 5. ,  3.4,  1.6,  0.4],
           [ 5.2,  3.5,  1.5,  0.2],
           [ 5.2,  3.4,  1.4,  0.2],
           [ 4.7,  3.2,  1.6,  0.2],
           [ 4.8,  3.1,  1.6,  0.2],
           [ 5.4,  3.4,  1.5,  0.4],
           [ 5.2,  4.1,  1.5,  0.1],
           [ 5.5,  4.2,  1.4,  0.2],
           [ 4.9,  3.1,  1.5,  0.1],
           [ 5. ,  3.2,  1.2,  0.2],
           [ 5.5,  3.5,  1.3,  0.2],
           [ 4.9,  3.1,  1.5,  0.1],
           [ 4.4,  3. ,  1.3,  0.2],
           [ 5.1,  3.4,  1.5,  0.2],
           [ 5. ,  3.5,  1.3,  0.3],
           [ 4.5,  2.3,  1.3,  0.3],
           [ 4.4,  3.2,  1.3,  0.2],
           [ 5. ,  3.5,  1.6,  0.6],
           [ 5.1,  3.8,  1.9,  0.4],
           [ 4.8,  3. ,  1.4,  0.3],
           [ 5.1,  3.8,  1.6,  0.2],
           [ 4.6,  3.2,  1.4,  0.2],
           [ 5.3,  3.7,  1.5,  0.2],
           [ 5. ,  3.3,  1.4,  0.2],
           [ 7. ,  3.2,  4.7,  1.4],
           [ 6.4,  3.2,  4.5,  1.5],
           [ 6.9,  3.1,  4.9,  1.5],
           [ 5.5,  2.3,  4. ,  1.3],
           [ 6.5,  2.8,  4.6,  1.5],
           [ 5.7,  2.8,  4.5,  1.3],
           [ 6.3,  3.3,  4.7,  1.6],
           [ 4.9,  2.4,  3.3,  1. ],
           [ 6.6,  2.9,  4.6,  1.3],
           [ 5.2,  2.7,  3.9,  1.4],
           [ 5. ,  2. ,  3.5,  1. ],
           [ 5.9,  3. ,  4.2,  1.5],
           [ 6. ,  2.2,  4. ,  1. ],
           [ 6.1,  2.9,  4.7,  1.4],
           [ 5.6,  2.9,  3.6,  1.3],
           [ 6.7,  3.1,  4.4,  1.4],
           [ 5.6,  3. ,  4.5,  1.5],
           [ 5.8,  2.7,  4.1,  1. ],
           [ 6.2,  2.2,  4.5,  1.5],
           [ 5.6,  2.5,  3.9,  1.1],
           [ 5.9,  3.2,  4.8,  1.8],
           [ 6.1,  2.8,  4. ,  1.3],
           [ 6.3,  2.5,  4.9,  1.5],
           [ 6.1,  2.8,  4.7,  1.2],
           [ 6.4,  2.9,  4.3,  1.3],
           [ 6.6,  3. ,  4.4,  1.4],
           [ 6.8,  2.8,  4.8,  1.4],
           [ 6.7,  3. ,  5. ,  1.7],
           [ 6. ,  2.9,  4.5,  1.5],
           [ 5.7,  2.6,  3.5,  1. ],
           [ 5.5,  2.4,  3.8,  1.1],
           [ 5.5,  2.4,  3.7,  1. ],
           [ 5.8,  2.7,  3.9,  1.2],
           [ 6. ,  2.7,  5.1,  1.6],
           [ 5.4,  3. ,  4.5,  1.5],
           [ 6. ,  3.4,  4.5,  1.6],
           [ 6.7,  3.1,  4.7,  1.5],
           [ 6.3,  2.3,  4.4,  1.3],
           [ 5.6,  3. ,  4.1,  1.3],
           [ 5.5,  2.5,  4. ,  1.3],
           [ 5.5,  2.6,  4.4,  1.2],
           [ 6.1,  3. ,  4.6,  1.4],
           [ 5.8,  2.6,  4. ,  1.2],
           [ 5. ,  2.3,  3.3,  1. ],
           [ 5.6,  2.7,  4.2,  1.3],
           [ 5.7,  3. ,  4.2,  1.2],
           [ 5.7,  2.9,  4.2,  1.3],
           [ 6.2,  2.9,  4.3,  1.3],
           [ 5.1,  2.5,  3. ,  1.1],
           [ 5.7,  2.8,  4.1,  1.3],
           [ 6.3,  3.3,  6. ,  2.5],
           [ 5.8,  2.7,  5.1,  1.9],
           [ 7.1,  3. ,  5.9,  2.1],
           [ 6.3,  2.9,  5.6,  1.8],
           [ 6.5,  3. ,  5.8,  2.2],
           [ 7.6,  3. ,  6.6,  2.1],
           [ 4.9,  2.5,  4.5,  1.7],
           [ 7.3,  2.9,  6.3,  1.8],
           [ 6.7,  2.5,  5.8,  1.8],
           [ 7.2,  3.6,  6.1,  2.5],
           [ 6.5,  3.2,  5.1,  2. ],
           [ 6.4,  2.7,  5.3,  1.9],
           [ 6.8,  3. ,  5.5,  2.1],
           [ 5.7,  2.5,  5. ,  2. ],
           [ 5.8,  2.8,  5.1,  2.4],
           [ 6.4,  3.2,  5.3,  2.3],
           [ 6.5,  3. ,  5.5,  1.8],
           [ 7.7,  3.8,  6.7,  2.2],
           [ 7.7,  2.6,  6.9,  2.3],
           [ 6. ,  2.2,  5. ,  1.5],
           [ 6.9,  3.2,  5.7,  2.3],
           [ 5.6,  2.8,  4.9,  2. ],
           [ 7.7,  2.8,  6.7,  2. ],
           [ 6.3,  2.7,  4.9,  1.8],
           [ 6.7,  3.3,  5.7,  2.1],
           [ 7.2,  3.2,  6. ,  1.8],
           [ 6.2,  2.8,  4.8,  1.8],
           [ 6.1,  3. ,  4.9,  1.8],
           [ 6.4,  2.8,  5.6,  2.1],
           [ 7.2,  3. ,  5.8,  1.6],
           [ 7.4,  2.8,  6.1,  1.9],
           [ 7.9,  3.8,  6.4,  2. ],
           [ 6.4,  2.8,  5.6,  2.2],
           [ 6.3,  2.8,  5.1,  1.5],
           [ 6.1,  2.6,  5.6,  1.4],
           [ 7.7,  3. ,  6.1,  2.3],
           [ 6.3,  3.4,  5.6,  2.4],
           [ 6.4,  3.1,  5.5,  1.8],
           [ 6. ,  3. ,  4.8,  1.8],
           [ 6.9,  3.1,  5.4,  2.1],
           [ 6.7,  3.1,  5.6,  2.4],
           [ 6.9,  3.1,  5.1,  2.3],
           [ 5.8,  2.7,  5.1,  1.9],
           [ 6.8,  3.2,  5.9,  2.3],
           [ 6.7,  3.3,  5.7,  2.5],
           [ 6.7,  3. ,  5.2,  2.3],
           [ 6.3,  2.5,  5. ,  1.9],
           [ 6.5,  3. ,  5.2,  2. ],
           [ 6.2,  3.4,  5.4,  2.3],
           [ 5.9,  3. ,  5.1,  1.8]]), 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
           2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
           2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), 'target_names': array(['setosa', 'versicolor', 'virginica'],
          dtype='<U10'), 'DESCR': 'Iris Plants Database\n====================\n\nNotes\n-----\nData Set Characteristics:\n    :Number of Instances: 150 (50 in each of three classes)\n    :Number of Attributes: 4 numeric, predictive attributes and the class\n    :Attribute Information:\n        - sepal length in cm\n        - sepal width in cm\n        - petal length in cm\n        - petal width in cm\n        - class:\n                - Iris-Setosa\n                - Iris-Versicolour\n                - Iris-Virginica\n    :Summary Statistics:\n\n    ============== ==== ==== ======= ===== ====================\n                    Min  Max   Mean    SD   Class Correlation\n    ============== ==== ==== ======= ===== ====================\n    sepal length:   4.3  7.9   5.84   0.83    0.7826\n    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n    petal width:    0.1  2.5   1.20  0.76     0.9565  (high!)\n    ============== ==== ==== ======= ===== ====================\n\n    :Missing Attribute Values: None\n    :Class Distribution: 33.3% for each of 3 classes.\n    :Creator: R.A. Fisher\n    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n    :Date: July, 1988\n\nThis is a copy of UCI ML iris datasets.\nhttp://archive.ics.uci.edu/ml/datasets/Iris\n\nThe famous Iris database, first used by Sir R.A Fisher\n\nThis is perhaps the best known database to be found in the\npattern recognition literature.  Fisher\'s paper is a classic in the field and\nis referenced frequently to this day.  (See Duda & Hart, for example.)  The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant.  One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\nReferences\n----------\n   - Fisher,R.A. 
"The use of multiple measurements in taxonomic problems"\n     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n     Mathematical Statistics" (John Wiley, NY, 1950).\n   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.\n     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.\n   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n     Structure and Classification Rule for Recognition in Partially Exposed\n     Environments".  IEEE Transactions on Pattern Analysis and Machine\n     Intelligence, Vol. PAMI-2, No. 1, 67-71.\n   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions\n     on Information Theory, May 1972, 431-433.\n   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II\n     conceptual clustering system finds 3 classes in the data.\n   - Many, many more ...\n', 'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']}
    预测结果:
    [2]
    
    Process finished with exit code 0
    

    2. 自己实现KNN算法

    2.1 数据

    数据文件irisdata.txt与前面使用的Iris数据相同,只是以逗号分隔的txt格式保存,每行依次为4个特征值和类别名称:

    5.1,3.5,1.4,0.2,Iris-setosa
    4.9,3.0,1.4,0.2,Iris-setosa
    4.7,3.2,1.3,0.2,Iris-setosa
    4.6,3.1,1.5,0.2,Iris-setosa
    5.0,3.6,1.4,0.2,Iris-setosa
    5.4,3.9,1.7,0.4,Iris-setosa
    4.6,3.4,1.4,0.3,Iris-setosa
    5.0,3.4,1.5,0.2,Iris-setosa
    4.4,2.9,1.4,0.2,Iris-setosa
    4.9,3.1,1.5,0.1,Iris-setosa
    5.4,3.7,1.5,0.2,Iris-setosa
    4.8,3.4,1.6,0.2,Iris-setosa
    4.8,3.0,1.4,0.1,Iris-setosa
    4.3,3.0,1.1,0.1,Iris-setosa
    5.8,4.0,1.2,0.2,Iris-setosa
    5.7,4.4,1.5,0.4,Iris-setosa
    5.4,3.9,1.3,0.4,Iris-setosa
    5.1,3.5,1.4,0.3,Iris-setosa
    5.7,3.8,1.7,0.3,Iris-setosa
    5.1,3.8,1.5,0.3,Iris-setosa
    5.4,3.4,1.7,0.2,Iris-setosa
    5.1,3.7,1.5,0.4,Iris-setosa
    4.6,3.6,1.0,0.2,Iris-setosa
    5.1,3.3,1.7,0.5,Iris-setosa
    4.8,3.4,1.9,0.2,Iris-setosa
    5.0,3.0,1.6,0.2,Iris-setosa
    5.0,3.4,1.6,0.4,Iris-setosa
    5.2,3.5,1.5,0.2,Iris-setosa
    5.2,3.4,1.4,0.2,Iris-setosa
    4.7,3.2,1.6,0.2,Iris-setosa
    4.8,3.1,1.6,0.2,Iris-setosa
    5.4,3.4,1.5,0.4,Iris-setosa
    5.2,4.1,1.5,0.1,Iris-setosa
    5.5,4.2,1.4,0.2,Iris-setosa
    4.9,3.1,1.5,0.1,Iris-setosa
    5.0,3.2,1.2,0.2,Iris-setosa
    5.5,3.5,1.3,0.2,Iris-setosa
    4.9,3.1,1.5,0.1,Iris-setosa
    4.4,3.0,1.3,0.2,Iris-setosa
    5.1,3.4,1.5,0.2,Iris-setosa
    5.0,3.5,1.3,0.3,Iris-setosa
    4.5,2.3,1.3,0.3,Iris-setosa
    4.4,3.2,1.3,0.2,Iris-setosa
    5.0,3.5,1.6,0.6,Iris-setosa
    5.1,3.8,1.9,0.4,Iris-setosa
    4.8,3.0,1.4,0.3,Iris-setosa
    5.1,3.8,1.6,0.2,Iris-setosa
    4.6,3.2,1.4,0.2,Iris-setosa
    5.3,3.7,1.5,0.2,Iris-setosa
    5.0,3.3,1.4,0.2,Iris-setosa
    7.0,3.2,4.7,1.4,Iris-versicolor
    6.4,3.2,4.5,1.5,Iris-versicolor
    6.9,3.1,4.9,1.5,Iris-versicolor
    5.5,2.3,4.0,1.3,Iris-versicolor
    6.5,2.8,4.6,1.5,Iris-versicolor
    5.7,2.8,4.5,1.3,Iris-versicolor
    6.3,3.3,4.7,1.6,Iris-versicolor
    4.9,2.4,3.3,1.0,Iris-versicolor
    6.6,2.9,4.6,1.3,Iris-versicolor
    5.2,2.7,3.9,1.4,Iris-versicolor
    5.0,2.0,3.5,1.0,Iris-versicolor
    5.9,3.0,4.2,1.5,Iris-versicolor
    6.0,2.2,4.0,1.0,Iris-versicolor
    6.1,2.9,4.7,1.4,Iris-versicolor
    5.6,2.9,3.6,1.3,Iris-versicolor
    6.7,3.1,4.4,1.4,Iris-versicolor
    5.6,3.0,4.5,1.5,Iris-versicolor
    5.8,2.7,4.1,1.0,Iris-versicolor
    6.2,2.2,4.5,1.5,Iris-versicolor
    5.6,2.5,3.9,1.1,Iris-versicolor
    5.9,3.2,4.8,1.8,Iris-versicolor
    6.1,2.8,4.0,1.3,Iris-versicolor
    6.3,2.5,4.9,1.5,Iris-versicolor
    6.1,2.8,4.7,1.2,Iris-versicolor
    6.4,2.9,4.3,1.3,Iris-versicolor
    6.6,3.0,4.4,1.4,Iris-versicolor
    6.8,2.8,4.8,1.4,Iris-versicolor
    6.7,3.0,5.0,1.7,Iris-versicolor
    6.0,2.9,4.5,1.5,Iris-versicolor
    5.7,2.6,3.5,1.0,Iris-versicolor
    5.5,2.4,3.8,1.1,Iris-versicolor
    5.5,2.4,3.7,1.0,Iris-versicolor
    5.8,2.7,3.9,1.2,Iris-versicolor
    6.0,2.7,5.1,1.6,Iris-versicolor
    5.4,3.0,4.5,1.5,Iris-versicolor
    6.0,3.4,4.5,1.6,Iris-versicolor
    6.7,3.1,4.7,1.5,Iris-versicolor
    6.3,2.3,4.4,1.3,Iris-versicolor
    5.6,3.0,4.1,1.3,Iris-versicolor
    5.5,2.5,4.0,1.3,Iris-versicolor
    5.5,2.6,4.4,1.2,Iris-versicolor
    6.1,3.0,4.6,1.4,Iris-versicolor
    5.8,2.6,4.0,1.2,Iris-versicolor
    5.0,2.3,3.3,1.0,Iris-versicolor
    5.6,2.7,4.2,1.3,Iris-versicolor
    5.7,3.0,4.2,1.2,Iris-versicolor
    5.7,2.9,4.2,1.3,Iris-versicolor
    6.2,2.9,4.3,1.3,Iris-versicolor
    5.1,2.5,3.0,1.1,Iris-versicolor
    5.7,2.8,4.1,1.3,Iris-versicolor
    6.3,3.3,6.0,2.5,Iris-virginica
    5.8,2.7,5.1,1.9,Iris-virginica
    7.1,3.0,5.9,2.1,Iris-virginica
    6.3,2.9,5.6,1.8,Iris-virginica
    6.5,3.0,5.8,2.2,Iris-virginica
    7.6,3.0,6.6,2.1,Iris-virginica
    4.9,2.5,4.5,1.7,Iris-virginica
    7.3,2.9,6.3,1.8,Iris-virginica
    6.7,2.5,5.8,1.8,Iris-virginica
    7.2,3.6,6.1,2.5,Iris-virginica
    6.5,3.2,5.1,2.0,Iris-virginica
    6.4,2.7,5.3,1.9,Iris-virginica
    6.8,3.0,5.5,2.1,Iris-virginica
    5.7,2.5,5.0,2.0,Iris-virginica
    5.8,2.8,5.1,2.4,Iris-virginica
    6.4,3.2,5.3,2.3,Iris-virginica
    6.5,3.0,5.5,1.8,Iris-virginica
    7.7,3.8,6.7,2.2,Iris-virginica
    7.7,2.6,6.9,2.3,Iris-virginica
    6.0,2.2,5.0,1.5,Iris-virginica
    6.9,3.2,5.7,2.3,Iris-virginica
    5.6,2.8,4.9,2.0,Iris-virginica
    7.7,2.8,6.7,2.0,Iris-virginica
    6.3,2.7,4.9,1.8,Iris-virginica
    6.7,3.3,5.7,2.1,Iris-virginica
    7.2,3.2,6.0,1.8,Iris-virginica
    6.2,2.8,4.8,1.8,Iris-virginica
    6.1,3.0,4.9,1.8,Iris-virginica
    6.4,2.8,5.6,2.1,Iris-virginica
    7.2,3.0,5.8,1.6,Iris-virginica
    7.4,2.8,6.1,1.9,Iris-virginica
    7.9,3.8,6.4,2.0,Iris-virginica
    6.4,2.8,5.6,2.2,Iris-virginica
    6.3,2.8,5.1,1.5,Iris-virginica
    6.1,2.6,5.6,1.4,Iris-virginica
    7.7,3.0,6.1,2.3,Iris-virginica
    6.3,3.4,5.6,2.4,Iris-virginica
    6.4,3.1,5.5,1.8,Iris-virginica
    6.0,3.0,4.8,1.8,Iris-virginica
    6.9,3.1,5.4,2.1,Iris-virginica
    6.7,3.1,5.6,2.4,Iris-virginica
    6.9,3.1,5.1,2.3,Iris-virginica
    5.8,2.7,5.1,1.9,Iris-virginica
    6.8,3.2,5.9,2.3,Iris-virginica
    6.7,3.3,5.7,2.5,Iris-virginica
    6.7,3.0,5.2,2.3,Iris-virginica
    6.3,2.5,5.0,1.9,Iris-virginica
    6.5,3.0,5.2,2.0,Iris-virginica
    6.2,3.4,5.4,2.3,Iris-virginica
    5.9,3.0,5.1,1.8,Iris-virginica
    
    

    2.2 代码

    import csv
    import random
    import math
    import operator
    
    
    # Load the data set and randomly split its rows into a training set and a test set.
    def loadDataset(filename, split, trainingSet=None, testSet=None):
        """Read a comma-separated data file and randomly split its rows.
    
        Every non-blank row is expected to hold numeric feature columns followed
        by a class label in the last column.  The feature columns are converted
        to float in place; the label is kept as the string read from the file.
    
        Args:
            filename: path of the comma-separated text file to read.
            split: probability that a row is put in the training set; each row
                is assigned independently via ``random.random() < split``
                (so 1.0 sends every row to training, 0.0 every row to test).
            trainingSet: list that training rows are appended to; a new list is
                created when omitted.
            testSet: list that test rows are appended to; a new list is created
                when omitted.
    
        Returns:
            ``(trainingSet, testSet)`` -- the same list objects that were passed
            in (or freshly created), so callers that rely on in-place filling
            keep working.
        """
        # Fresh lists per call: a mutable default (``[]``) is evaluated once and
        # would silently accumulate rows across calls.
        if trainingSet is None:
            trainingSet = []
        if testSet is None:
            testSet = []
        # newline='' is how the csv module expects its input files to be opened.
        with open(filename, 'rt', newline='') as csvfile:
            for row in csv.reader(csvfile):
                # Skip blank lines (e.g. trailing newlines at end of file) rather
                # than dropping the last row unconditionally, which lost a real
                # sample whenever the file had no trailing blank line.
                if not any(field.strip() for field in row):
                    continue
                # All columns except the last (the label) are numeric features.
                row[:-1] = [float(value) for value in row[:-1]]
                if random.random() < split:
                    trainingSet.append(row)
                else:
                    testSet.append(row)
        return trainingSet, testSet
    
    
    # Euclidean distance between two instances.
    def euclideanDistance(instance1, instance2, length):
        """Return the Euclidean distance over the first ``length`` coordinates.
    
        Elements at index ``length`` and beyond (such as a trailing class label)
        are ignored.  ``length == 0`` gives 0.0.
        """
        squared_sum = sum((instance1[i] - instance2[i]) ** 2 for i in range(length))
        return math.sqrt(squared_sum)
    
    
    # Select the k training instances closest to testInstance.
    def getNeighbors(trainingSet, testInstance, k):
        """Return the ``k`` rows of ``trainingSet`` closest to ``testInstance``.
    
        The last element of every row is treated as the class label and is left
        out of the distance: only the leading ``len(testInstance) - 1``
        coordinates are compared (4 for the Iris rows).
    
        Args:
            trainingSet: list of rows ``[feature..., label]``.
            testInstance: row to classify, same layout as the training rows.
            k: number of neighbours to return.
    
        Returns:
            Up to ``k`` training rows ordered from nearest to farthest.  Rows at
            equal distance keep their order from ``trainingSet`` (stable sort).
            If ``k`` exceeds the training-set size, every row is returned; a
            non-positive ``k`` yields an empty list.
        """
        length = len(testInstance) - 1  # feature dimensions; the last column is the label
        # Pair every training row with its distance to the query instance.
        distances = [(row, euclideanDistance(testInstance, row, length))
                     for row in trainingSet]
        distances.sort(key=operator.itemgetter(1))
        # Slice instead of indexing so k > len(trainingSet) no longer raises
        # IndexError; clamp at 0 because a negative slice bound would drop rows
        # from the end instead of returning nothing.
        return [row for row, _ in distances[:max(k, 0)]]
    
    
    # Majority vote over the neighbours' class labels; returns the winning label.
    def getResponse(neighbors):
        """Return the most common class label among ``neighbors``.
    
        Each neighbour casts one unweighted vote for the label stored in its
        last element (distances play no role here).  On a tie, the tied label
        that first appears earliest in ``neighbors`` wins.  An empty list
        raises IndexError.
        """
        tally = {}
        for neighbour in neighbors:
            label = neighbour[-1]
            tally[label] = tally.get(label, 0) + 1
        # sorted() is stable even with reverse=True, so equal counts keep
        # first-seen order.
        ranked = sorted(tally.items(), key=operator.itemgetter(1), reverse=True)
        return ranked[0][0]
    
    
    # Classification accuracy of the predictions, as a percentage.
    def getAccuracy(testSet, predictions):
        """Return the percentage (0-100) of test rows whose label was predicted correctly.
    
        ``predictions[i]`` is compared with the label in the last element of
        ``testSet[i]``.  An empty ``testSet`` raises ZeroDivisionError.
        """
        total = len(testSet)
        hits = sum(1 for i, row in enumerate(testSet) if row[-1] == predictions[i])
        return (hits / float(total)) * 100.0
    
    
    def main():
        """Load irisdata.txt, classify every test row with k-NN (k=3) and print the accuracy."""
        # Prepare data: each row goes to the training set with probability 0.67
        # (about 2/3 of the data); the remaining ~1/3 forms the test set.
        trainingSet = []
        testSet = []
        split = 0.67
        loadDataset(r'irisdata.txt', split, trainingSet, testSet)
        print('Train set: ' + repr(len(trainingSet)))
        print('Test set: ' + repr(len(testSet)))
        # The split is random, so either set may come out empty; report it
        # instead of crashing in getNeighbors / getAccuracy.
        if not trainingSet or not testSet:
            print('Not enough data to evaluate: need non-empty training and test sets.')
            return
    
        predictions = []  # predicted label for each test row, in test-set order
        k = 3
        for testInstance in testSet:
            neighbors = getNeighbors(trainingSet, testInstance, k)
            result = getResponse(neighbors)
            predictions.append(result)
            print('>predicted=' + repr(result) + ', actual=' + repr(testInstance[-1]))
    
        print('--'*90)
    
        print('predictions: ' + repr(predictions))
        accuracy = getAccuracy(testSet, predictions)
        print('Accuracy: ' + repr(accuracy) + '%')
    
    if __name__ == '__main__':
        main()
    





                【注】:本文为麦子学院机器学习课程的学习笔记

    相关文章

      网友评论

          本文标题:KNN算法应用

          本文链接:https://www.haomeiwen.com/subject/wchuoxtx.html