
Spark Streaming: MapWithStateDStream

Author: 分裂四人组 | Published 2017-10-17 10:28

    MapWithStateDStream

    `MapWithStateDStream` is the DStream returned by the `mapWithState` operator; besides the normal DStream operations it adds a state-snapshot API:

    def stateSnapshots(): DStream[(KeyType, StateType)]
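
    For orientation, here is a minimal sketch of how `mapWithState` yields a `MapWithStateDStream` and what `stateSnapshots()` gives you. The socket source, the 4-second batch interval and the running-count logic are illustrative assumptions, not part of the source discussed here.

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

    val conf = new SparkConf().setAppName("snapshot-demo").setMaster("local[2]")
    val ssc = new StreamingContext(conf, Seconds(4))
    ssc.checkpoint("/tmp/snapshot-demo")   // mapWithState requires a checkpoint directory

    val wordPairs = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" ")).map((_, 1))

    // Running count per word; the mapped value is (word, newCount).
    val countSpec = StateSpec.function((word: String, one: Option[Int], state: State[Long]) => {
      val newCount = state.getOption().getOrElse(0L) + one.getOrElse(0)
      state.update(newCount)
      (word, newCount)
    })

    val counted = wordPairs.mapWithState(countSpec)   // a MapWithStateDStream
    counted.stateSnapshots().print()                  // the full (word, count) state at every batch

    ssc.start()
    ssc.awaitTermination()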
    
    • `MapWithStateDStream` is a `sealed abstract class`, so all of its implementations are visible in its source file;
    • `MapWithStateDStreamImpl` is the only implementation of `MapWithStateDStream`;

    What the `sealed` keyword does:

    A trait or class marked `sealed` can only be extended within the file that defines it.
    This tells the Scala compiler the complete set of subtypes, so when it checks a pattern match it can verify at compile time that no case has been missed, reducing programming errors.
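
    A quick, generic illustration of that compile-time exhaustiveness check (plain Scala, unrelated to Spark's classes):

    sealed abstract class Shape
    case class Circle(radius: Double) extends Shape
    case class Square(side: Double) extends Shape

    def area(s: Shape): Double = s match {
      case Circle(r)    => math.Pi * r * r
      case Square(side) => side * side
      // Removing the Square case makes the compiler warn "match may not be exhaustive",
      // because every subclass of a sealed class must live in this file and is therefore known.
    }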

    MapWithStateDStreamImpl

    • `MapWithStateDStreamImpl` is internal (private); its parent dependency is a key-value `DStream`;
    • Its implementation relies on the `InternalMapWithStateDStream` class;
    • Its `slideDuration` and `dependencies` values are both taken from the `internalStream` member;

    InternalMapWithStateDStream

    • `InternalMapWithStateDStream` does the actual work behind `MapWithStateDStreamImpl`;
    • It extends `DStream[MapWithStateRDDRecord[K, S, E]]` and uses the MEMORY_ONLY storage level by default;
    • It partitions its state with the partitioner configured on the `StateSpec`, falling back to a `HashPartitioner` (see the StateSpec sketch after this list);
    • It forces checkpointing (`override val mustCheckpoint = true`); if `checkpointDuration` is not set, it is derived from `slideDuration`, the batch window;
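
    As a hedged illustration of where that partitioner and the timeout come from, the public `StateSpec` builder exposes both directly. The mapping function and the numbers below are placeholders:

    import org.apache.spark.HashPartitioner
    import org.apache.spark.streaming.{Minutes, State, StateSpec}

    val spec = StateSpec
      .function((key: String, value: Option[String], state: State[String]) => {
        value.foreach(state.update)          // keep the latest value seen for each key
        (key, state.getOption())
      })
      .partitioner(new HashPartitioner(16))  // picked up by InternalMapWithStateDStream as its partitioner
      .timeout(Minutes(30))                  // becomes spec.getTimeoutInterval() in compute() below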

    InternalMapWithStateDStream.compute()

      /** Method that generates an RDD for the given time */
      // Generates the RDD for the given batch time, turning the state operation into a MapWithStateRDD
      override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
        // Get the previous state or create a new empty state RDD
        val prevStateRDD = getOrCompute(validTime - slideDuration) match {
          case Some(rdd) =>
            if (rdd.partitioner != Some(partitioner)) {
              // If the RDD is not partitioned the right way, let us repartition it using the
              // partition index as the key. This is to ensure that state RDD is always partitioned
              // before creating another state RDD using it
              MapWithStateRDD.createFromRDD[K, V, S, E](
                rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
            } else {
              rdd
            }
          case None =>
            MapWithStateRDD.createFromPairRDD[K, V, S, E](
              spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
              partitioner,
              validTime
            )
        }
    
    
        // Compute the new state RDD with previous state RDD and partitioned data RDD
        // Even if there is no data RDD, use an empty one to create a new state RDD
        val dataRDD = parent.getOrCompute(validTime).getOrElse {
          context.sparkContext.emptyRDD[(K, V)]
        }
        val partitionedDataRDD = dataRDD.partitionBy(partitioner)
        val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
          (validTime - interval).milliseconds
        }
        Some(new MapWithStateRDD(
          prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
      }
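
    A small sanity check on the `timeoutThresholdTime` arithmetic above (the numbers are hypothetical): the threshold is simply the batch time minus the configured idle timeout, and a key is considered timed out once its last update timestamp in the `StateMap` falls behind it.

    import org.apache.spark.streaming.{Minutes, Time}

    val validTime = Time(1508205600000L)                             // batch time (epoch millis)
    val interval  = Minutes(30)                                      // spec.getTimeoutInterval()
    val timeoutThresholdTime = (validTime - interval).milliseconds
    // 1508205600000 - 1800000 == 1508203800000; any key whose StateMap timestamp is
    // older than this is handed to the mapping function as timing out (value = None).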
    

    Next, let's look at the `MapWithStateRDD.createFromPairRDD` method:

    def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
          pairRDD: RDD[(K, S)],
          partitioner: Partitioner,
          updateTime: Time): MapWithStateRDD[K, V, S, E] = {
        
    // Convert the pair RDD into one MapWithStateRDDRecord per partition
        val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
          val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
          iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
          Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
        }, preservesPartitioning = true)
    
        val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
    
        val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
    
        new MapWithStateRDD[K, V, S, E](
          stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
      }
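
    `createFromPairRDD` is only reached when there is no previous state RDD; the pair RDD it receives is whatever was registered through `StateSpec.initialState`, or an `EmptyRDD` otherwise. A minimal, hypothetical sketch of supplying such a warm-start state:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.streaming.{State, StateSpec}

    // Hypothetical warm-start state: user id -> count carried over from an earlier run.
    def warmStartSpec(initial: RDD[(Int, Long)]) =
      StateSpec
        .function((id: Int, value: Option[Int], state: State[Long]) => {
          val count = state.getOption().getOrElse(0L) + value.getOrElse(0)
          state.update(count)
          (id, count)
        })
        .initialState(initial)   // surfaces as spec.getInitialStateRDD() in compute() above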
    

    MapWithStateRDD

    • `MapWithStateRDD` extends `RDD`; its dependencies are `prevStateRDD` and `partitionedDataRDD`:
    RDD[MapWithStateRDDRecord[K, S, E]](
        partitionedDataRDD.sparkContext,
        List(
          new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD),
          new OneToOneDependency(partitionedDataRDD))
      )
    

    Its `compute()` logic:

     override def compute(
          partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {
    
        val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
        val prevStateRDDIterator = prevStateRDD.iterator(
          stateRDDPartition.previousSessionRDDPartition, context)
        val dataIterator = partitionedDataRDD.iterator(
          stateRDDPartition.partitionedDataRDDPartition, context)
    
        val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
        val newRecord = MapWithStateRDDRecord.updateRecordWithData(
          prevRecord,
          dataIterator,
          mappingFunction,
          batchTime,
          timeoutThresholdTime,
          removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
        )
        Iterator(newRecord)
      }
    

    It relies mainly on `MapWithStateRDDRecord.updateRecordWithData`, which returns an iterator over a record whose `stateMap` holds the state of every key and whose `mappedData` collects the values returned by the mapping function:

        // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
    // newStateMap is the key -> state mapping that stores the state of each key
    val newStateMap = prevRecord.map { _.stateMap.copy() }.getOrElse { new EmptyStateMap[K, S]() }
        
    // Collects the values returned by mappingFunction()
        val mappedData = new ArrayBuffer[E]
        
    // Reusable wrapper around the user-facing State object
        val wrappedState = new StateImpl[S]()
    
        // Call the mapping function on each record in the data iterator, and accordingly
        // update the states touched, and collect the data returned by the mapping function
        dataIterator.foreach { case (key, value) =>
          wrappedState.wrap(newStateMap.get(key))
          val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
          if (wrappedState.isRemoved) {
            newStateMap.remove(key)
          } else if (wrappedState.isUpdated
              || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
            newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
          }
          mappedData ++= returned
        }
    
        // Get the timed out state records, call the mapping function on each and collect the
        // data returned
    // The user can configure a timeout; here every timed-out key is visited and its timeout logic is triggered
        if (removeTimedoutData && timeoutThresholdTime.isDefined) {
          newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
            wrappedState.wrapTimingOutState(state)
            val returned = mappingFunction(batchTime, key, None, wrappedState)
            mappedData ++= returned
            newStateMap.remove(key)
          }
        }
    
        MapWithStateRDDRecord(newStateMap, mappedData)
      }
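
    To connect the flags checked above back to the user-facing API: `state.update(...)` sets `isUpdated`, `state.remove()` sets `isRemoved`, and during the timeout pass the state is wrapped with `wrapTimingOutState`, so `isTimingOut()` is true and the value is `None`. A hedged sketch of a mapping function written against those rules (names and types are illustrative):

    import org.apache.spark.streaming.{State, Time}

    val countFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Long]) => {
      if (state.isTimingOut()) {
        None                                   // key is being evicted; update/remove are not allowed here
      } else {
        val newCount = state.getOption().getOrElse(0L) + value.getOrElse(0)
        state.update(newCount)                 // marks the wrapped state as updated -> put() into newStateMap
        Some((key, newCount))
      }
    }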
    

    StateMap

    /** Internal interface for defining the map that keeps track of sessions. */
    private[streaming] abstract class StateMap[K, S] extends Serializable {
    
      /** Get the state for a key if it exists */
      def get(key: K): Option[S]
    
      /** Get all the keys and states whose updated time is older than the given threshold time */
      def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)]
    
      /** Get all the keys and states in this map. */
      def getAll(): Iterator[(K, S, Long)]
    
      /** Add or update state */
      def put(key: K, state: S, updatedTime: Long): Unit
    
      /** Remove a key */
      def remove(key: K): Unit
    
      /**
       * Shallow copy `this` map to create a new state map.
       * Updates to the new map should not mutate `this` map.
       */
      def copy(): StateMap[K, S]
    
      def toDebugString(): String = toString()
    }
    
    • Location: `org.apache.spark.streaming.util.StateMap`;
    • It is the class that stores Spark Streaming's per-key state information;
    • Two implementations are provided by default: `EmptyStateMap` and `OpenHashMapBasedStateMap` (a toy sketch of the contract follows below);
    • The underlying `OpenHashMap` is a hash map that supports nullable keys and is roughly 5x faster than the default JDK `HashMap`, but callers need to be careful with 0.0/0/0L/non-existent values;
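
    To make the contract above concrete, here is a toy, simplified illustration of a state map keyed by last-update time. It is NOT Spark's `OpenHashMapBasedStateMap` (which layers deltas over a parent map and uses `OpenHashMap` underneath); it only mirrors the interface semantics:

    import scala.collection.mutable

    // Toy illustration of the StateMap contract above.
    class ToyStateMap[K, S] {
      // Each key carries its state plus the timestamp of its last update,
      // which is what getByTime() uses to find timed-out keys.
      private val entries = mutable.HashMap.empty[K, (S, Long)]

      def get(key: K): Option[S] = entries.get(key).map(_._1)

      def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] =
        entries.iterator.collect { case (k, (s, t)) if t < threshUpdatedTime => (k, s, t) }

      def getAll(): Iterator[(K, S, Long)] =
        entries.iterator.map { case (k, (s, t)) => (k, s, t) }

      def put(key: K, state: S, updatedTime: Long): Unit = entries.put(key, (state, updatedTime))

      def remove(key: K): Unit = entries.remove(key)

      // A full clone stands in for Spark's delta-chained copy().
      def copy(): ToyStateMap[K, S] = {
        val cloned = new ToyStateMap[K, S]
        cloned.entries ++= entries
        cloned
      }
    }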

    Demo

    object SparkStatefulRunner {
      /**
        * Aggregates User Sessions using Stateful Streaming transformations.
        *
        * Usage: SparkStatefulRunner <hostname> <port>
        * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
        */
      def main(args: Array[String]): Unit = {
        if (args.length < 2) {
      System.err.println("Usage: SparkStatefulRunner <hostname> <port>")
          System.exit(1)
        }
    
        val sparkConfig = loadConfigOrThrow[SparkConfiguration]("spark")
    
        val sparkContext = new SparkContext(sparkConfig.sparkMasterUrl, "Spark Stateful Streaming")
        val ssc = new StreamingContext(sparkContext, Milliseconds(4000))
        ssc.checkpoint(sparkConfig.checkpointDirectory)
    
        val stateSpec =
          StateSpec
            .function(updateUserEvents _)
            .timeout(Minutes(sparkConfig.timeoutInMinutes))
    
        ssc
          .socketTextStream(args(0), args(1).toInt)
          .map(deserializeUserEvent)
          .filter(_ != UserEvent.empty)
          .mapWithState(stateSpec)
          .foreachRDD { rdd =>
            if (!rdd.isEmpty()) {
              rdd.foreach(maybeUserSession => maybeUserSession.foreach {
                userSession =>
                  // Store user session here
                  println(userSession)
              })
            }
          }
    
        ssc.start()
        ssc.awaitTermination()
      }
    
      def deserializeUserEvent(json: String): (Int, UserEvent) = {
        json.decodeEither[UserEvent] match {
          case \/-(userEvent) =>
            (userEvent.id, userEvent)
          case -\/(error) =>
            println(s"Failed to parse user event: $error")
            (UserEvent.empty.id, UserEvent.empty)
        }
      }
    
      def updateUserEvents(key: Int,
                           value: Option[UserEvent],
                           state: State[UserSession]): Option[UserSession] = {
        def updateUserSessions(newEvent: UserEvent): Option[UserSession] = {
          val existingEvents: Seq[UserEvent] =
            state
              .getOption()
              .map(_.userEvents)
              .getOrElse(Seq[UserEvent]())
    
          val updatedUserSessions = UserSession(newEvent +: existingEvents)
    
          updatedUserSessions.userEvents.find(_.isLast) match {
            case Some(_) =>
              state.remove()
              Some(updatedUserSessions)
            case None =>
              state.update(updatedUserSessions)
              None
          }
        }
    
        value match {
          case Some(newEvent) => updateUserSessions(newEvent)
          case _ if state.isTimingOut() => state.getOption()
        }
      }
    }
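
    The demo refers to `UserEvent`, `UserSession`, `SparkConfiguration`, `loadConfigOrThrow` and the JSON codec (`decodeEither`), which are defined in that project's own code and are not shown here. Purely as a hypothetical sketch to make the state flow readable, the two domain classes could look like this:

    // Hypothetical shapes only -- the real definitions live in the demo's own project.
    case class UserEvent(id: Int, data: String, isLast: Boolean)

    object UserEvent {
      val empty = UserEvent(-1, "", isLast = false)
    }

    case class UserSession(userEvents: Seq[UserEvent])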
    
