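Once a machine learning model has been trained, you often need to save it so it can be reused later. Below are a few commonly used methods.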
1. pickle
The pickle module serializes Python objects into a binary format and deserializes them back into objects.
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=200, n_features=10,
                       random_state=0, shuffle=False)
model = RandomForestRegressor(random_state=0)
model.fit(X, y)
print(model.predict([range(0,10)]))
import pickle
pkl_path = "model.pkl"
# Serialize the fitted model to a binary file
with open(pkl_path, 'wb') as f:
    pickle.dump(model, f)
# Deserialize it back into a new model object
with open(pkl_path, 'rb') as f:
    pkl_model = pickle.load(f)
print(pkl_model.predict([range(0,10)]))
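Because pickle produces plain bytes, the same round trip can also be done in memory with `pickle.dumps` and `pickle.loads`, without writing a file. The sketch below is an illustration added here, not part of the original walkthrough: it retrains the same forest, serializes it to a `bytes` object, and checks that the restored copy reproduces the original predictions exactly.

```python
import pickle

import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Same data and model settings as the example above
X, y = make_regression(n_samples=200, n_features=10,
                       random_state=0, shuffle=False)
model = RandomForestRegressor(random_state=0).fit(X, y)

# Serialize to an in-memory bytes object instead of a file
blob = pickle.dumps(model)
# Deserialize the bytes back into a new estimator object
restored = pickle.loads(blob)

# The restored forest should give exactly the same predictions
same = np.array_equal(model.predict(X), restored.predict(X))
print(type(blob).__name__, len(blob), same)
```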
2. joblib
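joblib is optimized for storing large data: for objects that carry large NumPy arrays internally, as fitted scikit-learn estimators often do, it can be more efficient than pickle.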
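One concrete form of that optimization is memory mapping: an array that joblib stored uncompressed can be loaded with `mmap_mode`, so the data is mapped from disk instead of being read into RAM up front. This is a minimal sketch under assumptions: the 1,000,000-element array merely stands in for a model's internal arrays, and the file name `big_array.joblib` is hypothetical.

```python
import os
import tempfile

import joblib
import numpy as np

# Hypothetical "large" array standing in for a model's internal buffers
big = np.arange(1_000_000, dtype=np.float64)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "big_array.joblib")
    # Dump without compression so the array can be memory-mapped later
    joblib.dump(big, path)
    # mmap_mode="r" maps the stored array read-only from disk
    loaded = joblib.load(path, mmap_mode="r")
    is_memmap = isinstance(loaded, np.memmap)
    same = bool(np.array_equal(loaded, big))
    # Release the mapping before the temporary directory is deleted
    del loaded

print(is_memmap, same)
```

Memory mapping is not available for compressed dumps, so this trades disk space for lower load-time memory use.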
from sklearn.ensemble import RandomForestRegressor
from sklearn.datasets import make_regression
X, y = make_regression(n_samples=200, n_features=10,
                       random_state=0, shuffle=False)
model = RandomForestRegressor(random_state=0)
model.fit(X, y)
print(model.predict([range(0,10)]))
import joblib
# Save the fitted model to disk
joblib.dump(model, 'model.pkl')
# Load it back into a model object
model = joblib.load('model.pkl')
print(model.predict([range(0,10)]))
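`joblib.dump` also accepts a `compress` argument. The sketch below assumes the same training data as above; the compression level 3 and the file names `model.joblib` and `model_compressed.joblib` are arbitrary choices for illustration. It writes the model with and without zlib compression, compares the resulting file sizes, and confirms the compressed copy still predicts identically.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=200, n_features=10,
                       random_state=0, shuffle=False)
model = RandomForestRegressor(random_state=0).fit(X, y)

with tempfile.TemporaryDirectory() as tmp:
    plain_path = os.path.join(tmp, "model.joblib")
    packed_path = os.path.join(tmp, "model_compressed.joblib")

    joblib.dump(model, plain_path)               # default: no compression
    joblib.dump(model, packed_path, compress=3)  # zlib, compression level 3

    plain_size = os.path.getsize(plain_path)
    packed_size = os.path.getsize(packed_path)

    # joblib detects the compression automatically when loading
    restored = joblib.load(packed_path)
    same = np.array_equal(model.predict(X), restored.predict(X))

print(plain_size, packed_size, same)
```

Higher compression levels generally shrink the file further at the cost of slower dumping and loading.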