参考文章https://www.cnblogs.com/mantch/p/11164221.html
https://blog.csdn.net/v_JULY_v/article/details/81410574
一、原理
如果用一句话定义xgboost,很简单:Xgboost就是由很多CART树集成。但,什么是CART树?
数据挖掘或机器学习中使用的决策树有两种主要类型:
- 分类树分析是指预测结果是数据所属的类(比如某个电影去看还是不看)
- 回归树分析是指预测结果可以被认为是实数(例如房屋的价格,或患者在医院中的逗留时间)
-
而术语分类回归树(CART,Classification And Regression Tree)分析是用于指代上述两种树的总称,由Breiman等人首先提出。
————————————————
事实上,如果不考虑工程实现、解决问题上的一些差异,XGBoost与GBDT比较大的不同就是目标函数的定义。XGBoost的目标函数如下图所示:
image.png
image.png
二、代码实例
import numpy as np
from sklearn import tree
from sklearn.ensemble import GradientBoostingClassifier
from sklearn import datasets
from xgboost import XGBClassifier
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from xgboost import plot_tree,plot_importance
from graphviz import Source
X,y = datasets.load_iris(True)
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size = .2,random_state = 1024)
gbdt = GradientBoostingClassifier(n_estimators=3,learning_rate=0.1)
gbdt.fit(X_train,y_train)
gbdt.score(X_test,y_test)
1.0
gbdt.estimators_.shape
(3, 3)
传统的GBDT在每轮迭代时使用全部的数据
plt.figure(figsize=(12,12))
_ = tree.plot_tree(gbdt[1,1],filled=True)
output_5_0.png
Xgboost树可视化
# 3次for循环!
# 3分类问题,softmax
# 一次循环,构建了3棵树,所以,三次循环,9棵树
# 9棵树和GBDT一样
xgb = XGBClassifier(learning_rate = 0.1,n_estimators=3)
xgb.fit(X_train,y_train)
xgb.score(X_test,y_test)
1.0
plt.figure(figsize=(20,20))
ax = plt.subplot(1,1,1)
plot_tree(xgb,num_trees=1,ax = ax,)
plt.savefig('./xgboost.png')
# 模型训练,花了三天时间(神经网络)
# 模型保存
xgb.save_model('./xgboost.json')
# 模型加载
xgb2 = XGBClassifier()# 没有训练过
xgb2.load_model('./xgboost.json') # 加载模型,原模型,3个estimator
xgb2.set_params(n_estimators =10)#上面设置这个参数,但是不训练,不起作用
xgb2.score(X_test,y_test)
xgb2
plt.figure(figsize=(20,20))
ax = plt.subplot(1,1,1)
plot_tree(xgb2,num_trees=29,ax = ax,)
xgb3 = XGBClassifier(n_estimator3 = 10)
xgb3.fit(X_train,y_train)
xgb3.score(X_test,y_test)
ax = plt.subplot(1,1,1)
plot_tree(xgb3,num_trees=29,ax = ax,)
# 模型,也是智慧的结晶,模型保存下来
# joblib是sklearn下面包
from sklearn.externals import joblib# job工作 ,lib图书馆
joblib.dump(gbdt,'./gbdt')#保存
gbdt2 = joblib.load('./gbdt')
gbdt2.score(X_test,y_test)
网友评论