美文网首页机器学习机器学习
机器学习快人一步:LASSO回归

机器学习快人一步:LASSO回归

作者: 光行天下 | 来源:发表于2017-11-03 11:14 被阅读92次

回归是监督学习的一个重要问题,回归用于预测输入变量和输出变量之间的关系。

回归模型是表示输入变量到输出变量之间映射的函数。

回归问题的学习等价于函数拟合:使用一条函数曲线使其能够很好的拟合已知数据,以期能很好的预测未知数据。

本文对“广告投放与销量”之间的关系进行拟合。

LASSO模型拟合

训练数据集:一共200行,第1列是行号(可以忽略,不用),第2、3、4列分别是电视、广播和报纸投放的广告费,第5列是投放广告后的商品销量。
训练数据见文末。
测试数据集:从训练数据集中拆分。

Python(2.7)代码如下:

# !/usr/bin/python
# -*- coding:utf-8 -*-
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LassoCV

if __name__ == "__main__":
    np.set_printoptions(suppress=True)  # 关闭科学计数法,便于观察数据

    # pandas读入csv数据文件
    data = pd.read_csv('Advertising.csv')
    x = data[['TV', 'Radio', 'Newspaper']] # x,3个维度
    y = data['Sales'] # y,标签

    # 分隔训练集和测试集,random_state=1指定随机数种子,让训练集和测试集每次都一样,便于调参及模型之间评估,生产上不指定
    x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1, train_size=0.8, test_size=0.2)
    model = LassoCV(alphas=np.logspace(-3, 2, 10), fit_intercept=False) # 从0.001到100,取10个数,作为超参数的候选
    model.fit(x_train, y_train) # 训练模型
    print '超参数:', model.alpha_  # 查看(交叉验证选出的最优超参数)lambda λ

    # 下面3行为了画图好看,按照测试数据的y值做递增排序,并对应排序测试数据的x
    order = y_test.argsort(axis=0)
    y_test = y_test.values[order]
    x_test = x_test.values[order, :]

    y_hat = model.predict(x_test)  # 测试模型

    print 'R2:', model.score(x_test, y_test)  # R2(R平方),经验(但是没有标准):低于0.85说明模型不咋样
    mse = np.average((y_hat - np.array(y_test)) ** 2)  # Mean Squared Error 均方误差
    rmse = np.sqrt(mse)  # Root Mean Squared Error 均方误差的开方
    print 'MSE:', mse
    print 'RMSE:', rmse

    t = np.arange(len(x_test))
    mpl.rcParams['font.sans-serif'] = [u'simHei']
    mpl.rcParams['axes.unicode_minus'] = False
    plt.figure(facecolor='w')
    plt.plot(t, y_test, 'r-', linewidth=2, label=u'真实数据')
    plt.plot(t, y_hat, 'g-', linewidth=2, label=u'预测数据')
    plt.title(u'线性回归预测销量', fontsize=18)
    plt.legend(loc='upper left')
    plt.grid()
    plt.show()

以上代码运行过程中,控制台打印信息:

超参数: 0.599484250319
R2: 0.839581753334
MSE: 2.97963959461
RMSE: 1.72616325839

超参数是sklearn交叉验证在给定alphas的10个参数中自动选择的一个最优“扰动参数λ”作为损失函数中L1正则的λ。

扰动参数Python中记作alpha(就是以下损失函数中的λ)。

LASSO回归的损失函数:

L1正则-LASSO

Ridge回归的损失函数:

L2正则-Ridge

结合L1正则和L2正则的Elastic Net回归的损失函数:

Elastic Net

一般而言,在线性回归中R方低于0.85说明这个模型不咋样。就像参加考试,考到85分(满分100分)以下通常认为成绩不咋样。但是,模型能不能用,也要看具体场景。

RMSE1.7,相比真实销量10-20区间,还不错。

sklearn提供的线性回归模型有很多(如,LinearRegression、LassoCV、RidgeCV),使用上大同小异,可根据需要进行选择。

附:训练数据

,TV,Radio,Newspaper,Sales
1,230.1,37.8,69.2,22.1
2,44.5,39.3,45.1,10.4
3,17.2,45.9,69.3,9.3
4,151.5,41.3,58.5,18.5
5,180.8,10.8,58.4,12.9
6,8.7,48.9,75,7.2
7,57.5,32.8,23.5,11.8
8,120.2,19.6,11.6,13.2
9,8.6,2.1,1,4.8
10,199.8,2.6,21.2,10.6
11,66.1,5.8,24.2,8.6
12,214.7,24,4,17.4
13,23.8,35.1,65.9,9.2
14,97.5,7.6,7.2,9.7
15,204.1,32.9,46,19
16,195.4,47.7,52.9,22.4
17,67.8,36.6,114,12.5
18,281.4,39.6,55.8,24.4
19,69.2,20.5,18.3,11.3
20,147.3,23.9,19.1,14.6
21,218.4,27.7,53.4,18
22,237.4,5.1,23.5,12.5
23,13.2,15.9,49.6,5.6
24,228.3,16.9,26.2,15.5
25,62.3,12.6,18.3,9.7
26,262.9,3.5,19.5,12
27,142.9,29.3,12.6,15
28,240.1,16.7,22.9,15.9
29,248.8,27.1,22.9,18.9
30,70.6,16,40.8,10.5
31,292.9,28.3,43.2,21.4
32,112.9,17.4,38.6,11.9
33,97.2,1.5,30,9.6
34,265.6,20,0.3,17.4
35,95.7,1.4,7.4,9.5
36,290.7,4.1,8.5,12.8
37,266.9,43.8,5,25.4
38,74.7,49.4,45.7,14.7
39,43.1,26.7,35.1,10.1
40,228,37.7,32,21.5
41,202.5,22.3,31.6,16.6
42,177,33.4,38.7,17.1
43,293.6,27.7,1.8,20.7
44,206.9,8.4,26.4,12.9
45,25.1,25.7,43.3,8.5
46,175.1,22.5,31.5,14.9
47,89.7,9.9,35.7,10.6
48,239.9,41.5,18.5,23.2
49,227.2,15.8,49.9,14.8
50,66.9,11.7,36.8,9.7
51,199.8,3.1,34.6,11.4
52,100.4,9.6,3.6,10.7
53,216.4,41.7,39.6,22.6
54,182.6,46.2,58.7,21.2
55,262.7,28.8,15.9,20.2
56,198.9,49.4,60,23.7
57,7.3,28.1,41.4,5.5
58,136.2,19.2,16.6,13.2
59,210.8,49.6,37.7,23.8
60,210.7,29.5,9.3,18.4
61,53.5,2,21.4,8.1
62,261.3,42.7,54.7,24.2
63,239.3,15.5,27.3,15.7
64,102.7,29.6,8.4,14
65,131.1,42.8,28.9,18
66,69,9.3,0.9,9.3
67,31.5,24.6,2.2,9.5
68,139.3,14.5,10.2,13.4
69,237.4,27.5,11,18.9
70,216.8,43.9,27.2,22.3
71,199.1,30.6,38.7,18.3
72,109.8,14.3,31.7,12.4
73,26.8,33,19.3,8.8
74,129.4,5.7,31.3,11
75,213.4,24.6,13.1,17
76,16.9,43.7,89.4,8.7
77,27.5,1.6,20.7,6.9
78,120.5,28.5,14.2,14.2
79,5.4,29.9,9.4,5.3
80,116,7.7,23.1,11
81,76.4,26.7,22.3,11.8
82,239.8,4.1,36.9,12.3
83,75.3,20.3,32.5,11.3
84,68.4,44.5,35.6,13.6
85,213.5,43,33.8,21.7
86,193.2,18.4,65.7,15.2
87,76.3,27.5,16,12
88,110.7,40.6,63.2,16
89,88.3,25.5,73.4,12.9
90,109.8,47.8,51.4,16.7
91,134.3,4.9,9.3,11.2
92,28.6,1.5,33,7.3
93,217.7,33.5,59,19.4
94,250.9,36.5,72.3,22.2
95,107.4,14,10.9,11.5
96,163.3,31.6,52.9,16.9
97,197.6,3.5,5.9,11.7
98,184.9,21,22,15.5
99,289.7,42.3,51.2,25.4
100,135.2,41.7,45.9,17.2
101,222.4,4.3,49.8,11.7
102,296.4,36.3,100.9,23.8
103,280.2,10.1,21.4,14.8
104,187.9,17.2,17.9,14.7
105,238.2,34.3,5.3,20.7
106,137.9,46.4,59,19.2
107,25,11,29.7,7.2
108,90.4,0.3,23.2,8.7
109,13.1,0.4,25.6,5.3
110,255.4,26.9,5.5,19.8
111,225.8,8.2,56.5,13.4
112,241.7,38,23.2,21.8
113,175.7,15.4,2.4,14.1
114,209.6,20.6,10.7,15.9
115,78.2,46.8,34.5,14.6
116,75.1,35,52.7,12.6
117,139.2,14.3,25.6,12.2
118,76.4,0.8,14.8,9.4
119,125.7,36.9,79.2,15.9
120,19.4,16,22.3,6.6
121,141.3,26.8,46.2,15.5
122,18.8,21.7,50.4,7
123,224,2.4,15.6,11.6
124,123.1,34.6,12.4,15.2
125,229.5,32.3,74.2,19.7
126,87.2,11.8,25.9,10.6
127,7.8,38.9,50.6,6.6
128,80.2,0,9.2,8.8
129,220.3,49,3.2,24.7
130,59.6,12,43.1,9.7
131,0.7,39.6,8.7,1.6
132,265.2,2.9,43,12.7
133,8.4,27.2,2.1,5.7
134,219.8,33.5,45.1,19.6
135,36.9,38.6,65.6,10.8
136,48.3,47,8.5,11.6
137,25.6,39,9.3,9.5
138,273.7,28.9,59.7,20.8
139,43,25.9,20.5,9.6
140,184.9,43.9,1.7,20.7
141,73.4,17,12.9,10.9
142,193.7,35.4,75.6,19.2
143,220.5,33.2,37.9,20.1
144,104.6,5.7,34.4,10.4
145,96.2,14.8,38.9,11.4
146,140.3,1.9,9,10.3
147,240.1,7.3,8.7,13.2
148,243.2,49,44.3,25.4
149,38,40.3,11.9,10.9
150,44.7,25.8,20.6,10.1
151,280.7,13.9,37,16.1
152,121,8.4,48.7,11.6
153,197.6,23.3,14.2,16.6
154,171.3,39.7,37.7,19
155,187.8,21.1,9.5,15.6
156,4.1,11.6,5.7,3.2
157,93.9,43.5,50.5,15.3
158,149.8,1.3,24.3,10.1
159,11.7,36.9,45.2,7.3
160,131.7,18.4,34.6,12.9
161,172.5,18.1,30.7,14.4
162,85.7,35.8,49.3,13.3
163,188.4,18.1,25.6,14.9
164,163.5,36.8,7.4,18
165,117.2,14.7,5.4,11.9
166,234.5,3.4,84.8,11.9
167,17.9,37.6,21.6,8
168,206.8,5.2,19.4,12.2
169,215.4,23.6,57.6,17.1
170,284.3,10.6,6.4,15
171,50,11.6,18.4,8.4
172,164.5,20.9,47.4,14.5
173,19.6,20.1,17,7.6
174,168.4,7.1,12.8,11.7
175,222.4,3.4,13.1,11.5
176,276.9,48.9,41.8,27
177,248.4,30.2,20.3,20.2
178,170.2,7.8,35.2,11.7
179,276.7,2.3,23.7,11.8
180,165.6,10,17.6,12.6
181,156.6,2.6,8.3,10.5
182,218.5,5.4,27.4,12.2
183,56.2,5.7,29.7,8.7
184,287.6,43,71.8,26.2
185,253.8,21.3,30,17.6
186,205,45.1,19.6,22.6
187,139.5,2.1,26.6,10.3
188,191.1,28.7,18.2,17.3
189,286,13.9,3.7,15.9
190,18.7,12.1,23.4,6.7
191,39.5,41.1,5.8,10.8
192,75.5,10.8,6,9.9
193,17.2,4.1,31.6,5.9
194,166.8,42,3.6,19.6
195,149.7,35.6,6,17.3
196,38.2,3.7,13.8,7.6
197,94.2,4.9,8.1,9.7
198,177,9.3,6.4,12.8
199,283.6,42,66.2,25.5
200,232.1,8.6,8.7,13.4

相关文章

  • 机器学习快人一步:LASSO回归

    回归是监督学习的一个重要问题,回归用于预测输入变量和输出变量之间的关系。 回归模型是表示输入变量到输出变量之间映射...

  • 统计学 惩罚-LR RR和ENR

    岭回归和lasso回归 1.学习基础 偏差和方差(bias and variance) 我们在机器学习中理解bia...

  • 机器学习基础:用 Lasso 做特征选择

    大家入门机器学习第一个接触的模型应该是简单线性回归,但是在学Lasso时往往一带而过。其实 Lasso 回归也是机...

  • 机器学习之Lasso回归

    生存模型:• Cox单因素分析 (前面一篇文章讲生存分析的时候讲了)• Lasso回归 (本篇文章)• Cox多因...

  • 机器学习快人一步:逻辑回归

    机器学习,模式识别中很重要的一环,就是分类,因为计算机其实无法深层次地理解文字图片目标的意思,只能回答是或者不是。...

  • 广义线性模型(3)线性回归模型—Lasso回归、Ridge回归

    1 Ridge回归和Lasso回归概述 在机器学习中,如果特征很多,但是训练数据量不够大的情况下,学习器很容易把特...

  • ElasticNet回归的python实现及与岭回归、lasso

    ElasticNet回归与岭回归、Lasso回归ElasticNet回归也叫弹性网络回归,是岭回归和Lasso回归...

  • ML06-LASSO回归

    本主题主要说明LASSO回归,LASSO回归与Ridge回归一样,都是属于广义线性回归的一种。LASSO回归与Ri...

  • regression

    lm()即linear model线性模型函数,用来建立OLS回归模型 OLS线性回归 LASSO回归 LASSO...

  • 机器学习

    监督学习: 分类与回归 线性回归: 线性模型:最小二乘法,岭回归,lasso回归 解决线性问题...

网友评论

    本文标题:机器学习快人一步:LASSO回归

    本文链接:https://www.haomeiwen.com/subject/zbxtmxtx.html