美文网首页为防失联
scikit-learn_模型的保存与加载

scikit-learn_模型的保存与加载

作者: Ledestin | 来源:发表于2017-05-16 23:32 被阅读238次

    主要介绍scikit-learn中的模型的保存与加载
    我们训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步。这次主要介绍两种保存Model的模块pickle与joblib。


    Demo.py

    import pickle
    from sklearn.externals import joblib
    from sklearn.svm import SVC
    from sklearn import datasets
    
    # 定义分类器
    svm = SVC()
    
    # 加载iris数据集
    iris = datasets.load_iris()
    # 读取特征
    X = iris.data
    # 读取分类标签
    y = iris.target
    
    # 训练模型
    svm.fit(X, y)
    
    # 第一种:保存成python支持的文件格式pickle, 在当前目录下可以看到svm.pickle
    with open('svm.pickle', 'wb') as fw:
        pickle.dump(svm, fw)
    # 加载svm.pickle
    with open('svm.pickle', 'rb') as fr:
        new_svm = pickle.load(fr)
        print (new_svm.predict(X[0:1]))
        
        
    # 第二种:保存成sklearn自带的文件格式
    joblib.dump(svm, 'svm.pkl')
    # 加载svm.pkl
    new_svm = joblib.load('svm.pkl')
    print (new_svm.predict(X[0:1]))
    

    结果:

    [0] #第一种保存方式产生结果
    [0] #第二种保存方式产生结果
    

    最后可以知道joblib在使用上比较容易,读取速度也相对pickle快。

    相关文章

      网友评论

        本文标题:scikit-learn_模型的保存与加载

        本文链接:https://www.haomeiwen.com/subject/opsrxxtx.html