Overview
The Spark shuffle process redistributes data according to the partitioning rules. Shuffle involves disk I/O and network transfer, so its efficiency largely determines the efficiency of the whole job. Spark is a high-performance distributed computing engine; its shuffle implementation has evolved step by step from the early hash-based shuffle into the current three write-side implementations, SortShuffleWriter, BypassMergeSortShuffleWriter and UnsafeShuffleWriter, plus a single read-side implementation, BlockStoreShuffleReader. Spark shuffle is a very complex process; in what follows we analyze the Spark shuffle write path through the source code (version 2.4.0).
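For context, here is a minimal, self-contained job (an illustration, not Spark source; the object and app names are made up) that triggers this write path. The reduceByKey below introduces a ShuffleDependency: its map side runs one of the three writers discussed in this article, and its reduce side reads the output through BlockStoreShuffleReader.

import org.apache.spark.{SparkConf, SparkContext}

object ShuffleWriteDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("shuffle-write-demo")
    val sc = new SparkContext(conf)

    // reduceByKey is a wide dependency, so a shuffle is registered for it and
    // every map task obtains a ShuffleWriter through SortShuffleManager.getWriter.
    val counts = sc.parallelize(Seq("a", "b", "a", "c"))
      .map(word => (word, 1))
      .reduceByKey(_ + _)

    counts.collect().foreach(println)
    sc.stop()
  }
}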
The architecture of Spark Shuffle

On the write side, the entry point is SortShuffleManager.getWriter, which creates the concrete ShuffleWriter for each map task based on the handle it receives:
override def getWriter[K, V](
    handle: ShuffleHandle,
    mapId: Int,
    context: TaskContext): ShuffleWriter[K, V] = {
  numMapsForShuffle.putIfAbsent(
    handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
  val env = SparkEnv.get
  handle match {
    case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
      new UnsafeShuffleWriter(
        env.blockManager,
        shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
        context.taskMemoryManager(),
        unsafeShuffleHandle,
        mapId,
        context,
        env.conf)
    case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
      new BypassMergeSortShuffleWriter(
        env.blockManager,
        shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
        bypassMergeSortHandle,
        mapId,
        context,
        env.conf)
    case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
      new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
  }
}
As we can see, getWriter takes three parameters, handle, mapId and context, and pattern-matches on the handle to create the corresponding writer. So where is this handle created?
The handle is created when the shuffle introduced by a wide-dependency operator is registered, in registerShuffle:
override def registerShuffle[K, V, C](
    shuffleId: Int,
    numMaps: Int,
    dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
  if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) {
    // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
    // need map-side aggregation, then write numPartitions files directly and just concatenate
    // them at the end. This avoids doing serialization and deserialization twice to merge
    // together the spilled files, which would happen with the normal code path. The downside is
    // having multiple files open at a time and thus more memory allocated to buffers.
    new BypassMergeSortShuffleHandle[K, V](
      shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
    // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
    new SerializedShuffleHandle[K, V](
      shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
  } else {
    // Otherwise, buffer map outputs in a deserialized form:
    new BaseShuffleHandle(shuffleId, numMaps, dependency)
  }
}
registerShuffle decides which handle to create based on two checks. Let's look at the corresponding conditions:
def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
  // We cannot bypass sorting if we need to do map-side aggregation.
  if (dep.mapSideCombine) {
    false
  } else {
    val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
    dep.partitioner.numPartitions <= bypassMergeThreshold
  }
}
/**
 * Helper method for determining whether a shuffle should use an optimized serialized shuffle
 * path or whether it should fall back to the original path that operates on deserialized objects.
 */
def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
  val shufId = dependency.shuffleId
  val numPartitions = dependency.partitioner.numPartitions
  if (!dependency.serializer.supportsRelocationOfSerializedObjects) {
    log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
      s"${dependency.serializer.getClass.getName}, does not support object relocation")
    false
  } else if (dependency.mapSideCombine) {
    log.debug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " +
      s"map-side aggregation")
    false
  } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
    log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
      s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions")
    false
  } else {
    log.debug(s"Can use serialized shuffle for shuffle $shufId")
    true
  }
}
From shouldBypassMergeSort we can see that a BypassMergeSortShuffleHandle is created when there is no map-side combine and dep.partitioner.numPartitions <= spark.shuffle.sort.bypassMergeThreshold (200 by default).
canUseSerializedShuffle in turn requires that the serializer used for the shuffle supports relocation of serialized objects, that no map-side aggregation is needed, and that the number of reduce partitions does not exceed MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE (2^24 = 16777216); if all three hold, a SerializedShuffleHandle is created.
If none of the above conditions are met, a BaseShuffleHandle is created.
The decision flow can be summarized as:
if (!mapSideCombine && dep.partitioner.numPartitions <= spark.shuffle.sort.bypassMergeThreshold (200 by default)) {
  new BypassMergeSortShuffleHandle
} else if (serializer supports relocation of serialized objects && no map-side aggregation && number of reduce partitions <= 2^24) {
  new SerializedShuffleHandle
} else {
  new BaseShuffleHandle
}
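The same logic, extracted into a standalone sketch (an illustration, not Spark source; the object and method names are made up), may make the three branches easier to experiment with. mapSideCombine, numPartitions and supportsRelocation stand in for the corresponding fields of ShuffleDependency, and the defaults mirror spark.shuffle.sort.bypassMergeThreshold and MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE:

object HandleSelectionSketch {
  sealed trait HandleKind
  case object BypassMergeSort extends HandleKind // handled by BypassMergeSortShuffleWriter
  case object Serialized extends HandleKind      // handled by UnsafeShuffleWriter
  case object Base extends HandleKind            // handled by SortShuffleWriter

  val MaxSerializedPartitions: Int = 1 << 24     // 16777216

  def select(
      mapSideCombine: Boolean,
      numPartitions: Int,
      supportsRelocation: Boolean,
      bypassMergeThreshold: Int = 200): HandleKind = {
    if (!mapSideCombine && numPartitions <= bypassMergeThreshold) {
      BypassMergeSort
    } else if (supportsRelocation && !mapSideCombine && numPartitions <= MaxSerializedPartitions) {
      Serialized
    } else {
      Base
    }
  }
}

For example, select(mapSideCombine = false, numPartitions = 300, supportsRelocation = true) returns Serialized: with more than 200 partitions the bypass branch is ruled out, but a relocatable serializer and the absence of aggregation still allow the serialized path.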
getWriter then instantiates the writer that matches the handle.
Let's start with BypassMergeSortShuffleWriter.
From the conditions for BypassMergeSortShuffleHandle we already know that this writer is used only when no map-side aggregation is needed and the number of reduce partitions is at most 200 (spark.shuffle.sort.bypassMergeThreshold).

This approach is similar to the old hash-based shuffle: each map task allocates one buffered writer per reduce partition and, following the partitioner, appends every record to the temporary file of its target partition, so each temporary file holds the data of exactly one reduce partition. Unlike hash-based shuffle, the per-partition files of a map task are merged at the end into a single data file plus an index file, and the data file is ordered by reduce partition id.
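As a tiny illustration of the routing step inside the write loop shown below (this is not Spark source; the object name is made up and a HashPartitioner with 4 partitions is assumed), partitioner.getPartition(key) is what decides which per-reduce temporary file, i.e. which partitionWriters[i], a record is appended to:

import org.apache.spark.HashPartitioner

object PartitionRoutingSketch {
  def main(args: Array[String]): Unit = {
    // 4 reduce partitions means 4 temporary files per map task in the bypass writer
    val partitioner = new HashPartitioner(4)
    val records = Seq(("a", 1), ("b", 2), ("c", 3), ("a", 4))
    records.foreach { case (key, value) =>
      println(s"record ($key, $value) -> temp file of reduce partition ${partitioner.getPartition(key)}")
    }
  }
}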
Now let's look at its write method:
@Override
public void write(Iterator<Product2<K, V>> records) throws IOException {
  assert (partitionWriters == null);
  if (!records.hasNext()) {
    // No records: just commit an empty index (one entry per reduce partition)
    partitionLengths = new long[numPartitions];
    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
    // Report the result of this ShuffleMapTask to the scheduler
    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
    return;
  }
  final SerializerInstance serInstance = serializer.newInstance();
  final long openStartTime = System.nanoTime();
  // One DiskBlockObjectWriter per reduce partition, each appending to its own temp file
  partitionWriters = new DiskBlockObjectWriter[numPartitions];
  // One FileSegment per temp file, recorded after the writer is committed
  partitionWriterSegments = new FileSegment[numPartitions];
  for (int i = 0; i < numPartitions; i++) {
    final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
      blockManager.diskBlockManager().createTempShuffleBlock();
    final File file = tempShuffleBlockIdPlusFile._2();
    final BlockId blockId = tempShuffleBlockIdPlusFile._1();
    // Initialize the disk writer for reduce partition i
    partitionWriters[i] =
      blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
  }
  // Creating the file to write to and creating a disk writer both involve interacting with
  // the disk, and can take a long time in aggregate when we open many files, so should be
  // included in the shuffle write time.
  writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
  // Iterate over the records and route each one to the temp file of its target partition
  while (records.hasNext()) {
    final Product2<K, V> record = records.next();
    final K key = record._1();
    partitionWriters[partitioner.getPartition(key)].write(key, record._2());
  }
  // Flush each writer and record the file segment produced for its partition
  for (int i = 0; i < numPartitions; i++) {
    final DiskBlockObjectWriter writer = partitionWriters[i];
    partitionWriterSegments[i] = writer.commitAndGet();
    writer.close();
  }
  File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
  File tmp = Utils.tempFileWith(output);
  try {
    // Merge the temp files into the data file and record each partition's length
    partitionLengths = writePartitionedFile(tmp);
    // Build the index file from the partition lengths and commit both files
    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
  } finally {
    if (tmp.exists() && !tmp.delete()) {
      logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
    }
  }
  // Report the result of this ShuffleMapTask to the scheduler
  mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
Now let's look at the implementation of writePartitionedFile. It merges the temporary files generated above; since each temporary file already holds the data of exactly one reduce partition, it simply copies those files into the target data file in partition order and returns the size of each one, which is later used to build the index (a standalone sketch of the same pattern follows the method):
private long[] writePartitionedFile(File outputFile) throws IOException {
  // Track location of the partition starts in the output file
  // (i.e. the number of bytes copied for each partition)
  final long[] lengths = new long[numPartitions];
  if (partitionWriters == null) {
    // We were passed an empty iterator
    return lengths;
  }
  // Output stream for the data file
  final FileOutputStream out = new FileOutputStream(outputFile, true);
  final long writeStartTime = System.nanoTime();
  boolean threwException = true;
  try {
    // Copy the per-partition temp files into the data file in partition order
    for (int i = 0; i < numPartitions; i++) {
      // The temp file recorded for partition i
      final File file = partitionWriterSegments[i].file();
      if (file.exists()) {
        final FileInputStream in = new FileInputStream(file);
        boolean copyThrewException = true;
        try {
          // Copy the temp file directly into the data file and record its size
          lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
          copyThrewException = false;
        } finally {
          Closeables.close(in, copyThrewException);
        }
        if (!file.delete()) {
          logger.error("Unable to delete file for partition {}", i);
        }
      }
    }
    threwException = false;
  } finally {
    Closeables.close(out, threwException);
    writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
  }
  partitionWriters = null;
  return lengths;
}
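The same concatenation pattern, stripped of Spark's metrics and error handling, can be sketched as follows (an illustration, not Spark code; the object and method names are made up, and the segment files are assumed to already exist on local disk). Each segment is appended to the output file and its length recorded so that an index of cumulative offsets can be built afterwards:

import java.io.{File, FileInputStream, FileOutputStream}

object ConcatSegmentsSketch {
  /** Appends each segment file to `output` in order and returns the per-segment lengths. */
  def concat(segments: Seq[File], output: File): Array[Long] = {
    val lengths = new Array[Long](segments.length)
    val out = new FileOutputStream(output, true) // append mode, like the data file above
    try {
      segments.zipWithIndex.foreach { case (file, i) =>
        if (file.exists()) {
          val in = new FileInputStream(file)
          try {
            // Zero-copy transfer, roughly what Utils.copyStream does when transferTo is enabled
            val total = file.length()
            var transferred = 0L
            while (transferred < total) {
              transferred += in.getChannel.transferTo(transferred, total - transferred, out.getChannel)
            }
            lengths(i) = total
          } finally {
            in.close()
          }
          file.delete()
        }
      }
    } finally {
      out.close()
    }
    lengths
  }
}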
Finally, let's look at writeIndexFileAndCommit.
It takes the array of per-partition lengths returned by writePartitionedFile, converts them to cumulative offsets, writes them to a temporary index file, checks that the index and data are consistent (in case another attempt of the same task already committed its output), and then renames the temporary files to their final names.
def writeIndexFileAndCommit(
    shuffleId: Int,
    mapId: Int,
    lengths: Array[Long],
    dataTmp: File): Unit = {
  val indexFile = getIndexFile(shuffleId, mapId)
  val indexTmp = Utils.tempFileWith(indexFile)
  try {
    val dataFile = getDataFile(shuffleId, mapId)
    // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
    // the following check and rename are atomic.
    synchronized {
      // Check whether an existing index file matches the data file; returns null if not
      val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
      if (existingLengths != null) {
        // Another attempt for the same task has already written our map outputs successfully,
        // so just use the existing partition lengths and delete our temporary map outputs.
        System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
        if (dataTmp != null && dataTmp.exists()) {
          dataTmp.delete()
        }
      } else {
        // This is the first successful attempt in writing the map outputs for this task,
        // so override any existing index and data files with the ones we wrote.
        val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
        Utils.tryWithSafeFinally {
          // We take in lengths of each block, need to convert it to offsets.
          var offset = 0L
          out.writeLong(offset)
          for (length <- lengths) {
            offset += length
            out.writeLong(offset)
          }
        } {
          out.close()
        }
        if (indexFile.exists()) {
          indexFile.delete()
        }
        if (dataFile.exists()) {
          dataFile.delete()
        }
        if (!indexTmp.renameTo(indexFile)) {
          throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
        }
        if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
          throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
        }
      }
    }
  } finally {
    if (indexTmp.exists() && !indexTmp.delete()) {
      logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}")
    }
  }
}
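To see how this index is consumed, here is a simplified sketch (an illustration, not Spark source; the object and method names are made up, and skipBytes is assumed to skip fully) of locating one reduce partition's bytes inside the data file. The read side, IndexShuffleBlockResolver.getBlockData, does essentially this: read the two adjacent offsets stored for reduceId and reduceId + 1 and serve that byte range of the data file.

import java.io.{DataInputStream, File, FileInputStream}

object IndexLookupSketch {
  /** Returns (offset, length) of a reduce partition's data inside the merged data file. */
  def locate(indexFile: File, reduceId: Int): (Long, Long) = {
    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      // The index file holds numPartitions + 1 longs: cumulative offsets starting at 0.
      in.skipBytes(reduceId * 8)
      val offset = in.readLong()
      val nextOffset = in.readLong()
      (offset, nextOffset - offset)
    } finally {
      in.close()
    }
  }
}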
This completes the write path of BypassMergeSortShuffleWriter: each map task produces one data file and one index file, and the data file is ordered only by reduce partition id (records within a partition are not sorted).