美文网首页
Spark中的一些数据结构

Spark中的一些数据结构

作者: 天之見證 | 来源:发表于2020-02-04 18:36 被阅读0次

1. ChunkedByteBuffer

ChunkedByteBuffer 是一个只读的 ByteBuffer，其本质是一个 ByteBuffer 数组（chunks）。

private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
  require(chunks != null, "chunks must not be null")
  require(chunks.forall(_.position() == 0), "chunks' positions must be 0")

  /**
   * Write this buffer to a channel.
   *
   * Every chunk is pushed to `channel` in slices of at most `bufferWriteChunkSize` bytes.
   * When a chunk is an on-heap ByteBuffer, Java NIO first copies it into a temporary direct
   * ByteBuffer that is cached per thread. That cached buffer grows to the largest input it has
   * seen and is not released until the thread exits. Bounding the slice size therefore bounds
   * the native memory it can pin.
   * See http://www.evanjones.ca/java-bytebuffer-leak.html for details.
   *
   * @param channel destination channel; written to until no chunk has bytes remaining
   */
  def writeFully(channel: WritableByteChannel): Unit = {
    getChunks().foreach { chunk =>
      val fullLimit = chunk.limit()
      while (chunk.hasRemaining) {
        // Temporarily narrow the limit so a single write call sees at most one slice.
        val sliceEnd = chunk.position() + Math.min(chunk.remaining(), bufferWriteChunkSize)
        chunk.limit(sliceEnd)
        channel.write(chunk)
        // Restore the limit so the next iteration can see the rest of the chunk.
        chunk.limit(fullLimit)
      }
    }
  }
}

为什么写入 channel 时需要按固定大小分片写入？

当一个 on-heap ByteBuffer 写入 channel 时，Java NIO API 会先把它拷贝到一个临时的 direct ByteBuffer 中。这个临时 buffer 按线程缓存（通过 ThreadLocal 实现），它的大小会随遇到的最大输入不断增长，并且直到线程退出才会释放。因此当源 ByteBuffer 很大时，很容易造成 native 内存泄漏。按固定大小分片写入，可以限制这个缓存 buffer 的大小。

当多个线程各自写入很大的 on-heap ByteBuffer 时，每个线程都会缓存一个同样大的 direct ByteBuffer，容易耗费大量内存。具体参见 http://www.evanjones.ca/java-bytebuffer-leak.html。

如果要主动释放 direct ByteBuffer 占用的内存，需要借助 JDK 内置的 Cleaner 机制来处理。

2. BitSet

内容存储

// Backing storage for the bits, one Long (64 bits) per word; sized by bit2words(numBits)
// (presumably ceil(numBits / 64) -- confirm against bit2words).
private val words: Array[Long] = new Array[Long](bit2words(numBits))
// Number of words, cached for the scan loops (e.g. nextSetBit).
private val numWords: Int = words.length

从它的数据结构（一个 Long 数组，每个 Long 存放 64 个 bit）基本可以推断出 BitSet 各个 API 的实现方式。

/**
 * Set all the bits up to a given index.
 *
 * Every bit in [0, bitIndex) becomes 1; bits at or above `bitIndex` are left unchanged.
 *
 * @param bitIndex exclusive upper bound of the bits to set
 */
def setUntil(bitIndex: Int): Unit = {
  // Number of words lying entirely below bitIndex (bitIndex / 64).
  val fullWords = bitIndex >> 6
  // -1L is all ones in two's complement, so each of those whole words becomes fully set.
  Arrays.fill(words, 0, fullWords, -1L)
  if (fullWords < words.length) {
    // 0x3f keeps the low 6 bits, i.e. bitIndex mod 64. The mask covers that many low bits of the
    // partial word. It is zero when bitIndex is a multiple of 64, leaving the word untouched.
    val lowBits = bitIndex & 0x3f
    words(fullWords) |= ~(-1L << lowBits)
  }
}

nextSetBit 的实现很值得参考：它先在当前 word 中查找，再逐个扫描后续的 word，并利用 numberOfTrailingZeros 直接定位最低的置位 bit。

/**
 * Returns the index of the first set bit at or after `fromIndex`, or -1 if there is none.
 *
 * The word containing `fromIndex` is checked first, with the bits below `fromIndex` shifted
 * out. The following words are then scanned in order. The lowest set bit of the first
 * non-zero word is located with `numberOfTrailingZeros`.
 *
 * @param fromIndex first bit index to consider (inclusive)
 * @return index of the next set bit, or -1
 */
def nextSetBit(fromIndex: Int): Int = {
  val startWord = fromIndex >> 6
  if (startWord >= numWords) {
    return -1
  }

  // Drop the bits below fromIndex within the starting word.
  val offset = fromIndex & 0x3f
  val shifted = words(startWord) >> offset
  if (shifted != 0) {
    return (startWord << 6) + offset + java.lang.Long.numberOfTrailingZeros(shifted)
  }

  // Skip over the all-zero words that follow.
  var w = startWord + 1
  while (w < numWords && words(w) == 0) {
    w += 1
  }
  if (w < numWords) (w << 6) + java.lang.Long.numberOfTrailingZeros(words(w)) else -1
}

3. OpenHashSet

只支持插入和查询（不支持删除）的 HashSet。

之所以不支持删除，是因为它使用开放寻址法（open addressing）来存储 hash 和数据：查找时一旦遇到空槽就会停止探测，删除元素会打断后续 key 的探测链。

/**
 * A hash set that supports only insertion and lookup (no removal), using open addressing.
 *
 * Slot occupancy is tracked in `_bitset`, and each key lives in `_data` at its slot index.
 * A lookup stops at the first unoccupied slot on its probe sequence. That is why removal is
 * not supported: clearing a slot would cut the probe chain for keys stored beyond it.
 *
 * @param initialCapacity requested capacity; passed through nextPowerOf2, and the `& _mask`
 *                        slot indexing assumes the result is a power of two
 * @param loadFactor      used to compute `_growThreshold` from `_capacity`
 */
class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
    initialCapacity: Int,
    loadFactor: Double)
  extends Serializable {
  protected var _capacity = nextPowerOf2(initialCapacity)
  protected var _mask = _capacity - 1
  protected var _size = 0
  protected var _growThreshold = (loadFactor * _capacity).toInt

  // Occupancy bitmap: bit i is set iff slot i of `_data` holds a key.
  protected var _bitset = new BitSet(_capacity)

  // Init of the array in constructor (instead of in declaration) to work around a Scala compiler
  // specialization bug that would generate two arrays (one for Object and one for specialized T).
  // Holds the keys themselves, indexed by slot.
  protected var _data: Array[T] = _
  _data = new Array[T](_capacity)

  /**
   * Returns the slot holding `k`, or INVALID_POS if `k` is not in the set.
   */
  def getPos(k: T): Int = {
    var pos = hashcode(hasher.hash(k)) & _mask
    var delta = 1
    while (true) {
      if (!_bitset.get(pos)) {
        // Reached an empty slot: k was never inserted.
        return INVALID_POS
      } else if (k == _data(pos)) {
        return pos
      } else {
        // Probe offsets grow by 1, 2, 3, ... (triangular-number / quadratic probing).
        pos = (pos + delta) & _mask
        delta += 1
      }
    }
    throw new RuntimeException("Should never reach here.")
  }

  /**
   * Adds `k` to the set, then calls rehashIfNeeded (presumably grows the table once
   * `_growThreshold` is exceeded -- confirm against rehashIfNeeded).
   */
  def add(k: T): Unit = {
    addWithoutResize(k)
    rehashIfNeeded(k, grow, move)
  }

  /**
   * Inserts `k` without resizing the table.
   *
   * @return the slot of `k`; when `k` was newly inserted, the slot is OR-ed with NONEXISTENCE_MASK
   */
  def addWithoutResize(k: T): Int = {
    var pos = hashcode(hasher.hash(k)) & _mask
    var delta = 1
    while (true) {
      if (!_bitset.get(pos)) {
        // This is a new key.
        _data(pos) = k
        _bitset.set(pos)
        _size += 1
        return pos | NONEXISTENCE_MASK
      } else if (_data(pos) == k) {
        // Found an existing key.
        return pos
      } else {
        // Same probe sequence as getPos: offsets grow by 1, 2, 3, ...
        pos = (pos + delta) & _mask
        delta += 1
      }
    }
    throw new RuntimeException("Should never reach here.")
  }
}

4. OpenHashMap

key 使用 OpenHashSet 存储，OpenHashSet 会返回 key 在内部数组中的位置 pos。因此 value 可以直接存放在一个与之等长的 Array 中，并用同一个 pos 作为下标访问。

/**
 * A map backed by an OpenHashSet of keys plus a parallel `_values` array.
 *
 * The slot that `_keySet` returns for a key is also the index of that key's value in `_values`.
 * A null key is kept outside the hash table, in `haveNullValue` / `nullValue`.
 *
 * @param initialCapacity initial capacity passed to the key set
 */
class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
    initialCapacity: Int)
  extends Iterable[(K, V)]
  with Serializable {

  def this() = this(64)
  // Stores the keys; a key's slot in this set is also the index of its value.
  protected var _keySet = new OpenHashSet[K](initialCapacity)
  // Values, indexed by the slot of the corresponding key in `_keySet`.
  private var _values: Array[V] = _
  _values = new Array[V](_keySet.capacity)

  /** Returns true if `k` (including a null key) is present in the map. */
  def contains(k: K): Boolean = {
    if (k == null) {
      haveNullValue
    } else {
      _keySet.getPos(k) != OpenHashSet.INVALID_POS
    }
  }

  /**
   * Get the value for a given key.
   *
   * A missing key yields `null.asInstanceOf[V]`, which is the zero value (e.g. 0) when V is a
   * primitive type.
   */
  def apply(k: K): V = {
    if (k == null) {
      nullValue
    } else {
      val pos = _keySet.getPos(k)
      if (pos < 0) {
        null.asInstanceOf[V]
      } else {
        _values(pos)
      }
    }
  }

  /**
   * Set the value for a key.
   *
   * The key is inserted without resizing so that its slot is known. The value is stored at that
   * slot, and only then is the key set given the chance to rehash.
   */
  def update(k: K, v: V): Unit = {
    if (k == null) {
      haveNullValue = true
      nullValue = v
    } else {
      // Strip the "newly inserted" flag bit to recover the plain slot index.
      val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
      _values(pos) = v
      _keySet.rehashIfNeeded(k, grow, move)
      _oldValues = null
    }
  }
}

5. AppendOnlyMap

内容存储：key 和 value 交替存放在同一个数组中（key 位于 2*pos，value 位于 2*pos+1），这也决定了下面的寻址方式与 OpenHashMap 不同。

// Keys and values are interleaved in one array for memory locality: the key for slot i sits at
// index 2 * i and its value at 2 * i + 1 (key0, value0, key1, value1, key2, value2, ...).
private var data: Array[AnyRef] = new Array[AnyRef](2 * capacity)

实现的API:

/**
 * Get the value for a given key.
 *
 * A null key is served from `nullValue`. Any other key walks its probe sequence until it finds
 * an equal key (and returns that key's value) or an empty slot (and returns null cast to V).
 */
def apply(key: K): V = {
  assert(!destroyed, destructionMessage)
  val probeKey = key.asInstanceOf[AnyRef]
  if (probeKey.eq(null)) {
    return nullValue
  }
  var slot = rehash(probeKey.hashCode) & mask
  var step = 1
  while (true) {
    val storedKey = data(2 * slot)
    if (probeKey.eq(storedKey) || probeKey.equals(storedKey)) {
      return data(2 * slot + 1).asInstanceOf[V]
    }
    if (storedKey.eq(null)) {
      // Hit an empty slot: the key is not in the map.
      return null.asInstanceOf[V]
    }
    // Open addressing: advance by 1, 2, 3, ... slots.
    slot = (slot + step) & mask
    step += 1
  }
  null.asInstanceOf[V]
}

/**
 * Set the value for a key.
 *
 * A null key is stored outside the table, in `nullValue`. For any other key, the probe
 * sequence is walked exactly as in `apply`. An equal key has its value overwritten in place.
 * Otherwise the first empty slot receives the new key/value pair, and incrementSize() is called.
 * Note that this probing loop largely duplicates the one in `apply`.
 *
 * @param key key to insert or update; may be null
 * @param value value to associate with `key`
 */
def update(key: K, value: V): Unit = {
  assert(!destroyed, destructionMessage)
  val k = key.asInstanceOf[AnyRef]
  if (k.eq(null)) {
    // Only the first null key counts toward the map's size.
    if (!haveNullValue) {
      incrementSize()
    }
    nullValue = value
    haveNullValue = true
  } else {
    var slot = rehash(key.hashCode) & mask
    var step = 1
    var done = false
    while (!done) {
      val keyIdx = 2 * slot
      val storedKey = data(keyIdx)
      if (storedKey.eq(null)) {
        // Empty slot: claim it for the new key.
        data(keyIdx) = k
        data(keyIdx + 1) = value.asInstanceOf[AnyRef]
        incrementSize() // only adding a new key can require a resize
        done = true
      } else if (k.eq(storedKey) || k.equals(storedKey)) {
        // Existing key: overwrite its value; the size is unchanged.
        data(keyIdx + 1) = value.asInstanceOf[AnyRef]
        done = true
      } else {
        slot = (slot + step) & mask
        step += 1
      }
    }
  }
}

返回一个按 key 排序的 KV 迭代器。由于排序直接在底层数组上原地进行（并设置 destroyed = true），调用之后这个 map 就不能再使用了。

/**
 * Returns an iterator over the map's key/value pairs, sorted by key with `keyComparator`.
 *
 * The pairs are packed and sorted in place in the underlying array. This method therefore marks
 * the map as destroyed, and the map must not be used afterwards. A null key, if present, is
 * returned first.
 *
 * @param keyComparator ordering for the non-null keys
 * @return iterator of (key, value) pairs; `next()` throws NoSuchElementException once exhausted
 */
def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
  destroyed = true
  // Pack KV pairs into the front of the underlying array
  var keyIndex, newIndex = 0
  while (keyIndex < capacity) {
    if (data(2 * keyIndex) != null) {
      data(2 * newIndex) = data(2 * keyIndex)
      data(2 * newIndex + 1) = data(2 * keyIndex + 1)
      newIndex += 1
    }
    keyIndex += 1
  }
  // Every stored pair has been packed; the null key (if any) lives outside the array.
  assert(curSize == newIndex + (if (haveNullValue) 1 else 0))

  new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, newIndex, keyComparator)

  new Iterator[(K, V)] {
    var i = 0
    var nullValueReady = haveNullValue
    def hasNext: Boolean = (i < newIndex || nullValueReady)
    def next(): (K, V) = {
      if (!hasNext) {
        // Slots at and beyond newIndex were not cleared by the packing loop, so reading them
        // would return stale pairs (or go out of bounds); honour the Iterator contract instead.
        throw new NoSuchElementException("next on empty iterator")
      }
      if (nullValueReady) {
        nullValueReady = false
        (null.asInstanceOf[K], nullValue)
      } else {
        val item = (data(2 * i).asInstanceOf[K], data(2 * i + 1).asInstanceOf[V])
        i += 1
        item
      }
    }
  }
}

6. Sorter

Sorter 主要用于各个 ShuffleWriter 中。

6.1 ShuffleExternalSorter

主要用于UnsafeShuffleWriter

ShuffleExternalSorter 不同于 ExternalSorter 的地方在于：前者不会合并 spill 出来的文件，合并的动作由 UnsafeShuffleWriter 来完成。

其中排序算法的实现有 2 种：

6.1.1 RadixSort

从最低位开始逐位排序（LSD，即“末尾比较法”），时间复杂度为 $O(d(n+b))$，其中 $d$ 为数字的最大位数，$n$ 为元素个数，$b$ 为基数（每一位可能取值的个数）。

radixsort.png

6.1.2 TimSort

该算法基于这样一个假设：实际数据中的数组往往已经部分有序（由许多已排好序的片段组成），TimSort 会利用这些有序片段来加速排序。

timsort.jpg

6.2 ExternalSorter

主要用于SortShuffleWriter

/**
 * Write a bunch of records to this task's output.
 *
 * The method does the following, in order:
 * - Builds an ExternalSorter; the aggregator and key ordering are passed only when map-side
 *   combine is enabled.
 * - Feeds every record to the sorter.
 * - Writes the partitioned result to a temporary file next to the shuffle data file.
 * - Hands that file and the partition lengths to shuffleBlockResolver.writeIndexFileAndCommit.
 *
 * If the temporary file still exists afterwards, it is deleted in `finally`.
 *
 * @param records the map task's output records
 */
override def write(records: Iterator[Product2[K, V]]): Unit = {
  sorter =
    if (!dep.mapSideCombine) {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      new ExternalSorter[K, V, V](
        context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
    } else {
      new ExternalSorter[K, V, C](
        context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    }
  sorter.insertAll(records)

  // Don't bother including the time to open the merged output file in the shuffle write time,
  // because it just opens a single file, so is typically too fast to measure accurately
  // (see SPARK-3570).
  val dataFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
  val tmpFile = Utils.tempFileWith(dataFile)
  try {
    val shuffleBlock =
      ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
    val lengths = sorter.writePartitionedFile(shuffleBlock, tmpFile)
    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, lengths, tmpFile)
    mapStatus = MapStatus(blockManager.shuffleServerId, lengths)
  } finally {
    val leftOver = tmpFile.exists()
    if (leftOver && !tmpFile.delete()) {
      logError(s"Error while deleting temp file ${tmpFile.getAbsolutePath}")
    }
  }
}
// ExternalSorter.scala
/**
 * Inserts all `records` into this sorter's in-memory collection, spilling when needed.
 *
 * With an aggregator, values are combined per (partition, key) in `map` via the aggregator's
 * createCombiner / mergeValue. Without one, records are appended to `buffer` unchanged.
 * After each record, maybeSpillCollection decides whether to spill.
 *
 * @param records input records
 */
def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  // TODO: stop combining if we find that the reduction factor isn't high
  aggregator match {
    case Some(agg) =>
      // Combine values in-memory first using our AppendOnlyMap
      val mergeValue = agg.mergeValue
      val createCombiner = agg.createCombiner
      // One closure is reused for every record. It reads the current record through `kv`,
      // which is reassigned before each changeValue call.
      var kv: Product2[K, V] = null
      val update = (hadValue: Boolean, oldValue: C) =>
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      while (records.hasNext) {
        addElementsRead()
        kv = records.next()
        map.changeValue((getPartition(kv._1), kv._1), update)
        maybeSpillCollection(usingMap = true)
      }
    case None =>
      // Stick values into our buffer
      while (records.hasNext) {
        addElementsRead()
        val record = records.next()
        buffer.insert(getPartition(record._1), record._1, record._2.asInstanceOf[C])
        maybeSpillCollection(usingMap = false)
      }
  }
}

/**
 * Spills the active in-memory collection if maybeSpill says so, and records the peak size.
 *
 * After a spill, the spilled collection is replaced with a fresh, empty one.
 *
 * @param usingMap true to check `map` (combining path), false to check `buffer`
 */
private def maybeSpillCollection(usingMap: Boolean): Unit = {
  val estimatedSize: Long =
    if (usingMap) {
      val mapSize = map.estimateSize()
      if (maybeSpill(map, mapSize)) {
        // The old map was spilled; start over with a new one.
        map = new PartitionedAppendOnlyMap[K, C]
      }
      mapSize
    } else {
      val bufferSize = buffer.estimateSize()
      if (maybeSpill(buffer, bufferSize)) {
        // The old buffer was spilled; start over with a new one.
        buffer = new PartitionedPairBuffer[K, C]
      }
      bufferSize
    }

  if (estimatedSize > _peakMemoryUsedBytes) {
    _peakMemoryUsedBytes = estimatedSize
  }
}

/**
 * Write all the data added into this ExternalSorter into a file in the disk store. This is
 * called by the SortShuffleWriter.
 *
 * If nothing has been spilled, the in-memory collection is sorted and written directly.
 * Otherwise the spilled and in-memory data are merge-sorted through `partitionedIterator`.
 * Each partition is committed as its own segment; partitions with no records keep a length of 0.
 *
 * The disk writer is closed on the failure path as well as on success, so an exception while
 * writing does not leak it.
 *
 * @param blockId block ID to write to. The index file will be blockId.name + ".index".
 * @param outputFile file that the partitioned data is written to
 * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
 */
def writePartitionedFile(
    blockId: BlockId,
    outputFile: File): Array[Long] = {

  // Track location of each range in the output file
  val lengths = new Array[Long](numPartitions)
  val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
    context.taskMetrics().shuffleWriteMetrics)

  try {
    if (spills.isEmpty) {
      // Case where we only have in-memory data
      val collection = if (aggregator.isDefined) map else buffer
      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
      while (it.hasNext) {
        val partitionId = it.nextPartition()
        while (it.hasNext && it.nextPartition() == partitionId) {
          it.writeNext(writer)
        }
        val segment = writer.commitAndGet()
        lengths(partitionId) = segment.length
      }
    } else {
      // We must perform merge-sort; get an iterator by partition and write everything directly.
      for ((id, elements) <- this.partitionedIterator) {
        if (elements.hasNext) {
          for (elem <- elements) {
            writer.write(elem._1, elem._2)
          }
          val segment = writer.commitAndGet()
          lengths(id) = segment.length
        }
      }
    }
  } catch {
    case primary: Throwable =>
      // Release the writer on the failure path too. A secondary failure from close() is
      // attached to the original exception rather than replacing it.
      try {
        writer.close()
      } catch {
        case scala.util.control.NonFatal(closeError) => primary.addSuppressed(closeError)
      }
      throw primary
  }

  writer.close()
  context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
  context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
  context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)

  lengths
}

ref:

  1. https://brilliant.org/wiki/radix-sort/
  2. https://medium.com/@rylanbauermeister/understanding-timsort-191c758a42f3

相关文章

网友评论

      本文标题:Spark中的一些数据结构

      本文链接:https://www.haomeiwen.com/subject/mdhdxhtx.html