Kaggle: Recognizing Handwritten Digits

Author: 阿发贝塔伽马 | Published: 2017-08-07 18:59 | Views: 92

    Download the data to your local machine and load it

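    import csv
    import numpy as np
    import pandas as pd
    
    def load_data(path):
        # The parameter is called `path` so that it no longer shadows the csv module.
        with open(path) as f:
            # Read every row, header included, as a list of strings.
            return list(csv.reader(f))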
    
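    l = load_data('train.csv')
    # Drop the header row and convert the remaining rows to floats.
    l = np.array(l[1:], dtype=float)
    # The header is already dropped above, so keep every row here:
    # column 0 is the label, the remaining columns are pixels.
    train = l[:, 1:]
    label = l[:, 0]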
    
    # Binarize in place: every pixel value above 1 becomes 1.
    # The digit shapes are still recognizable afterwards.
    train[train > 1] = 1
    
    
    l = load_data('test.csv')
    # test.csv has no label column, so every column is a pixel.
    test = np.array(l[1:], dtype=float)
    # Apply the same in-place binarization to the test pixels.
    test[test > 1] = 1
    
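    import matplotlib.pyplot as plt
    import seaborn as sns
    # Jupyter/IPython magic: render figures inline in the notebook.
    %matplotlib inline
    # Correlation heatmap over all pixel columns plus the label column.
    df = pd.DataFrame(np.hstack((train, label[:, None])),
                      columns=list(range(train.shape[1])) + ["class"])
    plt.figure(figsize=(8, 6))
    _ = sns.heatmap(df.corr(), annot=False)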
    
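The loading and binarization steps above can also be written directly with NumPy. The following is a minimal sketch, not the author's code: `SAMPLE_CSV`, `load_labeled_csv`, and `binarize` are hypothetical names, and the four-pixel rows are made-up stand-ins for the much wider rows of train.csv.

```python
import io

import numpy as np

# Made-up miniature stand-in for train.csv: a header row, then one row per
# image with the label first and the pixel values after it.
SAMPLE_CSV = """label,pixel0,pixel1,pixel2,pixel3
7,0,128,255,0
2,0,0,1,64
"""


def load_labeled_csv(source):
    """Return (pixels, labels) from a CSV path or file-like object."""
    # skiprows=1 drops the header; ndmin=2 keeps a 2-D array even for one row.
    data = np.loadtxt(source, delimiter=",", skiprows=1, ndmin=2)
    return data[:, 1:], data[:, 0]


def binarize(pixels):
    """Cap every pixel value at 1, so values above 1 become 1."""
    return np.minimum(pixels, 1.0)


pixels, labels = load_labeled_csv(io.StringIO(SAMPLE_CSV))
binary = binarize(pixels)
print(labels)
print(binary)
```

Returning a new array from `binarize` leaves the raw pixel values available, which can be convenient if both versions are needed later.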

    Classify with LogisticRegression

    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import StandardScaler
    from sklearn.metrics import accuracy_score
    
    X_train = train
    y_train = label
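    # Fit the scaler on the training pixels only, then standardize them.
    sc = StandardScaler()
    sc.fit(X_train)
    X_train_std = sc.transform(X_train)
    
    # C is the inverse regularization strength, so C=10000.0 means very weak regularization.
    lr = LogisticRegression(C=10000.0, random_state=0)
    lr.fit(X_train_std, y_train)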
    
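To make the scaling step concrete: StandardScaler subtracts each column's training mean and divides by its training standard deviation, and a constant column (for example a pixel that is 0 in every training image) is mapped to 0 rather than divided by zero. A minimal NumPy sketch of that computation follows; `standardize` and the tiny arrays are illustrative and not part of the original code.

```python
import numpy as np


def standardize(X_fit, X_apply):
    """Scale X_apply with the per-column mean and std computed on X_fit."""
    mean = X_fit.mean(axis=0)
    std = X_fit.std(axis=0)  # population standard deviation (ddof=0)
    # A constant column has std 0; use 1 instead so it maps to 0, not NaN.
    std = np.where(std == 0, 1.0, std)
    return (X_apply - mean) / std


# Tiny made-up data: the last column is constant, like an always-black pixel.
X_train_demo = np.array([[0.0, 2.0, 0.0],
                         [1.0, 4.0, 0.0],
                         [1.0, 6.0, 0.0]])
X_test_demo = np.array([[1.0, 8.0, 0.0]])

train_std = standardize(X_train_demo, X_train_demo)
# The test rows reuse the training statistics, as sc.transform does for new data.
test_std = standardize(X_train_demo, X_test_demo)
print(train_std.mean(axis=0), train_std.std(axis=0))
print(test_std)
```

Reusing the training statistics for the test rows keeps both sets on the same scale, which is why the scaler is fitted once and only on the training data.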

    Check the error on the training set: it is about 6.7954%, which is still fairly large

    y_pred = lr.predict(X_train_std)
    print('Misclassified samples: %.8f' % ((y_train != y_pred).sum()/float(len(y_train))))
    
    OUT:Misclassified samples: 0.06795400
    
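A single overall error rate hides which digits are being confused. As one way to dig further, the sketch below computes the same misclassification rate plus a per-digit breakdown. `error_report` is a hypothetical helper and the label arrays are made up for illustration; with the variables above it would be called as `error_report(y_train, y_pred)`.

```python
import numpy as np


def error_report(y_true, y_pred):
    """Return the overall error rate and a {label: error rate} dict."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    wrong = y_true != y_pred
    overall = float(wrong.mean())
    per_class = {}
    for cls in np.unique(y_true):
        mask = y_true == cls
        # Fraction of samples with this true label that were misclassified.
        per_class[int(cls)] = float(wrong[mask].mean())
    return overall, per_class


# Made-up labels and predictions, stored as floats like the arrays above.
y_true_demo = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2], dtype=float)
y_pred_demo = np.array([0, 0, 1, 7, 1, 2, 2, 2, 3, 3], dtype=float)

overall, per_class = error_report(y_true_demo, y_pred_demo)
print("overall error: %.4f" % overall)
for digit, rate in sorted(per_class.items()):
    print("digit %d: %.4f" % (digit, rate))
```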

    Predict on the test set

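    X_test = test
    # Reuse the scaler that was fitted on the training set.
    X_test_std = sc.transform(X_test)
    # sc.scale_ holds the standard deviation, sc.mean_ the mean, and sc.var_ the variance of each feature.
    y_pred = lr.predict(X_test_std)
    print(y_pred)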
    
    OUT: [ 2.  0.  9. ...,  3.  9.  2.]
    

    Submitted to Kaggle; the resulting score ranks fairly low on the leaderboard

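The post does not show how the submission file was produced. A minimal sketch is below, assuming the competition expects a CSV with an `ImageId` column counting from 1 and a `Label` column (check the competition's sample submission to confirm); `write_submission` and the output filename are hypothetical. The predictions printed above are floats because the labels were loaded as floats, so they are cast to integers before writing.

```python
import csv
import os
import tempfile


def write_submission(predictions, path):
    """Write one (ImageId, Label) row per prediction, with ImageId starting at 1."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        # Assumed header; confirm it against the competition's sample submission.
        writer.writerow(["ImageId", "Label"])
        for image_id, label in enumerate(predictions, start=1):
            # Labels arrive as floats such as 2.0, so cast them to int.
            writer.writerow([image_id, int(label)])


# Made-up predictions standing in for y_pred, written to a temporary location.
demo_path = os.path.join(tempfile.gettempdir(), "submission_demo.csv")
write_submission([2.0, 0.0, 9.0], demo_path)
```

With the variables from the prediction step, the call would look like `write_submission(y_pred, "submission.csv")`.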

    Draw one digit as a pixel image: the second test image, which was predicted as 0 above

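    import numpy as np
    import matplotlib.pyplot as plt
    
    # Second test image: a flat row of pixel values.
    test1 = test[1]
    # Turn each pixel into an RGB triple that uses only the blue channel.
    test2 = []
    for el in test1:
        test2.append([0, 0, el])
    
    img = np.array(test2)
    print(img.shape)
    img1 = img.reshape((28, 28, 3))
    # For float input, imshow expects RGB values in [0, 1] and clips larger ones.
    plt.figure("digit")
    plt.imshow(img1)
    plt.axis('off')
    plt.show()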
    
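The loop that builds the RGB triples can also be expressed with NumPy indexing. The sketch below is one way to do it, using a made-up row in place of `test[1]`; `to_blue_rgb` is a hypothetical helper name. It fills only the blue channel of a 28x28x3 array, which gives the same array as the loop followed by `reshape((28, 28, 3))`.

```python
import numpy as np


def to_blue_rgb(row, side=28):
    """Turn a flat row of side*side pixel values into a (side, side, 3) RGB array."""
    rgb = np.zeros((side, side, 3), dtype=float)
    # Channel index 2 is blue; red and green stay at 0.
    rgb[..., 2] = np.asarray(row, dtype=float).reshape(side, side)
    return rgb


# Made-up binarized row standing in for test[1]: a vertical bar of ones.
row_demo = np.zeros(28 * 28)
row_demo[14::28] = 1.0

img_demo = to_blue_rgb(row_demo)
print(img_demo.shape)
```

Passing the result to `plt.imshow` draws the digit in blue on black, as in the code above. For a plain grayscale picture, `plt.imshow(row.reshape(28, 28), cmap='gray')` is a common alternative.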

Article link: https://www.haomeiwen.com/subject/qqlnlxtx.html