Task:
Predict breast cancer from pathology data (malignant = 4, benign = 2), building the model with logistic regression.
Data source:
http://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+%28Original%29
Data format:
1000025,5,1,1,1,2,1,3,1,1,2
1002945,5,4,4,5,7,10,3,2,1,2
1015425,3,1,1,1,2,2,3,1,1,2
1016277,6,8,8,1,3,4,3,7,1,2
1017023,4,1,1,3,2,1,3,1,1,2
1017122,8,10,10,8,7,10,9,7,1,4
1018099,1,1,1,1,2,10,3,1,1,2
1018561,2,1,2,1,2,1,3,1,1,2
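Attributes: id, Clump Thickness, Uniformity of Cell Size, Uniformity of Cell Shape, Marginal Adhesion, Single Epithelial Cell Size, Bare Nuclei, Bland Chromatin, Normal Nucleoli, Mitoses, Class.
The last column is the class label (malignant = 4, benign = 2).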
1. Imports, fixing garbled Chinese characters in plots, and suppressing warnings
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import warnings
from sklearn.linear_model import LogisticRegressionCV
from sklearn.exceptions import ConvergenceWarning  # sklearn.linear_model.coordinate_descent no longer exists in recent versions
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
## Use a font with Chinese glyphs (SimHei) so Chinese text in plots is not garbled
mpl.rcParams['font.sans-serif'] = ['SimHei']
mpl.rcParams['axes.unicode_minus'] = False  # keep the minus sign rendering correctly with this font
## Suppress convergence warnings
warnings.filterwarnings(action='ignore', category=ConvergenceWarning)
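The filter above silences ConvergenceWarning for the whole session. The sketch below assumes a recent scikit-learn, where the class is importable from sklearn.exceptions. It first confirms that the warning is really raised, then catches it only around a single fit. The toy data and the deliberately tiny max_iter=1 are illustrative assumptions, chosen to force a non-converged fit:

```python
import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression

# Toy data (an assumption for illustration): the label depends on the first feature
rng = np.random.RandomState(0)
X = rng.normal(size=(200, 5))
y = (X[:, 0] + 0.5 * rng.normal(size=200) > 0).astype(int)

# catch_warnings() restores the previous filters when the block exits,
# so the filter below only affects the fit inside it
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always", ConvergenceWarning)
    # max_iter=1 is deliberately too small for the lbfgs solver to converge
    LogisticRegression(max_iter=1).fit(X, y)

n_convergence = sum(issubclass(w.category, ConvergenceWarning) for w in caught)
print("ConvergenceWarnings recorded:", n_convergence)
```

To silence the warning instead of recording it, use "ignore" in place of "always" inside the same with block.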
2. Reading the data and handling missing values
path = "datas/breast-cancer-wisconsin.data"
names = ['id','Clump Thickness','Uniformity of Cell Size','Uniformity of Cell Shape',
'Marginal Adhesion','Single Epithelial Cell Size','Bare Nuclei',
'Bland Chromatin','Normal Nucleoli','Mitoses','Class']
df = pd.read_csv(path, header=None,names=names)
datas = df.replace('?', np.nan).dropna(how='any')  # missing values are marked '?'; drop every row that has one in any column
datas.head(5)  ## show the first rows
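Note: check the column data types. Bare Nuclei is still object (strings), because the '?' placeholders made pandas read the whole column as text: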
datas.dtypes
id int64
Clump Thickness int64
Uniformity of Cell Size int64
Uniformity of Cell Shape int64
Marginal Adhesion int64
Single Epithelial Cell Size int64
Bare Nuclei object
Bland Chromatin int64
Normal Nucleoli int64
Mitoses int64
Class int64
dtype: object
datas['Bare Nuclei'].value_counts()
1 402
10 132
5 30
2 30
3 28
8 21
4 19
9 9
7 8
6 4
Name: Bare Nuclei, dtype: int64
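StandardScaler can usually coerce the leftover strings, but converting them explicitly makes the dtype unambiguous. Below is a minimal sketch on a few rows copied from the sample data above. The last row, containing '?', is a made-up example that mimics a missing value:

```python
import io

import numpy as np
import pandas as pd

# Rows in the same format as breast-cancer-wisconsin.data; the last one is hypothetical
raw = """1000025,5,1,1,1,2,1,3,1,1,2
1002945,5,4,4,5,7,10,3,2,1,2
1017122,8,10,10,8,7,10,9,7,1,4
9999999,1,1,1,1,2,?,3,1,1,2
"""
names = ['id', 'Clump Thickness', 'Uniformity of Cell Size', 'Uniformity of Cell Shape',
         'Marginal Adhesion', 'Single Epithelial Cell Size', 'Bare Nuclei',
         'Bland Chromatin', 'Normal Nucleoli', 'Mitoses', 'Class']
df = pd.read_csv(io.StringIO(raw), header=None, names=names)
print(df['Bare Nuclei'].dtype)  # not numeric: the '?' keeps the column as text

# Same cleaning as above, then convert the remaining strings to integers explicitly
datas = df.replace('?', np.nan).dropna(how='any').copy()
datas['Bare Nuclei'] = pd.to_numeric(datas['Bare Nuclei']).astype('int64')
print(datas.dtypes['Bare Nuclei'])  # int64
```

Alternatively, you can pass na_values='?' to pd.read_csv. The placeholders then become NaN while reading, and the column comes out numeric (float64).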
3. Extracting features and labels, and splitting the data
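## Extract: the features are the 9 attributes (id is dropped), the label is Class
X = datas[names[1:10]]
Y = datas[names[10]]
## Split: 90% for training, 10% for testing
X_train,X_test,Y_train,Y_test = train_test_split(X,Y,test_size=0.1,random_state=0)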
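test_size=0.1 sends 10% of the rows to the test set. The classes here are imbalanced, so the stratify option is worth knowing: it keeps the 2/4 proportions the same in both parts. Below is a minimal sketch on hypothetical labels. The 70/30 split is assumed for illustration and is not the real class count:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical labels in the 2/4 coding: 70 of class 2, 30 of class 4
y = np.array([2] * 70 + [4] * 30)
X = np.arange(len(y)).reshape(-1, 1)

# Plain split: 10% of the rows go to the test set, class mix left to chance
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.1, random_state=0)

# Stratified split: the test set keeps the 70/30 proportion
X_tr_s, X_te_s, y_tr_s, y_te_s = train_test_split(
    X, y, test_size=0.1, random_state=0, stratify=y)

print(len(y_te), (y_te == 4).sum())      # 10 test rows, class-4 count varies
print(len(y_te_s), (y_te_s == 4).sum())  # 10 test rows, exactly 3 of class 4
```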
4. Feature scaling (standardization)
ss = StandardScaler()
X_train = ss.fit_transform(X_train) ## fit the scaler on the training data and standardize it
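StandardScaler performs standardization, z = (x - μ) / σ per column; it does not rescale values to [0, 1]. The small sketch below uses made-up numbers and checks the scaler against a manual computation:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# A few made-up rows with two features, just to check the formula
X_train = np.array([[5.0, 1.0], [3.0, 1.0], [8.0, 10.0], [4.0, 2.0]])

ss = StandardScaler()
Z = ss.fit_transform(X_train)

# z = (x - mean) / std per column; StandardScaler uses the population std (ddof=0)
Z_manual = (X_train - X_train.mean(axis=0)) / X_train.std(axis=0)

print(np.allclose(Z, Z_manual))  # True
print(ss.mean_, ss.scale_)       # stored mean and std, reused later by ss.transform(X_test)
```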
Steps 1-4 are fairly standard data preparation, the same as for linear regression.
5. Building and training the model (the key part)
First, a look at some parameters of LogisticRegressionCV:
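- multi_class: how to handle more than two classes. The options are ovr (the default in older scikit-learn versions) and multinomial. For a binary problem like this one the two are essentially equivalent; for multiclass problems they give different models.
  - ovr: one-vs-rest. A T-class problem is split into T binary problems, each separating one class from all the others. Example: with the classes (0, 1, 2), model1 is 0 vs [1, 2], model2 is 1 vs [0, 2] and model3 is 2 vs [0, 1]. Together, the three binary models solve the three-class problem.
  - multinomial: fits a single multinomial (softmax) logistic regression over all T classes at once, minimizing one joint cross-entropy loss. It is sometimes explained as many-vs-many (MvM), using one-vs-one as the example: take every pair of classes T1 and T2 and train a binary model with T1 as positive and T2 as negative, T(T-1)/2 models in total. That pairwise scheme is a separate strategy, however; scikit-learn's multinomial option does not train pairwise models.
- fit_intercept: whether to fit an intercept (bias) term; without it the decision boundary passes through the origin.
- Cs: the candidate regularization values tried by cross-validation. Each value is C = 1/λ, the inverse of the regularization strength; as with support vector machines, smaller values mean stronger regularization. If Cs is an int, that many values are chosen on a logarithmic grid between 1e-4 and 1e4. (For np.logspace, see the earlier post "01 Processing data with NumPy: creating ndarrays".)
- cv: an int or a cross-validation generator, default None. The default generator is stratified K-Folds; an int gives the number of folds. See the sklearn.model_selection module for the available cross-validation objects.
- penalty: the regularization term that counters overfitting, l1 or l2. With λ = 1/C, the objective becomes $J(\theta) = \text{loss}(\theta) + \lambda \lVert\theta\rVert_1$ for L1 and $J(\theta) = \text{loss}(\theta) + \frac{\lambda}{2} \lVert\theta\rVert_2^2$ for L2. (See also the earlier post "07 Regression algorithms: overfitting and underfitting examples".)
- solver: the optimization algorithm; the default is usually fine. When choosing one yourself:
  - With the L1 penalty, use liblinear (coordinate descent); newer scikit-learn versions also accept saga. lbfgs and newton-cg rely on second-order information from a Taylor expansion of the objective. That requires a differentiable objective, so they cannot handle the L1 term.
  - With the L2 penalty, the options include lbfgs (quasi-Newton), newton-cg (a Newton-method variant) and sag (stochastic average gradient).
  - Rule of thumb: lbfgs for fewer than 10000 features, newton-cg above that. sag is mainly worthwhile on large datasets, where it is faster.
- tol: the tolerance of the stopping criterion. Optimization stops once the improvement drops below it, which avoids needless iterations.
- class_weight: weights for the classes (not the features), e.g. 'balanced' to weight classes inversely to their frequency; the default None gives every class weight 1.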
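The sketch below makes the ovr/multinomial distinction concrete on a synthetic 4-class problem (T = 4, an assumption for illustration). Recent scikit-learn versions deprecate the multi_class argument, so the two decomposition strategies use the explicit wrappers OneVsRestClassifier and OneVsOneClassifier. The multinomial case uses a plain LogisticRegression with the default lbfgs solver, which fits the softmax model when there are more than two classes:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier

# Synthetic 4-class problem, used only to count the fitted models
X, y = make_classification(n_samples=400, n_features=6, n_informative=4,
                           n_classes=4, random_state=0)

ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000)).fit(X, y)
ovo = OneVsOneClassifier(LogisticRegression(max_iter=1000)).fit(X, y)
softmax = LogisticRegression(max_iter=1000).fit(X, y)  # one joint multinomial model

print(len(ovr.estimators_))  # 4 binary models: each class vs. the rest
print(len(ovo.estimators_))  # T(T-1)/2 = 6 pairwise models
print(softmax.coef_.shape)   # (4, 6): one weight vector per class, fitted jointly
print(np.allclose(softmax.predict_proba(X).sum(axis=1), 1.0))  # softmax rows sum to 1
```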
lr = LogisticRegressionCV(multi_class='ovr',
fit_intercept=True, Cs=np.logspace(-2, 2, 20),
cv=2, penalty='l1', solver='liblinear', tol=0.01)
re = lr.fit(X_train, Y_train)  # fit() returns the fitted estimator itself, so re and lr refer to the same object
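After fitting, LogisticRegressionCV exposes which C (= 1/λ) cross-validation picked. The sketch below uses the same hyperparameters on stand-in data: make_classification relabelled to 2/4, an assumption, since the real file is not bundled. multi_class is left out because it makes no difference for a binary problem, and random_state=0 only fixes liblinear's internal shuffling:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegressionCV
from sklearn.preprocessing import StandardScaler

# Stand-in data: 9 features and the same 2/4 label coding (not the real dataset)
X, y01 = make_classification(n_samples=400, n_features=9, random_state=0)
y = np.where(y01 == 1, 4, 2)
X = StandardScaler().fit_transform(X)

lr = LogisticRegressionCV(fit_intercept=True, Cs=np.logspace(-2, 2, 20),
                          cv=2, penalty='l1', solver='liblinear', tol=0.01,
                          random_state=0)
lr.fit(X, y)

# scores_ is a dict keyed by class; for a binary problem it holds one entry
fold_scores = next(iter(lr.scores_.values()))  # accuracy per fold and per C
print(lr.Cs_)                          # the 20 candidate C values
print(lr.C_)                           # the C with the best mean CV accuracy, used for the refit
print(fold_scores.mean(axis=0).max())  # that best mean CV accuracy
```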
For another take on these parameters, see:
https://www.cnblogs.com/pinard/p/6035872.html
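Note: logistic regression is a classification algorithm and cannot be used for regression. The y passed to the model must contain discrete class labels, not continuous values. Integer labels are the natural choice; floats are accepted only if they are whole numbers such as 2.0 and 4.0.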
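scikit-learn checks this requirement with sklearn.utils.multiclass.type_of_target. The small sketch below uses made-up arrays to show that whole-number floats such as 2.0/4.0 pass, while genuinely continuous targets are rejected:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.utils.multiclass import type_of_target

print(type_of_target(np.array([2, 4, 2, 4])))          # binary
print(type_of_target(np.array([2.0, 4.0, 2.0, 4.0])))  # binary: whole-number floats still count as labels
print(type_of_target(np.array([0.3, 1.7, 2.2, 4.9])))  # continuous

X = np.array([[0.0], [1.0], [2.0], [3.0]])
raised = False
try:
    # A continuous target is rejected before any fitting happens
    LogisticRegression().fit(X, np.array([0.3, 1.7, 2.2, 4.9]))
except ValueError as err:
    raised = True
    print("ValueError:", err)
```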
6. Evaluating the model
r = re.score(X_train, Y_train)  # accuracy on the training set
print("Accuracy:", r)
print("Sparsity (share of zero coefficients): %.2f%%" % (np.mean(lr.coef_.ravel() == 0) * 100))
print("Coefficients θ:", re.coef_)
print("Intercept:", re.intercept_)
# Predicted probabilities p
print('Predicted probabilities p =', re.predict_proba(X_train))
y_hat = re.predict(X_train)
y_hat
Accuracy: 0.970684039088
Sparsity (share of zero coefficients): 0.00%
Coefficients θ: [[ 1.01911603 0.56332225 0.24484925 0.57587411 0.24628713 1.2237774
0.6513926 0.43876176 0.25823551]]
Intercept: [-0.70401175]
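Each row of the probabilities below is [probability of the first class (2), probability of the second class (4)]; the two sum to 1.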
Predicted probabilities p =
[[ 0.30234188 0.69765812]
[ 0.98943725 0.01056275]
[ 0.98014558 0.01985442]
...,
[ 0.99251492 0.00748508]
[ 0.98014558 0.01985442]
[ 0.02355113 0.97644887]]
Predicted classes:
array([4, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 4, 2, 4, 2, 2, 4, 2, 4, 4, 2, 4, 2,
4, 4, 2, 2, 4, 4, 4, 2, 2, 2, 4, 4, 2, 2, 4, 4, 2, 2, 4, 2, 2, 4, 2,
2, 2, 2, 2, 2, 2, 4, 2, 2, 4, 4, 2, 4, 2, 4, 2, 2, 4, 2, 2, 4, 2, 4,
2, 2, 2, 4, 2, 2, 2, 4, 4, 2, 4, 2, 4, 2, 2, 2, 2, 2, 4, 4, 2, 4, 4,
4, 4, 2, 4, 2, 2, 2, 2, 2, 2, 4, 4, 4, 2, 2, 2, 4, 2, 2, 4, 4, 2, 4,
2, 2, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 4, 2,
4, 2, 4, 4, 4, 2, 2, 2, 2, 4, 4, 2, 2, 4, 4, 2, 2, 4, 4, 2, 4, 2, 4,
4, 2, 2, 2, 4, 2, 4, 2, 4, 2, 2, 2, 2, 2, 4, 2, 2, 4, 2, 2, 4, 2, 2,
2, 2, 2, 4, 4, 2, 4, 2, 4, 2, 2, 4, 4, 4, 2, 2, 2, 2, 2, 2, 4, 4, 2,
2, 2, 4, 2, 2, 2, 4, 2, 2, 4, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 4, 4,
2, 4, 2, 2, 2, 4, 2, 2, 2, 4, 4, 2, 4, 2, 2, 4, 2, 2, 2, 2, 2, 2, 2,
4, 4, 4, 4, 2, 4, 2, 4, 2, 4, 4, 4, 2, 2, 4, 2, 2, 2, 2, 4, 4, 2, 2,
2, 4, 2, 2, 4, 2, 2, 2, 2, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4,
2, 4, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 2, 2, 4, 2, 4, 2, 4, 2, 2, 2, 2,
4, 2, 4, 2, 2, 2, 4, 2, 2, 2, 4, 2, 2, 2, 2, 2, 2, 2, 4, 2, 4, 2, 2,
2, 4, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 4, 2, 2, 4, 2, 2, 2, 2, 4, 4, 2,
2, 2, 2, 4, 2, 2, 4, 2, 2, 2, 2, 4, 4, 2, 4, 2, 4, 2, 2, 2, 4, 4, 4,
2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 2, 2, 2, 2, 2, 2, 2, 4, 4, 2, 2, 2, 2,
4, 4, 4, 2, 2, 2, 4, 2, 2, 2, 2, 2, 4, 2, 4, 4, 2, 2, 2, 2, 2, 2, 2,
4, 2, 2, 2, 4, 2, 2, 4, 4, 4, 2, 4, 4, 4, 2, 2, 2, 4, 2, 2, 2, 2, 4,
2, 4, 4, 4, 2, 2, 2, 4, 2, 4, 4, 4, 2, 2, 2, 4, 2, 4, 2, 2, 2, 2, 4,
4, 2, 2, 2, 4, 4, 2, 2, 4, 2, 4, 2, 4, 4, 2, 2, 2, 4, 2, 2, 2, 4, 2,
4, 2, 2, 4, 2, 2, 2, 2, 4, 4, 2, 2, 4, 4, 2, 2, 4, 4, 4, 2, 2, 4, 2,
2, 2, 2, 2, 4, 2, 4, 4, 2, 2, 2, 2, 4, 2, 2, 2, 2, 4, 2, 4, 2, 4, 2,
2, 4, 2, 2, 2, 2, 4, 2, 2, 2, 4, 2, 2, 4, 4, 4, 4, 2, 4, 4, 2, 4, 4,
2, 2, 2, 2, 2, 2, 4, 4, 2, 2, 2, 2, 4, 2, 4, 4, 2, 2, 2, 4, 4, 2, 2,
2, 2, 2, 2, 4, 2, 2, 4, 2, 2, 4, 2, 2, 2, 2, 4], dtype=int64)
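The 0.00% sparsity above means that, at the C chosen by cross-validation, the L1 penalty did not zero out any of the 9 coefficients. The sketch below uses synthetic data (20 features, only 4 informative, an assumption for illustration). It shows that a smaller C, i.e. stronger L1 regularization, sets more coefficients exactly to zero:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic data: 20 features, only 4 of them informative
X, y = make_classification(n_samples=300, n_features=20, n_informative=4,
                           random_state=0)

def zero_ratio(C):
    """Share of coefficients that an L1-penalized fit sets exactly to zero."""
    model = LogisticRegression(penalty='l1', solver='liblinear', C=C,
                               random_state=0).fit(X, y)
    return np.mean(model.coef_.ravel() == 0)

for C in (0.01, 0.1, 1.0, 100.0):
    print(f"C={C:>6}: {zero_ratio(C) * 100:.2f}% of coefficients are zero")
```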
lr.decision_function(X_train)
array([ 0.83617074, -4.53980252, -3.89927425, -4.62156243,
-3.55411336, -2.58111119, -3.27236187, 5.05462469,
-3.80389846, -3.89927425, -4.62156243, 6.2800504 ,
-3.40062758, 5.7815552 , -4.35579413, -4.88733072,
9.51137851, -2.71645503, 0.79481497, 5.9890357 ,
-2.07831512, 5.16346462, -4.01063324, 3.29522805,
0.33551 , -0.33814577, -0.14679295, 3.83862642,
5.37992594, 0.8417199 , -2.91121778, -2.81584199,
-4.15415869, 6.12709598, 5.19723322, -3.27236187,
-3.80866893, 5.09348596, 0.46342081, -4.01063324,
-4.88733072, 4.24143555, -2.6783394 , -3.89927425,
7.39351057, -4.46715312, -1.64616132, -3.03476789,
-0.40763651, -4.62156243, -4.52618663, -4.62156243,
2.86529812, -3.89927425, -4.62156243, 6.08154738,
8.57959703, -3.53654891, 1.95085165, -4.35579413,
4.9060374 , -3.99645883, -4.46715312, 5.1692686 ,
-4.53980252, -3.69253947, 5.84999432, -4.14905935,
2.89195289, -4.46715312, -2.91121778, -3.81751434,
0.70208682, -4.35579413, -3.99465005, -3.24621827,
1.79865624, 7.71842293, -4.62156243, 0.99844616,
-3.89927425, 2.53915811, -1.0160766 , -3.78791526,
-4.26041834, -4.99868971, -3.62144598, 5.85109957,
5.99063012, -4.77597174, 3.85872063, 4.09824676,
3.75324419, 6.89369995, -0.61170837, 4.49754974,
-3.55861874, -3.36099446, -2.79244823, -2.76345271,
-2.77872347, -3.99465005, 6.43867029, 5.93793372,
8.98988169, -4.88733072, -4.5694016 , -4.88733072,
2.58454758, -3.89927425, -3.89526309, 3.31242646,
0.3677743 , -1.98806356, 7.26002008, -1.31653505,
-4.62156243, 5.96846522, 10.31976733, -4.35579413,
-4.62156243, -3.63350596, -3.91976283, -3.63350596,
-4.88733072, -3.49128777, -4.72381091, -3.27236187,
-3.60226302, 3.24419901, -4.07239879, -2.91121778,
-4.45804262, -3.8816472 , -4.01063324, 6.007092 ,
2.76866809, -3.64948915, 6.30102268, -2.56892006,
4.22664964, 1.26890893, 7.88341743, -3.81751434,
-4.88733072, -4.46715312, -4.52618663, 6.57465539,
6.65264132, -4.03583907, -1.5605493 , 9.43787009,
5.85954585, -4.26041834, -3.78791526, 1.50517291,
4.07414888, -4.48405984, 4.0670997 , -3.27236187,
6.78184873, 6.542454 , -4.52618663, -4.26041834,
-4.35579413, 3.38252396, -2.86273925, 5.08254766,
-4.15938918, 3.79683776, -3.75150918, -4.88733072,
-3.96340711, -4.71425334, -4.10600903, 5.98888962,
-2.91121778, -2.86909099, 2.6451489 , -4.88733072,
-4.01063324, 7.24421905, -4.35579413, -3.45637026,
-4.35579413, -3.79301461, -2.55007369, 4.1969755 ,
0.29299258, -3.80389846, 8.48038944, -3.44275437,
7.33672926, -2.48782025, -2.38655388, 5.78121283,
7.24757065, 4.45044651, -4.62156243, -4.11443928,
-3.63350596, -4.54135843, -4.48405984, -3.99465005,
3.85089929, 3.36558941, -3.17698607, -3.80389846,
-3.80389846, 1.27716748, -3.53813016, -3.17698607,
-3.82720471, 5.81549838, -2.82249723, -4.16504255,
5.44785102, -4.73292142, -3.99465005, -4.99868971,
-3.44275437, -3.44275437, -4.16504255, -3.80389846,
6.98020737, -3.27391777, -2.51145018, -2.6783394 ,
4.84727663, 2.80989368, -4.14905935, 3.16972085,
-4.88733072, -4.88733072, -4.16504255, 1.8316676 ,
-4.16504255, -3.89927425, -3.91289014, 2.71442568,
3.16468554, -2.87707738, 6.86283152, -4.88733072,
-3.99465005, 3.48770771, -3.99465005, -4.52618663,
-3.19060196, -3.17698607, -3.03948349, -2.82914147,
-4.35579413, 1.45098614, 4.45044651, 2.41279783,
3.64625999, -4.04992268, 1.47376533, -2.58571345,
8.80922017, -3.15195618, 6.74233482, 9.37690191,
1.40550516, -4.73292142, -3.17698607, 2.6042202 ,
-2.14301074, -4.99868971, -3.89927425, -4.79705427,
5.39919603, 7.69864569, -3.03948349, -4.28386697,
-3.03476789, 5.90499227, -4.88733072, -1.41046592,
8.07193464, -2.98010631, -3.3506255 , -3.63350596,
-3.23507294, 5.0459199 , 0.55215957, -3.44275437,
-4.11820016, -0.82304126, -4.62156243, -3.63350596,
-4.37628271, -4.35579413, -4.52618663, -2.86909099,
-4.88733072, -3.17698607, 4.17147326, -4.88733072,
8.00554695, -4.26041834, -3.33649472, -4.16504255,
-4.35579413, -3.17698607, -3.89927425, 6.43124931,
9.08147525, 4.69502243, 3.97446649, -3.99465005,
-4.88733072, 6.44834235, -2.36943163, 3.70334491,
-3.19296927, 6.51568368, -2.23273733, -3.61483788,
-3.68040323, -3.87424436, 10.4874013 , -4.62156243,
5.18595895, -4.88733072, -2.37188689, -4.99868971,
3.31257898, -0.71319518, -2.39118318, -3.89551338,
6.60786816, -2.30631146, -3.98922184, -3.53813016,
-4.62156243, -4.16504255, -4.33712606, -3.53813016,
1.85276476, -3.57460359, 4.03576528, -3.26030189,
-4.45804262, -4.88733072, 3.19328816, -2.7689996 ,
-2.29063178, -3.80389846, -4.35579413, -4.73292142,
3.91922993, -3.17698607, -2.73695298, -4.63754562,
-4.35579413, 7.42016949, -4.26041834, -4.88733072,
1.90607251, -3.64948915, -4.88733072, -3.89927425,
-3.13014368, 3.60685555, 7.56753116, -2.73563799,
-3.53813016, -3.80389846, -4.72381091, 5.20226008,
-3.2958105 , -4.62156243, 4.47097843, -3.17698607,
-4.46715312, -4.35579413, -2.56127095, 4.84234883,
4.93279543, -0.8194471 , 5.52180933, -4.52618663,
2.9004162 , -2.91121778, -2.55007369, -3.89927425,
3.63773548, 3.9643458 , 4.44633198, -4.41992699,
-3.99063888, -3.89927425, -3.53813016, -4.16504255,
-3.27236187, -3.45637026, -2.81426073, 3.72761548,
6.14030394, -3.89927425, -3.63350596, -4.88733072,
-3.53813016, -3.49600337, -0.37201871, -2.74769797,
6.08696315, 3.92137146, -2.28908596, -3.81751434,
-4.35579413, -3.49755928, 9.97392126, 2.74225188,
7.03065921, -3.28455299, -0.13037112, -2.88473051,
7.49856938, -2.05744989, -4.26041834, -3.00865428,
-3.24941804, -4.35579413, 1.47376533, -3.04079076,
6.3425246 , 6.45873121, -4.73292142, -3.17698607,
-2.45950072, -1.75118755, -3.5283523 , -4.62156243,
-4.35579413, 5.59588512, -2.82945787, -4.01063324,
-3.76177167, 2.28791038, -3.03476789, -3.36099446,
3.31790563, 5.56482563, 2.4476091 , -4.52618663,
5.00746323, 3.14192615, 3.27003768, -4.62311834,
-3.34159406, -4.21357595, 0.77692448, -3.89927425,
-1.89543645, -3.99465005, -1.86846665, 6.25926658,
-4.62156243, 9.46951918, 5.39431796, 3.70589169,
-3.12638115, -3.89927425, -2.83929062, 4.9494201 ,
-3.36099446, 4.23635063, 2.30411278, 3.1834564 ,
-2.36439118, -3.03948349, -3.11335802, 5.20400687,
-3.82818076, 5.91994517, -3.12850755, -4.35735004,
-3.53813016, -3.72213855, 2.61434904, 5.28657624,
-4.35579413, -3.53813016, -4.16504255, 6.23334378,
9.79646173, -2.55007369, -4.52618663, 4.25274726,
-4.35579413, 0.19320599, -4.32455119, 4.20400586,
5.29501444, -1.84279266, -0.65757587, -3.27923455,
7.83620319, -2.96763492, -3.53813016, -4.88733072,
4.68273787, -4.62156243, 5.99846683, -3.79301461,
-2.7287653 , 4.69037719, -4.26041834, -4.47379736,
-3.27236187, -1.41590702, 0.91975041, 0.22713335,
-2.66904987, -4.62156243, 4.69318545, 6.29937646,
-4.62156243, -3.44275437, 3.25582969, 6.61857252,
5.83192693, -4.35579413, -4.16504255, 5.84294552,
-4.52618663, -3.27236187, -3.80389846, -4.46715312,
-3.02543995, 6.86464549, -4.16504255, 3.83053713,
4.00845483, -1.96966711, -2.91121778, -2.91121778,
-3.44275437, 5.38129835, -3.45637026, -4.02171717,
-2.82945787, -2.06979858, 4.21454584, -4.17250799,
2.99549431, -3.75705607, 3.51031116, -4.73292142,
-3.02257677, 1.13547285, -3.09522617, -3.17540482,
-2.82945787, -4.62156243, 5.5976563 , -4.52618663,
-2.75221393, -4.35579413, 0.66107561, -4.88733072,
-2.52371988, 1.51277975, 5.37280688, 5.15427952,
2.93808353, -2.6485613 , 2.12954345, 2.67440665,
-3.27236187, 7.24773299, 2.90342619, -3.27236187,
-3.01346626, -4.72381091, -4.00982184, -3.44275437,
-3.96340711, 2.30442705, 2.7999138 , -4.26041834,
-3.17698607, -3.89927425, -0.35231568, 3.84079402,
-3.23710776, 2.97721892, 5.78165353, -1.12567279,
-3.66168027, -4.88733072, 5.52748511, 1.61423847,
-2.75967165, -3.85297523, -4.46715312, -3.94742391,
-3.55174605, -3.53813016, 2.91792804, -4.62156243,
-3.37461035, 8.49172991, -2.42044256, -4.37177733,
9.89288169, -3.35567768, -2.97535063, -4.88733072,
-3.89927425, 3.72474854])
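For each sample, decision_function gives the signed score θᵀx + b. The sign tells which side of the separating hyperplane the sample falls on, and the magnitude grows with its distance from it (the geometric distance is the score divided by ‖θ‖). Positive scores are assigned to class 4 and negative scores to class 2. Compare this array with the predicted classes above.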
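The relationship between these arrays can be checked directly. The sketch below uses stand-in data (make_classification relabelled to 2/4, an assumption) and verifies four things:

- decision_function is the linear score θᵀx + b.
- The columns of predict_proba follow classes_ = [2, 4] and sum to 1.
- P(class 4) is the sigmoid of the score.
- A positive score therefore means a probability above 0.5, i.e. class 4.

```python
import numpy as np
from scipy.special import expit  # logistic sigmoid: 1 / (1 + exp(-z))
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Stand-in data relabelled to the article's 2/4 coding (not the real dataset)
X, y01 = make_classification(n_samples=200, n_features=5, random_state=2)
y = np.where(y01 == 1, 4, 2)
model = LogisticRegression().fit(X, y)

z = model.decision_function(X)                          # linear score theta^T x + b
z_manual = X @ model.coef_.ravel() + model.intercept_[0]
proba = model.predict_proba(X)                          # columns follow model.classes_

print(model.classes_)                                   # [2 4]
print(np.allclose(z, z_manual))                         # True
print(np.allclose(proba.sum(axis=1), 1.0))              # True: each row sums to 1
print(np.allclose(proba[:, 1], expit(z)))               # True: P(class 4) = sigmoid(score)
print(np.array_equal(model.predict(X), np.where(z > 0, 4, 2)))  # True: the sign picks the class
```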
7. Saving and loading the model
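# Save the fitted scaler and model
import os
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23; use the standalone joblib package
model_dir = "datas/models/logistic"
os.makedirs(model_dir, exist_ok=True)  # joblib.dump does not create missing folders
joblib.dump(ss, os.path.join(model_dir, "ss.model"))  # save the fitted StandardScaler
joblib.dump(lr, os.path.join(model_dir, "lr.model"))  # save the fitted LogisticRegressionCV
# Load them back from the same path they were written to
oss = joblib.load(os.path.join(model_dir, "ss.model"))
olr = joblib.load(os.path.join(model_dir, "lr.model"))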
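The sketch below is a self-contained round-trip check. It assumes the standalone joblib package (installed together with scikit-learn) and writes to a temporary folder instead of the datas/models/logistic path above; the model is fitted on stand-in data:

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Stand-in data and model (not the article's fitted objects)
X, y = make_classification(n_samples=100, n_features=4, random_state=0)
ss = StandardScaler().fit(X)
lr = LogisticRegression().fit(ss.transform(X), y)

with tempfile.TemporaryDirectory() as model_dir:
    joblib.dump(ss, os.path.join(model_dir, "ss.model"))
    joblib.dump(lr, os.path.join(model_dir, "lr.model"))
    oss = joblib.load(os.path.join(model_dir, "ss.model"))
    olr = joblib.load(os.path.join(model_dir, "lr.model"))

# The reloaded scaler and model reproduce the original predictions exactly
same = np.array_equal(olr.predict(oss.transform(X)), lr.predict(ss.transform(X)))
print(same)  # True
```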
8. Predicting on the test data
# Prediction
## a. Scale the test data with the scaler fitted on the training set (transform only, no refitting)
X_test = ss.transform(X_test)
## b. Predict the classes
Y_predict = re.predict(X_test)
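The key point in step a is calling ss.transform, not fit_transform, so the test data is scaled with the training-set mean and standard deviation. sklearn.pipeline is one way to make this automatic. The sketch below uses stand-in data (make_classification, an assumption) with the same LogisticRegressionCV settings and checks that the pipeline reproduces the manual two-step result:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Stand-in data; in the article X and Y come from step 3
X, Y = make_classification(n_samples=400, n_features=9, random_state=0)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, random_state=0)

def make_model():
    # Same hyperparameters as step 5; random_state fixes liblinear's internal shuffling
    return LogisticRegressionCV(Cs=np.logspace(-2, 2, 20), cv=2, penalty='l1',
                                solver='liblinear', tol=0.01, random_state=0)

# The pipeline fits the scaler on X_train only and reuses it on X_test automatically
pipe = make_pipeline(StandardScaler(), make_model())
pipe.fit(X_train, Y_train)
Y_predict = pipe.predict(X_test)

# Manual two-step version, as in steps 4, 5 and 8
ss = StandardScaler().fit(X_train)
manual = make_model().fit(ss.transform(X_train), Y_train)
same = np.array_equal(Y_predict, manual.predict(ss.transform(X_test)))
print(same)  # True
```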
9. Visualizing the results
x_len = range(len(X_test))
plt.figure(figsize=(14,7), facecolor='w')
plt.ylim(0,6)
plt.plot(x_len, Y_test, 'ro', markersize=8,
         zorder=3, label='True value')
plt.plot(x_len, Y_predict, 'go', markersize=14, zorder=2,
         label='Predicted value, accuracy=%.3f' % re.score(X_test, Y_test))
plt.legend(loc='upper left')
plt.xlabel('Sample index', fontsize=18)
plt.ylabel('Breast cancer class (2 = benign, 4 = malignant)', fontsize=18)
plt.title('Classifying the data with logistic regression', fontsize=20)
plt.show()
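The plot shows where predicted and true classes disagree, and a confusion matrix counts those disagreements per class. That distinction matters here: missing a malignant case (true 4, predicted 2) is a different error from a false alarm. The sketch below uses sklearn.metrics on made-up label arrays in the 2/4 coding:

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Hypothetical true and predicted labels (not results from the article)
Y_test = np.array([2, 2, 4, 4, 2, 4, 2, 2, 4, 2])
Y_predict = np.array([2, 2, 4, 2, 2, 4, 2, 4, 4, 2])

# Rows are the true class (2, 4), columns the predicted class (2, 4)
cm = confusion_matrix(Y_test, Y_predict, labels=[2, 4])
missed_malignant = cm[1, 0]  # true 4 predicted as 2
false_alarms = cm[0, 1]      # true 2 predicted as 4

print(cm)
print("missed malignant:", missed_malignant, "false alarms:", false_alarms)
print("accuracy:", accuracy_score(Y_test, Y_predict))
```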