Some Data Structures in Spark

Author: 天之見證 | Published 2020-02-04 18:36

    1. ChunkedByteBuffer

    ChunkedByteBuffer is a read-only ByteBuffer; under the hood it is simply an array of ByteBuffer chunks.

    private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
      require(chunks != null, "chunks must not be null")
      require(chunks.forall(_.position() == 0), "chunks' positions must be 0")
    
      /**
       * Write this buffer to a channel.
       */
      def writeFully(channel: WritableByteChannel): Unit = {
        for (bytes <- getChunks()) {
          val originalLimit = bytes.limit()
          while (bytes.hasRemaining) {
            // If `bytes` is an on-heap ByteBuffer, the Java NIO API will copy it to a temporary direct
            // ByteBuffer when writing it out. This temporary direct ByteBuffer is cached per thread.
            // Its size has no limit and can keep growing if it sees a larger input ByteBuffer. This may
            // cause significant native memory leak, if a large direct ByteBuffer is allocated and
            // cached, as it's never released until thread exits. Here we write the `bytes` with
            // fixed-size slices to limit the size of the cached direct ByteBuffer.
            // Please refer to http://www.evanjones.ca/java-bytebuffer-leak.html for more details.
            val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize)
            bytes.limit(bytes.position() + ioSize)
            channel.write(bytes)
            bytes.limit(originalLimit)
          }
        }
      }
    }
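
    A hypothetical usage sketch (ignoring the private[spark] visibility): wrap a couple of on-heap chunks and stream them to a file channel through writeFully. The output path is purely illustrative.

    import java.nio.ByteBuffer
    import java.nio.channels.FileChannel
    import java.nio.charset.StandardCharsets.UTF_8
    import java.nio.file.{Paths, StandardOpenOption}

    // two on-heap chunks, both at position 0 as required by the constructor
    val chunks = Array(ByteBuffer.wrap("hello ".getBytes(UTF_8)), ByteBuffer.wrap("world".getBytes(UTF_8)))
    val cbb = new ChunkedByteBuffer(chunks)

    val channel = FileChannel.open(Paths.get("/tmp/out.bin"),
      StandardOpenOption.CREATE, StandardOpenOption.WRITE)
    try cbb.writeFully(channel) finally channel.close()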
    

    Why does the write to the channel have to be done in fixed-size slices?

    When an on-heap ByteBuffer is written to a channel, the Java NIO API first copies it into a temporary direct ByteBuffer, and it caches one such buffer per thread (via a ThreadLocal). That cached buffer has no size limit: it keeps growing to match the largest input it has seen and is only released when the thread exits, so a large source ByteBuffer can easily cause a native memory leak.

    With many threads each copying large on-heap ByteBuffers this way, far too much memory can be consumed; see the article linked in the comment above (evanjones.ca) for details.

    If we want to release the memory held by a direct ByteBuffer eagerly, we have to go through its built-in Cleaner.
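
    A minimal sketch of that eager disposal, assuming a JDK 8 style buffer that implements sun.nio.ch.DirectBuffer and exposes its Cleaner (Spark wraps similar logic in a helper, StorageUtils.dispose):

    import java.nio.ByteBuffer

    // Eagerly free a direct buffer's native memory instead of waiting for GC.
    // On newer JDKs the Cleaner is internal and this needs --add-opens or Unsafe instead.
    def dispose(buffer: ByteBuffer): Unit = {
      if (buffer != null && buffer.isDirect) {
        buffer match {
          case db: sun.nio.ch.DirectBuffer if db.cleaner() != null =>
            db.cleaner().clean()  // releases the off-heap allocation immediately
          case _ =>               // nothing to clean; let GC reclaim it eventually
        }
      }
    }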

    2. BitSet

    Underlying storage:

    private val words = new Array[Long](bit2words(numBits))
    private val numWords = words.length
    

    From this data layout you can essentially infer the code needed to implement the BitSet API: pick the word with index >> 6 and the bit within it with index & 0x3f (see the set/get sketch after the setUntil example below).

    /**
     * Set all the bits up to a given index
     */
    def setUntil(bitIndex: Int): Unit = {
      val wordIndex = bitIndex >> 6 // divide by 64
      Arrays.fill(words, 0, wordIndex, -1) // -1L is all ones in two's complement: set every bit of the preceding whole words
      if (wordIndex < words.length) {
        // Set the remaining bits (note that the mask could still be zero)
        val mask = ~(-1L << (bitIndex & 0x3f)) // 0x3f masks the lower 6 bits, i.e. the offset within the word
        words(wordIndex) |= mask
      }
    }
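
    As a complement, a minimal sketch (not Spark's class) of the set/get operations implied by the words array:

    // Simplified BitSet core: each Long word holds 64 bits.
    class SimpleBitSet(numBits: Int) {
      private val words = new Array[Long]((numBits + 63) >> 6)  // round up to whole words

      def set(index: Int): Unit = {
        val bitmask = 1L << (index & 0x3f)  // bit position inside the 64-bit word
        words(index >> 6) |= bitmask        // index >> 6 selects the word
      }

      def get(index: Int): Boolean = {
        val bitmask = 1L << (index & 0x3f)
        (words(index >> 6) & bitmask) != 0
      }
    }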
    

    nextSetBit is a nice example of how these bit-twiddling routines are written:

    def nextSetBit(fromIndex: Int): Int = {
      var wordIndex = fromIndex >> 6
      if (wordIndex >= numWords) {
        return -1
      }
    
      // Try to find the next set bit in the current word
      val subIndex = fromIndex & 0x3f
      var word = words(wordIndex) >> subIndex
      if (word != 0) {
        return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word)
      }
    
      // Find the next set bit in the rest of the words
      wordIndex += 1
      while (wordIndex < numWords) {
        word = words(wordIndex)
        if (word != 0) {
          return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word)
        }
        wordIndex += 1
      }
    
      -1
    }
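
    The usual way to iterate over all set bits with this API is a small loop (a usage sketch, assuming bs is a BitSet instance):

    // visit every set bit in ascending order
    var i = bs.nextSetBit(0)
    while (i >= 0) {
      println(i)                // do something with bit i
      i = bs.nextSetBit(i + 1)  // continue searching after the current bit
    }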
    

    3. OpenHashSet

    A HashSet that only supports insertion and lookup (keys are never removed).

    It uses open addressing, so both the slot occupancy (a BitSet) and the actual data live in flat arrays; the probe pattern is illustrated right after the code below.

    class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
        initialCapacity: Int,
        loadFactor: Double)
      extends Serializable {
      protected var _capacity = nextPowerOf2(initialCapacity)
      protected var _mask = _capacity - 1
      protected var _size = 0
      protected var _growThreshold = (loadFactor * _capacity).toInt
    
      // _bitset tracks which slots in _data are occupied
      protected var _bitset = new BitSet(_capacity)
    
      // Init of the array in constructor (instead of in declaration) to work around a Scala compiler
      // specialization bug that would generate two arrays (one for Object and one for specialized T).
      // stores the actual elements
      protected var _data: Array[T] = _
      _data = new Array[T](_capacity)
        
      def getPos(k: T): Int = {
        var pos = hashcode(hasher.hash(k)) & _mask
        var delta = 1
        while (true) {
          if (!_bitset.get(pos)) {
            return INVALID_POS
          } else if (k == _data(pos)) {
            return pos
          } else {
            // quadratic probing with values increase by 1, 2, 3, ...
            pos = (pos + delta) & _mask  // probe the next slot
            delta += 1
          }
        }
        throw new RuntimeException("Should never reach here.")
      }
        
      def add(k: T): Unit = {
        addWithoutResize(k)
        rehashIfNeeded(k, grow, move)
      }
        
      def addWithoutResize(k: T): Int = {
        var pos = hashcode(hasher.hash(k)) & _mask
        var delta = 1
        while (true) {
          if (!_bitset.get(pos)) {
            // This is a new key.
            _data(pos) = k
            _bitset.set(pos)
            _size += 1
            return pos | NONEXISTENCE_MASK
          } else if (_data(pos) == k) {
            // Found an existing key.
            return pos
          } else {
            // quadratic probing with values increase by 1, 2, 3, ...
            pos = (pos + delta) & _mask
            delta += 1
          }
        }
        throw new RuntimeException("Should never reach here.")
      }
    }
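
    The "delta = 1, 2, 3, ..." probing above visits offsets from the initial slot that are the triangular numbers 0, 1, 3, 6, 10, ... A small self-contained sketch of that sequence (mask = capacity - 1, as in the class):

    // enumerate the first n probe positions for a given start slot and mask
    def probeSequence(start: Int, mask: Int, n: Int): Seq[Int] =
      Iterator.iterate((start, 1)) { case (pos, delta) => ((pos + delta) & mask, delta + 1) }
        .map(_._1)
        .take(n)
        .toSeq

    // e.g. probeSequence(5, 15, 5) == Seq(5, 6, 8, 11, 15)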
    

    4. OpenHashMap

    Because the keys are stored in an OpenHashSet, the corresponding values can simply be kept in a plain Array indexed by the key's slot.

    class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
        initialCapacity: Int)
      extends Iterable[(K, V)]
      with Serializable {
    
      def this() = this(64)
      // stores the keys
      protected var _keySet = new OpenHashSet[K](initialCapacity)
      // stores the value for each key, at the same slot index
      private var _values: Array[V] = _
      _values = new Array[V](_keySet.capacity)
        
      def contains(k: K): Boolean = {
        if (k == null) {
          haveNullValue
        } else {
          _keySet.getPos(k) != OpenHashSet.INVALID_POS
        }
      }
        
      /** Get the value for a given key */
      def apply(k: K): V = {
        if (k == null) {
          nullValue
        } else {
          val pos = _keySet.getPos(k)
          if (pos < 0) {
            null.asInstanceOf[V]
          } else {
            _values(pos)
          }
        }
      }
        
      /** Set the value for a key */
      def update(k: K, v: V): Unit = {
        if (k == null) {
          haveNullValue = true
          nullValue = v
        } else {
          val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
          _values(pos) = v
          _keySet.rehashIfNeeded(k, grow, move)
          _oldValues = null
        }
      } 
    }
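
    In update, addWithoutResize packs a "this key is new" flag into the high bit of the returned position, and POSITION_MASK strips it off again. A small sketch of that bit arithmetic, with illustrative constants matching the usage above:

    val NONEXISTENCE_MASK = 1 << 31    // high bit: set when the key was newly inserted
    val POSITION_MASK = (1 << 31) - 1  // lower 31 bits hold the slot index

    val ret = 5 | NONEXISTENCE_MASK                // "new key stored at slot 5"
    val pos = ret & POSITION_MASK                  // recover the slot index: 5
    val isNewKey = (ret & NONEXISTENCE_MASK) != 0  // true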
    

    5. AppendOnlyMap

    Underlying storage: a single flat array holds the keys and values interleaved, which also explains the slightly different addressing (2 * pos for the key, 2 * pos + 1 for the value) used below.

    // Holds keys and values in the same array for memory locality; specifically, the order of
    // elements is key0, value0, key1, value1, key2, value2, etc.
    private var data = new Array[AnyRef](2 * capacity)
    

    Implemented API:

    /** Get the value for a given key */
    def apply(key: K): V = {
      assert(!destroyed, destructionMessage)
      val k = key.asInstanceOf[AnyRef]
      if (k.eq(null)) {
        return nullValue
      }
      var pos = rehash(k.hashCode) & mask
      var i = 1
      while (true) {
        val curKey = data(2 * pos)
        if (k.eq(curKey) || k.equals(curKey)) {
          return data(2 * pos + 1).asInstanceOf[V]
        } else if (curKey.eq(null)) {  // the key is not present
          return null.asInstanceOf[V]
        } else {  // open addressing: probe the next slot
          val delta = i
          pos = (pos + delta) & mask
          i += 1
        }
      }
      null.asInstanceOf[V]
    }
    
    /** Set the value for a key */
    // Note that most of this code duplicates apply
    def update(key: K, value: V): Unit = {
      assert(!destroyed, destructionMessage)
      val k = key.asInstanceOf[AnyRef]
      if (k.eq(null)) {
        if (!haveNullValue) {
          incrementSize()
        }
        nullValue = value
        haveNullValue = true
        return
      }
      var pos = rehash(key.hashCode) & mask
      var i = 1
      while (true) {
        val curKey = data(2 * pos)
        if (curKey.eq(null)) {
          data(2 * pos) = k
          data(2 * pos + 1) = value.asInstanceOf[AnyRef]
          incrementSize()  // Since we added a new key; the other branches don't add elements, so no resize check is needed
          return
        } else if (k.eq(curKey) || k.equals(curKey)) {
          data(2 * pos + 1) = value.asInstanceOf[AnyRef]
          return
        } else {
          val delta = i
          pos = (pos + delta) & mask
          i += 1
        }
      }
    }
    

    destructiveSortedIterator returns the KV pairs in sorted order; of course, after calling it the map can no longer be used.

    def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
      destroyed = true
      // Pack KV pairs into the front of the underlying array
      var keyIndex, newIndex = 0
      while (keyIndex < capacity) {
        if (data(2 * keyIndex) != null) {
          data(2 * newIndex) = data(2 * keyIndex)
          data(2 * newIndex + 1) = data(2 * keyIndex + 1)
          newIndex += 1
        }
        keyIndex += 1
      }
      assert(curSize == newIndex + (if (haveNullValue) 1 else 0))
    
      new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, newIndex, keyComparator)
    
      new Iterator[(K, V)] {
        var i = 0
        var nullValueReady = haveNullValue
        def hasNext: Boolean = (i < newIndex || nullValueReady)
        def next(): (K, V) = {
          if (nullValueReady) {
            nullValueReady = false
            (null.asInstanceOf[K], nullValue)
          } else {
            val item = (data(2 * i).asInstanceOf[K], data(2 * i + 1).asInstanceOf[V])
            i += 1
            item
          }
        }
      }
    }
    

    6. Sorter

    The sorters are mainly used inside the shuffle writers.

    6.1 ShuffleExternalSorter

    Mainly used by UnsafeShuffleWriter.

    ShuffleExternalSorter differs from ExternalSorter in that the former does not merge the spilled files itself; that merge is performed by UnsafeShuffleWriter.

    The underlying in-memory sort has two implementations:

    6.1.1 RadixSort

    It sorts by the least-significant digit first; the time complexity is O(d(n + b)), where d is the maximum number of digits and b is the radix (base).

    [Figure: radix sort illustration (radixsort.png)]
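
    A minimal LSD radix sort sketch for non-negative Ints with radix b = 256 (so d = 4 byte-sized digits). This only illustrates the counting-sort-per-digit idea; it is not Spark's RadixSort, which operates on 8-byte records held in a LongArray.

    // Illustrative least-significant-digit radix sort, radix 256.
    def lsdRadixSort(input: Array[Int]): Array[Int] = {
      var arr = input.clone()
      for (shift <- 0 until 32 by 8) {                     // one byte ("digit") per pass
        val count = new Array[Int](257)
        arr.foreach(x => count(((x >> shift) & 0xff) + 1) += 1)
        for (i <- 1 until 257) count(i) += count(i - 1)    // prefix sums -> start offsets
        val out = new Array[Int](arr.length)
        arr.foreach { x =>                                 // stable scatter by current digit
          val d = (x >> shift) & 0xff
          out(count(d)) = x
          count(d) += 1
        }
        arr = out
      }
      arr
    }

    // lsdRadixSort(Array(170, 45, 75, 90, 802, 24, 2, 66)).toList == List(2, 24, 45, 66, 75, 90, 170, 802)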

    6.1.2 TimSort

    The algorithm is built on the observation that real-world input is usually already partially sorted, i.e. it contains pre-existing sorted runs.

    [Figure: TimSort illustration (timsort.jpg)]
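
    TimSort first scans the array for such runs and then merges them; a tiny sketch of just the run-detection step (ascending runs only, for brevity; the real algorithm also reverses descending runs and enforces a minimum run length):

    // find maximal ascending runs as (start, end-exclusive) index pairs
    def ascendingRuns(a: Array[Int]): List[(Int, Int)] = {
      var result = List.empty[(Int, Int)]
      var start = 0
      var i = 1
      while (i <= a.length) {
        if (i == a.length || a(i) < a(i - 1)) {  // the current run ends here
          result = (start, i) :: result
          start = i
        }
        i += 1
      }
      result.reverse
    }

    // ascendingRuns(Array(1, 2, 5, 3, 4, 0)) == List((0, 3), (3, 5), (5, 6))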

    6.2 ExternalSorter

    Mainly used by SortShuffleWriter.

    /** Write a bunch of records to this task's output */
    override def write(records: Iterator[Product2[K, V]]): Unit = {
      sorter = if (dep.mapSideCombine) {
        new ExternalSorter[K, V, C](
          context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
      } else {
        // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
        // care whether the keys get sorted in each partition; that will be done on the reduce side
        // if the operation being run is sortByKey.
        new ExternalSorter[K, V, V](
          context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
      }
      sorter.insertAll(records)
    
      // Don't bother including the time to open the merged output file in the shuffle write time,
      // because it just opens a single file, so is typically too fast to measure accurately
      // (see SPARK-3570).
      val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
      val tmp = Utils.tempFileWith(output)
      try {
        val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
        val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
        shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
        mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
      } finally {
        if (tmp.exists() && !tmp.delete()) {
          logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
        }
      }
    }
    
    // ExternalSorter.scala
    def insertAll(records: Iterator[Product2[K, V]]): Unit = {
      // TODO: stop combining if we find that the reduction factor isn't high
      val shouldCombine = aggregator.isDefined
      if (shouldCombine) {
        // Combine values in-memory first using our AppendOnlyMap
        val mergeValue = aggregator.get.mergeValue
        val createCombiner = aggregator.get.createCombiner
        var kv: Product2[K, V] = null
        val update = (hadValue: Boolean, oldValue: C) => {
          if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
        }
        while (records.hasNext) {
          addElementsRead()
          kv = records.next()
          map.changeValue((getPartition(kv._1), kv._1), update)
          maybeSpillCollection(usingMap = true)
        }
      } else {
        // Stick values into our buffer
        while (records.hasNext) {
          addElementsRead()
          val kv = records.next()
          buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
          maybeSpillCollection(usingMap = false)
        }
      }
    }
    
    // After a spill, a fresh in-memory collection is allocated to replace the spilled one
    private def maybeSpillCollection(usingMap: Boolean): Unit = {
      var estimatedSize = 0L
      if (usingMap) {
        estimatedSize = map.estimateSize()
        if (maybeSpill(map, estimatedSize)) {
          map = new PartitionedAppendOnlyMap[K, C]
        }
      } else {
        estimatedSize = buffer.estimateSize()
        if (maybeSpill(buffer, estimatedSize)) {
          buffer = new PartitionedPairBuffer[K, C]
        }
      }
    
      if (estimatedSize > _peakMemoryUsedBytes) {
        _peakMemoryUsedBytes = estimatedSize
      }
    }
    
    /**
     * Write all the data added into this ExternalSorter into a file in the disk store. This is
     * called by the SortShuffleWriter.
     *
     * @param blockId block ID to write to. The index file will be blockId.name + ".index".
     * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
     */
    def writePartitionedFile(
        blockId: BlockId,
        outputFile: File): Array[Long] = {
    
      // Track location of each range in the output file
      val lengths = new Array[Long](numPartitions)
      val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
        context.taskMetrics().shuffleWriteMetrics)
    
      if (spills.isEmpty) {
        // Case where we only have in-memory data
        val collection = if (aggregator.isDefined) map else buffer
        val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
        while (it.hasNext) {
          val partitionId = it.nextPartition()
          while (it.hasNext && it.nextPartition() == partitionId) {
            it.writeNext(writer)
          }
          val segment = writer.commitAndGet()
          lengths(partitionId) = segment.length
        }
      } else {
        // We must perform merge-sort; get an iterator by partition and write everything directly.
        for ((id, elements) <- this.partitionedIterator) {
          if (elements.hasNext) {
            for (elem <- elements) {
              writer.write(elem._1, elem._2)
            }
            val segment = writer.commitAndGet()
            lengths(id) = segment.length
          }
        }
      }
    
      writer.close()
      context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
      context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
      context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
    
      lengths
    }
    

    ref:

    1. https://brilliant.org/wiki/radix-sort/
    2. https://medium.com/@rylanbauermeister/understanding-timsort-191c758a42f3
