// Create an instance of the class with the given name, possibly initializing it with our conf.
// Instantiates the class reflectively; this helper could be extracted into Utils.
// `conf` and `isDriver` are not parameters: they come from the enclosing scope (not shown here).
def instantiateClass[T](className: String): T = {
    // Resolve the Class object by name (a Class.forName-style lookup).
    val cls = Utils.classForName(className)
    // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
    // SparkConf, then one taking no arguments
    try {
    // First choice: the (SparkConf, boolean) constructor, passing conf and isDriver.
    // NOTE(review): the java.lang.Boolean(boolean) constructor is deprecated since Java 9;
    // java.lang.Boolean.valueOf(isDriver) is the usual replacement.
    cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
        .newInstance(conf, new java.lang.Boolean(isDriver))
    } catch {
    case _: NoSuchMethodException =>
        try {
// Fallback: construct with a SparkConf-only constructor.
// NOTE(review): the fallback constructor calls and closing braces are elided in this excerpt.
        } catch {
        case _: NoSuchMethodException =>
            // Last resort: use the default no-argument constructor.

// Create an instance of the class named by the given SparkConf property, or defaultClassName
// if the property is not set, possibly initializing it with our conf.
// Reads the class name stored under the conf key `propertyName` (falling back to
// defaultClassName when the key is absent) and instantiates it via instantiateClass.
def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = {
    instantiateClass[T](conf.get(propertyName, defaultClassName))

// Config key "spark.serializer"; another implementation, e.g. KryoSerializer, can be set in conf.
// default class: "org.apache.spark.serializer.JavaSerializer"
val serializer = instantiateClassFromConf[Serializer](
    "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
logDebug(s"Using serializer: ${serializer.getClass}")

// Build the serializer manager from the default serializer, conf and the optional IO key.
// NOTE(review): in this excerpt ioEncryptionKey is referenced here before its definition below;
// inside a method body that is a forward-reference compile error, and in a class body it would
// read null. Confirm the real ordering.
val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)
// Closure serializer: always a JavaSerializer, independent of "spark.serializer".
val closureSerializer = new JavaSerializer(conf)

// Enabled by "spark.io.encryption.enabled" (IO_ENCRYPTION_ENABLED); default false.
val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
  // "spark.io.encryption.keySizeBits": key length in bits; 128, 192 or 256.
  // "spark.io.encryption.keygen.algorithm": key-generation algorithm, HmacSHA1 by default.
} else {





 * Component which configures serialization, compression and encryption for various Spark
 * components, including automatic selection of which [[Serializer]] to use for shuffles.
private[spark] class SerializerManager(
    defaultSerializer: Serializer,
    conf: SparkConf,
    encryptionKey: Option[Array[Byte]]) {

  // Auxiliary constructor without an IO encryption key (encryptionKey = None).
  def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None)
  // KryoSerializer built from conf. NOTE(review): presumably what getSerializer returns when
  // Kryo is auto-picked; the branch bodies are elided in this excerpt.
  private[this] val kryoSerializer = new KryoSerializer(conf)

  // Sets the class loader used by the serializers. KryoSerializer uses it mainly in newKryo(),
  // via Class.forName(className, true, classLoader).
  // https://issues.apache.org/jira/browse/SPARK-21928
  // When a user-defined KryoRegistrator registers user classes, deserialization in Netty's
  // MessageToMessageDecoder threw ClassNotFoundException, so newKryo() must switch to this
  // class loader. Always remember to switch class loaders when loading user-defined classes.
  def setDefaultClassLoader(classLoader: ClassLoader): Unit = {

  private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]]
  // Class tags of primitive types plus arrays of those primitives
  // (e.g. Boolean, Array[Boolean], Int, Array[Int]); array tags are derived with ClassTag.wrap.
  private[this] val primitiveAndPrimitiveArrayClassTags: Set[ClassTag[_]] = {
    val primitiveClassTags = Set[ClassTag[_]](
    val arrayClassTags = primitiveClassTags.map(_.wrap)
    primitiveClassTags ++ arrayClassTags

  // Per-block-type compression flags: broadcast variables, shuffle output, RDD partitions,
  // and shuffle data spilled to disk.
  // Whether to compress broadcast variables that are stored
  private[this] val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
  // Whether to compress shuffle output that is stored
  private[this] val compressShuffle = conf.getBoolean("spark.shuffle.compress", true)
  // Whether to compress RDD partitions that are stored serialized
  private[this] val compressRdds = conf.getBoolean("spark.rdd.compress", false)
  // Whether to compress shuffle output temporarily spilled to disk
  private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)

  /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
   * the initialization of the compression codec until it is first used. The reason is that a Spark
   * program could be using a user-defined codec in a third party jar, which is loaded in
   * Executor.updateDependencies. When the BlockManager is initialized, user level jars haven't been
   * loaded yet. */
// Four codecs are available, selected by "spark.io.compression.codec" (default "lz4"):
//   "lz4"    -> LZ4CompressionCodec    -> lz4-java
//   "lzf"    -> LZFCompressionCodec    -> compress-lzf
//   "snappy" -> SnappyCompressionCodec -> snappy-java
//   "zstd"   -> ZStdCompressionCodec   -> zstd-jni
// CompressionCodec.createCodec(conf) builds the codec configured in conf.
  private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
  // Whether IO encryption is enabled, i.e. an encryption key was supplied.
  def encryptionEnabled: Boolean = encryptionKey.isDefined
  // Kryo may be auto-picked only for the primitive types above, arrays of those primitives,
  // and String.
  def canUseKryo(ct: ClassTag[_]): Boolean = {
    primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag

  // SPARK-18617: As feature in SPARK-13990 can not be applied to Spark Streaming now. The worst
  // result is streaming job based on `Receiver` mode can not run on Spark 2.x properly. It may be
  // a rational choice to close `kryo auto pick` feature for streaming in the first step.
  // Callers pass autoPick = !blockId.isInstanceOf[StreamBlockId], so blocks of receiver-based
  // streaming jobs never auto-pick Kryo and (presumably; branch bodies elided here) keep the
  // default serializer.
  def getSerializer(ct: ClassTag[_], autoPick: Boolean): Serializer = {
    if (autoPick && canUseKryo(ct)) {
    } else {

   * Pick the best serializer for shuffling an RDD of key-value pairs.
  // Kryo qualifies only when both the key and the value class tags pass canUseKryo.
  def getSerializer(keyClassTag: ClassTag[_], valueClassTag: ClassTag[_]): Serializer = {
    if (canUseKryo(keyClassTag) && canUseKryo(valueClassTag)) {
    } else {

  // Maps the block type to its compression flag; any other block type is not compressed.
  private def shouldCompress(blockId: BlockId): Boolean = {
    blockId match {
      case _: ShuffleBlockId => compressShuffle
      case _: BroadcastBlockId => compressBroadcast
      case _: RDDBlockId => compressRdds
      case _: TempLocalBlockId => compressShuffleSpill
      case _: TempShuffleBlockId => compressShuffle
      case _ => false

   * Wrap an input stream for encryption and compression
   // Decorator chain: wrapForEncryption wraps the raw stream and wrapForCompression wraps that,
   // so data read is decrypted first and then decompressed.
  def wrapStream(blockId: BlockId, s: InputStream): InputStream = {
    wrapForCompression(blockId, wrapForEncryption(s))

   * Wrap an output stream for encryption and compression
   // Decorator chain: wrapForEncryption wraps the raw stream and wrapForCompression wraps that,
   // so data written is compressed first and then encrypted.
  def wrapStream(blockId: BlockId, s: OutputStream): OutputStream = {
    wrapForCompression(blockId, wrapForEncryption(s))

   * Wrap an input stream for encryption if shuffle encryption is enabled
   // Encryption goes through CryptoStreamUtils (Apache Commons Crypto, AES per the original note);
   // see CryptoParams for how its configuration is built. CryptoUtils.toCryptoConf translates
   // Spark conf keys into the corresponding Commons Crypto keys.
   // NOTE(review): the encryptionKey receiver and the unencrypted fallback are elided here.
  def wrapForEncryption(s: InputStream): InputStream = {
      .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }

   * Wrap an output stream for encryption if shuffle encryption is enabled
  def wrapForEncryption(s: OutputStream): OutputStream = {
      .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }

   * Wrap an output stream for compression if block compression is enabled for its block type
   // Decorator pattern: the codec's stream wraps s only when shouldCompress(blockId) holds;
   // otherwise s is returned unchanged.
  def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
    if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s

   * Wrap an input stream for compression if block compression is enabled for its block type
  def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
    if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s

  /** Serializes into a stream. */
  // Writes `values` to outputStream through a buffer and, depending on the block type, a
  // compressor. No encryption wrapper is applied here.
  def dataSerializeStream[T: ClassTag](
      blockId: BlockId,
      outputStream: OutputStream,
      values: Iterator[T]): Unit = {
    // Buffer writes before they reach the (possibly compressing) stream.
    val byteStream = new BufferedOutputStream(outputStream)
    val autoPick = !blockId.isInstanceOf[StreamBlockId]
    // Choose a Serializer (Kryo when auto-pick applies), then create a SerializerInstance.
    val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance()
    // Open a serialization stream over the compressed, buffered output, write every value,
    // then close the serialization stream.
    ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()

  /** Serializes into a chunked byte buffer. */
  def dataSerialize[T: ClassTag](
      blockId: BlockId,
      values: Iterator[T]): ChunkedByteBuffer = {
    // Delegates, making the implicit ClassTag explicit.
    dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])

  /** Serializes into a chunked byte buffer. */
  // Serializes `values` into a ChunkedByteBuffer made of fixed-size ByteBuffer chunks.
  def dataSerializeWithExplicitClassTag(
      blockId: BlockId,
      values: Iterator[_],
      classTag: ClassTag[_]): ChunkedByteBuffer = {
    // Chunks are heap buffers from ByteBuffer.allocate, 4 MB (1024 * 1024 * 4 bytes) each.
    val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
    val byteStream = new BufferedOutputStream(bbos)
    val autoPick = !blockId.isInstanceOf[StreamBlockId]
    val ser = getSerializer(classTag, autoPick).newInstance()
    ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
    // The collected chunks (ArrayBuffer[ByteBuffer]) become an Array[ByteBuffer] in the result.
    // NOTE(review): the return expression is elided in this excerpt.

   * Deserializes an InputStream into an iterator of values and disposes of it when the end of
   * the iterator is reached.
   // Reads through a buffer and, depending on the block type, a decompressor; yields Iterator[T].
  def dataDeserializeStream[T](
      blockId: BlockId,
      inputStream: InputStream)
      (classTag: ClassTag[T]): Iterator[T] = {
    val stream = new BufferedInputStream(inputStream)
    val autoPick = !blockId.isInstanceOf[StreamBlockId]
    // NOTE(review): the other call sites call .newInstance() on the chosen Serializer first, and
    // deserializeStream is declared on SerializerInstance; that step (and the conversion to
    // Iterator[T]) appears to be elided from this excerpt — confirm.
    getSerializer(classTag, autoPick)
      .deserializeStream(wrapForCompression(blockId, stream))

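SerializerManager 中以 data 开头的序列化/反序列化方法(dataSerializeStream、dataSerializeWithExplicitClassTag、dataDeserializeStream),都会先用 BufferedOutputStream/BufferedInputStream 对流做缓冲包装,再通过 wrapForCompression(blockId, stream) 按 blockId 类型决定是否压缩(写出时压缩,读取时解压)。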




// Pluggable compression codec: decorates raw streams with compressing/decompressing wrappers.
trait CompressionCodec {

  // Returns a stream that compresses the bytes written to it before passing them on to s.
  def compressedOutputStream(s: OutputStream): OutputStream

  // Returns a stream that yields the decompressed bytes read from s.
  def compressedInputStream(s: InputStream): InputStream

再分别基于 lz4-java、compress-lzf、snappy-java、zstd-jni 实现 lz4、lzf、snappy、zstd 四种压缩算法:

// lz4: codec backed by lz4-java block streams.
class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec {
  override def compressedOutputStream(s: OutputStream): OutputStream = {
    // Block size from "spark.io.compression.lz4.blockSize" (default 32k).
    val blockSize = conf.getSizeAsBytes("spark.io.compression.lz4.blockSize", "32k").toInt
    new LZ4BlockOutputStream(s, blockSize)

  override def compressedInputStream(s: InputStream): InputStream = {
    // Passed as LZ4BlockInputStream's boolean flag. NOTE(review): per the variable name, false
    // keeps concatenated LZ4 streams readable past the first one — confirm against lz4-java.
    val disableConcatenationOfByteStream = false
    new LZ4BlockInputStream(s, disableConcatenationOfByteStream)

// lzf: codec backed by compress-lzf. `conf` is not used by this codec.
class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec {
  override def compressedOutputStream(s: OutputStream): OutputStream = {
    // setFinishBlockOnFlush(true): flush() also compresses and writes out the buffered block.
    new LZFOutputStream(s).setFinishBlockOnFlush(true)

  override def compressedInputStream(s: InputStream): InputStream = new LZFInputStream(s)

// snappy: codec backed by snappy-java.
class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec {
  // Native snappy library version (presumably Snappy.getNativeLibraryVersion, computed in the
  // companion object).
  val version = SnappyCompressionCodec.version

  override def compressedOutputStream(s: OutputStream): OutputStream = {
    // Block size from "spark.io.compression.snappy.blockSize" (default 32k).
    val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt
    // Wrapped in the custom SnappyOutputStreamWrapper because SnappyOutputStream.close() was not
    // idempotent: when two SnappyOutputStream objects shared one buffer, closing one could break
    // the other. The original note says snappy-java 1.1.2 fixed this, after which
    // SnappyOutputStream could be returned directly. NOTE(review): confirm the fixed version.
    new SnappyOutputStreamWrapper(new SnappyOutputStream(s, blockSize))

  override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s)

// zstd: codec backed by zstd-jni. Streams are wrapped in buffered streams to reduce the
// number of JNI calls and their overhead.
class ZStdCompressionCodec(conf: SparkConf) extends CompressionCodec {
  // Buffer size from "spark.io.compression.zstd.bufferSize" (default 32k).
  private val bufferSize = conf.getSizeAsBytes("spark.io.compression.zstd.bufferSize", "32k").toInt
  // Default compression level for zstd compression to 1 because it is
  // fastest of all with reasonably high compression ratio.
  private val level = conf.getInt("spark.io.compression.zstd.level", 1)

  override def compressedOutputStream(s: OutputStream): OutputStream = {
    // Wrap the zstd output stream in a buffered output stream, so that we can
    // avoid excessive JNI-call overhead while compressing small amounts of data.
    new BufferedOutputStream(new ZstdOutputStream(s, level), bufferSize)

  override def compressedInputStream(s: InputStream): InputStream = {
    // Wrap the zstd input stream in a buffered input stream so that we can
    // avoid excessive JNI-call overhead while decompressing small amounts of data.
    new BufferedInputStream(new ZstdInputStream(s), bufferSize)


// A serializer instance: converts single values to/from ByteBuffers and opens streams for
// serializing/deserializing sequences of values.
abstract class SerializerInstance {
  // Serializes t into a ByteBuffer.
  def serialize[T: ClassTag](t: T): ByteBuffer

  // Deserializes a value of type T from bytes.
  def deserialize[T: ClassTag](bytes: ByteBuffer): T

  // Same as above; presumably resolves classes through the given class loader.
  def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T

  // Opens a SerializationStream that writes serialized values to s.
  def serializeStream(s: OutputStream): SerializationStream

  // Opens a DeserializationStream that reads serialized values from s.
  def deserializeStream(s: InputStream): DeserializationStream



Spark 的 KryoSerializer 实现相对复杂,它使用了 Twitter 开源的 Scala Kryo 库 chill,该库封装了 KryoPool、基础类注册(register)等操作。下面提供一个日常开发中可用的 Kryo 工具类:

// Kryo-based serialize/deserialize helpers sharing a pool of Kryo instances
// (a Kryo instance is not thread-safe, hence the pool).
public class KryoSerDe {

    private static KryoPool pool = new KryoPool.Builder(() -> {
        Kryo kryo = new Kryo();

        // Disable reference tracking to save space; objects with circular references may then
        // cause a StackOverflowError.
        // NOTE(review): the setReferences(false) call this describes is not visible in this
        // excerpt. The line below sets the instantiator strategy instead: use a no-arg
        // constructor when available, otherwise fall back to Objenesis' StdInstantiatorStrategy.
        kryo.setInstantiatorStrategy(new Kryo.DefaultInstantiatorStrategy(new StdInstantiatorStrategy()));
        return kryo;

     * 使用writeObject只序列化对象,不记录类信息。反序列时readObject+Class
     * 使用writeClassAndObject时,序列对象和类信息。反序列化readClassAndObject,不需要Class
     * @param t
     * @param <T>
     * @return
    // Serializes t with writeClassAndObject, so class information is written with the object
    // and deserialize() needs no Class argument. (writeObject writes only the object, so reading
    // it back requires readObject plus the Class.)
    public static <T> byte[] serialize(T t) {

        // Apache Commons IO ByteArrayOutputStream, not the java.io one.
        ByteArrayOutputStream stream = new ByteArrayOutputStream();
        // NOTE(review): Kryo's Output buffers internally; unless it is flushed or closed before
        // stream.toByteArray(), the returned bytes may be empty or truncated. No flush is visible
        // in this excerpt — confirm.
        Output output = new Output(stream);

        pool.run(kryo -> {
            kryo.writeClassAndObject(output, t);
            return output;

        return stream.toByteArray();

    // Deserializes bytes produced by serialize(); returns null for null input.
    // The cast to T is unchecked: the caller must know the stored type.
    public static <T> T deserialize(byte[] bytes) {
        if (bytes == null) {
            return null;

        Input input = new Input(new ByteArrayInputStream(bytes));
        T t = (T) pool.run(kryo -> kryo.readClassAndObject(input));

        return t;

Kryo 官方库基于 ConcurrentLinkedQueue 和 SoftReferenceQueue 实现了 KryoPool 对象池,值得借鉴:KryoPool 接口定义 borrow、release 方法;KryoFactory 定义 create 方法,用于构建 Kryo 对象;KryoCallback 封装业务代码,屏蔽 borrow、release 操作(类似 JedisTemplate 的模板思想);KryoPoolQueueImpl 是基于队列的具体实现。

可以事先对 Class 进行 register,这样 Kryo 序列化时会用整数 ID 代替类名,节省空间;也可以继承 Serializer,自定义对象的序列化、反序列化实现,例如:


// Map entry pairing RoaringBitmap with a custom Kryo serializer that delegates to
// RoaringBitmap's own serialize/deserialize. KryoInputObjectInputBridge adapts Kryo's Input to
// java.io.ObjectInput; KryoOutputObjectOutputBridge presumably does the same for Output.
classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() {
  override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = {
    bitmap.serialize(new KryoOutputObjectOutputBridge(kryo, output))
  override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = {
    // Create an empty bitmap and fill it from the Kryo input (returning `ret` is elided here).
    val ret = new RoaringBitmap
    ret.deserialize(new KryoInputObjectInputBridge(kryo, input))


// Adapts a Kryo Input to java.io.ObjectInput, so code written against ObjectInput (such as
// RoaringBitmap.deserialize) can read directly from Kryo. Each read delegates to the matching
// Kryo Input method.
private[spark] class KryoInputObjectInputBridge(
    kryo: Kryo, input: KryoInput) extends FilterInputStream(input) with ObjectInput {
  override def readLong(): Long = input.readLong()
  override def readChar(): Char = input.readChar()
  override def readFloat(): Float = input.readFloat()
  override def readByte(): Byte = input.readByte()
  override def readShort(): Short = input.readShort()
  override def readUTF(): String = input.readString() // readString in kryo does utf8
  override def readInt(): Int = input.readInt()
  override def readUnsignedShort(): Int = input.readShortUnsigned()
  override def skipBytes(n: Int): Int = {
  // NOTE(review): DataInput.readFully must fill the requested range or throw EOFException, but
  // these delegate to input.read, which may read fewer bytes and ignores the count; Kryo's
  // readBytes enforces a full read — confirm against the Kryo Input docs.
  override def readFully(b: Array[Byte]): Unit = input.read(b)
  override def readFully(b: Array[Byte], off: Int, len: Int): Unit = input.read(b, off, len)
  // readLine is not supported by this bridge.
  override def readLine(): String = throw new UnsupportedOperationException("readLine")
  override def readBoolean(): Boolean = input.readBoolean()
  override def readUnsignedByte(): Int = input.readByteUnsigned()
  override def readDouble(): Double = input.readDouble()
  // Objects are read together with their class information (readClassAndObject).
  override def readObject(): AnyRef = kryo.readClassAndObject(input)




JavaSerializer 实现了 Externalizable 接口的 writeExternal()、readExternal() 方法,以自定义自身的序列化、反序列化内容,涉及的字段为 counterReset 和 extraDebugInfo:


// Serializer based on java.io serialization. Implements Externalizable, so when the serializer
// itself is serialized its settings are written/read by writeExternal/readExternal.
class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
  // writeObject calls between object-stream resets ("spark.serializer.objectStreamReset").
  private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
  // Whether to enrich NotSerializableException with debug info ("spark.serializer.extraDebugInfo").
  private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true)

  protected def this() = this(new SparkConf())  // For deserialization only

  // Uses the default class loader if one was set, else the current thread's context loader.
  override def newInstance(): SerializerInstance = {
    val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
    new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader)

  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {

  // Restores the settings in order: counterReset (int), then extraDebugInfo (boolean).
  // This order must match what writeExternal writes (its body is elided in this excerpt).
  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
    counterReset = in.readInt()
    extraDebugInfo = in.readBoolean()


// Plain java.io serialization helpers for Serializable values.
public class JavaSerDe {

    // Serializes t to a byte array; IOException is rethrown as RuntimeException.
    public static <T extends Serializable> byte[] serialize(T t) {

        try {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            ObjectOutputStream objOut = new ObjectOutputStream(bos);
            // NOTE(review): no objOut.writeObject(t) and no flush()/close() are visible before
            // bos.toByteArray() in this excerpt. ObjectOutputStream buffers, so it must be
            // flushed or closed first or the bytes may be incomplete — confirm. Using
            // try-with-resources would also close the streams.
            return bos.toByteArray();
        } catch (IOException e) {
            throw new RuntimeException(e);

    // Deserializes bytes produced by serialize(); returns null for null input.
    // The cast to T is unchecked. Only use on trusted data: Java deserialization of untrusted
    // bytes can instantiate attacker-chosen classes.
    public static <T extends Serializable> T deserialize(byte[] bytes) {
        if (bytes == null) {
            return null;

        try {
            ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
            ObjectInputStream objIn = new ObjectInputStream(bis);

            return (T) objIn.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);


// ObjectInputStream that resolves classes through a chosen class loader (`loader`).
ObjectInputStream objIn = new ObjectInputStream(bis) {
    protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
        String name = desc.getName();
        try {
            // `loader` can be supplied by the caller; initialize=false loads the class
            // without running its static initializers.
            return Class.forName(name, false, loader);
        } catch (ClassNotFoundException ex) {
            // Primitive type names (e.g. "int") cannot be loaded by name: look them up in
            // primClasses and return the match, otherwise rethrow.
            Class<?> cl = primClasses.get(name);
            if (cl != null) {
                return cl;
            } else {
                throw ex;

// The code below is from the JDK's ObjectInputStream source (around line 223); Spark's
// JavaDeserializationStream (around line 60) contains similar code.
// Maps each primitive type name, plus "void", to its Class object.
private static final HashMap<String, Class<?>> primClasses
        = new HashMap<>(8, 1.0F);
static {
    primClasses.put("boolean", boolean.class);
    primClasses.put("byte", byte.class);
    primClasses.put("char", char.class);
    primClasses.put("short", short.class);
    primClasses.put("int", int.class);
    primClasses.put("long", long.class);
    primClasses.put("float", float.class);
    primClasses.put("double", double.class);
    primClasses.put("void", void.class);





  1. 避免把 ObjectOutputStream 作为类成员变量长期持有:每次使用时新建,用完后 close。
  2. 如果需要用同一个 ObjectOutputStream 多次调用 writeObject,切记定期调用 reset(),清除它持有的已写对象引用,否则这些对象无法被回收,造成内存泄漏。
// SerializationStream backed by a java.io.ObjectOutputStream over `out`, resetting the object
// stream every counterReset writes.
private[spark] class JavaSerializationStream(
    out: OutputStream, counterReset: Int, extraDebugInfo: Boolean)
  extends SerializationStream {
  private val objOut = new ObjectOutputStream(out)
  // Number of writeObject calls since the last reset.
  private var counter = 0

   * Calling reset to avoid memory leak:
   * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api
   * But only call it every 100th time to avoid bloated serialization streams (when
   * the stream 'resets' object class descriptions have to be re-written)
  def writeObject[T: ClassTag](t: T): SerializationStream = {
    try {
    } catch {
      // With extraDebugInfo, rethrow the exception enriched by SerializationDebugger.
      case e: NotSerializableException if extraDebugInfo =>
        throw SerializationDebugger.improveException(t, e)
    counter += 1
    // Count every write; once counter reaches counterReset (100 by default) it is zeroed and the
    // object stream is reset (that call is elided in this excerpt). Resetting periodically rather
    // than after every write limits memory growth without re-writing class descriptors each
    // time. counterReset <= 0 disables resets.
    if (counterReset > 0 && counter >= counterReset) {
      counter = 0

  def flush() { objOut.flush() }
  def close() { objOut.close() }

Spark 的 JavaSerializationStream.writeObject 会对调用次数计数,当计数达到阈值 conf.getInt("spark.serializer.objectStreamReset", 100) 时,计数清零并调用 objOut.reset()(阈值 <= 0 时不会 reset)!



