美文网首页
01 手工就近原则实现一个简单的鸢尾花分类器

01 手工就近原则实现一个简单的鸢尾花分类器

作者: 夏威夷的芒果 | 来源:发表于2018-08-16 17:26 被阅读45次
    基础知识 问题描述
    任务描述 数据集 思路

    人工智能数据源下载地址,下载压缩包后解压即可.
    小脚本下载地址

    原理

    数据集分开,一部分用来训练集合,一部分作为测试集合,测试集合里面每一条用来与训练集合中的元素比对,就近标类。

    代码

    import pandas as pd
    import ai_utils
    from sklearn.model_selection import train_test_split
    from scipy.spatial.distance import euclidean
    import numpy as np
    
    #读取文件
    data_file = '/Users/miraco/PycharmProjects/ai/data_ai/Iris.csv'
    
    #种类
    species = ['Iris-setosa',
               'Iris-versicolor',
               'Iris-virginica'
               ]
    #特征
    feat_cols = ['SepalLengthCm','SepalWidthCm','PetalLengthCm','PetalWidthCm']
    
    def get_pred_label(test_sample_feat, train_data):
        #近朱者赤,,找最近距离的样本,取其标签作为预测样本的标签
        dis_list = []
        for idx, row in train_data.iterrows():
            #训练样本特征
            train_sample_feat = row[feat_cols].values
            #计算当前条目和样本集合之间的距离
            dis = euclidean(test_sample_feat, train_sample_feat)
            dis_list.append(dis)
    
        #最小距离对应的位置
        pos = np.argmin(dis_list)
        #离谁最近就算成谁
        pred_label = train_data.iloc[pos]['Species']
        return pred_label
    
    
    #读取数据
    
    iris_data = pd.read_csv(data_file, index_col = 'Id')
    
    #eda
    
    ai_utils.do_eda_plot_for_iris(iris_data)
    
    # 划分数据集
    #三分之一作为训练集
    train_data, test_data = train_test_split(iris_data, test_size= 1/3 , random_state= 10)
    
    # 预测对的个数
    acc_count = 0
    
    # 分类器
    
    for idx,row in test_data.iterrows():
        # 测试样本特征
        test_sample_feat = row[feat_cols].values
    
        # 预测值
        pred_label = get_pred_label(test_sample_feat, train_data)
    
        # 真实值
        true_label = row['Species']
        print(f'样本{idx}的真实标签是{true_label},预测标签是{pred_label}')
        if true_label == pred_label:
            acc_count += 1
    
    
    # 准确率
    accuracy  = acc_count / test_data.shape[0]
    print('预测准确率{:2f}%'.format(accuracy*100))
    
    

    运行结果

    样本88的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本112的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本11的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本92的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本50的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本61的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本73的真实标签是Iris-versicolor,预测标签是Iris-virginica
    样本68的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本40的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本56的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本67的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本143的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本54的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本2的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本20的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本113的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本86的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本39的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本22的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本36的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本103的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本133的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本127的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本25的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本62的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本3的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本96的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本91的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本77的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本118的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本59的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本98的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本130的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本115的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本147的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本48的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本125的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本121的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本119的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本142的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本27的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本44的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本60的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本42的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本57的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本33的真实标签是Iris-setosa,预测标签是Iris-setosa
    样本53的真实标签是Iris-versicolor,预测标签是Iris-versicolor
    样本71的真实标签是Iris-versicolor,预测标签是Iris-virginica
    样本122的真实标签是Iris-virginica,预测标签是Iris-virginica
    样本145的真实标签是Iris-virginica,预测标签是Iris-virginica
    预测准确率96.000000%
    
    运行的图

    复习需要注意的地方:

    • 知识点:


    • sklearntrain_test_split
    from sklearn.model_selection import train_test_split
    train_test_split(train_data,train_target,test_size=0.3, random_state=0)
    

    参数解释:
    train_data:被划分的样本特征集
    train_target:被划分的样本标签
    test_size:如果是浮点数,在0-1之间,表示样本占比;如果是整数,就是样本的数量
    random_state:是随机数的种子。随机数种子其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:

    1. 种子不同,产生不同的随机数;
    2. 种子相同,即使实例不同也产生相同的随机数。
    • 按行遍历
    for idx,row in test_data.iterrows():
        # 测试样本特征
        test_sample_feat = row[feat_cols].values
    
    • 欧式空间距离
    from scipy.spatial.distance import euclidean
    dis = euclidean(test_sample_feat, train_sample_feat) 
    

    练习:手工实现一个简单的水果识别器

    • 题目描述:创建一个水果识别器,根据水果的属性,判断该水果的种类。

    • 题目要求:

    • 根据“近朱者赤”的原则,手工实现一个简单的分类器

    • 选取1/5的数据作为测试集

    • 数据文件:

    • 数据源下载地址:https://video.mugglecode.com/fruit_data.csv

    • fruit_data.csv,包含了59个水果的的数据样本。

    • 共5列数据

    • fruit_name:水果类别

    • mass: 水果质量

    • width: 水果的宽度

    • height: 水果的高度

    • color_score: 水果的颜色数值,范围0-1。

    • 0.85 - 1.00:红色

    • 0.75 - 0.85: 橙色

    • 0.65 - 0.75: 黄色

    • 0.45 - 0.65: 绿色


      image

    参考答案

    import pandas as pd
    from sklearn.model_selection import train_test_split
    from scipy.spatial.distance import euclidean
    import numpy as np
    
    
    #特征文字
    
    feat_cols =['mass','width','height','color_score']
    
    #读取数据
    
    data = pd.read_csv('/Users/miraco/PycharmProjects/ai/data_ai/fruit_data.csv')
    
    #划分数据
    
    train_set, test_set = train_test_split(data, random_state = 10, test_size= 0.4)
    
    #计算结果
    
    acc_count = 0  # 预测对的个数
    
    for idx, row in test_set.iterrows():
        #提取每一行的各特征的值
        test_sample_feat = row[feat_cols].values  #多维的一定写value
    
        #预测值
    
        pos = np.argmin([euclidean(test_sample_feat,train_row[feat_cols].values) for idx2, train_row in train_set.iterrows()])
        pred_label = train_set.iloc[pos]['fruit_name']
    
        #实际值
        real_label = row['fruit_name']
    
        print(f'样本{idx}的真实标签是{real_label},预测标签是{pred_label}')
    
        if real_label == pred_label:
            acc_count += 1
    
    # 准确率
    accuracy  = acc_count / test_set.shape[0]
    print('预测准确率{:2f}%'.format(accuracy*100))
    

    运行结果

    /Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
      return f(*args, **kwds)
    样本31的真实标签是orange,预测标签是lemon
    样本3的真实标签是mandarin,预测标签是mandarin
    样本38的真实标签是orange,预测标签是orange
    样本27的真实标签是orange,预测标签是lemon
    样本21的真实标签是apple,预测标签是apple
    样本17的真实标签是apple,预测标签是apple
    样本46的真实标签是lemon,预测标签是lemon
    样本2的真实标签是apple,预测标签是apple
    样本23的真实标签是apple,预测标签是apple
    样本26的真实标签是orange,预测标签是orange
    样本35的真实标签是orange,预测标签是apple
    样本39的真实标签是orange,预测标签是orange
    样本20的真实标签是apple,预测标签是orange
    样本37的真实标签是orange,预测标签是orange
    样本7的真实标签是mandarin,预测标签是mandarin
    样本6的真实标签是mandarin,预测标签是mandarin
    样本45的真实标签是lemon,预测标签是orange
    样本56的真实标签是lemon,预测标签是lemon
    样本47的真实标签是lemon,预测标签是lemon
    样本10的真实标签是apple,预测标签是orange
    样本44的真实标签是lemon,预测标签是lemon
    样本54的真实标签是lemon,预测标签是lemon
    样本18的真实标签是apple,预测标签是apple
    样本4的真实标签是mandarin,预测标签是mandarin
    预测准确率75.000000%
    
    Process finished with exit code 0
    

    这个警告的原因是是各种库之间的版本不匹配,只需要把numpy的版本降到1.14.5就可以了。

    sudo pip uninstall numpy
    sudo pip install numpy==1.14.5
    

    我懒得理他,就这样吧。

    相关文章

      网友评论

          本文标题:01 手工就近原则实现一个简单的鸢尾花分类器

          本文链接:https://www.haomeiwen.com/subject/djsgbftx.html