美文网首页
如何借用Mapperduce框架进行xgboost分布式预测(j

如何借用Mapperduce框架进行xgboost分布式预测(j

作者: 一个菜鸟的自我修养 | 来源:发表于2019-08-05 18:35 被阅读0次

    背景:能分布式地预测数据(当然spark-scala框架本身就可以做到)。本文主要目的是通过一个项目,弄明白MR的执行原理

    实现步骤:

    使用MapReduce进行预测,目前实现3分钟内完成112w个样本,16维特征的数据预测,具体实现思路如下:

    1、mapreduce主入口类 main 函数中传入模型所在hdfs的路径及数据输入输出的hdfs路径

    2、Mapper类中重写Mapper里的setup()、map()、cleanup()三个方法。

    1)setup(Context context)方法获取context调用ml.dmlc.xgboost4j.java.XGBoost.loadModel将训练好的模型load完成

    2)map()里解析数据,封装Dmatrix并进行数据的预测,具体如下:

      - 首先将读取的一行行数据封装成 Dmatrix,达到阈值(暂时定为6000)时执行预测,并将预测值写入hdfs路径

      - cleanup里实现最后的清理收尾预测工作

    接下来分别讲解每个部分的实现。

    mapreduce主入口类 main 函数

    //設置conf

    Configuration conf = new Configuration();

    conf.set("hadoop.tmp.dir",args[0]);  //用于解决自动运行时目录权限问题,可以将此目录指定到一个有权限的目录 

    例如 /tmp suffle过程会有数据落到本地磁盘,这里的路径必须有权限

    很重要,因为默认的路径可能不具备访问的权限。

    conf.set("mapreduce.framework.name","yarn");

    设置运行的模式是yarn模式还是local模式

    conf.set("mapreduce.map.cpu.vcores","8");//指定这个mapreduce任务运行时cpu的个数

    根据数据量来设定合适的,因为集群上默认的map的cpu的核心数是1,未设置之前,任务一度出现map 0% reduce 0%保持不动。

    conf.set("mapreduce.map.memory.mb","8296");//一个 Map Task 可使用的内存上限(单位:MB),默认为 1024

    要想将模型的路径当成参数变量,传给每个map算子。则要考虑如何将模型的hdfs地址广播出去。

    conf.set("xgboost.model", modelPath);

    接下来就是设置Job提交的一些配置

            // 设置执行jar名

            Job = Job.getInstance(conf);

           job.setJarByClass(XGBoost.class);

          // 设置文件读取、输出的路径

            FileInputFormat.setInputPaths(job, new Path(inputFile));

            Path outputFile = new Path(ouptFile1);

            FileOutputFormat.setOutputPath(job, outputFile);

           // 设置mapper的类

            job.setMapperClass(XGBoostMapper.class);

           job.setMapOutputKeyClass(Text.class);//map的输出key值

            job.setMapOutputValueClass(IntWritable.class);//map的输出的value值

            //设置InputFormat类 设置 OutputFormat 类

            job.setInputFormatClass(TextInputFormat.class);

           job.setOutputFormatClass(TextOutputFormat.class);

           job.setNumReduceTasks(0);

           //因为对于本的的例子而言,没有reduce阶段,则将reduce的个数设置为0

            // 设置reduce输出的key value 类型

            job.setOutputKeyClass(Text.class);

            job.setOutputValueClass(IntWritable.class);

           // 提交job

            job.submit();

           // 等待执行完成

            boolean noErr = job.waitForCompletion(true);

           System.exit(noErr? 0 : 1);

          XGBoostMapper extends  Mapper<LongWritable, Text, Text, Text>每部分的实现

        public class XGBoostMapper extends Mapper<LongWritable, Text, Text, Text>{

        private  Booster;//模型变量

        private Text oKey = new Text();//write时的key值

        private Text oValue=new Text();//write时的value值

        private final static int ROUND_NUM = 6000;//每达到ROUND_NUM 开始进行预测

        private List<String> acct = new ArrayList<>(ROUND_NUM );//用来缓存待预测数据的pin值

        private List<String[]> preData = new ArrayList<>(ROUND_NUM );//用来缓存待预测数据的value值

    @Override protected void setup(Mapper.Context context) throws IOException {

         Configuration conf = context.getConfiguration();//目的是拿到模型的存储地址和待预测数据的地址

          String modelPath = conf.get("xgboost.model");//从conf中拿到模型的地址

          FileSystem fs = FileSystem.get(conf);

          FSDataInputStream open = fs.open(new Path(modelPath));

         //如果直接.loadModel(modelPath )则会将该地址解析为本地路径,就会报错,not file://。因为java版的xgboost没有读取hdfs的API,所以需要借助inputStream

          try {

                      booster = ml.dmlc.xgboost4j.java.XGBoost.loadModel(open);

                 } catch (XGBoostError xgBoostError) {

                        xgBoostError.printStackTrace();

                   }

              }

    @Override protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {

    //读取每一行的数据,并将其转换成array

            String[] line = value.toString().split("\001");//输入数据的文件是以\001间隔的,注意检查输入数据的分隔符。即待预测数据的hive的数据存储时的间隔符

            try {

    predictUntilRoundNum(context, line);

    //当输入数据达到某个值的时候开始预测

    } catch (XGBoostError xgBoostError) {

    xgBoostError.printStackTrace();

    System.out.println("执行失败");

    }

        }

    private void predictUntilRoundNum(Context context, String[] line) throws IOException, InterruptedException, XGBoostError {

    acct.add(line[0]); //存放账号,hive表的第一列是用户的pin

    preData.add(Arrays.copyOfRange(line,3,line.length));//所选用的特征列从第3列开始

    //如果达到ROUND_NUM行数据

     if (acct.size() == ROUND_NUM){

    predictNow(context, acct,preData);

    //清空数据

                acct = new ArrayList<>();//清空存储账号的list

                preData = new ArrayList<>();//清空要组装成Dmatrix的list

            }

        }

    private void predictNow(Context context, List<String> acct, List<String[]> preData) throws InterruptedException, XGBoostError, IOException {

    if (!preData.isEmpty()) {

    DMatrix dMatrix = buildDMatrix(preData);

    //开始预测

     float[][] predict = booster.predict(dMatrix);

    System.out.println("rowPredixt=" + predict.length + " colPredict=" + predict[0].length);

    //write data to hdfs use context.write

    for (int i = 0; i < acct.size(); i++) {

    oKey.set(acct.get(i));//用户名作为输出key

    oValue.set(String.valueOf(predict[i][0]));

    context.write(oKey, oValue);

    }

        }

    }

    /**多维数组转成dmatrix数据

         *@param data:输入的数据,是一个二维数据

         *@return

    :返回的是一个Dmatrix的数据

         */

        private DMatrix buildDMatrix(List<String[]> data) throws XGBoostError {

    //        System.out.println("数组的row:"+data.size()+"数组的列:"+data.get(0).length);

            int num = 0;

    int col = data.get(0).length;

    int row = data.size();//行数,也就是Dmatrix的行数

    float[] resData = new float[row*col];

    for (String[] str: data){

    for (String aStr : str) {

    resData[num] = Float.valueOf(aStr);

    num++;

    }

            }

    return new DMatrix(resData, row, col);

    //其中Dmatrix的构造方法的第二列表示样本数,col表示feature的个数    }

    其中为什么设计了读到多少行后才开始进行Dmatrix的封装和预测。因为在封装Dmatrix和预测的过程中本身就很耗时。但是不全部进行预测的是因为当数据量太大的时候,内存可能不够。

    @Override

    protected void cleanup(Context context)throws IOException, InterruptedException {

    if (!acct.isEmpty()) {

    try {

    predictNow(context, acct, preData);

            }catch (XGBoostError xgBoostError) {

    xgBoostError.printStackTrace();

            }

    System.out.println("执行失败");

        }

    }

    该步是为了计算最后的数据。因为不可能保证数据的行数是ROUND_NUM的倍数,在map执行完后,list中还会有一些数据未被预测,所以需要在最后进行最后数据的预测工作。

    相关文章

      网友评论

          本文标题:如何借用Mapperduce框架进行xgboost分布式预测(j

          本文链接:https://www.haomeiwen.com/subject/tatzrctx.html