Java 实现 Snowflake 算法

作者: 又语 | 来源:发表于2019-07-13 10:56 被阅读47次

    本文介绍 Java 实现 Snowflake 算法生成分布式 ID。


    目录

    • Snowflake 算法简介
    • 示例
    • 总结

    Snowflake 算法简介

    Snowflake 算法是 Twitter 开源的分布式 ID 生成算法,将 64 bit 划分为多个不同组成部分,每部分代表不同含义。

    • 第一部分占用 1 bit,即第 1 位,值始终为 0,可看做符号位暂时不用;
    • 第二部分占用 41 bit,即第 2 至 42 位,代表毫秒数,2 ^ 41 = 2199023255552,2199023255552 / (1000 * 3600 * 24 * 365) > 69.73,因此 Snowflake 算法可用时间年限大约是 69 年;
    • 第三部分占用 10 bit,即第 43 至 52 位,代表机器数,2 ^ 10 = 1024,所以总共允许有 1024 台机器参与生成分布式 ID。如果存在数据中心(Data Center),则可以将这 10 bit 进一步划分,如前 5 bit 代表数据中心,后 5 bit 代表机器,则允许 32 个数据中心且每个数据中心 32 台机器参与生成分布式 ID;
    • 第四部分占用 12 bit,即第 53 至 64 位,属于自增序列,2 ^ 12 = 4096,因此每毫秒一台机器上可生成 4096 个有序且不重复的 ID。

    Snowflake 算法生成的分布式 ID 并非绝对唯一,但已满足绝大多数应用场景需求。


    示例

    package tutorial.java.util;
    
    import java.rmi.UnexpectedException;
    import java.util.concurrent.atomic.AtomicLong;
    
    public class SnowflakeDistributedId {
    
        /**
         * Snowflake算法中第三部分长度,即数据中心和工作机器ID总共占位长度
         */
        private static final long DATA_CENTER_AND_WORKER_ID_BITS = 10;
    
        /**
         * Snowflake算法中第四部分长度,即自增序列占位长度
         */
        private static final long AUTO_INCREMENT_SEQUENCE_BITS = 12;
    
        /**
         * 自增序列最大值
         */
        private static final long MAX_SEQUENCE = 4095;
    
        /**
         * 开始时间戳
         */
        private final long epoch;
    
        /**
         * 数据中心ID
         */
        private final long dataCenterId;
    
        /**
         * 机器ID占位长度
         */
        private final long workerIdBits;
    
        /**
         * 机器ID
         */
        private final long workerId;
    
        /**
         * 保存上一次生成ID的时间戳
         */
        private long lastTimestamp;
    
        /**
         * 分布式ID自增序列
         */
        private AtomicLong autoIncrementSequence;
    
        /**
         * @param dataCenterIdBits 数据中心ID占位长度
         * @param dataCenterId     数据中心ID
         * @param workerId         工作机器ID
         */
        public SnowflakeDistributedId(long epoch, long dataCenterIdBits, long dataCenterId, long workerId) {
            this.epoch = epoch;
            this.dataCenterId = validateDataCenterId(dataCenterIdBits, dataCenterId);
            workerIdBits = DATA_CENTER_AND_WORKER_ID_BITS - dataCenterIdBits;
            this.workerId = validateWorkerId(workerId);
            this.lastTimestamp = -1L;
            this.autoIncrementSequence = new AtomicLong(0);
        }
    
        /**
         * 初始化数据中心ID
         *
         * @param dataCenterIdBits 数据中心ID占位长度
         * @param dataCenterId     数据中心ID
         * @return 校验通过的数据中心ID
         */
        private long validateDataCenterId(long dataCenterIdBits, long dataCenterId) {
            if (dataCenterIdBits < 0 || dataCenterIdBits >= DATA_CENTER_AND_WORKER_ID_BITS) {
                throw new IllegalArgumentException("Data center ID bits must be in [0, 10)!");
            }
            if (dataCenterIdBits > 0) {
                // 支持的最大数据中心 ID
                long maxDataCenterId = ~(-1 << dataCenterIdBits);
                if (dataCenterId < 0 || dataCenterId > maxDataCenterId) {
                    throw new IllegalArgumentException("Data center ID must be in [0, " + maxDataCenterId + "]!");
                }
                return dataCenterId;
            }
            return -1;
        }
    
        /**
         * 初始化工作机器ID
         *
         * @param workerId 工作机器ID
         * @return 校验通过的工作机器ID
         */
        private long validateWorkerId(long workerId) {
            // 支持的最大机器ID
            long maxWorkerId = ~(-1 << this.workerIdBits);
            if (workerId < 0 || workerId > maxWorkerId) {
                throw new IllegalArgumentException("Worker ID must be in [0, " + maxWorkerId + "]!");
            }
            return workerId;
        }
    
        /**
         * 生成分布式ID
         *
         * @return long类型ID
         * @throws UnexpectedException 如果系统时间回退则抛出此异常
         */
        public long generate() throws UnexpectedException {
            long currentTimestamp = System.currentTimeMillis();
            // 如果当前时间小于上一次ID生成时间,说明系统时间回退
            if (currentTimestamp < lastTimestamp) {
                throw new UnexpectedException("System clock moved backward, refused to generate ID!");
            }
            long currentSequence;
            if (currentTimestamp == lastTimestamp) {
                // 如果当前时间等于上一次ID生成时间,获取自增序列值后加1
                currentSequence = autoIncrementSequence.getAndIncrement();
                // 如果获取的自增序列值大于允许的最大值
                if (currentSequence > MAX_SEQUENCE) {
                    // 等待到下一毫秒
                    currentTimestamp = block(currentTimestamp);
                    // 更新时间戳
                    lastTimestamp = currentTimestamp;
                    // 重新获取自增序列值
                    currentSequence = resetAutoIncrementSequence();
                }
            } else {
                // 如果当前时间大于上一次ID生成时间,重置自增序列并获取自增序列值后加1
                currentSequence = resetAutoIncrementSequence();
                // 更新时间戳
                lastTimestamp = currentTimestamp;
            }
            // 时间戳左移
            long id = (currentTimestamp - epoch) << (DATA_CENTER_AND_WORKER_ID_BITS + AUTO_INCREMENT_SEQUENCE_BITS);
            if (dataCenterId != -1) {
                // 数据中心ID左移
                id = id | (this.dataCenterId << (workerIdBits + AUTO_INCREMENT_SEQUENCE_BITS));
            }
            return id | (this.workerId << AUTO_INCREMENT_SEQUENCE_BITS) | currentSequence;
        }
    
        /**
         * 重置自增序列
         *
         * @return 自增序列值
         */
        private synchronized long resetAutoIncrementSequence() {
            autoIncrementSequence = new AtomicLong(0);
            return autoIncrementSequence.getAndIncrement();
        }
    
        /**
         * 阻塞至下一毫秒
         *
         * @param timestamp 当前时间戳
         * @return 下一毫秒时间戳
         */
        private long block(long timestamp) {
            long currentTimestamp = System.currentTimeMillis();
            while (currentTimestamp <= timestamp) {
                currentTimestamp = System.currentTimeMillis();
            }
            return currentTimestamp;
        }
    }
    

    单元测试

    import org.junit.Assert;
    import org.junit.Test;
    
    import java.rmi.UnexpectedException;
    import java.time.Instant;
    import java.util.HashSet;
    import java.util.Set;
    
    public class SnowflakeDistributedIdTest {
    
        @Test
        public void test() {
            SnowflakeDistributedId id = new SnowflakeDistributedId(Instant.now().toEpochMilli(),
                    5, 1, 8);
            Set<Long> ids = new HashSet<>();
            int iteratorTimes = 100000;
            Runnable runnable = () -> {
                for (int i = 0; i < iteratorTimes; i++) {
                    try {
                        ids.add(id.generate());
                    } catch (UnexpectedException e) {
                        Assert.fail();
                    }
                }
            };
            Set<Thread> threads = new HashSet<>();
            int threadCount = 10;
            for (int i = 0; i < threadCount; i++) {
                threads.add(new Thread(runnable));
            }
            threads.forEach(thread -> {
                thread.start();
                try {
                    thread.join();
                } catch (InterruptedException e) {
                    Assert.fail();
                }
            });
            Assert.assertEquals(iteratorTimes * threadCount, ids.stream().distinct().count());
        }
    }
    

    单元测试说明:共启动 10 个线程,每个线程循环 100000 次执行生成 ID 操作,生成的 ID 全部放入 SET 数据结构中,执行过程抛出任何异常都会导致单元测试失败,最后检查 SET 中元素数量是否等于 10 * 100000,测试结果略。


    总结

    1. Java 中 long 类型长度为 64 bit,因此 Java 实现 Snowflake 算法生成的 ID 即保存为 long 类型。
    2. 除 Snowflake 算法外,常见的分布式 ID 生成方案还包括:
      • UUID
      • 数据库生成
      • Redis 生成
      • 百度 UidGenerator
      • 美团 Leaf

    相关文章

      网友评论

        本文标题:Java 实现 Snowflake 算法

        本文链接:https://www.haomeiwen.com/subject/qlypkctx.html