用KNN解决非线性回归问题

作者: 刘开心_8a6c | 来源:发表于2017-04-22 12:25 被阅读1956次

    一直以为KNN只是分类算法,只能在分类上用,昨天突然想起用KNN试试做回归,最近有一批数据,通过4个特征来预测1个值,原来用线性回归和神经网络尝试过,准确率只能到40%左右。用KNN结合网格搜索和交叉验证,正确率达到了79%,没错,KNN解决回归问题也很赞。

    什么是KNN

    KNN就是K近邻算法(k-NearestNeighbor),百度百科是这么写的:K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。

    KNN怎么做回归

    要预测的点的值通过求与它距离最近的K个点的值的平均值得到,这里的“距离最近”可以是欧氏距离,也可以是其他距离,具体的效果依数据而定,思路一样。如下图,x轴是一个特征,y是该特征得到的值,红色点是已知点,要预测第一个点的位置,则计算离它最近的三个点(黄色线框里的三个红点)的平均值,得出第一个绿色点,依次类推,就得到了绿色的线,可以看出,这样预测的值明显比直线准。


    K=3的拟合.png

    上述例子是基于一个特征的,如果是一个特征向量怎么办?其实一样,距离的衡量通过求两个特征向量的欧氏距离或者皮尔逊系数或者余弦距离就行。

    parametric learner和non-parametric learner

    parametric learner就是像线性回归一样,给一个y=mx+b的函数,找合适的m和b参数。non-parametric learner则没有猜测的函数,KNN做回归就是一个non-parametric learner,最终它也没有得到一个方程,只是能很好地作出预测。parametric learner的优点在于不用存储原始数据,训练慢但是查询快,缺点是不能轻易更新模型;non-parametric learner的优点在于更改模型容易,训练快但是查询慢,缺点是需要存储所有点,消耗空间。

    KNN解决非线性回归问题

    问题解决流程按照上篇的机器学习项目流程与模型评估验证完成。

    数据准备

    数据如下,一个csv表格,黄色是4个特征值,绿色是1和待预测值。


    数据.png

    加载数据

    import numpy as np
    import pandas as pd
    from sklearn.model_selection import ShuffleSplit
    
    # %matplotlib inline  将图表输出内嵌到jupyter notebook中,如果不用jupyter可以忽略这句
    
    data = pd.read_csv('fdata.csv')
    data = data[data['Friction']>16]   # 这句和下句点作用是去除异常数据
    data = data[data['Friction']<30]  
    friction = data['Friction'] 
    features = data.drop('Friction', axis = 1) #特征向量为原数据集剔除待预测列
    print("数据共有{}条,每条含有{}个特征.".format(*features.shape))
    

    输出为数据共有1635条,每条含有27个特征.

    数据分割与重排

    这一步使用train_test_split将数据随机拆分为80%的训练集与20%的测试集。如果不设定random_state,划分结果不那么随机,指定了random_state后,划分结果是随机的(具体工作原理没有细查,有朋友知道的感谢指教)。

    from sklearn.model_selection import train_test_split
    X_train, X_test, y_train, y_test = train_test_split(features, friction, test_size=0.2, random_state=50)
    
    # Success
    print("训练集与测试集拆分成功,训练集有{}条,测试集有{}条。".format(X_train.shape[0], X_test.shape[0]))
    

    输出为训练集与测试集拆分成功,训练集有1304条,测试集有327条。

    定义衡量标准

    这一步给模型表现定义一个衡量标准,也就是最后通过什么指标来看模型训练的表现,如果在训练中用了交叉验证来找模型的最优参数,在交叉验证里就可以调用这个衡量标准做评分。上篇的流程图中写过,分类问题的衡量标准有accuracy、precision、recall、F_bate分数,回归问题的衡量标准有平均绝对误差,均方误差,R2分数和可释方差分数。这里用R2分数。

    from sklearn.metrics import r2_score
    def performance_metric(y_true, y_predict):
        """ Calculates and returns the performance score between 
            true and predicted values based on the metric chosen. """
        
        score = r2_score(y_true, y_predict)
       
        return score
    

    训练模型

    重头戏到了,这个部分训练模型,我用了网格搜索和交叉验证从{3,4,5,6,7,8,9,10}里寻找R2分数最高的K作为最优参数,然后用这个K进行预测。我用了shuffleSplit和K-fold两种交叉验证。

    • shuffleSplit
    from sklearn.metrics import make_scorer
    from sklearn.neighbors import KNeighborsRegressor
    from sklearn.model_selection import GridSearchCV
    
    def fit_model_shuffle(X, y):
        """ Performs grid search over the 'max_depth' parameter for a 
            decision tree regressor trained on the input data [X, y]. """
        
        # Create cross-validation sets from the training data
        cv_sets = ShuffleSplit(n_splits = 10, test_size = 0.20, random_state = 0)
    
        # Create a KNN regressor object
        regressor = KNeighborsRegressor()
        # Create a dictionary for the parameter 'n_neighbors' with a range from 3 to 10
        params = {'n_neighbors':range(3,10)}
    
        # Transform 'performance_metric' into a scoring function using 'make_scorer' 
        scoring_fnc = make_scorer(performance_metric)
    
        # Create the grid search object
        grid = GridSearchCV(regressor, param_grid=params,scoring=scoring_fnc,cv=cv_sets)
    
        # Fit the grid search object to the data to compute the optimal model
        grid = grid.fit(X, y)
    
        # Return the optimal model after fitting the data
        return grid.best_estimator_
    
    • k-fold
    from sklearn.model_selection import KFold
    def fit_model_k_fold(X, y):
        """ Performs grid search over the 'max_depth' parameter for a 
            decision tree regressor trained on the input data [X, y]. """
        
        # Create cross-validation sets from the training data
        # cv_sets = ShuffleSplit(n_splits = 10, test_size = 0.20, random_state = 0)
        k_fold = KFold(n_splits=10)
        
        # TODO: Create a decision tree regressor object
        regressor = KNeighborsRegressor()
    
        # TODO: Create a dictionary for the parameter 'max_depth' with a range from 1 to 10
        params = {'n_neighbors':range(3,10)}
    
        # TODO: Transform 'performance_metric' into a scoring function using 'make_scorer' 
        scoring_fnc = make_scorer(performance_metric)
    
        # TODO: Create the grid search object
        grid = GridSearchCV(regressor, param_grid=params,scoring=scoring_fnc,cv=k_fold)
    
        # Fit the grid search object to the data to compute the optimal model
        grid = grid.fit(X, y)
    
        # Return the optimal model after fitting the data
        return grid.best_estimator_
    

    网格搜索返回的是一个Gridsearch的object,想用它的哪个属性就用哪个属性,API都写的很清楚,我这里返回最好的一个estimator。
    用下面代码查看找到的最优K:

    # Fit the training data to the model using grid search
    reg = fit_model_k_fold(X_train, y_train)
    
    print "Parameter 'n_neighbors' is {} for the optimal model.".format(reg.get_params()['n_neighbors'])
    

    用shuffleSplit找到的最优k是8,用k-fold找到的最优k是9。

    预测

    # Show predictions
    for i, friction in enumerate(reg.predict(features)):
        print(friction)
    

    预测表现

    用上面定义的衡量标准来衡量预测表现

    print(performance_metric(y_test, reg.predict(X_test)))
    

    到这里,整个模型就完成了。

    相关文章

      网友评论

      • 13cf2958cecc:非常感谢你提供的解方案
        刘开心_8a6c:@zbdess 我那个html里写了怎么把非数值型变量转成数值型变量 有一个函数叫get_dummy 原理你可以百度一下one hot编码
        13cf2958cecc:@刘开心_8a6c 如果加入电影类型与导演的话,这两个特征如何做量化
        刘开心_8a6c:@zbdess 不客气 想预测准确 很多细节还需要你们仔细思考 每一步都有可优化的点
      • 13cf2958cecc:目前在做一个电影评分的预测,要求使用knn,但又没思路,如果把样例发给你是否可以帮我们解决,一定赞赏支持
      • 13cf2958cecc:不错,最近一个需求考虑用这种方法,能否提供测试数据学习
        刘开心_8a6c:我用的数据是我们项目的保密数据,没法提供给您,不好意思啊

      本文标题:用KNN解决非线性回归问题

      本文链接:https://www.haomeiwen.com/subject/aaxszttx.html