- 默认按照数据hash码取模来分片
- 指定了分片数就按照指定数来
- 没有指定就按照默认数
- 当数据中相同key占多数时,按hash分片会出现数据倾斜(大部分数据集中在少数机器上);此时可使用RangePartitioner,它采用水塘抽样的方法确定分区边界(基本上均匀划分数据)
- 如果对数据划分要求比较严格,可以自定义划分partition
- 如何从一个长度未知的列表中随机抽取一个元素?
- 如何从一个长度未知的列表中随机抽取k个元素?
- 每个分区(partition)中用水塘抽样算法抽取部分样本(样本权重是抽样比的倒数)
- 汇总所有分区抽到的样本和权重,计算出边界数组,再按照边界数组对数据分区
- spark2.2中partition逻辑
- 离线pipeline中平均划分数据(自定义partition)
// Chooses the Partitioner to use for `rdd` together with the RDDs in `others`.
//
// NOTE(review): this excerpt is truncated. The body of the branch taken when an
// input RDD already has a partitioner, the argument of the final HashPartitioner,
// and the closing braces are not visible here; confirm against the full source.
def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
// All candidate RDDs, with `rdd` first.
val rdds = (Seq(rdd) ++ others)
// Keep only the RDDs that already carry a partitioner with a positive partition count.
val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0))
if (hasPartitioner.nonEmpty) {
} else {
// No usable existing partitioner: fall back to hash partitioning.
if (rdd.context.conf.contains("spark.default.parallelism")) {
// "spark.default.parallelism" is set, so use the context's default parallelism
// as the number of partitions.
new HashPartitioner(rdd.context.defaultParallelism)
} else {
new HashPartitioner(
// Partitioner that maps a key to a partition from the key's hashCode.
//
// NOTE(review): closing braces and the fallback value of `equals` are missing
// from this excerpt; confirm against the full source.
class HashPartitioner(partitions: Int) extends Partitioner {
// Zero partitions is accepted; only negative counts are rejected.
require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
def numPartitions: Int = partitions
// A null key always goes to partition 0; any other key is placed by
// Utils.nonNegativeMod of its hashCode and numPartitions.
def getPartition(key: Any): Int = key match {
case null => 0
case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
// Another HashPartitioner is compared by its number of partitions.
override def equals(other: Any): Boolean = other match {
case h: HashPartitioner =>
h.numPartitions == numPartitions
case _ =>
// Kept consistent with equals: the hash is the partition count itself.
override def hashCode: Int = numPartitions
// Partitioner that places a key by comparing it with an array of boundary keys
// (`rangeBounds`). The boundaries are derived from a weighted sample of the keys of
// `rdd`; `ascending` selects whether partition indices follow or mirror the key order.
//
// NOTE(review): this excerpt lost lines during extraction (closing braces, several
// branch bodies and parts of expressions). The comments below describe only what is
// still visible; confirm details against the full source.
class RangePartitioner[K : Ordering : ClassTag, V](
partitions: Int,
rdd: RDD[_ <: Product2[K, V]],
private var ascending: Boolean = true)
extends Partitioner {
// We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
// Ordering taken from the context bound on K; a var because readObject reassigns it.
private var ordering = implicitly[Ordering[K]]
// An array of upper bounds for the first (partitions - 1) partitions
private var rangeBounds: Array[K] = {
// NOTE(review): the result for partitions <= 1 is not visible in this excerpt.
if (partitions <= 1) {
} else {
// This is the sample size we need to have roughly balanced output partitions, capped at 1M.
val sampleSize = math.min(20.0 * partitions, 1e6)
// Assume the input partitions are roughly balanced and over-sample a little bit.
val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
// NOTE(review): the first argument of sketch(...) is missing from this excerpt.
val (numItems, sketched) = RangePartitioner.sketch(, sampleSizePerPartition)
// NOTE(review): the result for an empty input (numItems == 0) is not visible here.
if (numItems == 0L) {
} else {
// If a partition contains much more than the average number of items, we re-sample from it
// to ensure that enough items are collected from that partition.
// Overall sampling probability: sampleSize over the item count, capped at 1.0.
val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
// Sampled keys paired with the weight each one carries.
val candidates = ArrayBuffer.empty[(K, Float)]
// Indices of input partitions whose sketch is too small for their size.
val imbalancedPartitions = mutable.Set.empty[Int]
sketched.foreach { case (idx, n, sample) =>
// A partition is imbalanced when its expected share of the sample
// (fraction * n) exceeds the per-partition sample size.
if (fraction * n > sampleSizePerPartition) {
imbalancedPartitions += idx
} else {
// The weight is 1 over the sampling probability.
val weight = (n.toDouble / sample.length).toFloat
for (key <- sample) {
candidates += ((key, weight))
if (imbalancedPartitions.nonEmpty) {
// Re-sample imbalanced partitions with the desired sampling probability.
// NOTE(review): the first argument of PartitionPruningRDD(...), the argument of
// byteswap32(...) and the left part of the `candidates ++=` expression are
// missing from this excerpt.
val imbalanced = new PartitionPruningRDD(, imbalancedPartitions.contains)
val seed = byteswap32( - 1)
val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
// Re-sampled keys were drawn with probability `fraction`, so each weighs 1 / fraction.
val weight = (1.0 / fraction).toFloat
candidates ++= => (x, weight))
// Turn the weighted candidates into the boundary array.
RangePartitioner.determineBounds(candidates, partitions)
// One more partition than there are boundaries.
def numPartitions: Int = rangeBounds.length + 1
// Binary-search function for K; a var because readObject reassigns it.
private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
// Returns the partition index for `key`, which is cast to K without a type check.
def getPartition(key: Any): Int = {
val k = key.asInstanceOf[K]
var partition = 0
if (rangeBounds.length <= 128) {
// If we have less than 128 partitions naive search
// NOTE(review): the comparison in this loop condition is garbled in the excerpt.
while (partition < rangeBounds.length &&, rangeBounds(partition))) {
partition += 1
} else {
// Determine which binary search method to use only once.
partition = binarySearch(rangeBounds, k)
// binarySearch either returns the match location or -[insertion point]-1
if (partition < 0) {
partition = -partition-1
// Clamp to the last partition index.
if (partition > rangeBounds.length) {
partition = rangeBounds.length
// For descending order the index is mirrored across the range of partitions.
// NOTE(review): the ascending branch body is not visible in this excerpt.
if (ascending) {
} else {
rangeBounds.length - partition
// Another RangePartitioner matches when it has the same boundary elements and the
// same sort direction.
// NOTE(review): the fallback value for other types is not visible in this excerpt.
override def equals(other: Any): Boolean = other match {
case r: RangePartitioner[_, _] =>
r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
case _ =>
// Combines the hash codes of every boundary and of `ascending` with multiplier 31,
// matching the fields compared in equals.
override def hashCode(): Int = {
val prime = 31
var result = 1
var i = 0
while (i < rangeBounds.length) {
result = prime * result + rangeBounds(i).hashCode
i += 1
result = prime * result + ascending.hashCode
// Custom serialization hook. With a JavaSerializer configured in SparkEnv the default
// Java field serialization is used; otherwise a serializer instance is created and
// used through Utils.serializeViaNestedStream.
// NOTE(review): the statements of the non-Java branch that write the individual
// fields are missing from this excerpt; presumably they mirror the reads in
// readObject. Confirm against the full source.
private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
sfactory match {
case js: JavaSerializer => out.defaultWriteObject()
case _ =>
val ser = sfactory.newInstance()
Utils.serializeViaNestedStream(out, ser) { stream =>
// Custom deserialization hook, the counterpart of writeObject.
private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
sfactory match {
case js: JavaSerializer => in.defaultReadObject()
case _ =>
// Read order: ascending, ordering, binarySearch, then (through the nested stream)
// the class tag followed by the boundary array.
ascending = in.readBoolean()
ordering = in.readObject().asInstanceOf[Ordering[K]]
binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]
val ser = sfactory.newInstance()
Utils.deserializeViaNestedStream(in, ser) { ds =>
// The implicit class tag is in scope for the readObject call that follows.
implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
rangeBounds = ds.readObject[Array[K]]()
// Companion helpers for RangePartitioner: `sketch` samples each input partition and
// `determineBounds` turns weighted samples into partition boundaries.
//
// NOTE(review): this excerpt lost lines during extraction. The Scaladoc delimiters,
// the initialisers of `shift`, `numItems` and `sumWeights`, part of the duplicate
// check, closing braces and the return expression of determineBounds are not visible.
// Confirm against the full source.
private[spark] object RangePartitioner {
* Sketches the input RDD via reservoir sampling on each partition.
* @param rdd the input RDD to sketch
* @param sampleSizePerPartition max sample size per partition
* @return (total number of items, an array of (partitionId, number of items, sample))
def sketch[K : ClassTag](
rdd: RDD[K],
sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
val shift =
// val classTagK = classTag[K] // to avoid serializing the entire partitioner object
val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
// Per-partition seed derived from the partition index and `shift`.
val seed = byteswap32(idx ^ (shift << 16))
// `sample` holds the sampled elements and `n` the count for this partition.
val (sample, n) = SamplingUtils.reservoirSampleAndCount(
iter, sampleSizePerPartition, seed)
// Exactly one summary tuple is emitted per partition.
Iterator((idx, n, sample))
val numItems =
(numItems, sketched)
* Determines the bounds for range partitioning from candidates with weights indicating how many
* items each represents. Usually this is 1 over the probability used to sample this candidate.
* @param candidates unordered candidates with weights
* @param partitions number of partitions
* @return selected bounds
def determineBounds[K : Ordering : ClassTag](
candidates: ArrayBuffer[(K, Float)],
partitions: Int): Array[K] = {
val ordering = implicitly[Ordering[K]]
// Candidates sorted by key.
val ordered = candidates.sortBy(_._1)
val numCandidates = ordered.size
val sumWeights =
// Weight each output partition should receive.
val step = sumWeights / partitions
// Running total of the weights seen so far.
var cumWeight = 0.0
// Cumulative weight at which the next boundary is placed.
var target = step
val bounds = ArrayBuffer.empty[K]
// i indexes the candidates; j counts the boundaries emitted so far.
var i = 0
var j = 0
var previousBound = Option.empty[K]
// Stop once the candidates run out or partitions - 1 boundaries have been chosen.
while ((i < numCandidates) && (j < partitions - 1)) {
val (key, weight) = ordered(i)
cumWeight += weight
if (cumWeight >= target) {
// Skip duplicate values.
// NOTE(review): the comparison against the previous bound is garbled in this excerpt.
if (previousBound.isEmpty ||, previousBound.get)) {
bounds += key
target += step
j += 1
previousBound = Some(key)
i += 1
// Custom partitioning that splits the data evenly.
// Custom Partitioner intended to guarantee that the row counts of any two parts differ
// by at most 1, and that a lower-numbered part never has more rows than a
// higher-numbered one.
// NOTE(review): that guarantee presumably relies on the keys being consecutive Long
// row indices starting at 0; verify against the code that builds `tfrecords`.
// NOTE(review): this excerpt is truncated; the body of the `map` call and the call
// that takes `keyClass` / `valueClass` are not fully visible.
tfrecords.partitionBy(new Partitioner {
def numPartitions: Int = TF_RECORD_FILE_SIZE
// The key is cast to Long and mapped to (numPartitions - 1 - key % numPartitions),
// so key 0 goes to the highest-numbered part.
// A negative key would give a negative remainder and therefore an index
// >= numPartitions, so keys must be non-negative.
def getPartition(key: Any): Int =
(numPartitions.toLong - 1 - key.asInstanceOf[Long] % numPartitions.toLong).toInt
}).map(value=> {
keyClass= classOf[BytesWritable],
valueClass= classOf[NullWritable],