鸢尾花分类是一个入门级的分类问题。这个数据集源于1936年Fisher的统计学论文,是统计学的一个经典案例,共150条数据。
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import warnings
import sklearn
from sklearn.linear_model import LogisticRegressionCV
# FIX: ConvergenceWarning lives in sklearn.exceptions; the old private path
# sklearn.linear_model.coordinate_descent was removed in modern scikit-learn.
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier  # KNN (imported but unused below)
from sklearn.preprocessing import label_binarize
from sklearn import metrics
## Configure matplotlib fonts so Chinese axis labels/titles render correctly
mpl.rcParams['font.sans-serif'] = [u'simHei']
mpl.rcParams['axes.unicode_minus'] = False
## Suppress convergence warnings emitted during cross-validated fitting
warnings.filterwarnings(action='ignore', category=ConvergenceWarning)
数据加载
# Load the raw iris data; the file has no header row, so supply column names.
path = "datas/iris.data"
names = [
    'sepal length',
    'sepal width',
    'petal length',
    'petal width',
    'cla',
]
df = pd.read_csv(path, header=None, names=names)
df['cla'].value_counts()  # class distribution (notebook-style inspection)
df.head()                 # peek at the first few rows
哑编码处理
# Numeric code for each iris class label; unknown labels map to NaN.
_CLASS_CODES = {'Iris-setosa': 1, 'Iris-versicolor': 2, 'Iris-virginica': 3}

def parseRecord(record, field_names=None):
    """Convert one raw record to a list of numeric values.

    Feature columns are cast to float; the 'cla' column is replaced by its
    integer class code (1/2/3), or NaN for an unrecognized label so that the
    row can later be dropped with dropna().

    :param record: iterable of raw values, aligned with the column names
    :param field_names: column names to zip against ``record``; defaults to
        the module-level ``names`` (backward compatible with the original
        single-argument call).
    :return: list of floats / class codes, same order as ``field_names``
    """
    if field_names is None:
        field_names = names  # module-level column list defined at load time
    result = []
    for name, v in zip(field_names, record):
        if name == 'cla':
            # dict lookup replaces the original if/elif chain; .get gives
            # np.nan for any label outside the three known classes
            result.append(_CLASS_CODES.get(v, np.nan))
        else:
            result.append(float(v))
    return result
1、数据转换为数字以及分割
### Convert every row to numeric values (class name -> numeric code)
datas = df.apply(lambda r: parseRecord(r), axis=1)
### Drop any row containing NaN (i.e. an unrecognized class label)
datas = datas.dropna(how='any')
### Features = all columns but the last; target = the last column ('cla')
X = datas[names[0:-1]]
Y = datas[names[-1]]
### Train/test split: 60% train, 40% test; fixed seed for reproducibility
# FIX: the assignment target and the train_test_split call must be a single
# statement — the original text broke them across two lines without any
# continuation, which is a syntax error.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.4, random_state=0)
print("原始数据条数:%d;训练数据条数:%d;特征个数:%d;测试样本条数:%d"
      % (len(X), len(X_train), X_train.shape[1], X_test.shape[0]))
原始数据条数:150;训练数据条数:90;特征个数:4;测试样本条数:60
2. 数据标准化
# Standardize features: fit the scaler on the training split only, then
# apply that same transformation to both splits (avoids test-set leakage).
ss = StandardScaler()
ss.fit(X_train)
X_train = ss.transform(X_train)
X_test = ss.transform(X_test)
3. 特征选择(这里不进行特征选择操作)
4. 降维处理(这里不做降维处理)
5. 模型构建
# Cross-validated logistic regression: 50 candidate C values on a log grid
# [1e-4, 1e1], 3-fold CV, L2 penalty, multinomial (softmax) formulation.
# solver options: 'newton-cg', 'lbfgs', 'liblinear', 'sag' (default: liblinear);
# 'sag' is a stochastic average-gradient (mini-batch style) solver.
lr = LogisticRegressionCV(
    Cs=np.logspace(-4, 1, 50),
    cv=3,
    fit_intercept=True,
    penalty='l2',
    solver='lbfgs',
    tol=0.01,
    multi_class='multinomial',
)
lr.fit(X_train, Y_train)
LogisticRegressionCV(Cs=array([1.00000e-04, 1.26486e-04, 1.59986e-04, 2.02359e-04, 2.55955e-04,
3.23746e-04, 4.09492e-04, 5.17947e-04, 6.55129e-04, 8.28643e-04,
1.04811e-03, 1.32571e-03, 1.67683e-03, 2.12095e-03, 2.68270e-03,
3.39322e-03, 4.29193e-03, 5.42868e-03, 6.86649e-03, 8.68511e-03,
1...., 2.44205e+00, 3.08884e+00,
3.90694e+00, 4.94171e+00, 6.25055e+00, 7.90604e+00, 1.00000e+01]),
class_weight=None, cv=3, dual=False, fit_intercept=True,
intercept_scaling=1.0, max_iter=100, multi_class='multinomial',
n_jobs=1, penalty='l2', random_state=None, refit=True,
scoring=None, solver='lbfgs', tol=0.01, verbose=0)
6. 模型效果输出
## One-hot encode the true test labels so a micro-averaged ROC can be built
y_test_hot = label_binarize(Y_test, classes=(1, 2, 3))
print(y_test_hot)
## Decision-function scores (signed distances to each class boundary)
lr_y_score = lr.decision_function(X_test)
## Micro-average ROC: flatten the one-hot truth and the score matrix
flat_truth = y_test_hot.ravel()
flat_score = lr_y_score.ravel()
lr_fpr, lr_tpr, lr_threasholds = metrics.roc_curve(flat_truth, flat_score)
## Area under the micro-averaged ROC curve
lr_auc = metrics.auc(lr_fpr, lr_tpr)
# NOTE(review): despite the label, this prints TRAINING accuracy, not R^2
print("Logistic算法R值:", lr.score(X_train, Y_train))
print("Logistic算法AUC值:", lr_auc)
### 7. Model prediction
print(lr_y_score)
lr_y_predict = lr.predict(X_test)
print(lr.predict_proba(X_test))
[[0 0 1]
[0 1 0]
[1 0 0]
[0 0 1]
[1 0 0]
[0 0 1]
[1 0 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 0 1]
[0 1 0]
[0 1 0]
[0 1 0]
[0 1 0]
[1 0 0]
[0 1 0]
[0 1 0]
[1 0 0]
[1 0 0]
[0 0 1]
[0 1 0]
[1 0 0]
[1 0 0]
[0 0 1]
[1 0 0]
[1 0 0]
[0 1 0]
[0 1 0]
[1 0 0]
[0 0 1]
[0 1 0]
[1 0 0]
[0 0 1]
[0 0 1]
[0 1 0]
[1 0 0]
[0 1 0]
[0 1 0]
[0 1 0]
[0 0 1]
[1 0 0]
[0 0 1]
[1 0 0]
[1 0 0]
[0 1 0]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 1 0]
[0 0 1]
[0 1 0]
[0 1 0]
[0 0 1]
[0 0 1]
[0 0 1]
[0 0 1]
[0 1 0]
[0 0 1]]
Logistic算法R值: 0.9777777777777777
Logistic算法AUC值: 0.9269444444444445
[[ -6.27937676 1.05548892 5.22388784]
[ -2.5371109 4.02213826 -1.48502736]
[ 9.58561223 2.6059998 -12.19161203]
[ -8.18346495 2.6530125 5.53045244]
[ 8.06522513 3.206342 -11.27156713]
[ -7.22184333 0.48250241 6.73934092]
[ 8.3061655 2.99869891 -11.30486441]
[ -3.63926189 2.64337134 0.99589054]
[ -4.44558943 3.22131538 1.22427405]
[ -1.96604656 2.94686917 -0.98082261]
[ -4.96636124 2.69911013 2.26725111]
[ -2.62637732 2.30702815 0.31934917]
[ -2.60962466 2.90343809 -0.29381343]
[ -4.0043684 2.89099073 1.11337767]
[ -2.93851571 2.52286105 0.41565466]
[ 7.70822708 3.63929131 -11.34751839]
[ -2.8118009 2.36157169 0.4502292 ]
[ -1.8045935 2.73019664 -0.92560314]
[ 7.12845508 3.45997993 -10.58843501]
[ 8.46288289 2.68843117 -11.15131406]
[ -4.49448454 1.43394608 3.06053846]
[ -1.91130978 1.89262974 0.01868004]
[ 7.84043823 2.94114447 -10.7815827 ]
[ 7.78617371 3.37763355 -11.16380726]
[ -4.70281021 2.21479818 2.48801203]
[ 9.91131131 2.72673147 -12.63804279]
[ 7.85422387 2.42150625 -10.27573012]
[ -2.25629084 2.82839377 -0.57210293]
[ 0.25066183 3.21437436 -3.46503619]
[ 7.33003213 2.93259577 -10.2626279 ]
[ -5.14664337 1.82293543 3.32370794]
[ -1.59680356 1.72359445 -0.12679089]
[ 8.02219758 3.24450592 -11.2667035 ]
[ -4.13742303 1.84449723 2.29292581]
[ -7.2966421 1.69209766 5.60454444]
[ -0.98760882 2.21626947 -1.22866065]
[ 7.48053662 3.10154607 -10.5820827 ]
[ -4.46388278 2.35421825 2.10966453]
[ -1.97859927 2.21830435 -0.23970508]
[ -1.73705971 3.07991182 -1.3428521 ]
[ -6.94768906 1.91731888 5.03037018]
[ 8.26367368 3.26261293 -11.52628661]
[ -6.8945401 1.71742173 5.17711837]
[ 7.00114109 2.55201873 -9.55315982]
[ 8.2506487 3.15179846 -11.40244716]
[ -0.83342074 3.55890736 -2.72548661]
[ -5.57537515 2.03832445 3.53705071]
[ -7.44638593 0.89269467 6.55369126]
[ -5.386268 3.1574685 2.2287995 ]
[ -7.65174457 2.78959492 4.86214965]
[ -1.25902403 3.19064129 -1.93161727]
[ -9.77227645 2.28402866 7.48824779]
[ -2.76747696 1.91870518 0.84877178]
[ -1.23842631 3.26103558 -2.02260927]
[ -5.26636527 2.40614663 2.86021863]
[ -4.36397858 2.60175264 1.76222594]
[ -2.76830109 1.70565742 1.06264367]
[ -6.24433075 2.51899599 3.72533476]
[ -3.53935769 2.89090574 0.64845195]
[ -5.77127503 1.86074046 3.91053457]]
[[9.94308711e-06 1.52409814e-02 9.84749076e-01]
[1.40923429e-03 9.94555266e-01 4.03549934e-03]
[9.99070201e-01 9.29798504e-04 3.48231198e-10]
[1.04795442e-06 5.32800658e-02 9.46718886e-01]
[9.92300592e-01 7.69940419e-03 3.96991379e-09]
[8.62784771e-07 1.91362710e-03 9.98085510e-01]
[9.95069960e-01 4.93003683e-03 3.02615758e-09]
[1.56435834e-03 8.37238489e-01 1.61197152e-01]
[4.11954355e-04 8.80123366e-01 1.19464680e-01]
[7.15748439e-03 9.73671826e-01 1.91706895e-02]
[2.84121940e-04 6.06145229e-01 3.93570649e-01]
[6.29421519e-03 8.73961623e-01 1.19744162e-01]
[3.86037204e-03 9.57021938e-01 3.91176902e-02]
[8.65322130e-04 8.54661679e-01 1.44472999e-01]
[3.77297029e-03 8.88237645e-01 1.07989385e-01]
[9.83191769e-01 1.68082260e-02 5.20994486e-09]
[4.91131150e-03 8.66891307e-01 1.28197381e-01]
[1.03506360e-02 9.64720201e-01 2.49291631e-02]
[9.75119469e-01 2.48805118e-02 1.97110710e-08]
[9.96903714e-01 3.09628334e-03 3.02214789e-09]
[4.37277212e-04 1.64225862e-01 8.35336861e-01]
[1.89512121e-02 8.50485546e-01 1.30563242e-01]
[9.92603267e-01 7.39672472e-03 8.11586415e-09]
[9.87973457e-01 1.20265376e-02 5.81933161e-09]
[4.27698467e-04 4.31933453e-01 5.67638848e-01]
[9.99242387e-01 7.57613088e-04 1.60920285e-10]
[9.95647811e-01 4.35217610e-03 1.33157849e-08]
[5.95532105e-03 9.61956971e-01 3.20877079e-02]
[4.90338189e-02 9.49772774e-01 1.19340698e-03]
[9.87840788e-01 1.21591890e-02 2.26095427e-08]
[1.71351240e-04 1.82279096e-01 8.17549553e-01]
[3.02840778e-02 8.38001677e-01 1.31714245e-01]
[9.91654822e-01 8.34517412e-03 4.16195355e-09]
[9.82713315e-04 3.89351453e-01 6.09665834e-01]
[2.44617907e-06 1.95996503e-02 9.80397904e-01]
[3.78591986e-02 9.32391009e-01 2.97497922e-02]
[9.87617232e-01 1.23827534e-02 1.41283928e-08]
[6.13063609e-04 5.60491705e-01 4.38895231e-01]
[1.36665771e-02 9.08556294e-01 7.77771291e-02]
[7.93188460e-03 9.80303477e-01 1.17646383e-02]
[6.01310014e-06 4.25718447e-02 9.57422142e-01]
[9.93314195e-01 6.68580269e-03 2.52589966e-09]
[5.54497074e-06 3.04808279e-02 9.69513627e-01]
[9.88446166e-01 1.15537705e-02 6.39015637e-08]
[9.93933267e-01 6.06673050e-03 2.89818523e-09]
[1.21982255e-02 9.85962770e-01 1.83900493e-03]
[9.01386278e-05 1.82599114e-01 8.17310747e-01]
[8.28581172e-07 3.46698330e-03 9.96532188e-01]
[1.39586331e-04 7.16705121e-01 2.83155292e-01]
[3.26435596e-06 1.11792749e-01 8.88203987e-01]
[1.14799130e-02 9.82660931e-01 5.85915602e-03]
[3.17300208e-08 5.46332632e-03 9.94536642e-01]
[6.81960060e-03 7.39506484e-01 2.53673915e-01]
[1.09379036e-02 9.84069025e-01 4.99307120e-03]
[1.80743782e-04 3.88322849e-01 6.11496407e-01]
[6.58594413e-04 6.97905585e-01 3.01435820e-01]
[7.41788656e-03 6.50572468e-01 3.42009645e-01]
[3.60169716e-05 2.30341207e-01 7.69622776e-01]
[1.45514725e-03 9.02682174e-01 9.58626786e-02]
[5.52862776e-05 1.14066880e-01 8.85877833e-01]]
画图:预测结果画图
# Visualize predicted vs. true class, one marker per test sample.
x_test_len = range(len(X_test))
plt.figure(figsize=(12, 9), facecolor='w')
plt.ylim(0.5, 3.5)
# Truth drawn on top (zorder=3, small red dots); predictions behind it
# (zorder=2, larger green dots) so mismatches stand out visually.
plt.plot(x_test_len, Y_test, 'ro', markersize=6, zorder=3, label=u'真实值')
plt.plot(
    x_test_len,
    lr_y_predict,
    'go',
    markersize=10,
    zorder=2,
    label=u'Logis算法预测值,$R^2$=%.3f' % lr.score(X_test, Y_test),
)
plt.xlabel(u'数据编号', fontsize=18)
plt.ylabel(u'种类', fontsize=18)
plt.title(u'鸢尾花数据分类', fontsize=20)
plt.legend(loc='lower right')
plt.show()
网友评论