Spark Source Code Analysis Series - WordCount Source Code Analysis

Author: Rex_2013 | Published 2020-12-29 18:56

    Preface

    This article walks through the source code behind the Spark wordcount example to analyze how a Spark job runs.

    The Spark programming model

    In Spark, an RDD is represented as an object, and transformations are applied to it through method calls on that object. After an RDD has been defined through a series of transformations, an action can be invoked to trigger its computation; an action either returns a result to the application or saves data to a storage system. In Spark, an RDD is only actually computed when an action is encountered (i.e. lazy evaluation).

    sc.textFile("input").flatMap(_.split(" ")).map((_,1)).reduceByKey(_+_).collect

    Operator: from the perspective of cognitive psychology, solving a problem means taking the problem's initial state and, through a series of transformation operations (operators), turning it into the solution state.
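
    To make the lazy-evaluation point concrete, here is a minimal, self-contained sketch (the local[*] master, the in-memory sample data and the LazyDemo object name are illustrative assumptions, not part of the original example): the println inside map only fires when the collect action runs.

    import org.apache.spark.{SparkConf, SparkContext}

    object LazyDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("lazy-demo").setMaster("local[*]"))

        val mapped = sc.parallelize(Seq("hello world", "hello spark"))
          .flatMap(_.split(" "))
          .map { w => println(s"processing $w"); (w, 1) }   // nothing is printed yet

        println("no computation has happened so far")
        val result = mapped.reduceByKey(_ + _).collect()    // action: now the job actually runs
        result.foreach(println)

        sc.stop()
      }
    }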

    Complete Spark WordCount example code

    
    import org.apache.spark.rdd.RDD
    import org.apache.spark.{SparkConf, SparkContext}
    
    object WordCountScala {
    
    
      def main(args: Array[String]): Unit = {
    
        val conf = new SparkConf()
        conf.setAppName("wordcount")
        conf.setMaster("local")  
    
        val sc = new SparkContext(conf)
        // word count
        //DATASET
    //    val fileRDD: RDD[String] = sc.textFile("bigdata-spark/data/testdata.txt",16)
        //hello world
    
    //    fileRDD.flatMap(  _.split(" ") ).map((_,1)).reduceByKey(  _+_   ).foreach(println)
    
        val fileRDD: RDD[String] = sc.textFile("bigdata-spark/data/testdata.txt")
        //hello world
        val words: RDD[String] = fileRDD.flatMap((x:String)=>{x.split(" ")})
        //hello
        //world
        val pairWord: RDD[(String, Int)] = words.map((x:String)=>{new Tuple2(x,1)})
        //(hello,1)
        //(hello,1)
        //(world,1)
        val res: RDD[(String, Int)] = pairWord.reduceByKey(  (x:Int,y:Int)=>{x+y}   )
        //X:oldValue  Y:value
        //(hello,2)  -> (2,1)
        //(world,1)   -> (1,1)
        //(msb,2)   -> (2,1)
    
        val reverseRDD: RDD[(Int, Int)] = res.map((x)=>{  (x._2,1)  })
        val resOver: RDD[(Int, Int)] = reverseRDD.reduceByKey(_+_)
    
        resOver.foreach(println)
        res.foreach(println)
    
        Thread.sleep(Long.MaxValue)
    
      }
    
    }
    
    

    Overall flow of the Spark WordCount source code analysis

    [Figure: wordcount.png — overall flow diagram of the WordCount source code analysis]

    textFile() source code

    There are three ways to create an RDD in Spark: from a collection, from external storage, or from another RDD.

    In this example, textFile is used to read data through the Hadoop input API.

    We start the analysis from sc.textFile():

     val fileRDD: RDD[String] = sc.textFile("bigdata-spark/data/testdata.txt")
    

    Let's look at SparkContext.textFile():

      /**
       * Read a text file from HDFS, a local file system (available on all nodes), or any
       * Hadoop-supported file system URI, and return it as an RDD of Strings.
       * @param path path to the text file on a supported file system
       * @param minPartitions suggested minimum number of partitions for the resulting RDD
       * @return RDD of lines of the text file
       */
      def textFile(
          path: String,
          minPartitions: Int = defaultMinPartitions): RDD[String] = withScope {
        assertNotStopped()
         
      // input file format: TextInputFormat; key type: LongWritable; value type: Text
      // minimum number of partitions: defaultMinPartitions
        hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
          minPartitions).map(pair => pair._2.toString).setName(path)
      }
    

    textFile takes two parameters:

    The first parameter is path, the path of the file to read.

    The second parameter is minPartitions, the suggested minimum number of partitions. A user-supplied value is only a hint and is not necessarily the final partition count, because Spark leans toward the higher degree of parallelism: the hint is handed to getSplits, and the resulting number of partitions is effectively the larger of the hint and what the split computation produces. For example, if the user asks for 10 partitions but the split computation yields 15, the RDD gets 15 partitions; if the user asks for 20 and the computation alone would yield 15, the file is split more finely and roughly 20 partitions result.
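
    For reference, when no value is passed, textFile falls back to SparkContext.defaultMinPartitions, which (paraphrasing the Spark source) is capped at 2 so that small jobs are not over-partitioned by default. Reusing the example's SparkContext and input path, you can check the values directly:

    // Paraphrased from SparkContext:
    //   def defaultMinPartitions: Int = math.min(defaultParallelism, 2)
    println(sc.defaultParallelism)   // e.g. the number of local cores
    println(sc.defaultMinPartitions) // at most 2
    println(sc.textFile("bigdata-spark/data/testdata.txt").getNumPartitions)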

    hadoopFile() source code

    Inside textFile we can see the call to hadoopFile:

    /** Get an RDD for a Hadoop file with an arbitrary InputFormat
       *
       * @note Because Hadoop's RecordReader class re-uses the same Writable object for each
       * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
       * operation will create many references to the same object.
       * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
       * copy them using a `map` function.
       * @param path directory to the input data files, the path can be comma separated paths
       * as a list of inputs
       * @param inputFormatClass storage format of the data to be read
       * @param keyClass `Class` of the key associated with the `inputFormatClass` parameter
       * @param valueClass `Class` of the value associated with the `inputFormatClass` parameter
       * @param minPartitions suggested minimum number of partitions for the resulting RDD
       * @return RDD of tuples of key and corresponding value
       */
      def hadoopFile[K, V](
          path: String,
          inputFormatClass: Class[_ <: InputFormat[K, V]],
          keyClass: Class[K],
          valueClass: Class[V],
          minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope {
        assertNotStopped()
    
        // This is a hack to enforce loading hdfs-site.xml.
        // See SPARK-11227 for details.
        FileSystem.getLocal(hadoopConfiguration)
    
        // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
        val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration))
        val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
        new HadoopRDD(
          this,
          confBroadcast,
          Some(setInputPathsFunc),
          inputFormatClass,
          keyClass,
          valueClass,
          minPartitions).setName(path)
      }
    

    HadoopRDD source code

    This method mainly returns a HadoopRDD for the given input file format:

    class HadoopRDD[K, V](
        sc: SparkContext,
        broadcastedConf: Broadcast[SerializableConfiguration],
        initLocalJobConfFuncOpt: Option[JobConf => Unit],
        inputFormatClass: Class[_ <: InputFormat[K, V]],
        keyClass: Class[K],
        valueClass: Class[V],
        minPartitions: Int)
      extends RDD[(K, V)](sc, Nil) with Logging {
          def this ...
          def getJobConf ...
          def getInputFormat ...
          def getPartitions ...
          def compute ...
          ...
      }
    

    As we can see, HadoopRDD is a subclass of RDD, and it passes the SparkContext and Nil to the RDD parent constructor:

    abstract class RDD[T: ClassTag](
        @transient private var _sc: SparkContext,
        @transient private var deps: Seq[Dependency[_]]
      ) extends Serializable with Logging {
    

    The second parameter, deps, represents the dependencies. Because this RDD is the head of the Spark lineage, its dependency list points to Nil.

    An RDD's dependencies work like a singly linked list: each later RDD's dependency holds a reference to the RDD before it.

    In addition, HadoopRDD contains two important methods: getPartitions() and compute().

    getPartitions() source code

    Let's first look at getPartitions():

     override def getPartitions: Array[Partition] = {
        val jobConf = getJobConf()
        // add the credentials here as this can be called before SparkContext initialized
        SparkHadoopUtil.get.addCredentials(jobConf)
        val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions)
        val inputSplits = if (ignoreEmptySplits) {
          allInputSplits.filter(_.getLength > 0)
        } else {
          allInputSplits
        }
        val array = new Array[Partition](inputSplits.size)
        for (i <- 0 until inputSplits.size) {
          array(i) = new HadoopPartition(id, i, inputSplits(i))
        }
        array
      }
    

    Here we see FileInputFormat's getSplits() method. Anyone who has studied MapReduce will recognize it, because MapReduce has the same getSplits() method; stepping into it, we find that its implementation logic is almost identical to MapReduce's.

    getSplits() source code
    /** Splits files returned by {@link #listStatus(JobConf)} when
       * they're too big.*/ 
      public InputSplit[] getSplits(JobConf job, int numSplits)
        throws IOException {
        Stopwatch sw = new Stopwatch().start();
    // get all FileStatus objects
        FileStatus[] files = listStatus(job);
        
        // Save the number of input files for metrics/loadgen
        job.setLong(NUM_INPUT_FILES, files.length);
        long totalSize = 0;                           // compute total size
        for (FileStatus file: files) {                // check we have valid files
          if (file.isDirectory()) {
            throw new IOException("Not a file: "+ file.getPath());
          }
          totalSize += file.getLen();
        }
    // compute the target split size (goalSize) and the minimum split size (minSize)
        long goalSize = totalSize / (numSplits == 0 ? 1 : numSplits);
        long minSize = Math.max(job.getLong(org.apache.hadoop.mapreduce.lib.input.
          FileInputFormat.SPLIT_MINSIZE, 1), minSplitSize);
    
        // generate splits
        ArrayList<FileSplit> splits = new ArrayList<FileSplit>(numSplits);
        NetworkTopology clusterMap = new NetworkTopology();
        for (FileStatus file: files) {
          Path path = file.getPath();
          long length = file.getLen();
          if (length != 0) {
            FileSystem fs = path.getFileSystem(job);
            BlockLocation[] blkLocations;
            if (file instanceof LocatedFileStatus) {
              blkLocations = ((LocatedFileStatus) file).getBlockLocations();
            } else {
              blkLocations = fs.getFileBlockLocations(file, 0, length);
            }
          // check whether the file is splittable
            if (isSplitable(fs, path)) {
              long blockSize = file.getBlockSize();
                
          // if it is splittable, cut it into splits; compute the split size
              long splitSize = computeSplitSize(goalSize, minSize, blockSize);
    
              long bytesRemaining = length;
              while (((double) bytesRemaining)/splitSize > SPLIT_SLOP) {
                String[][] splitHosts = getSplitHostsAndCachedHosts(blkLocations,
                    length-bytesRemaining, splitSize, clusterMap);
                splits.add(makeSplit(path, length-bytesRemaining, splitSize,
                    splitHosts[0], splitHosts[1]));
                bytesRemaining -= splitSize;
              }
    
              if (bytesRemaining != 0) {
                String[][] splitHosts = getSplitHostsAndCachedHosts(blkLocations, length
                    - bytesRemaining, bytesRemaining, clusterMap);
                splits.add(makeSplit(path, length - bytesRemaining, bytesRemaining,
                    splitHosts[0], splitHosts[1]));
              }
            } else {
              String[][] splitHosts = getSplitHostsAndCachedHosts(blkLocations,0,length,clusterMap);
              splits.add(makeSplit(path, 0, length, splitHosts[0], splitHosts[1]));
            }
          } else { 
            //Create empty hosts array for zero length files
            splits.add(makeSplit(path, 0, length, new String[0]));
          }
        }
        sw.stop();
        if (LOG.isDebugEnabled()) {
          LOG.debug("Total # of splits generated by getSplits: " + splits.size()
              + ", TimeTaken: " + sw.elapsedMillis());
        }
        return splits.toArray(new FileSplit[splits.size()]);
      }
    

    getSplits builds the splits in a few steps:

    1. listStatus retrieves all FileStatus objects.
    2. Compute the target split size goalSize and the minimum split size minSize.
    3. Check whether each file is splittable.
    4. If it is splittable, cut it into splits of size Math.max(minSize, Math.min(goalSize, blockSize)) (see the worked sketch right after this list).
    5. The method finally returns the InputSplits; getPartitions then builds a HadoopPartition for each returned InputSplit.
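
    A worked sketch of the split-size rule in step 4 (the 128 MB block size and the 1000 MB file size are made-up numbers for illustration):

    // splitSize = max(minSize, min(goalSize, blockSize))
    def computeSplitSize(goalSize: Long, minSize: Long, blockSize: Long): Long =
      math.max(minSize, math.min(goalSize, blockSize))

    val mb        = 1024L * 1024L
    val totalSize = 1000 * mb             // a 1000 MB input file
    val numSplits = 2                     // the minPartitions hint coming from textFile
    val goalSize  = totalSize / numSplits // 500 MB
    val blockSize = 128 * mb              // assumed HDFS block size
    val minSize   = 1L

    // max(1, min(500 MB, 128 MB)) = 128 MB, giving about 8 splits -- more than the hint of 2
    println(computeSplitSize(goalSize, minSize, blockSize) / mb) // 128
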
    compute() source code

    Next, let's look at the compute method.

    Note that compute is only triggered after an action has been executed.

     override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
    // create an iterator over this split's records
         val iter = new NextIterator[(K, V)] {
    
          private val split = theSplit.asInstanceOf[HadoopPartition]
          logInfo("Input split: " + split.inputSplit)
          private val jobConf = getJobConf()
    
          private val inputMetrics = context.taskMetrics().inputMetrics
          private val existingBytesRead = inputMetrics.bytesRead
    
          // Sets InputFileBlockHolder for the file block's information
          split.inputSplit.value match {
            case fs: FileSplit =>
              InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength)
            case _ =>
              InputFileBlockHolder.unset()
          }
    
          // Find a function that will return the FileSystem bytes read by this thread. Do this before
          // creating RecordReader, because RecordReader's constructor might read some bytes
          private val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match {
            case _: FileSplit | _: CombineFileSplit =>
              Some(SparkHadoopUtil.get.getFSBytesReadOnThreadCallback())
            case _ => None
          }
    
          // We get our input bytes from thread-local Hadoop FileSystem statistics.
          // If we do a coalesce, however, we are likely to compute multiple partitions in the same
          // task and in the same thread, in which case we need to avoid override values written by
          // previous partitions (SPARK-13071).
          private def updateBytesRead(): Unit = {
            getBytesReadCallback.foreach { getBytesRead =>
              inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
            }
          }
    
          private var reader: RecordReader[K, V] = null
          private val inputFormat = getInputFormat(jobConf)
          HadoopRDD.addLocalConfiguration(
            new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime),
            context.stageId, theSplit.index, context.attemptNumber, jobConf)
    
          reader =
            try {
              inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
            } catch {
              case e: IOException if ignoreCorruptFiles =>
                logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e)
                finished = true
                null
            }
          // Register an on-task-completion callback to close the input stream.
          context.addTaskCompletionListener { context =>
            // Update the bytes read before closing is to make sure lingering bytesRead statistics in
            // this thread get correctly added.
            updateBytesRead()
            closeIfNeeded()
          }
    
          private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
          private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()
    
          override def getNext(): (K, V) = {
            try {
              finished = !reader.next(key, value)
            } catch {
              case e: IOException if ignoreCorruptFiles =>
                logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e)
                finished = true
            }
            if (!finished) {
              inputMetrics.incRecordsRead(1)
            }
            if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
              updateBytesRead()
            }
            (key, value)
          }
    
          override def close(): Unit = {
            if (reader != null) {
              InputFileBlockHolder.unset()
              try {
                reader.close()
              } catch {
                case e: Exception =>
                  if (!ShutdownHookManager.inShutdown()) {
                    logWarning("Exception in RecordReader.close()", e)
                  }
              } finally {
                reader = null
              }
              if (getBytesReadCallback.isDefined) {
                updateBytesRead()
              } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
                         split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
                // If we can't get the bytes read from the FS stats, fall back to the split size,
                // which may be inaccurate.
                try {
                  inputMetrics.incBytesRead(split.inputSplit.value.getLength)
                } catch {
                  case e: java.io.IOException =>
                    logWarning("Unable to get input size to set InputMetrics for task", e)
                }
              }
            }
          }
        }
        new InterruptibleIterator[(K, V)](context, iter)
      }
    

    Here we can see that compute() creates an iterator, NextIterator[(K, V)]. Let's step into this iterator.

    override def hasNext: Boolean = {
        if (!finished) {
          if (!gotNext) {
            nextValue = getNext()
            if (finished) {
              closeIfNeeded()
            }
            gotNext = true
          }
        }
        !finished
      }
    
      override def next(): U = {
        if (!hasNext) {
          throw new NoSuchElementException("End of stream")
        }
        gotNext = false
        nextValue
      }
    

    The iterator has hasNext() and next() methods, which are how the <key, value> pairs are fetched from the file; in other words, compute() uses this iterator to pull the <k, v> data. It is worth pointing out that, so far, getPartitions() and compute() have not actually been executed.
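
    To see the pattern in isolation, here is a simplified, self-contained NextIterator-style iterator (a sketch only; SimpleNextIterator and lineIterator are made-up names, not Spark's org.apache.spark.util.NextIterator): subclasses supply getNext() and close(), and hasNext/next take care of the finished/gotNext bookkeeping.

    import java.io.{BufferedReader, FileReader}

    abstract class SimpleNextIterator[U] extends Iterator[U] {
      protected var finished = false
      private var gotNext = false
      private var nextValue: U = _
      private var closed = false

      protected def getNext(): U    // must set finished = true when nothing is left
      protected def close(): Unit

      private def closeIfNeeded(): Unit = if (!closed) { closed = true; close() }

      override def hasNext: Boolean = {
        if (!finished && !gotNext) {
          nextValue = getNext()
          if (finished) closeIfNeeded()
          gotNext = true
        }
        !finished
      }

      override def next(): U = {
        if (!hasNext) throw new NoSuchElementException("End of stream")
        gotNext = false
        nextValue
      }
    }

    // Example use: stream lines from a file, closing the reader once it is exhausted.
    def lineIterator(path: String): Iterator[String] = new SimpleNextIterator[String] {
      private val reader = new BufferedReader(new FileReader(path))
      override protected def getNext(): String = {
        val line = reader.readLine()
        if (line == null) { finished = true; null } else line
      }
      override protected def close(): Unit = reader.close()
    }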

    Through textFile() we obtain the first RDD, fileRDD; the next operation is flatMap().

    flatMap() source code

      /**
       *  Return a new RDD by first applying a function to all elements of this
       *  RDD, and then flattening the results.
       */
      def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = withScope {
        val cleanF = sc.clean(f)
        new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.flatMap(cleanF))
      }
    

    sc.clean(f) cleans the closure f (stripping outer references it does not need and checking that it is serializable) so that it can be shipped out inside the flatMap computation.
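
    The other half of the call, the per-partition function (context, pid, iter) => iter.flatMap(cleanF), is nothing more than Scala's lazy Iterator.flatMap applied to the partition's iterator. A tiny sketch with made-up in-memory data:

    val partition: Iterator[String] = Iterator("hello world", "hello spark")
    val cleanF = (line: String) => line.split(" ").toSeq

    val wrapped = partition.flatMap(cleanF) // nothing has been computed yet
    println(wrapped.toList)                 // List(hello, world, hello, spark)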

    flatMap() creates a new RDD, a MapPartitionsRDD, and passes the anonymous function f into its constructor; ultimately an iterator will execute flatMap(cleanF). Let's step into MapPartitionsRDD.

    MapPartitionsRDD source code

    private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
        var prev: RDD[T],
        f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
        preservesPartitioning: Boolean = false,
        isOrderSensitive: Boolean = false)
      extends RDD[U](prev) {
    
      override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
    
      override def getPartitions: Array[Partition] = firstParent[T].partitions
    
      override def compute(split: Partition, context: TaskContext): Iterator[U] =
        f(context, split.index, firstParent[T].iterator(split, context))
    
      override def clearDependencies() {
        super.clearDependencies()
        prev = null
      }
    

    Note that prev here refers to the previous RDD, i.e. fileRDD. When MapPartitionsRDD calls the RDD constructor it passes this previous RDD along, so the MapPartitionsRDD holds a reference to fileRDD, and the function f has been passed in as well. Let's look at that RDD constructor.

      /** Construct an RDD with just a one-to-one dependency on one parent */
      def this(@transient oneParent: RDD[_]) =
        this(oneParent.context, List(new OneToOneDependency(oneParent)))
    
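
    To see this one-to-one dependency chain concretely, you can inspect the lineage of the wordcount RDDs from the example above (rdd.dependencies and rdd.toDebugString are public RDD APIs; the exact printed lineage will vary):

    // Each narrow transformation wraps its parent in a OneToOneDependency,
    // so the chain fileRDD <- words <- pairWord can be walked back through dependencies.
    println(pairWord.dependencies.head.rdd == words) // true: pairWord depends on words
    println(words.dependencies.head.rdd == fileRDD)  // true: words depends on fileRDD

    // toDebugString prints the whole lineage; reduceByKey adds a ShuffledRDD on top.
    println(res.toDebugString)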

    MapPartitionsRDD also has a compute() method; let's take a look at it.

    compute() source code

    override def compute(split: Partition, context: TaskContext): Iterator[U] =
      f(context, split.index, firstParent[T].iterator(split, context))
    

    The function f that was passed in is invoked when compute is called, with three arguments:

    The first argument is the TaskContext of the running task.

    The second argument is the index of the partition (split) being computed.

    The third argument is the iterator() of the previous RDD. HadoopRDD itself does not define an iterator() method, so we look it up in the parent class RDD.

    RDD iterator() source code
      /**
       * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
       * This should ''not'' be called by users directly, but is available for implementors of custom
       * subclasses of RDD.
       */
      final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
        if (storageLevel != StorageLevel.NONE) {
          getOrCompute(split, context)
        } else {
          computeOrReadCheckpoint(split, context)
        }
      }
    

    Looking at the logic of this code: if there is no cache and no persisted copy, the RDD calls its own compute() method. So MapPartitionsRDD calls fileRDD's iterator() method, fileRDD's iterator() in turn calls its own compute(), and that compute() returns the iterator that pulls the data.
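
    A toy sketch of this pull-based pipeline (not Spark code; ToyRDD, ToySourceRDD and ToyMapPartitionsRDD are made-up names): each node computes its partition by asking its parent for an iterator, so within one task the whole narrow chain collapses into nested iterators.

    trait ToyRDD[T] {
      def iterator(): Iterator[T] = compute() // no cache or checkpoint in this sketch
      def compute(): Iterator[T]
    }

    // Plays the role of HadoopRDD: its compute() produces the source iterator.
    class ToySourceRDD(lines: Seq[String]) extends ToyRDD[String] {
      override def compute(): Iterator[String] = lines.iterator
    }

    // Plays the role of MapPartitionsRDD: its compute() pulls from the parent's iterator.
    class ToyMapPartitionsRDD[T, U](parent: ToyRDD[T], f: Iterator[T] => Iterator[U]) extends ToyRDD[U] {
      override def compute(): Iterator[U] = f(parent.iterator())
    }

    val toySource = new ToySourceRDD(Seq("hello world", "hello spark"))
    val toyWords  = new ToyMapPartitionsRDD[String, String](toySource, _.flatMap(_.split(" ")))
    val toyPairs  = new ToyMapPartitionsRDD[String, (String, Int)](toyWords, _.map((_, 1)))

    toyPairs.iterator().foreach(println) // records stream through the whole chain one at a time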

    That completes the analysis of flatMap; let's move on to the map() method.

    map() source code

      /**
       * Return a new RDD by applying a function to all elements of this RDD.
       */
      def map[U: ClassTag](f: T => U): RDD[U] = withScope {
        val cleanF = sc.clean(f)
        new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
      }
    

    Just like flatMap(), map() creates a new MapPartitionsRDD, and its execution logic is the same as flatMap()'s.

    Next comes reduceByKey(). reduceByKey triggers a shuffle: the RDDs before it only care about the single record being processed and never need to look at other records, whereas reduceByKey has to process the whole group of records that share a key, so records with the same key must be sent to the same reducer. Let's step into reduceByKey().

    reduceByKey() source code

      /**
       * Merge the values for each key using an associative and commutative reduce function. This will
       * also perform the merging locally on each mapper before sending results to a reducer, similarly
       * to a "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
       * parallelism level.
       */
      def reduceByKey(func: (V, V) => V): RDD[(K, V)] = self.withScope {
        reduceByKey(defaultPartitioner(self), func)
      }
    

    This method calls itself again with two arguments. The first is defaultPartitioner(self), which here resolves to a HashPartitioner (again, just like MapReduce) — a quick sketch of hash partitioning follows, after which we step into reduceByKey(defaultPartitioner(self), func).
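
    A simplified sketch of how hash partitioning routes a key to a reducer (the usual "non-negative hashCode mod numPartitions" scheme; a sketch, not Spark's exact HashPartitioner source):

    def nonNegativeMod(x: Int, mod: Int): Int = {
      val r = x % mod
      if (r < 0) r + mod else r
    }

    def getPartition(key: Any, numPartitions: Int): Int = key match {
      case null => 0
      case _    => nonNegativeMod(key.hashCode, numPartitions)
    }

    // Every occurrence of the same word lands in the same partition,
    // which is what lets reduceByKey bring equal keys together.
    println(getPartition("hello", 4))
    println(getPartition("hello", 4)) // always the same value as above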

     /**
       * Merge the values for each key using an associative and commutative reduce function. This will
       * also perform the merging locally on each mapper before sending results to a reducer, similarly
       * to a "combiner" in MapReduce.
       */
      def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope {
        combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner)
      }
    

    Here it calls combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner), passing in four arguments:

    The first argument is an anonymous function (the identity, used as createCombiner).

    The second and third arguments are both the user-defined function func.

    The fourth argument is the partitioner.

    combineByKeyWithClassTag() source code

      def combineByKeyWithClassTag[C](
          createCombiner: V => C,
          mergeValue: (C, V) => C,
          mergeCombiners: (C, C) => C,
          partitioner: Partitioner,
          mapSideCombine: Boolean = true,
          serializer: Serializer = null)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope {
        require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
        if (keyClass.isArray) {
          if (mapSideCombine) {
            throw new SparkException("Cannot use map-side combining with array keys.")
          }
          if (partitioner.isInstanceOf[HashPartitioner]) {
            throw new SparkException("HashPartitioner cannot partition array keys.")
          }
        }
        val aggregator = new Aggregator[K, V, C](
          self.context.clean(createCombiner),
          self.context.clean(mergeValue),
          self.context.clean(mergeCombiners))
        if (self.partitioner == Some(partitioner)) {
          self.mapPartitions(iter => {
            val context = TaskContext.get()
            new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
          }, preservesPartitioning = true)
        } else {
          new ShuffledRDD[K, V, C](self, partitioner)
            .setSerializer(serializer)
            .setAggregator(aggregator)
            .setMapSideCombine(mapSideCombine)
        }
      }
    

    combineByKey goes through three stages:

    The first stage handles the first record seen for a key, the second stage handles the second and subsequent records for that key, and the third stage merges the partial results (for example the multiple small spill files written out earlier).

    This is also why combineByKey takes three functions.

    This method wraps the three functions into an Aggregator. At the end, it news a ShuffledRDD and calls the three methods setSerializer(), setAggregator() and setMapSideCombine(): the first sets the serializer, the second sets the aggregation logic (i.e. the three functions passed in), and the third enables map-side combining.
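
    Concretely, for reduceByKey(_ + _) in the wordcount example, the three functions handed to the Aggregator boil down to the following (a sketch; mapSide1 and mapSide2 are made-up partial results used to simulate the merge):

    val createCombiner: Int => Int        = (v: Int) => v       // first value seen for a key
    val mergeValue: (Int, Int) => Int     = (c, v) => c + v     // fold later values in (map side)
    val mergeCombiners: (Int, Int) => Int = (c1, c2) => c1 + c2 // merge partial sums across partitions/spills

    // Simulating map-side combine results from two partitions and the reduce-side merge:
    val mapSide1 = Map("hello" -> 2, "world" -> 1)
    val mapSide2 = Map("hello" -> 1, "spark" -> 1)
    val merged = (mapSide1.keySet ++ mapSide2.keySet).map { k =>
      k -> List(mapSide1.get(k), mapSide2.get(k)).flatten.reduce(mergeCombiners)
    }.toMap
    println(merged) // Map(hello -> 3, world -> 1, spark -> 1) (ordering may vary)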

    Let's first take a look at ShuffledRDD.

    ShuffledRDD source code

    class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
        @transient var prev: RDD[_ <: Product2[K, V]],
        part: Partitioner)
      extends RDD[(K, C)](prev.context, Nil) {
    

    Note that prev here is annotated with @transient, which means this field is skipped during serialization. As a result, serializing a ShuffledRDD does not serialize the RDDs before it, and the RDDs after the shuffle can no longer reach the pre-shuffle RDDs through object references; the ShuffledRDD therefore fetches the output of the previous stage rather than recomputing it from the source through the iterator chain.
    Also look at the arguments passed while constructing the ShuffledRDD: the deps argument is Nil, i.e. its dependency list is empty and it no longer depends on the previous RDD directly.

    ShuffledRDD instead has a getDependencies() method for obtaining its dependencies.

    getDependencies() source code
    override def getDependencies: Seq[Dependency[_]] = {
        val serializer = userSpecifiedSerializer.getOrElse {
          val serializerManager = SparkEnv.get.serializerManager
          if (mapSideCombine) {
            serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[C]])
          } else {
            serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[V]])
          }
        }
        List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
      }
    

    This method ends up returning a List containing a new ShuffleDependency. Its first argument is the previous RDD, the second is the partitioner, the third is the serializer, the fourth is the optional key ordering, the fifth is the aggregator (which holds the three aggregation functions from combineByKey()), and the sixth is whether to combine on the map side. Together these make up a ShuffleDependency.

    ShuffledRDD also has a compute() method.

    compute() source code
      override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
        val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
        SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
          .read()
          .asInstanceOf[Iterator[(K, C)]]
      }
    

    This compute() is likewise invoked by RDD's iterator() method. What happens when it is called? It takes the ShuffleDependency of this RDD, obtains the ShuffleManager through SparkEnv, and finally gets back a reader whose read() returns an Iterator. Notice that it does not call the parent RDD's iterator: the computation before the shuffle is an independent stage that writes its results to files, and the ShuffledRDD simply fetches that previously written output instead of re-running the upstream computation.

    To make Spark's execution flow easier to observe, the example adds another map() on the reduce-task side; this map works just like the earlier one, pulling its data from the ShuffledRDD.

    foreach(println) is executed as the action operator; let's step into foreach().

      /**
       * Applies a function f to all elements of this RDD.
       */
      def foreach(f: T => Unit): Unit = withScope {
        val cleanF = sc.clean(f)
        sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
      }
    

    foreach() ultimately calls SparkContext.runJob().
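
    In other words, an action is just a thin wrapper around sc.runJob. As a rough sketch using the res RDD from the example above (via the runJob overload that takes only an RDD and a per-partition function), the two calls below do essentially the same thing; note that in a real cluster the println runs on the executors:

    res.foreach(println)
    sc.runJob(res, (iter: Iterator[(String, Int)]) => iter.foreach(println))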
