Learning Spark MLlib: Classification and Regression


Author: shohokuooo | Published 2017-05-01 21:36

    This section covers classification and regression algorithms, including discussion of specific families of algorithms such as linear methods, trees, and ensembles.
    The outline below follows the official API docs. This note does not cover every item in detail; it introduces what is actually used here, and more will be added as needed. (The links point to the official documentation rather than to locations within this note, and some linked topics do not appear in this note at all.)

    Classification

    Logistic regression

    Logistic regression is a popular algorithm for classification problems. It is a special case of generalized linear models that predicts the probability of an outcome. In spark.ml, binomial logistic regression can be used to predict a binary outcome, and multinomial logistic regression can be used to predict a multiclass outcome. Use the family parameter to select between these two algorithms, or leave it unset and Spark will infer the correct variant.

    Multinomial logistic regression can also be used for binary classification by setting the family parameter to "multinomial". In that case it will produce two sets of coefficients and two intercepts.

    When a LogisticRegressionModel is fit without an intercept on a dataset that has a constant nonzero column, Spark MLlib outputs zero coefficients for that constant nonzero column. This behavior matches R glmnet but differs from LIBSVM.
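As a plain-Java sketch of what the binomial model actually computes (this is not Spark API code; the coefficients, intercept, and feature values below are made up for illustration), the predicted probability is the logistic sigmoid applied to a linear combination of the features:

```java
public class SigmoidSketch {
  // Logistic (sigmoid) function: maps a raw linear score to a probability in (0, 1).
  public static double sigmoid(double z) {
    return 1.0 / (1.0 + Math.exp(-z));
  }

  public static void main(String[] args) {
    // Hypothetical coefficients and intercept of a fitted binomial model.
    double[] w = {0.8, -1.2, 0.3};
    double intercept = 0.1;
    // Hypothetical feature vector for one example.
    double[] x = {1.0, 0.5, 2.0};

    // Linear score z = intercept + w . x
    double z = intercept;
    for (int i = 0; i < w.length; i++) {
      z += w[i] * x[i];
    }
    double p = sigmoid(z);                 // P(label = 1 | x)
    System.out.println(p >= 0.5 ? 1 : 0); // predicted label at the default 0.5 threshold
  }
}
```

Here z = 0.1 + 0.8 − 0.6 + 0.6 = 0.9, so p ≈ 0.71 and the predicted label is 1.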

    Binomial logistic regression

    For more background and implementation details on binomial logistic regression, refer to the documentation of logistic regression in spark.mllib.
    Code example:
    The following example shows how to train binomial and multinomial logistic regression models with elastic net regularization for a binary classification problem. elasticNetParam corresponds to α and regParam corresponds to λ (see Linear methods for the definitions of these two parameters).
    Java code:
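For reference, the elastic net regularizer that these two parameters control can be written as λ(α‖w‖₁ + (1−α)/2 ‖w‖₂²). A plain-Java sketch of that formula (the coefficient values are made up for illustration):

```java
public class ElasticNetPenalty {
  // Elastic net regularization term: lambda * (alpha * L1 + (1 - alpha) / 2 * L2^2),
  // where alpha = elasticNetParam and lambda = regParam in spark.ml.
  public static double penalty(double[] w, double lambda, double alpha) {
    double l1 = 0.0, l2sq = 0.0;
    for (double wi : w) {
      l1 += Math.abs(wi);   // L1 norm: sum of absolute values
      l2sq += wi * wi;      // squared L2 norm: sum of squares
    }
    return lambda * (alpha * l1 + (1.0 - alpha) / 2.0 * l2sq);
  }

  public static void main(String[] args) {
    double[] w = {0.5, -0.25}; // hypothetical coefficients
    // Same settings as the example below: regParam = 0.3, elasticNetParam = 0.8.
    System.out.println(penalty(w, 0.3, 0.8));
  }
}
```

With α = 0.8 the penalty is dominated by the L1 term, which is what drives many coefficients to exactly zero (the sparse solutions seen in the output below).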

    import org.apache.spark.ml.classification.LogisticRegression;
    import org.apache.spark.ml.classification.LogisticRegressionModel;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;
    
    public class JavaLogisticRegressionWithElasticNetExample {
      public static void main(String[] args) {
        SparkSession spark = SparkSession
          .builder()
          .appName("JavaLogisticRegressionWithElasticNetExample")
          .getOrCreate();
    
        // Load training data
        Dataset<Row> training = spark.read().format("libsvm")
          .load("/home/paul/spark/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_libsvm_data.txt");
    
        LogisticRegression lr = new LogisticRegression()
          .setMaxIter(10)
          .setRegParam(0.3)
          .setElasticNetParam(0.8);
    
        // Fit the model
        LogisticRegressionModel lrModel = lr.fit(training);
    
        // Print the coefficients and intercept for logistic regression
        System.out.println("\n---------- Binomial logistic regression's Coefficients: "
          + lrModel.coefficients() + "\nBinomial Intercept: " + lrModel.intercept());
    
        // We can also use the multinomial family for binary classification
        LogisticRegression mlr = new LogisticRegression()
                .setMaxIter(10)
                .setRegParam(0.3)
                .setElasticNetParam(0.8)
                .setFamily("multinomial");
    
        // Fit the model
        LogisticRegressionModel mlrModel = mlr.fit(training);
    
        // Print the coefficients and intercepts for logistic regression with multinomial family
        System.out.println("\n+++++++++ Multinomial coefficients: " + mlrModel.coefficientMatrix()
          + "\nMultinomial intercepts: " + mlrModel.interceptVector());
    
        spark.stop();
      }
    }
    
    In the output of the code above, the binomial coefficients are printed in sparse-vector format, while the multinomial ones are printed in matrix format.

    The logistic regression implementation in spark.ml also supports extracting a summary of the model over the training set (which is useful for analyzing model performance on the training data). Note that the predictions and metrics stored in BinaryLogisticRegressionSummary are DataFrames annotated @transient, and hence are only available on the driver.
    LogisticRegressionTrainingSummary provides the summary for a LogisticRegressionModel. Currently, only binary classification models support this, and the summary must be explicitly cast to BinaryLogisticRegressionTrainingSummary. Support for multiclass model summaries will be added in a future release.
    Java code:

    import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
    import org.apache.spark.ml.classification.LogisticRegression;
    import org.apache.spark.ml.classification.LogisticRegressionModel;
    import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.functions;
    
    public class JavaLogisticRegressionSummaryExample {
      public static void main(String[] args) {
        SparkSession spark = SparkSession
          .builder()
          .appName("JavaLogisticRegressionSummaryExample")
          .getOrCreate();
    
        // Load training data
        Dataset<Row> training = spark.read().format("libsvm")
          .load("/home/paul/spark/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_libsvm_data.txt");
    
        LogisticRegression lr = new LogisticRegression()
          .setMaxIter(10)
          .setRegParam(0.3)
          .setElasticNetParam(0.8);
    
        // Fit the model
        LogisticRegressionModel lrModel = lr.fit(training);
    
        // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier
        // example
        LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
    
        // Obtain the loss per iteration.
        double[] objectiveHistory = trainingSummary.objectiveHistory();
        for (double lossPerIteration : objectiveHistory) {
          System.out.println(lossPerIteration);
        }
    
        // Obtain the metrics useful to judge performance on test data.
        // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary
        // classification problem.
        BinaryLogisticRegressionSummary binarySummary =
          (BinaryLogisticRegressionSummary) trainingSummary;
    
        // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
        Dataset<Row> roc = binarySummary.roc();
        roc.show();
        roc.select("FPR").show();
        System.out.println(binarySummary.areaUnderROC());
    
        // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
        // this selected threshold.
        Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
        double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
        double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
          .select("threshold").head().getDouble(0);
        lrModel.setThreshold(bestThreshold);
    
        spark.stop();
      }
    }
    

    The output is:

    0.6833149135741672
    0.6662875751473734
    0.6217068546034618
    0.6127265245887887
    0.6060347986802873
    0.6031750687571562
    0.5969621534836274
    0.5940743031983118
    0.5906089243339022
    0.5894724576491042
    0.5882187775729587
    +---+--------------------+
    |FPR|                 TPR|
    +---+--------------------+
    |0.0|                 0.0|
    |0.0|0.017543859649122806|
    |0.0| 0.03508771929824561|
    |0.0| 0.05263157894736842|
    |0.0| 0.07017543859649122|
    |0.0| 0.08771929824561403|
    |0.0| 0.10526315789473684|
    |0.0| 0.12280701754385964|
    |0.0| 0.14035087719298245|
    |0.0| 0.15789473684210525|
    |0.0| 0.17543859649122806|
    |0.0| 0.19298245614035087|
    |0.0| 0.21052631578947367|
    |0.0| 0.22807017543859648|
    |0.0| 0.24561403508771928|
    |0.0|  0.2631578947368421|
    |0.0|  0.2807017543859649|
    |0.0|  0.2982456140350877|
    |0.0|  0.3157894736842105|
    |0.0|  0.3333333333333333|
    +---+--------------------+
    only showing top 20 rows
    
    +---+
    |FPR|
    +---+
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    |0.0|
    +---+
    only showing top 20 rows
    
    1.0
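The threshold-selection step at the end of the summary example can be illustrated without Spark. F-measure is the harmonic mean of precision and recall, F = 2PR/(P+R), computed at each candidate threshold; we then keep the threshold that maximizes it. A plain-Java sketch (the scores and labels below are made up for illustration):

```java
public class BestThresholdSketch {
  // Returns the threshold (among the observed scores) with the highest F-measure.
  public static double bestThreshold(double[] scores, int[] labels) {
    double bestT = 0.0, bestF = -1.0;
    for (double t : scores) {               // candidate thresholds = observed scores
      int tp = 0, fp = 0, fn = 0;
      for (int i = 0; i < scores.length; i++) {
        boolean predicted = scores[i] >= t; // predict positive if score >= threshold
        if (predicted && labels[i] == 1) tp++;
        else if (predicted && labels[i] == 0) fp++;
        else if (!predicted && labels[i] == 1) fn++;
      }
      double precision = tp == 0 ? 0.0 : (double) tp / (tp + fp);
      double recall = tp == 0 ? 0.0 : (double) tp / (tp + fn);
      double f = (precision + recall == 0.0)
          ? 0.0 : 2.0 * precision * recall / (precision + recall);
      if (f > bestF) { bestF = f; bestT = t; }
    }
    return bestT;
  }

  public static void main(String[] args) {
    // Hypothetical predicted probabilities and true binary labels.
    double[] scores = {0.9, 0.8, 0.6, 0.4, 0.3, 0.1};
    int[] labels    = {1,   1,   0,   1,   0,   0};
    System.out.println(bestThreshold(scores, labels)); // prints 0.4
  }
}
```

This mirrors what fMeasureByThreshold plus the max/where selection does in the Spark example, just on a tiny in-memory array instead of a DataFrame.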
    
    


        Source: https://www.haomeiwen.com/subject/ussqtxtx.html