美文网首页
spark自带的ALS算法实现协同过滤

spark自带的ALS算法实现协同过滤

作者: 匪_3f3e | 来源:发表于2018-10-26 20:37 被阅读0次

    环境:spark1.6.0 scala2.11.4
    使用的数据集是tpch数据集

    第一步进行文件的读取,将读取到的dataframe注册成table;(如果存到了hive上可以直接使用hiveContext进行数据处理)
    第二步利用sqlContext查询出(用户,他们所购买的商品),因为没有评分信息,所以评分都默认10分;
    第三步拆分训练集,进行模型训练;
    第四步利用训练好的模型给测试集进行商品推荐
    MyColleborativeFilter.scala

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.ml.recommendation.ALS
    import org.apache.spark.sql.{Dataset, SQLContext}
    
    object MyColleborativeFilter {
      case class Customer(id: Int, name: String, address: String, nation: String, phone: String, mktsegment:String,comment:String)
      case class Order(id:Int,customer:String,status:String,totalPrice:Double,date:String,priority:String,clerk:String,shipPriority:Double,comment:String)
      case class LineItem(orders:Int,part:Int)
    
    
      def main(args: Array[String]): Unit = {
    
        val path=args(0)
    
        //获取sparkSession
        val conf = new SparkConf()//.setAppName("MyRs").setMaster("local")
        //val sparkSession = new SparkSession(conf)
       // sparkSession.sparkContext.setLogLevel("WARN")
        //获取context
        val sparkContext =new  SparkContext(conf)
        val sqlContext = new SQLContext(sparkContext)
    
    
        import sqlContext.implicits._
    
        //读取三个表的数据
        val customerDf = sparkContext.textFile(path+"/customer/customer.tbl")
          .map(_.split("\\|"))
          .map(u => Customer(u(0).toInt, u(1), u(2), u(3), u(4), u(5), u(6))).toDF()
        customerDf.show()
        customerDf.registerTempTable("customer")
    
        val orderDf=sparkContext.textFile(path+"/orders/orders.tbl")
          .map(_.split("\\|"))
          .map(u=>Order(u(0).toInt, u(1), u(2), u(3).toDouble, u(4), u(5), u(6),u(7).toDouble,u(8)))
          .toDF()
        orderDf.registerTempTable("orders")
        orderDf.show()
    
        val itemlineDf=sparkContext.textFile(path+"/lineitem/lineitem.tbl")
            .map(_.split("\\|"))
            .map(u=>LineItem(u(0).toInt,u(1).toInt))
            .toDF()
        itemlineDf.registerTempTable("itemline")
        //利用sparksql查询数据
        val customerPartDf=sqlContext
          .sql("SELECT c.id customer,i.part part FROM customer c,orders o,itemline i " +
            "WHERE c.id=o.customer and o.id=i.orders")
        //增加评分,默认10
        val resultDf=customerPartDf.withColumn("rating",customerPartDf("customer")*0+10.0)
        resultDf.show()
    
        //生成测试,训练集
        val Array(traing,test) = resultDf.randomSplit(Array(0.8,0.2))
        //进行模型训练
        val als = new ALS()
          .setMaxIter(1)
          .setUserCol("customer")
          .setItemCol("part")
          .setRatingCol("rating")
          .setRegParam(0.01)//正则化参数
        val model = als.fit(traing)
        //model.setColdStartStrategy("drop")
        //model.write.overwrite().save("./ColleborativeFilterModle")
    
      //得出测试集的推荐结果
        val predictions = model.transform(test)
        predictions.show(false)
        //spark2.3.0之后可以用如下代码进行推荐
        //model.recommendForUserSubset(user1,10).show(false)
        //model.recommendForAllUsers(10)
    
        sparkContext.stop()
    
      }
    
    }
    
    

    建议把启动命令写成一个脚本,操作起来会更加的方便,master 可以根据自己的需要进行指定,测试可以用local模式,提交到集群上需要用standlone或yarn-client模式,/program/lxf/sql-1.0.jar 是你scala程序jar包的位置,再后面是我的程序需要传入的参数
    start-yarn.sh

    spark-submit  \
    --master yarn-client \       
    --class com.example.spark.MyColleborativeFilter  \
     /program/lxf/sql-1.0.jar  \       
     hdfs://10.77.20.23:8020/tpch
    

    相关文章

      网友评论

          本文标题:spark自带的ALS算法实现协同过滤

          本文链接:https://www.haomeiwen.com/subject/gyjktqtx.html