美文网首页
LightGBM Java实现在线预测

LightGBM Java实现在线预测

作者: 尼小摩 | 来源:发表于2020-11-14 16:37 被阅读0次

    LightGBM是三大知名GBDT的实现之一,支持二分类,多分类。与XGBoost相比,LGBM不需要通过所有样本计算信息增益,而且内置特征降维技术,支持高效率的并行训练,并且具有更快的训练速度、更低的内存消耗、更好的准确率、支持分布式可以快速处理海量数据等优点。 但在Java方面的支持不如XGBoost,没有封装好的Java在线预测包。

    至于XGB和LGB原理和优缺点自行百度,不在本文范围内。

    近期因为公司上线了很多XGBoost模型,在XGBoost训练消耗大量内存,为了节约资源选用LGBM替代XGBoost。在线预测服务就需要用Java封装训练好的LGBM模型,供线上实时预测使用。 在网上百度大多实现方式都是将模型封装为PMML格式,再在预测服务里预测结果。但是PMML版本模型单次预测需要100ms以上,显然不能满足性能需求。

    于是展开Google大法,发现微软开源的mmlspark库(https://github.com/Azure/mmlspark.git),其中有一个包可以将LightGBM部署在spark环境中分布式训练。使用swig封装LightGBM的接口,然后使用jni的方式在spark中调用。赶紧找到打包好的maven lib。

    <dependency>
          <groupId>com.microsoft.ml.lightgbm</groupId>
          <artifactId>lightgbmlib</artifactId>
          <version>2.3.180</version>
    </dependency>
    

    实现代码:

    package com.tuhu.algo.etl.features.model;
    
    import com.microsoft.ml.lightgbm.*;
    import org.apache.commons.lang3.StringUtils;
    
    import java.io.IOException;
    
    /**
     * <p></p>
     *
     * @Author: fc.w
     * @Date: 2020/11/14 16:29
     */
    public class LightGBMModelLoad {
    
        private SWIGTYPE_p_void boosterPtr;
    
        private String modelString;
    
        public LightGBMModelLoad(String modelString) {
            this.modelString = modelString;
            initModel();
        }
    
        public void initModel() {
            try {
                init(modelString);
            } catch (Exception e) {
                throw new RuntimeException("模型加载失败", e);
            }
        }
    
        public void init(String modelString) throws Exception {
            initEnv();
            if (StringUtils.isEmpty(modelString)) {
                throw new Exception("the inpute model string must not null");
            }
            this.boosterPtr = getBoosterPtrFromModelString(modelString);
        }
    
        private void initEnv() throws IOException {
            String osPrefix = NativeLoader.getOSPrefix();
            new NativeLoader("/com/microsoft/ml/lightgbm").loadLibraryByName(osPrefix + "_lightgbm");
            new NativeLoader("/com/microsoft/ml/lightgbm").loadLibraryByName(osPrefix + "_lightgbm_swig");
        }
    
        private void validate(int result) throws Exception {
            if (result == -1) {
                throw new Exception("Booster LoadFromString" + "call failed in LightGBM with error: " + lightgbmlib.LGBM_GetLastError());
            }
        }
    
        private SWIGTYPE_p_void getBoosterPtrFromModelString(String lgbModelString) throws Exception {
            SWIGTYPE_p_p_void boosterOutPtr = lightgbmlib.voidpp_handle();
            SWIGTYPE_p_int numItersOut = lightgbmlib.new_intp();
            validate(
                    lightgbmlib.LGBM_BoosterLoadModelFromString(lgbModelString, numItersOut, boosterOutPtr)
            );
            return lightgbmlib.voidpp_value(boosterOutPtr);
        }
    
        /**
         * 预测
         * @param data 批量向量
         * @param numRows 预测行数
         * @param numFeatures 向量大小
         * @return 批量预测结果
         */
        public double[] predictForMat(double[] data, int numRows, int numFeatures) {
            int data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64;
            int isRowMajor = 1;
            String datasetParams = "";
            SWIGTYPE_p_double scoredDataOutPtr = lightgbmlib.new_doubleArray(numRows * numFeatures);
    
            SWIGTYPE_p_long_long scoredDataLengthLongPtr = lightgbmlib.new_int64_tp();
            lightgbmlib.int64_tp_assign(scoredDataLengthLongPtr, numRows * numFeatures);
    
            SWIGTYPE_p_double doubleArray = lightgbmlib.new_doubleArray(data.length);
            for (int i = 0; i < data.length; i++) {
                lightgbmlib.doubleArray_setitem(doubleArray, i, data[i]);
            }
            SWIGTYPE_p_void pdata = lightgbmlib.double_to_voidp_ptr(doubleArray);
    
            try {
                lightgbmlib.LGBM_BoosterPredictForMat(
                        boosterPtr,
                        pdata,
                        data64bitType,
                        numRows,
                        numFeatures,
                        isRowMajor,
                        0,
                        -1,
                        datasetParams,
                        scoredDataLengthLongPtr,
                        scoredDataOutPtr);
                return predToArray(scoredDataOutPtr, numRows);
            } catch (Exception e) {
                e.printStackTrace();
                System.out.println(lightgbmlib.LastErrorMsg());
            } finally {
                lightgbmlib.delete_doublep(doubleArray);
                lightgbmlib.delete_doublep(scoredDataOutPtr);
                lightgbmlib.delete_int64_tp(scoredDataLengthLongPtr);
            }
            return new double[numRows];
        }
    
        private double[] predToArray(SWIGTYPE_p_double scoredDataOutPtr, int numRows) {
            double[] res = new double[numRows];
            for (int i = 0; i < numRows; i++) {
                res[i] = lightgbmlib.doubleArray_getitem(scoredDataOutPtr, i);
            }
            return res;
        }
    
    }
    

    资料
    LightGBM官网
    https://github.com/Azure/mmlspark
    无痛看懂LightGBM原文

    相关文章

      网友评论

          本文标题:LightGBM Java实现在线预测

          本文链接:https://www.haomeiwen.com/subject/dmdwbktx.html