小矩阵相乘,《智能与并行程序设计》一课作业
Spark scala RDD
主要思路:以目标矩阵中的位置作为键。
矩阵相乘是先乘后加。
所以对于矩阵1的每个元素,映射为 ((目标矩阵行, 目标矩阵列, 该元素的列号), 值);
对于矩阵2的每个元素,映射为 ((目标矩阵行, 目标矩阵列, 该元素的行号), 值)。
然后按键 reduce,把键相同的两个值相乘(键相同意味着目标矩阵坐标相同,且矩阵1的列号 = 矩阵2的行号)。
乘完之后还需要相加:
去掉键中的第三个元素(行号或列号),
再按键 reduce,把各个乘积相加,得到目标矩阵对应位置的值。
import org.apache.spark.SparkConf
import org.apache.spark.rdd._
import org.apache.spark.SparkContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.SparkSession
import scala.util.parsing.json.JSON
import org.apache.spark.sql.{ DataFrame, Dataset, SparkSession, Row }
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.functions._
import java.util.Calendar
import org.apache.spark.SparkContext
import scala.collection.mutable.ListBuffer
import org.apache.spark.sql.functions
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.mllib.recommendation.{ALS,Rating,MatrixFactorizationModel}
import org.apache.log4j.Logger
import org.apache.log4j.Level
import java.io.File
/**
 * Multiplies two small matrices with Spark RDDs.
 *
 * Each input is a text file holding one matrix entry per line in the form
 * `row col value` (whitespace-separated integers). The product is computed in
 * two steps: every pair of entries that share the inner index (column of the
 * left matrix == row of the right matrix) is multiplied, then the products
 * are summed per target cell.
 *
 * Usage: `MatrixM [leftPath [rightPath [outputPath]]]`. Missing arguments fall
 * back to `./mt1.txt`, `./mt2.txt` and `output`.
 */
object MatrixM {

  /** Paths used when the corresponding command-line argument is absent. */
  private val DefaultLeftPath = "./mt1.txt"
  private val DefaultRightPath = "./mt2.txt"
  private val DefaultOutputPath = "output"

  /**
   * Parses `row col value` lines into `(row, col, value)` triples.
   *
   * Blank lines are skipped and any run of whitespace separates fields. Only
   * the first three fields of a line are read.
   *
   * @param lines raw text lines of one matrix file
   * @return one triple per non-blank line
   * @throws IllegalArgumentException (when the RDD is evaluated) if a line has
   *         fewer than three fields
   * @throws NumberFormatException (when the RDD is evaluated) if a field is
   *         not an integer
   */
  private def parseEntries(lines: RDD[String]): RDD[(Int, Int, Int)] =
    lines
      .map(_.trim)
      .filter(_.nonEmpty)
      .map { line =>
        line.split("\\s+").take(3).map(_.toInt) match {
          case Array(row, col, value) => (row, col, value)
          case _ =>
            throw new IllegalArgumentException(
              s"Expected 'row col value' but got: $line")
        }
      }

  /**
   * Program entry point: reads both matrices, prints the intermediate and
   * final results, and writes the product as text files.
   *
   * @param args optional `leftPath`, `rightPath`, `outputPath` (in that order)
   */
  def main(args: Array[String]): Unit = {
    // Default to a local master for debugging, but do not override a master
    // that was already configured (e.g. by spark-submit).
    val conf = new SparkConf().setAppName("MatrixM")
    conf.setIfMissing("spark.master", "local")
    val sc = new SparkContext(conf)

    try {
      val leftPath = args.lift(0).getOrElse(DefaultLeftPath)
      val rightPath = args.lift(1).getOrElse(DefaultRightPath)
      val outputPath = args.lift(2).getOrElse(DefaultOutputPath)

      val leftEntries = parseEntries(sc.textFile(leftPath))
      val rightEntries = parseEntries(sc.textFile(rightPath))

      // Debug output: the parsed entries of the left matrix.
      leftEntries.collect().foreach(println)

      // Key both matrices by the shared inner index k so that a join pairs
      // a(i, k) only with b(k, j). Entries with no partner produce no term,
      // i.e. a missing entry is treated as zero, and no matrix dimensions
      // have to be known in advance.
      val leftByInner = leftEntries.map { case (i, k, a) => (k, (i, a)) }
      val rightByInner = rightEntries.map { case (k, j, b) => (k, (j, b)) }

      // Multiply step: one product per (target row, target column, inner index).
      val products = leftByInner
        .join(rightByInner)
        .map { case (k, ((i, a), (j, b))) => ((i, j, k), a * b) }
        .sortByKey()
      products.collect().foreach(println)

      // Drop the inner index so all terms of one target cell share a key.
      val terms = products
        .map { case ((i, j, _), product) => ((i, j), product) }
        .sortByKey()
      terms.collect().foreach(println)

      // Add step: sum the terms of each target cell.
      val result = terms.reduceByKey(_ + _).sortByKey()
      result.collect().foreach(println)

      // NOTE: by default Spark refuses to write into an existing directory.
      result.saveAsTextFile(outputPath)
    } finally {
      // Release the context even when a job above fails.
      sc.stop()
    }
  }
}
网友评论