美文网首页
Alink中机器学习模型的训练,保存及调用

Alink中机器学习模型的训练,保存及调用

作者: 老羊_肖恩 | 来源:发表于2022-02-15 11:29 被阅读0次

  机器学习训练算法比较复杂去数据集规模较大,通常需要在分布式环境中进行,但是使用训练出来的模型进行预测往往简单很多,一般可以单个或者多个节点对模型进行装载,从而进行多路预测。Alink提供了由参数或模型数据直接构建一个本地的java实例,我们称之为LocalPredictor,可以对单条数据进行预测。这样的话,预测任务不再必须由Flink完成,可以嵌入到提供RestAPI的预测服务系统,或者嵌入到用户的业务系统里。
  以酒店评论情感分析为例,数据集为:ChnSentiCorp_htl_all.csv,通过构建完整的中文情感分析pipeline,并将训练好的模型保存在指定位置:

        //数据文件地址
        String url = "D:\\Workspace\\data\\ChnSentiCorp_htl_small.csv";
        //数据schema
        String schemaStr = "label bigint, review string";
        //定义数据源
        BatchOperator data = new CsvSourceBatchOp()
                .setFilePath(url)
                .setSchemaStr(schemaStr)
                .setIgnoreFirstLine(true);
        //Shuffle数据集
        data = new ShuffleBatchOp().linkFrom(data);

        //按照7:3分割数据,生成训练集和测试集
        SplitBatchOp splitter = new SplitBatchOp().setFraction(0.7);
        BatchOperator trainData = splitter.linkFrom(data);
        BatchOperator testData = splitter.getSideOutput(0);

        //构建文本分类pipeline
        Pipeline pipeline = new Pipeline(
                //缺失值填充
                new Imputer()
                        .setSelectedCols("review")
                        .setOutputCols("featureText")
                        .setStrategy("value")
                        .setFillValue("null"),
                //分词
                new Segment()
                        .setSelectedCol("featureText"),
                //停用词过滤
                new StopWordsRemover()
                        .setSelectedCol("featureText"),
                //文本特征生成
                new DocCountVectorizer()
                        .setFeatureType("TF")
                        .setSelectedCol("featureText")
                        .setOutputCol("featureVector"),
                //逻辑回归二分类
                new LogisticRegression()
                        .setVectorCol("featureVector")
                        .setLabelCol("label")
                        .setPredictionCol("pred")
        );

        //模型训练
        PipelineModel model = pipeline.fit(trainData);

        //模型效果评估
        BatchOperator<?> predict = model.transform(testData);
        MultiClassMetrics metrics = new EvalMultiClassBatchOp()
                .setLabelCol("label")
                .setPredictionCol("pred")
                .linkFrom(predict)
                .collectMetrics();
        System.out.println("accuracy:" + metrics.getAccuracy("1"));
        System.out.println("recall:" + metrics.getRecall("1"));
        System.out.println("Macro Precision:" + metrics.getMacroPrecision());
        System.out.println("Micro Recall:" + metrics.getMicroRecall());
        System.out.println("Weighted Sensitivity:" + metrics.getWeightedSensitivity());

        //模型保存
        model.save("D:\\Workspace\\models\\SentimentHotel_model_0001", true);
        //这一行要有,不然不保存
        BatchOperator.execute();

  模型效果的评估结果如下图所示:

模型效果

  如果模型的效果能达到预期,那么将模型保存到指定的位置,方便后续的业务系统进行调用。这里我们可以发现,训练完成的模型保存到本地,生成了一个非常小的模型文件。后期业务系统可以直接使用这个模型对外提供模型预测服务。

模型保存

  业务系统可以使用LocalPredictor对指定位置的模型进行加载和调用,代码如下:

        //数据schema
        String SCHEMA_STR = "review string";

        //读取
        LocalPredictor localPredictor = new LocalPredictor("D:\\Workspace\\models\\SentimentHotel_model_0001", SCHEMA_STR);

        Row[] rows = new Row[]{
                Row.of("这个酒店给人留下了永远难忘的印象--垃圾!!" +
                        "奉劝各位千万不要再去那里了,保你后悔没及!总体印象是:无论酒店,商店," +
                        "交通还是旅游景点,除了收费已和国际接轨了以外," +
                        "其他的都是老、少、边区的水平!!!!!!!!悲哀呀"),
                Row.of("酒店的设施和环境都不错的,就是周围没有什么集市和超市," +
                        "在房间的阳台上就能看到一望无际的大海,真的心情非常的不错." +
                        "唯一的就是每天的早餐都是一样的东西.离机场和市区也不是太远."),
                Row.of("房间很宽敞,干净卫生,就是不知道为啥隔音很差,整体还行"),
                Row.of("价格还不错,出行很方便,临近地铁站,去机场也很方便," +
                        "楼上就四磁悬浮啦,但是房间是真的旧,地板缩水严重。")
        };
        for (Row row : rows) {
            Row predict = localPredictor.map(row);
            System.out.println(predict.getField(0) + "  prediction:" + predict.getField(3));
        }

  该中文情感分析的模型,调用结果如下:

模型调用结果

  以上就是使用Alink进行机器学习建模的全过程,建模的过程中,由于训练数据往往很庞大因此数据处理和模型训练的过程需要放在flink集群中去完成,最终生成满足业务需求的模型。由于生成的数据模型很小,且通常需要内嵌到业务系统中对外提供模型预测服务,因此可以将模型预测的功能于flink集群进行脱离,直接在业务系统中载入模型后对外提供预测服务。

参考:
https://www.yuque.com/pinshu/alink_guide/zo1y6q
https://www.yuque.com/pinshu/alink_guide/pz7rcl

相关文章

  • Alink中机器学习模型的训练,保存及调用

      机器学习训练算法比较复杂去数据集规模较大,通常需要在分布式环境中进行,但是使用训练出来的模型进行预测往往简单很...

  • Demo3 - 保存训练后模型

    训练好的模型,需要保存好,下次就直接拿来用,相当于是机器学习的成果。不用每次都去学习了, 直接保存成文件,然后下个...

  • pytorch学习(十七)—模型的保存与加载

    前言 在深度学习中,模型的保存和加载很重要,当我们辛辛苦苦训练好的一个网络模型,自然需要将训练好的模型保存为文件。...

  • 机器学习weka,java api调用随机森林及保存模型

    工作需要,了解了一下weka的java api,主要是随机森林这一块,刚开始学习,记录下。了解不多,直接上demo...

  • Pytorch Lightning系列 如何使用ModelChe

    在训练机器学习模型时,经常需要缓存模型。ModelCheckpoint是Pytorch Lightning中的一个...

  • iOS机器学习

    核心ML 将机器学习模型集成到您的应用程序中。 使用Core ML,您可以将训练有素的机器学习模型集成到应用程序中...

  • 机器学习中的评价指标

    机器学习中的评价指标 当一个机器学习模型建立好了之后,即模型训练已经完成,我们就可以利用这个模型进行分类识别。 正...

  • 深度学习_模型选择与拟合问题

    模型选择 首先我们需要考虑误差! 首先在机器学习模型中误差有如下两种: 训练误差 泛化误差 训练误差指模型在训练数...

  • Dropout

    1. Dropout简介 1.1 前言 在机器学习的模型中,如果模型的参数太多,而训练样本又太少,训练出来的模型很...

  • Tensorflow的模型保存和读取tf.train.Saver

    目标:训练网络后想保存训练好的模型,以及在程序中读取以保存的训练好的模型。 简介 首先,保存和恢复都需要实例化一个...

网友评论

      本文标题:Alink中机器学习模型的训练,保存及调用

      本文链接:https://www.haomeiwen.com/subject/lynklrtx.html