- 在使用sklearn的模型时,我们基本上不关心训练过程,只关心训练结果,然后直接使用模型去预测新数据集。比如,直接使用形如clf.predict(x) 或者clf.predict_proba(x),就能得出预测结果。或者使用joblib模块中的dump和load,将已经训练过的模型导入,直接使用。但是,在实际项目中,可能需要将模型嵌入到java工程中(因为现在的项目很大一部分是java开发的),这就需要考虑java工程和模型的由于不同语言之间存在的相通性而产生的一些兼容性问题或者效率问题,尤其在需要即时反馈的项目中是需要重点考虑的。在考虑模型嵌入问题时,可能有多种解决方法,比如用jpython调用python程序,通过系统的runtime调用python程序,将训练参数导出到对应语言的版本(使用python包m2cgen,porter等)或者用pmml转换成xml文件。通过以上几种方式实现的跨平台部署,可能存在以下问题,一个是在调用python程序时,语言之间的数据通信需要通过序列化和反序列化需要消耗一定时间,而且频繁调用程序对服务器的压力较大。如果是使用工具将参数导出,那么在将参数文件导入项目中时可能存在一些细节性问题,比如预测结果不准确,甚至无法编译运行,且可能参数极多,文件过大,编译时间很长。本文其实是采用参数导出的方式,实现跨平台部署,不过没有使用既有的导出工具,而是自己写了参数导出并应用在java项目中,实现与原版模型预测效果一致。
- 在应用模型时,只需五个参数。
- clf.estimators_.tree_.feature:存放每个节点所采用的特征
- clf.estimators_.tree_.threshold:存放每个节点特征的阈值
- clf.estimators_.tree_.children_left:存放经过特征和阈值分裂之后的左孩子
- clf.estimators_.tree_.children_right:存放经过特征和阈值分裂之后右孩子
- clf.estimators_.tree_.value:存放每个节点属0的个数和属于1的个数(两分类模型),归一化后就是该节点的各分类的概率值。
- 预测新值的一般流程:将新的x值放到训练好的这颗树模型中,根据每个节点的feature和threshold判断是分为children_left还是children_right,直到该x值被放入某个叶子节点中,根据叶子节点的value计算出概率。
- 关键点:实际上在使用模型时,最重要的是判断什么时候x已经到叶子节点了。其实只要feature=-2时,那么这之前x所在的节点就是叶子节点,只要取出该叶子节点的value,就能计算出概率。
- 还不明白的反复看关键点。也可以将五个值对应graphviz的树形图,对应起来观察理解。
网友评论