Implementing the LR Algorithm with SparkML

Author: 乌鲁木齐001号程序员 | Published 2020-06-16 14:39

    Discrete features

    Example
    • Gender (male/female) is a discrete feature;
    Discrete features | handling
    • One-hot encoding maps each category to its own dimension: gender can be represented as a two-dimensional vector, (1, 0) for male and (0, 1) for female;
    • If a discrete feature has many categories, say 10, the one-hot vector is ten-dimensional: the dimension the value falls on is 1, and all the others are 0;
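    The one-hot mapping above can be sketched in plain Java (the method name here is illustrative, not a Spark API):

    ```java
    // Minimal sketch of one-hot encoding: the category's index gets a 1,
    // every other dimension stays 0.
    public class OneHotSketch {
        static double[] oneHot(int index, int numCategories) {
            double[] vec = new double[numCategories];
            vec[index] = 1.0;
            return vec;
        }

        public static void main(String[] args) {
            // Gender over 2 categories: 0 = male, 1 = female.
            System.out.println(java.util.Arrays.toString(oneHot(0, 2))); // male → [1.0, 0.0]
            System.out.println(java.util.Arrays.toString(oneHot(1, 2))); // female → [0.0, 1.0]
        }
    }
    ```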

    Continuous features

    Example
    • age, ranging from 0 to 100, is a continuous feature;
    • price_per_man is also a continuous feature;
    • continuous features are generally not fed into the model directly;
    Continuous features | standardization
    • z-score standardization: (x - mean) / std;
      • compute the feature's mean and standard deviation (std), e.g. for price_per_man;
      • this centers price_per_man at 0 with a standard deviation of 1 (note: it does not bound the value to 0~1);
    • max-min standardization: (x - min) / (max - min);
      • this compresses price_per_man into the 0~1 range;
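    A minimal sketch of the two standardization formulas; the price_per_man statistics used below are made up for illustration:

    ```java
    // The two standardization schemes described above.
    public class NormalizeSketch {
        // z-score: (x - mean) / std; result has mean 0 and std 1, not bounded to 0~1.
        static double zScore(double x, double mean, double std) {
            return (x - mean) / std;
        }

        // max-min: (x - min) / (max - min); result is bounded to [0, 1].
        static double maxMin(double x, double min, double max) {
            return (x - min) / (max - min);
        }

        public static void main(String[] args) {
            // Illustrative stats: price_per_man spans 64~227 in the sample rows.
            System.out.println(maxMin(193, 64, 227)); // ≈ 0.79, inside [0, 1]
            System.out.println(zScore(193, 150, 60)); // unbounded, centered at 0
        }
    }
    ```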
    Continuous features | discretization
    • bucket encoding;
    • take age: define 1~10 as child, 10~30 as youth, 30~50 as middle-aged, and 50+ as senior; although age itself is continuous, it can be dropped into these buckets and treated as a discrete feature, and the buckets can then be one-hot encoded;
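    The bucketing above, followed by one-hot encoding of the bucket, might look like this (boundaries taken from the text):

    ```java
    // Bucket (bin) a continuous age, then one-hot encode the bucket index.
    public class AgeBucketSketch {
        // 0 = child, 1 = youth, 2 = middle-aged, 3 = senior.
        static int ageBucket(int age) {
            if (age < 10) return 0;
            if (age < 30) return 1;
            if (age < 50) return 2;
            return 3;
        }

        // One-hot over the 4 buckets, as in the discrete-feature section.
        static double[] ageOneHot(int age) {
            double[] vec = new double[4];
            vec[ageBucket(age)] = 1.0;
            return vec;
        }

        public static void main(String[] args) {
            System.out.println(java.util.Arrays.toString(ageOneHot(22))); // youth → [0.0, 1.0, 0.0, 0.0]
        }
    }
    ```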

    Feature processing

    featurevalue.csv
    "user_id","age","gender","shop_id","rating","price_per_man","clicked"
    "1","22","M","315","4","193","0"
    "1","16","F","431","3","193","1"
    "1","62","F","489","3","72","1"
    "1","12","M","398","0","216","1"
    "1","76","M","307","3","131","0"
    "1","54","M","490","1","205","0"
    "1","38","M","308","2","227","1"
    "1","56","M","400","3","82","1"
    "1","65","F","426","0","136","0"
    "2","48","F","328","3","64","1"
    
    feature.csv
    • drop the meaningless fields (userid, shopid) from featurevalue.csv, then map the remaining columns as follows;
    • age is bucketed into the first 4 columns;
    • gender goes into columns 5 and 6;
    • rating is max-min standardized into column 7;
    • price_per_man is bucket encoded into columns 8 through 11;
    • whether the shop was clicked is the last column;
    "1","0","0","0","1","0","0.8","0","0","1","0","0"
    "1","0","0","0","0","1","0.6","0","0","1","0","1"
    "0","0","0","1","0","1","0.6","0","1","0","0","1"
    "1","0","0","0","1","0","0.0","0","0","0","1","1"
    "0","0","0","1","1","0","0.6","0","0","1","0","0"
    "0","0","1","0","1","0","0.2","0","0","0","1","0"
    "0","1","0","0","1","0","0.4","0","0","0","1","1"
    "0","0","1","0","1","0","0.6","0","1","0","0","1"
    "0","0","0","1","0","1","0.0","0","0","1","0","0"
    "0","0","1","0","0","1","0.6","0","1","0","0","1"
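    The raw-row to feature-row mapping can be sketched as below. The price bucket boundaries (100/150/200) and the 0~5 rating range are assumptions for illustration; the original does not state the boundaries actually used to generate feature.csv, so they may differ:

    ```java
    // Sketch of encoding one featurevalue.csv row into one feature.csv row.
    // Price bucket boundaries and the rating range are assumed, not from the source.
    public class FeatureRowSketch {
        static double[] encode(int age, String gender, int rating, int price, int clicked) {
            double[] row = new double[12];
            // Columns 1~4: age bucket, one-hot (boundaries from the text above).
            int ageBucket = age < 10 ? 0 : age < 30 ? 1 : age < 50 ? 2 : 3;
            row[ageBucket] = 1.0;
            // Columns 5~6: gender, one-hot.
            row[gender.equals("M") ? 4 : 5] = 1.0;
            // Column 7: rating, max-min standardized over an assumed 0~5 range.
            row[6] = rating / 5.0;
            // Columns 8~11: price bucket, one-hot (assumed boundaries 100/150/200).
            int priceBucket = price < 100 ? 0 : price < 150 ? 1 : price < 200 ? 2 : 3;
            row[7 + priceBucket] = 1.0;
            // Column 12: the click label.
            row[11] = clicked;
            return row;
        }

        public static void main(String[] args) {
            // First sample row: age 22, male, rating 4, price_per_man 193, not clicked.
            System.out.println(java.util.Arrays.toString(encode(22, "M", 4, 193, 0)));
        }
    }
    ```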
    

    Generating the LR model

    Generating the LR model | steps
    • train the model on the preprocessed feature values;
    • evaluate the model after training;
    Generating the LR model | code
    package tech.lixinlei.dianping.recommand;
    
    import java.io.IOException;
    
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.ml.classification.LogisticRegression;
    import org.apache.spark.ml.classification.LogisticRegressionModel;
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
    import org.apache.spark.ml.linalg.VectorUDT;
    import org.apache.spark.ml.linalg.Vectors;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.Metadata;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    public class LRTrain {
    
        public static void main(String[] args) throws IOException {
    
            // Initialize the Spark runtime
            SparkSession spark = SparkSession.builder().master("local").appName("DianpingApp").getOrCreate();
    
            // Load the feature + label training file
            JavaRDD<String> csvFile = spark.read().textFile("file:///home/lixinlei/project/gitee/dianping/src/main/resources/feature.csv").toJavaRDD();
    
            // Convert each CSV line into a Row of (label, feature vector)
            JavaRDD<Row> rowJavaRDD = csvFile.map(new Function<String, Row>() {
                /**
                 * @param v1 one line of feature.csv;
                 * @return a Row holding the label and the 11-dimensional feature vector
                 * @throws Exception
                 */
                @Override
                public Row call(String v1) throws Exception {
                    v1 = v1.replace("\"", "");
                    String[] strArr = v1.split(",");
                    return RowFactory.create(Double.valueOf(strArr[11]),
                                             Vectors.dense(
                                                  Double.valueOf(strArr[0]),
                                                  Double.valueOf(strArr[1]),
                                                  Double.valueOf(strArr[2]),
                                                  Double.valueOf(strArr[3]),
                                                  Double.valueOf(strArr[4]),
                                                  Double.valueOf(strArr[5]),
                                                  Double.valueOf(strArr[6]),
                                                  Double.valueOf(strArr[7]),
                                                  Double.valueOf(strArr[8]),
                                                  Double.valueOf(strArr[9]),
                                                  Double.valueOf(strArr[10])));
                }
            });
    
            // Define the columns
            StructType schema = new StructType(
                    new StructField[]{
                            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                            new StructField("features",new VectorUDT(),false, Metadata.empty())
                    }
            );
    
            // data has only two columns: the label, and an 11-dimensional feature vector;
            Dataset<Row> data = spark.createDataFrame(rowJavaRDD, schema);
    
            // Split into training and test sets
            Dataset<Row>[] dataArr = data.randomSplit(new double[]{0.8, 0.2});
            Dataset<Row> trainData = dataArr[0];
            Dataset<Row> testData = dataArr[1];
    
            // Train the model | logistic regression
            LogisticRegression lr = new LogisticRegression()
                    .setMaxIter(10) // number of iterations
                    .setRegParam(0.3) // regularization strength
                    .setElasticNetParam(0.8) // elastic-net mixing: 0 = L2, 1 = L1
                    .setFamily("multinomial"); // "binomial" would also fit this 0/1 label
            LogisticRegressionModel lrModel = lr.fit(trainData);
            lrModel.save("file:///home/lixinlei/project/gitee/dianping/src/main/resources/lrmode");
    
            // Evaluate on the test set
            Dataset<Row> predictions = lrModel.transform(testData);
            MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator();
            double accuracy = evaluator.setMetricName("accuracy").evaluate(predictions);
    
            System.out.println("accuracy = " + accuracy);
    
        }
    
    }
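    Conceptually, what the trained model computes at prediction time is a sigmoid over a weighted sum of the 11 features. The weights below are made up for illustration; real ones come from the fitted model (e.g. lrModel.coefficients() in the binomial case):

    ```java
    // Plain-Java sketch of LR scoring: probability = sigmoid(w · x + b).
    // Weights and bias here are illustrative, not from a trained model.
    public class LRScoreSketch {
        static double sigmoid(double z) {
            return 1.0 / (1.0 + Math.exp(-z));
        }

        // Click probability for an 11-dimensional feature vector.
        static double predict(double[] weights, double bias, double[] features) {
            double z = bias;
            for (int i = 0; i < weights.length; i++) {
                z += weights[i] * features[i];
            }
            return sigmoid(z);
        }

        public static void main(String[] args) {
            double[] w = {0.1, -0.2, 0.3, 0.0, 0.05, -0.05, 0.4, 0.0, 0.1, 0.2, -0.1};
            double[] x = {1, 0, 0, 0, 1, 0, 0.8, 0, 0, 1, 0}; // a feature.csv row, minus the label
            System.out.println(predict(w, 0.0, x));
        }
    }
    ```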
    
