1. ChunkedByteBuffer
ChunkedByteBuffer is a read-only byte buffer; under the hood it is simply an array of ByteBuffers.
private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
require(chunks != null, "chunks must not be null")
require(chunks.forall(_.position() == 0), "chunks' positions must be 0")
/**
* Write this buffer to a channel.
*/
def writeFully(channel: WritableByteChannel): Unit = {
for (bytes <- getChunks()) {
val originalLimit = bytes.limit()
while (bytes.hasRemaining) {
// If `bytes` is an on-heap ByteBuffer, the Java NIO API will copy it to a temporary direct
// ByteBuffer when writing it out. This temporary direct ByteBuffer is cached per thread.
// Its size has no limit and can keep growing if it sees a larger input ByteBuffer. This may
// cause significant native memory leak, if a large direct ByteBuffer is allocated and
// cached, as it's never released until thread exits. Here we write the `bytes` with
// fixed-size slices to limit the size of the cached direct ByteBuffer.
// Please refer to http://www.evanjones.ca/java-bytebuffer-leak.html for more details.
val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize)
bytes.limit(bytes.position() + ioSize)
channel.write(bytes)
bytes.limit(originalLimit)
}
}
}
}
Why does the write to the channel have to be done in fixed-size slices?
When an on-heap ByteBuffer is written to a channel, the Java NIO API first copies it into a temporary direct ByteBuffer, and each thread caches its own copy (implemented with a ThreadLocal). That cached buffer grows to fit the largest input it has ever seen and is only freed when the thread exits, so writing a very large source ByteBuffer can easily turn into a native memory leak. With many threads each copying large on-heap ByteBuffers, this wastes far too much native memory; see the evanjones.ca article referenced in the code comment above for the details.
If we want to release the memory held by a direct ByteBuffer ourselves, we have to go through the JDK's built-in Cleaner, as sketched below.
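A minimal sketch of eager disposal, assuming a JDK 9+ runtime where sun.misc.Unsafe exposes invokeCleaner; this is not Spark's code (Spark wraps the equivalent logic in StorageUtils.dispose, which ChunkedByteBuffer.dispose calls for each chunk), and disposeDirect is a name made up here for illustration.

import java.nio.ByteBuffer

object DirectBufferDisposer {
  private val unsafe: sun.misc.Unsafe = {
    val f = classOf[sun.misc.Unsafe].getDeclaredField("theUnsafe")
    f.setAccessible(true)
    f.get(null).asInstanceOf[sun.misc.Unsafe]
  }

  // Frees the native memory behind a direct ByteBuffer immediately
  // instead of waiting for the garbage collector.
  def disposeDirect(buffer: ByteBuffer): Unit = {
    // invokeCleaner rejects non-direct buffers and slices/duplicates
    if (buffer.isDirect) unsafe.invokeCleaner(buffer)
  }
}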
2. BitSet
Internal storage:
private val words = new Array[Long](bit2words(numBits))
private val numWords = words.length
From this data structure alone you can pretty much infer what the code implementing the BitSet API has to look like (see the set/get sketch after setUntil below).
/**
* Set all the bits up to a given index
*/
def setUntil(bitIndex: Int): Unit = {
val wordIndex = bitIndex >> 6 // divide by 64
Arrays.fill(words, 0, wordIndex, -1) // -1L is all ones in two's complement, so this sets every bit of the preceding words to 1
if(wordIndex < words.length) {
// Set the remaining bits (note that the mask could still be zero)
val mask = ~(-1L << (bitIndex & 0x3f)) // 0x3f keeps exactly the low 6 bits, i.e. the position inside the word
words(wordIndex) |= mask
}
}
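As a sanity check on that claim, here is a minimal sketch of the basic set and get operations given the words layout above (Spark's actual implementation follows the same idea):

// Minimal sketch of set/get on top of the `words: Array[Long]` layout.
def set(index: Int): Unit = {
  val bitmask = 1L << (index & 0x3f)  // position of the bit inside its 64-bit word
  words(index >> 6) |= bitmask        // index >> 6 selects the word
}

def get(index: Int): Boolean = {
  val bitmask = 1L << (index & 0x3f)
  (words(index >> 6) & bitmask) != 0
}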
nextSetBit is a nice illustration of how to write this kind of bit-twiddling routine:
def nextSetBit(fromIndex: Int): Int = {
var wordIndex = fromIndex >> 6
if (wordIndex >= numWords) {
return -1
}
// Try to find the next set bit in the current word
val subIndex = fromIndex & 0x3f
var word = words(wordIndex) >> subIndex
if (word != 0) {
return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word)
}
// Find the next set bit in the rest of the words
wordIndex += 1
while (wordIndex < numWords) {
word = words(wordIndex)
if (word != 0) {
return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word)
}
wordIndex += 1
}
-1
}
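A small usage sketch (the class is private[spark], so this is illustration only) showing the standard pattern for iterating over all set bits with nextSetBit:

val bits = new BitSet(128)
bits.set(3)
bits.set(70)

var i = bits.nextSetBit(0)
while (i >= 0) {
  println(s"bit $i is set")   // prints 3, then 70
  i = bits.nextSetBit(i + 1)
}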
3. OpenHashSet
A hash set that only supports insertion and lookup (no deletion), because it uses open addressing to store both the occupancy information and the data.
class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
initialCapacity: Int,
loadFactor: Double)
extends Serializable {
protected var _capacity = nextPowerOf2(initialCapacity)
protected var _mask = _capacity - 1
protected var _size = 0
protected var _growThreshold = (loadFactor * _capacity).toInt
// Bitmap tracking which slots are occupied
protected var _bitset = new BitSet(_capacity)
// Init of the array in constructor (instead of in declaration) to work around a Scala compiler
// specialization bug that would generate two arrays (one for Object and one for specialized T).
// Stores the actual elements
protected var _data: Array[T] = _
_data = new Array[T](_capacity)
def getPos(k: T): Int = {
var pos = hashcode(hasher.hash(k)) & _mask
var delta = 1
while (true) {
if (!_bitset.get(pos)) {
return INVALID_POS
} else if (k == _data(pos)) {
return pos
} else {
// quadratic probing with values increase by 1, 2, 3, ...
pos = (pos + delta) & _mask // probe the next slot
delta += 1
}
}
throw new RuntimeException("Should never reach here.")
}
def add(k: T) {
addWithoutResize(k)
rehashIfNeeded(k, grow, move)
}
def addWithoutResize(k: T): Int = {
var pos = hashcode(hasher.hash(k)) & _mask
var delta = 1
while (true) {
if (!_bitset.get(pos)) {
// This is a new key.
_data(pos) = k
_bitset.set(pos)
_size += 1
return pos | NONEXISTENCE_MASK
} else if (_data(pos) == k) {
// Found an existing key.
return pos
} else {
// quadratic probing with values increase by 1, 2, 3, ...
pos = (pos + delta) & _mask
delta += 1
}
}
throw new RuntimeException("Should never reach here.")
}
}
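A usage sketch (again private[spark], illustration only): add() inserts and rehashes when needed, while getPos() returns the slot index or INVALID_POS. Note that although the comment says "quadratic probing", the offsets 1, 1+2, 1+2+3, ... are triangular numbers, which visit every slot when the capacity is a power of two.

val set = new OpenHashSet[Long](64)
set.add(42L)
set.add(42L)                                     // adding the same key twice leaves size == 1
assert(set.size == 1)
assert(set.getPos(42L) != OpenHashSet.INVALID_POS)
assert(set.getPos(7L) == OpenHashSet.INVALID_POS)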
4. OpenHashMap
Because the keys are stored in an OpenHashSet, the corresponding values can be stored directly in a plain Array, indexed by the key's slot in the set.
class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
initialCapacity: Int)
extends Iterable[(K, V)]
with Serializable {
def this() = this(64)
// Stores the keys
protected var _keySet = new OpenHashSet[K](initialCapacity)
// Stores the value for each key, at the same index as the key's slot in _keySet
private var _values: Array[V] = _
_values = new Array[V](_keySet.capacity)
def contains(k: K): Boolean = {
if (k == null) {
haveNullValue
} else {
_keySet.getPos(k) != OpenHashSet.INVALID_POS
}
}
/** Get the value for a given key */
def apply(k: K): V = {
if (k == null) {
nullValue
} else {
val pos = _keySet.getPos(k)
if (pos < 0) {
null.asInstanceOf[V]
} else {
_values(pos)
}
}
}
/** Set the value for a key */
def update(k: K, v: V) {
if (k == null) {
haveNullValue = true
nullValue = v
} else {
val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
_values(pos) = v
_keySet.rehashIfNeeded(k, grow, move)
_oldValues = null
}
}
}
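A usage sketch (private[spark], illustration only): the slot returned for a key by the underlying OpenHashSet doubles as the index into the parallel _values array.

val m = new OpenHashMap[String, Int](64)
m("a") = 1                 // sugar for m.update("a", 1)
m("a") = m("a") + 1        // apply + update together
assert(m("a") == 2)
assert(m.contains("a") && !m.contains("b"))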
5. AppendOnlyMap
Internal storage: a single array holds keys and values together for memory locality, which is also why the addressing below differs from the hash structures above.
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
private var data = new Array[AnyRef](2 * capacity)
The implemented API:
/** Get the value for a given key */
def apply(key: K): V = {
assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
return nullValue
}
var pos = rehash(k.hashCode) & mask
var i = 1
while (true) {
val curKey = data(2 * pos)
if (k.eq(curKey) || k.equals(curKey)) {
return data(2 * pos + 1).asInstanceOf[V]
} else if (curKey.eq(null)) { // empty slot: the key is not present
return null.asInstanceOf[V]
} else { // open addressing: probe the next slot
val delta = i
pos = (pos + delta) & mask
i += 1
}
}
null.asInstanceOf[V]
}
/** Set the value for a key */
// Note that most of this code duplicates apply
def update(key: K, value: V): Unit = {
assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
if (!haveNullValue) {
incrementSize()
}
nullValue = value
haveNullValue = true
return
}
var pos = rehash(key.hashCode) & mask
var i = 1
while (true) {
val curKey = data(2 * pos)
if (curKey.eq(null)) {
data(2 * pos) = k
data(2 * pos + 1) = value.asInstanceOf[AnyRef]
incrementSize() // Since we added a new key; the other branches never grow the map
return
} else if (k.eq(curKey) || k.equals(curKey)) {
data(2 * pos + 1) = value.asInstanceOf[AnyRef]
return
} else {
val delta = i
pos = (pos + delta) & mask
i += 1
}
}
}
destructiveSortedIterator returns the KV pairs in sorted order; of course, once it has been called the map can no longer be used.
def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
destroyed = true
// Pack KV pairs into the front of the underlying array
var keyIndex, newIndex = 0
while (keyIndex < capacity) {
if (data(2 * keyIndex) != null) {
data(2 * newIndex) = data(2 * keyIndex)
data(2 * newIndex + 1) = data(2 * keyIndex + 1)
newIndex += 1
}
keyIndex += 1
}
assert(curSize == newIndex + (if (haveNullValue) 1 else 0))
new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, newIndex, keyComparator)
new Iterator[(K, V)] {
var i = 0
var nullValueReady = haveNullValue
def hasNext: Boolean = (i < newIndex || nullValueReady)
def next(): (K, V) = {
if (nullValueReady) {
nullValueReady = false
(null.asInstanceOf[K], nullValue)
} else {
val item = (data(2 * i).asInstanceOf[K], data(2 * i + 1).asInstanceOf[V])
i += 1
item
}
}
}
}
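A usage sketch (private[spark], illustration only). changeValue is the entry point ExternalSorter uses for map-side combining (see section 6.2), and destructiveSortedIterator is what gets called when the contents must be written out in sorted order; the map must not be used afterwards.

val counts = new AppendOnlyMap[String, Int]()
counts("a") = 1
counts.changeValue("a", (hadValue, oldValue) => if (hadValue) oldValue + 1 else 1)
assert(counts("a") == 2)

// Sorts the underlying array in place and hands back an iterator;
// `counts` is unusable after this call.
val it = counts.destructiveSortedIterator(Ordering[String])
it.foreach { case (k, v) => println(s"$k -> $v") }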
6. Sorter
Sorters are mainly used inside the shuffle writers.
6.1 ShuffleExternalSorter
Mainly used by UnsafeShuffleWriter. ShuffleExternalSorter differs from ExternalSorter in that it does not merge the spilled files itself; that merge is carried out by UnsafeShuffleWriter.
The in-memory sort inside this sorter has two implementations:
6.1.1 RadixSort
It sorts by comparing the least significant digits first (LSD radix sort); the time complexity is O(n·k), where k is the maximum number of digits in the keys. A toy sketch of the idea follows.
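A toy LSD radix sort over non-negative Ints, just to illustrate the idea; this is not Spark's RadixSort, which (roughly speaking) sorts 8-byte records held in a LongArray one byte per pass.

// Four stable counting-sort passes, one per byte, starting from the least
// significant byte.
def lsdRadixSort(input: Array[Int]): Array[Int] = {
  var a = input.clone()
  var shift = 0
  while (shift < 32) {
    val count = new Array[Int](257)
    for (x <- a) count(((x >>> shift) & 0xff) + 1) += 1
    for (i <- 1 to 256) count(i) += count(i - 1)   // prefix sums = bucket start offsets
    val out = new Array[Int](a.length)
    for (x <- a) {
      val b = (x >>> shift) & 0xff
      out(count(b)) = x
      count(b) += 1
    }
    a = out
    shift += 8
  }
  a
}

// lsdRadixSort(Array(170, 45, 75, 90, 2)) == Array(2, 45, 75, 90, 170)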
[figure: radixsort.png]
6.1.2 TimSort
This algorithm is built on the assumption that real-world arrays are usually already partially sorted, i.e. they contain long pre-sorted runs.
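A toy sketch of the "natural runs" idea (this is not Spark's TimSort, which is adapted from the JDK/Android implementation): partition the input into maximal non-decreasing runs, which a TimSort-style merge then combines; the more pre-sorted the input, the fewer and longer the runs.

// Detect the maximal non-decreasing runs in an array -- the building block
// that TimSort merges. Returns (start, end) index pairs, end exclusive.
def naturalRuns(a: Array[Int]): Seq[(Int, Int)] = {
  val runs = scala.collection.mutable.ArrayBuffer.empty[(Int, Int)]
  var start = 0
  var i = 1
  while (i < a.length) {
    if (a(i) < a(i - 1)) {      // a run ends where the order breaks
      runs += ((start, i))
      start = i
    }
    i += 1
  }
  if (a.nonEmpty) runs += ((start, a.length))
  runs.toSeq
}

// naturalRuns(Array(1, 3, 5, 2, 4, 0)) == Seq((0, 3), (3, 5), (5, 6))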
[figure: timsort.jpg]
6.2 ExternalSorter
Mainly used by SortShuffleWriter:
/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
new ExternalSorter[K, V, C](
context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
// if the operation being run is sortByKey.
new ExternalSorter[K, V, V](
context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
sorter.insertAll(records)
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
val tmp = Utils.tempFileWith(output)
try {
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} finally {
if (tmp.exists() && !tmp.delete()) {
logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
}
}
}
// ExternalSorter.scala
def insertAll(records: Iterator[Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined
if (shouldCombine) {
// Combine values in-memory first using our AppendOnlyMap
val mergeValue = aggregator.get.mergeValue
val createCombiner = aggregator.get.createCombiner
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
while (records.hasNext) {
addElementsRead()
kv = records.next()
map.changeValue((getPartition(kv._1), kv._1), update)
maybeSpillCollection(usingMap = true)
}
} else {
// Stick values into our buffer
while (records.hasNext) {
addElementsRead()
val kv = records.next()
buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
maybeSpillCollection(usingMap = false)
}
}
}
// After spilling, a fresh in-memory collection is allocated
private def maybeSpillCollection(usingMap: Boolean): Unit = {
var estimatedSize = 0L
if (usingMap) {
estimatedSize = map.estimateSize()
if (maybeSpill(map, estimatedSize)) {
map = new PartitionedAppendOnlyMap[K, C]
}
} else {
estimatedSize = buffer.estimateSize()
if (maybeSpill(buffer, estimatedSize)) {
buffer = new PartitionedPairBuffer[K, C]
}
}
if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
}
/**
* Write all the data added into this ExternalSorter into a file in the disk store. This is
* called by the SortShuffleWriter.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
def writePartitionedFile(
blockId: BlockId,
outputFile: File): Array[Long] = {
// Track location of each range in the output file
val lengths = new Array[Long](numPartitions)
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics().shuffleWriteMetrics)
if (spills.isEmpty) {
// Case where we only have in-memory data
val collection = if (aggregator.isDefined) map else buffer
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext) {
val partitionId = it.nextPartition()
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(writer)
}
val segment = writer.commitAndGet()
lengths(partitionId) = segment.length
}
} else {
// We must perform merge-sort; get an iterator by partition and write everything directly.
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
for (elem <- elements) {
writer.write(elem._1, elem._2)
}
val segment = writer.commitAndGet()
lengths(id) = segment.length
}
}
}
writer.close()
context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
lengths
}