在预测问题中,我们会经常遇到两种常用术语:回归(Regression)和分类(classification),他们的区别是回归算法解决的是预测连续值,而分类问题则是预测的是离散值,因此回归模型的输出是无限的,而分类问题的输出是有限的.
在统计学中,线性回归是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析。这种函数是一个或多个称为回归系数的模型参数的线性组合。
本文首先介绍简单线性回归,下一节介绍多元线性回归。简单线性回归就是根据一个特征 X
来预测一个与其相关的变量 Y
。通常我们我们假设这两种变量之间是现行相关的,即可以通过一条直线把这些变量区分开.因此,简单线性回归就是我们试图寻找一个线性函数,该函数以特征 X
为输入,输出一个变量,且在训练集中,使其预测的值尽可能接近目标值.
在本文我们以一个学生学习时间(hours),来预测的他该门课程的分数(scores).我们假设这两个变量之间存在某种线性关系,如图所示
simple_linear_regression.png
我们的目标是找到最有的 和 ,使我们训练集中的数据,我们的预测值和真实值之间的误差最小.即
其中, 为我们预测的值, 为真是值.
1. 数据预处理
数据预处理,我们将按照第一天介绍的模型进行处理:
- 导入相关库
- 导入数据集
- 检查缺失值
- 划分数据集
- 特征标准化
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
dataset = pd.read_csv("studentscores.csv")
X = dataset.iloc[:, 0].values
Y = dataset.iloc[:, 1].values
from sklearn.cross_validation import train_test_split
X_train, Y_train, X_test, Y_test = train_test_split(X, Y, test_size=1/4, random_state=0)
2. 训练线性回归模型
sklearn 机器学习库为我们提供了许多的常用机器学习模型,线性回归模型 LinearRegression 存在于 sklearn.linear_model 文件中, 该文件为我们提供了许多的线性模型.
from sklearn.linear_model import LinearRegression
regressor = LinearRegression()
regressor.fit(X_train, Y_train)
我们首先通过 LinearRegression() 初始化一个 regressor 实例来表示线性回归模型.然后通过给 .fit
传入我们的训练集的特征和标签来训练 regressor.
注意: 在 sklearn 中对训练数据的格式有一个规定,对于输入 X, 要求其格式是 N*M,其中 N 表示样本数, M 表示每个样本的特征数, 此示例中 M=1.对于标签 Y, 其格式是N*F, N表示样本数, F表示输出值的个数,此处F=1.
print(regressor.coef_, regressor.intercept_)
输出:
[[9.94167834]] [1.93220425]
regressor.coef_表示模型的权重, regressor.intercept_ 表示模型的偏执,分别表示上面模型公式的 和 .
3. 预测结果
当我们通过 .fit
函数训练好后模型,我们可以通过 .predict
函数来预测我们未知的数据.
Y_pred = regressor.predict(X_test)
我们通过手动预测来验证上面提到的 regressor.coef_, regressor.intercept_ 表示的意义:
temp = regressor.intercept_[0] + regressor.coef_[0] * X_test[0]
print(temp, Y_pred[0])
输出:
[16.84472176] [16.84472176]
4. 可视化
首先我们可视化训练集的结果
plt.scatter(X_train, Y_train, color="red")
plt.plot(X_train, regressor.predict(X_train), color='blue')
plt.show()
simple_linear_regression_fig1.png
对测试集进行可视化
plt.scatter(X_test, Y_test, color="red")
plt.plot(X_test, regressor.predict(X_test), color="blue")
plt.show()
simple_linear_regression_fig2.png
其中 .scatter
用于画散点图, .plot
用于画直线.
致谢
感谢大家的阅读和支持, 欢迎大家上星..该博客的原始Github项目地址点击这里
网友评论