美文网首页
Custom Accumulator in Spark 2.1

Custom Accumulator in Spark 2.1

作者: zoyanhui | 来源:发表于2018-05-04 19:32 被阅读0次

Custom Accumulator in Spark 2.1

An accumulator can sum or count values across Spark tasks running on all nodes, and then return the final result. LongAccumulator is one example.
When values need to be accumulated in a custom way, this can be implemented by extending AccumulatorV2 in Spark 2.1.

Below, a MapLongAccumulator is implemented to count the occurrences of several distinct values separately.

import java.util
import java.util.Collections

import org.apache.spark.util.AccumulatorV2
import scala.collection.JavaConversions._


/**
  * Accumulator that counts how many times each distinct value of type `T` has been added.
  *
  * The input type is the value being counted; the output type is a `java.util.Map` from each
  * value to its occurrence count.
  *
  * All compound read-modify-write operations hold the monitor of the backing map, so a thread
  * calling [[value]] never observes a half-applied `add`, `merge` or `setValue`.
  *
  * @tparam T type of the values being counted; used as the key of the backing hash map
  */
class MapLongAccumulator[T] extends AccumulatorV2[T, java.util.Map[T, java.lang.Long]] {
    // Collections.synchronizedMap uses the returned wrapper as its own lock, so the
    // `_map.synchronized { ... }` blocks below share the monitor used by every single map call.
    private val _map: java.util.Map[T, java.lang.Long] =
        Collections.synchronizedMap(new util.HashMap[T, java.lang.Long]())

    /** @return true when no value has been counted yet */
    override def isZero: Boolean = _map.isEmpty

    /** @return a new, empty accumulator of the same type */
    override def copyAndReset(): MapLongAccumulator[T] = new MapLongAccumulator[T]

    /** @return a new accumulator holding a snapshot of the current counts */
    override def copy(): MapLongAccumulator[T] = {
        val newAcc = new MapLongAccumulator[T]
        _map.synchronized {
            newAcc._map.putAll(_map)
        }
        newAcc
    }

    /** Removes every counted value, bringing the accumulator back to its zero state. */
    override def reset(): Unit = _map.clear()

    /**
      * Counts one more occurrence of `v`.
      *
      * @param v the value whose counter is incremented
      */
    override def add(v: T): Unit = increment(v, 1L)

    /**
      * Adds every count held by `other` to the counts held by this accumulator.
      *
      * @param other the accumulator to merge in; must be a `MapLongAccumulator`
      * @throws UnsupportedOperationException if `other` is any other kind of accumulator
      */
    override def merge(other: AccumulatorV2[T, java.util.Map[T, java.lang.Long]]): Unit = other match {
        case o: MapLongAccumulator[T] =>
            // `o.value` is a private copy, so iterating it is safe even when `o eq this`.
            val incoming = o.value
            // Hold the lock for the whole merge so readers see either none or all of it.
            _map.synchronized {
                val entries = incoming.entrySet().iterator()
                while (entries.hasNext) {
                    val entry = entries.next()
                    increment(entry.getKey, entry.getValue.longValue())
                }
            }
        case _ => throw new UnsupportedOperationException(
            s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
    }

    /** @return an unmodifiable snapshot of the current counts, detached from later updates */
    override def value: java.util.Map[T, java.lang.Long] = _map.synchronized {
        java.util.Collections.unmodifiableMap(new util.HashMap[T, java.lang.Long](_map))
    }

    /**
      * Replaces all current counts with the entries of `newValue`.
      *
      * @param newValue the counts to install; it is copied, so later changes to it are not seen
      */
    def setValue(newValue: java.util.Map[T, java.lang.Long]): Unit = {
        // Copy first: the swap below stays correct even if `newValue` is backed by `_map`.
        val replacement = new util.HashMap[T, java.lang.Long](newValue)
        _map.synchronized {
            _map.clear()
            _map.putAll(replacement)
        }
    }

    /** Adds `delta` to the counter of `key` with a single write, creating the counter if absent. */
    private def increment(key: T, delta: Long): Unit = _map.synchronized {
        val updated = Option(_map.get(key)).fold(delta)(_.longValue() + delta)
        _map.put(key, java.lang.Long.valueOf(updated))
    }
}

Use case:

// Create the accumulator and register it with the SparkContext under a display name.
val accumulator = new MapLongAccumulator[String]()
sc.register(accumulator, "CustomAccumulator")
someRdd.map(a => {
            ...
            val convertResult = convert(a)
            // Count this conversion result on the accumulator registered above.
            accumulator.add(convertResult.toString)
            ...
        }).repartition(1).saveAsTextFile(outputPath)
// Printed after the saveAsTextFile call above has been issued.
System.out.println(accumulator.value)   // each distinct convertResult is counted separately, and the counts are printed here.

AccumulatorV2 is the new accumulator base class, available since Spark 2.0.0.

public abstract class AccumulatorV2<IN,OUT>
extends Object
implements scala.Serializable

The base class for accumulators, that can accumulate inputs of type IN, and produce output of type OUT.
OUT should be a type that can be read atomically (e.g., Int, Long), or thread-safely (e.g., synchronized collections) because it will be read from other threads.

Methods of AccumulatorV2:
The merge method does not need to be thread-safe: it is called from only one thread, when a task completes (see handleTaskCompletion in DAGScheduler).

org.apache.spark.scheduler.DAGScheduler.scala
/**
   * Responds to a task finishing. This is called inside the event loop so it assumes that it can
   * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
   */
  private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
    val task = event.task
    val taskId = event.taskInfo.id
    val stageId = task.stageId
    val taskType = Utils.getFormattedClassName(task)

    outputCommitCoordinator.taskCompleted(
      stageId,
      task.partitionId,
      event.taskInfo.attemptNumber, // this is a task attempt number
      event.reason)

    // Reconstruct task metrics. Note: this may be null if the task has failed.
    val taskMetrics: TaskMetrics =
      if (event.accumUpdates.nonEmpty) {
        try {
          TaskMetrics.fromAccumulators(event.accumUpdates)
        } catch {
          case NonFatal(e) =>
            logError(s"Error when attempting to reconstruct metrics for task $taskId", e)
            null
        }
      } else {
        null
      }

    ......
    val stage = stageIdToStage(task.stageId)
    event.reason match {
      case Success =>
        stage.pendingPartitions -= task.partitionId
        task match {
          case rt: ResultTask[_, _] =>
            // Cast to ResultStage here because it's part of the ResultTask
            // TODO Refactor this out to a function that accepts a ResultStage
            val resultStage = stage.asInstanceOf[ResultStage]
            resultStage.activeJob match {
              case Some(job) =>
                if (!job.finished(rt.outputId)) {
                  updateAccumulators(event)
                  ......
           case smt: ShuffleMapTask =>
            val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
            updateAccumulators(event)
            val status = event.result.asInstanceOf[MapStatus]
            val execId = status.location.executorId
            logDebug("ShuffleMapTask finished on " + execId)
            if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
              logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
            } else {
              shuffleStage.addOutputLoc(smt.partitionId, status)
            }
            ......
    case exceptionFailure: ExceptionFailure =>
        // Tasks failed with exceptions might still have accumulator updates.
        updateAccumulators(event)
    ......
}


/**
   * Folds the accumulator updates carried by a task completion event into the matching
   * accumulators that were registered on the driver.
   *
   * Accumulators are not thread-safe on their own, but this method is invoked only from the
   * single thread that runs the scheduling loop. Task completion events are therefore handled
   * one at a time and no locking of the accumulators is needed here. Code that updates an
   * accumulator from outside the scheduler is not guarded against; nothing can be done about
   * that from this method.
   */
  private def updateAccumulators(event: CompletionEvent): Unit = {
    val task = event.task
    val owningStage = stageIdToStage(task.stageId)
    try {
      for (partial <- event.accumUpdates) {
        val id = partial.id
        // Resolve the driver-side accumulator this partial update belongs to
        val driverAcc: AccumulatorV2[Any, Any] = AccumulatorContext.get(id)
          .map(_.asInstanceOf[AccumulatorV2[Any, Any]])
          .getOrElse {
            throw new SparkException(s"attempted to access non-existent accumulator $id")
          }
        driverAcc.merge(partial.asInstanceOf[AccumulatorV2[Any, Any]])
        // Skip unnamed accumulators and untouched updates so the UI is not cluttered
        val worthReporting = driverAcc.name.isDefined && !partial.isZero
        if (worthReporting) {
          owningStage.latestInfo.accumulables(id) =
            driverAcc.toInfo(None, Some(driverAcc.value))
          event.taskInfo.accumulables +=
            driverAcc.toInfo(Some(partial.value), Some(driverAcc.value))
        }
      }
    } catch {
      case NonFatal(e) =>
        logError(s"Failed to update accumulators for task ${task.partitionId}", e)
    }
  }

org.apache.spark.executor.TaskMetrics.scala
  /**
   * Builds a [[TaskMetrics]] instance out of a list of accumulator updates. Called on the
   * driver only.
   *
   * Updates whose name matches one of the new instance's own named accumulators are merged
   * into that accumulator; all remaining updates are appended to its external accumulators.
   */
  def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = {
    val metrics = new TaskMetrics
    // Split the updates by whether their name is one of this TaskMetrics' own accumulators
    val (ownUpdates, foreignUpdates) = accums.partition { candidate =>
      candidate.name.exists(n => metrics.nameToAccums.contains(n))
    }

    for (incoming <- ownUpdates) {
      val target =
        metrics.nameToAccums(incoming.name.get).asInstanceOf[AccumulatorV2[Any, Any]]
      target.metadata = incoming.metadata
      target.merge(incoming.asInstanceOf[AccumulatorV2[Any, Any]])
    }

    metrics.externalAccums ++= foreignUpdates
    metrics
  }

相关文章

网友评论

      本文标题:Custom Accumulator in Spark 2.1

      本文链接:https://www.haomeiwen.com/subject/unljrftx.html