We illustrate broadcast variables by comparing an ordinary join (commonJoin) against a broadcast join (broadcastJoin):
1. Ordinary join
scala> val personInfo = sc.parallelize(Array(("G301","hello"),("G302","world"),("G303","welcome"))).map(x=>(x._1,x))
personInfo: org.apache.spark.rdd.RDD[(String, (String, String))] = MapPartitionsRDD[19] at map at <console>:24
scala> val personDetail = sc.parallelize(Array(("G301","spark","2018"))).map(x=>(x._1,x))
personDetail: org.apache.spark.rdd.RDD[(String, (String, String, String))] = MapPartitionsRDD[21] at map at <console>:24
scala> personInfo.join(personDetail).map(x=>(x._1,x._2._1._2,x._2._2._2,x._2._2._3)).collect().foreach(println)
(G301,hello,spark,2018)
(DAG screenshot: commonJoin)
2. Broadcast join: broadcast the small table as a broadcast variable, then join against it
scala> val personInfo = sc.parallelize(Array(("G301","hello"),("G302","world"),("G303","welcome"))).collectAsMap()
personInfo: scala.collection.Map[String,String] = Map(G302 -> world, G301 -> hello, G303 -> welcome)
scala> val personDetail = sc.parallelize(Array(("G301","spark","2018"))).map(x=>(x._1,x))
personDetail: org.apache.spark.rdd.RDD[(String, (String, String, String))] = MapPartitionsRDD[28] at map at <console>:24
scala> val personBroadcast = sc.broadcast(personInfo)
personBroadcast: org.apache.spark.broadcast.Broadcast[scala.collection.Map[String,String]] = Broadcast(11)
scala> personDetail.mapPartitions(x => {
| val map = personBroadcast.value
| for ((key, value) <- x if map.contains(key))
| yield (key, map.get(key).getOrElse(""), value._2, value._3)
| }).foreach(println)
(G301,hello,spark,2018)
From the DAG you can see that no shuffle occurs at any point: each record in personDetail is looked up against the broadcast personInfo map, matching records are emitted, and non-matching ones are skipped.
A broadcast join only makes sense when one table is large and the other is small; broadcast the small one, since a complete copy of it must fit in memory on every executor.
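The same map-side join is available for DataFrames through Spark SQL's broadcast hint. A minimal sketch, assuming a local SparkSession (the DataFrame and column names here are illustrative, not from the example above):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.broadcast

val spark = SparkSession.builder().appName("BroadcastHintApp").master("local[2]").getOrCreate()
import spark.implicits._

val info   = Seq(("G301", "hello"), ("G302", "world")).toDF("id", "word")  // small table
val detail = Seq(("G301", "spark", "2018")).toDF("id", "name", "year")    // large table

// broadcast() hints the optimizer to ship `info` to every executor,
// so `detail` is joined map-side without being shuffled
detail.join(broadcast(info), "id").show()

spark.stop()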
For reference, here is the source of a few of the functions involved:
1.collectAsMap
/**
* Return the key-value pairs in this RDD to the master as a Map.
*
* Warning: this doesn't return a multimap (so if you have multiple values to the same key, only
* one value per key is preserved in the map returned)
*
* @note this method should only be used if the resulting data is expected to be small, as
* all the data is loaded into the driver's memory.
*/
def collectAsMap(): Map[K, V] = self.withScope {
val data = self.collect()
val map = new mutable.HashMap[K, V]
map.sizeHint(data.length)
data.foreach { pair => map.put(pair._1, pair._2) }
map
}
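The multimap warning above is easy to trip over: if the small table contains duplicate keys, collectAsMap silently keeps a single value per key. A minimal sketch of the behavior:

// duplicate keys: only one binding per key survives collectAsMap
val dup = sc.parallelize(Seq(("G301", "first"), ("G301", "second"))).collectAsMap()
// dup contains a single entry for G301; the other value is silently dropped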
2.contains
/** Tests whether this map contains a binding for a key.
*
* @param key the key
* @return `true` if there is a binding for `key` in this map, `false` otherwise.
*/
def contains(key: A): Boolean = get(key).isDefined
3.broadcast
/**
* Broadcast a read-only variable to the cluster, returning a
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*
* @param value value to broadcast to the Spark nodes
* @return `Broadcast` object, a read-only variable cached on each machine
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
assertNotStopped()
require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
"Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
}
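The cleaner.foreach(_.registerBroadcastForCleanup(bc)) line means a broadcast is released automatically once it is garbage-collected on the driver, but you can also release it explicitly. A minimal sketch:

val bc = sc.broadcast(Map("G301" -> "hello"))
// ... use bc.value inside tasks ...
bc.unpersist()  // drop the cached copies on the executors; the value is re-sent on next use
bc.destroy()    // release all state for good; reading bc.value afterwards throws an exception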
4.get
/** Optionally returns the value associated with a key.
*
* @param key the key value
* @return an option value containing the value associated with `key` in this map,
* or `None` if none exists.
*/
def get(key: A): Option[B]
5.getOrElse
/** Returns the option's value if the option is nonempty, otherwise
* return the result of evaluating `default`.
*
* @param default the default expression.
*/
@inline final def getOrElse[B >: A](default: => B): B =
if (isEmpty) default else this.get
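In the join above, map.get(key).getOrElse("") chains get and getOrElse; scala.collection.Map also provides the same thing as a single call, which reads slightly better:

// equivalent for a scala.collection.Map:
map.get(key).getOrElse("")  // Option#getOrElse, as used in the example
map.getOrElse(key, "")      // Map#getOrElse, one call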
The IDEA version of the code:
import org.apache.spark.{SparkConf, SparkContext}

object BroadcastJoinApp {

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("BroadcastJoinApp").setMaster("local[2]")
    val sc = new SparkContext(sparkConf)
    // commonJoin(sc)
    broadcastJoin(sc)
    sc.stop()
  }

  // ordinary join: both RDDs are shuffled by key
  def commonJoin(sc: SparkContext): Unit = {
    val personInfo = sc.parallelize(Array(("G301", "hello"), ("G302", "world"), ("G303", "welcome"))).map(x => (x._1, x))
    val personDetail = sc.parallelize(Array(("G301", "spark", "2018"))).map(x => (x._1, x))
    personInfo.join(personDetail).collect().foreach(println)  // raw joined tuples
    personInfo.join(personDetail).map(x => (x._1, x._2._1._2, x._2._2._2, x._2._2._3)).collect().foreach(println)
  }

  // broadcast join: collect the small table to the driver, broadcast it, join map-side
  def broadcastJoin(sc: SparkContext): Unit = {
    val personInfo = sc.parallelize(Array(("G301", "hello"), ("G302", "world"), ("G303", "welcome"))).collectAsMap()
    val personDetail = sc.parallelize(Array(("G301", "spark", "2018"))).map(x => (x._1, x))
    val personBroadcast = sc.broadcast(personInfo)
    personDetail.mapPartitions { x =>
      val map = personBroadcast.value
      for ((key, value) <- x if map.contains(key))
        yield (key, map.get(key).getOrElse(""), value._2, value._3)
    }.foreach(println)
  }
}
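To build this, only the spark-core dependency is needed. A minimal build.sbt line, assuming a Spark 2.x cluster (pick the version that matches yours):

libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.8"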