ReentrantLock Internals: How to Implement a Lock (Part 1)

Author: 在下喵星人 | Published: 2021-07-26 07:53

    I. Implementing a Lock with CAS

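    In the abstract, a lock exists to prevent a shared resource from being accessed concurrently. Following this idea, an implementation in Java can be outlined as follows: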

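    1. Define a lock variable, state.
    2. When multiple threads access the same shared resource at the same time, use CAS to ensure that only one thread succeeds in changing state; that thread has acquired the lock. Threads that fail keep spinning and retrying.
    3. When a thread is done with the shared resource, it restores state so that another thread can acquire the lock.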
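
    Step 2 relies on compareAndSet succeeding for exactly one of the competing threads while the losers retry. The same spin-on-CAS pattern can be seen in isolation in a lock-free counter. The sketch below is illustrative only (the class CasCounterDemo and its helpers are ours, not part of the article's code): four threads each perform 10,000 increments through a CAS retry loop, and no update is lost.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class CasCounterDemo {
    static final AtomicInteger counter = new AtomicInteger();

    // One increment = read, then CAS. If another thread changed the value
    // in between, the CAS fails and we spin and retry (step 2 above).
    static void increment() {
        int current;
        do {
            current = counter.get();
        } while (!counter.compareAndSet(current, current + 1));
    }

    // Starts `threads` workers that each call increment() `perThread` times
    static int run(int threads, int perThread) throws InterruptedException {
        counter.set(0);
        List<Thread> workers = new ArrayList<>();
        for (int t = 0; t < threads; t++) {
            Thread w = new Thread(() -> {
                for (int i = 0; i < perThread; i++) {
                    increment();
                }
            });
            workers.add(w);
            w.start();
        }
        for (Thread w : workers) {
            w.join(); // wait for every worker before reading the result
        }
        return counter.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run(4, 10_000)); // 40000: no increment is lost
    }
}
```

    A lock generalizes this: instead of retrying a CAS on the data itself, threads retry a CAS on a separate state flag and then touch the data freely while they hold it.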

    Define the lock interface:

    public interface Lock {
        void lock();
        void unlock();
    }
    

    Following these principles, the implementation is:

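    import java.util.concurrent.atomic.AtomicInteger;

    public class SpinLock implements Lock {
        // 0 = unlocked, 1 = locked
        private final AtomicInteger state = new AtomicInteger();

        @Override
        public void lock() {
            boolean flag;
            do {
                // spin until the CAS from 0 to 1 succeeds
                flag = this.state.compareAndSet(0, 1);
            } while (!flag);
        }

        @Override
        public void unlock() {
            // restore state so that another thread can acquire the lock
            state.compareAndSet(1, 0);
        }
    }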
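
    Note that SpinLock.unlock() does not check which thread is calling it: compareAndSet(1, 0) succeeds for any thread, so a thread that never acquired the lock can release it. One possible hardening, sketched below as a hypothetical OwnedSpinLock (not part of the article's code), stores the owning Thread in an AtomicReference so that only the owner's CAS back to null can succeed. It calls Thread.onSpinWait(), which assumes Java 9 or later.

```java
import java.util.concurrent.atomic.AtomicReference;

public class OwnedSpinLock {
    // null means "unlocked"; otherwise the thread that holds the lock
    private final AtomicReference<Thread> owner = new AtomicReference<>();

    public void lock() {
        Thread current = Thread.currentThread();
        while (!owner.compareAndSet(null, current)) {
            Thread.onSpinWait(); // busy-wait hint to the runtime (Java 9+)
        }
    }

    public void unlock() {
        // Succeeds only if the caller is the recorded owner
        if (!owner.compareAndSet(Thread.currentThread(), null)) {
            throw new IllegalMonitorStateException("current thread does not hold the lock");
        }
    }

    public boolean isLocked() {
        return owner.get() != null;
    }

    public static void main(String[] args) throws InterruptedException {
        OwnedSpinLock lock = new OwnedSpinLock();
        lock.lock();
        Thread intruder = new Thread(() -> {
            try {
                lock.unlock();
                System.out.println("intruder released the lock");
            } catch (IllegalMonitorStateException e) {
                System.out.println("rejected: " + e.getMessage());
            }
        });
        intruder.start();
        intruder.join();   // prints "rejected: current thread does not hold the lock"
        lock.unlock();     // the real owner can release it
        System.out.println("released by owner");
    }
}
```

    Recording the owner is also what the reentrant version in Section II builds on.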
    

    Test

    import java.util.concurrent.CyclicBarrier;
    import java.util.concurrent.TimeUnit;

    public class Main {

        static int value = 0;
        public static void main(String[] args) throws InterruptedException {
            SpinLock spinLock = new SpinLock();
            final CyclicBarrier cyclicBarrier = new CyclicBarrier(10);
            for (int i = 0; i < 10; i++) {
                new Thread(new Runnable() {
                    public void run() {
                        try {
                            cyclicBarrier.await();
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                        spinLock.lock();
                        for (int j = 0; j < 100; j++) {
                            value++;
                        }
                        spinLock.unlock();
    
                    }
                }).start();
            }
            TimeUnit.SECONDS.sleep(3);
            System.out.println("value: " + value);
        }
    }
    

    Result

    value: 1000
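
    SpinLock is not reentrant, however. If the thread that already holds it calls lock() again (from a recursive method, say), state is already 1, so the CAS from 0 to 1 can never succeed and the thread spins forever on a lock it owns. The sketch below makes this observable with a hypothetical bounded tryLock(maxSpins) over the same state scheme; it is an illustration, not one of the article's classes.

```java
import java.util.concurrent.atomic.AtomicInteger;

public class NonReentrantDemo {
    // Same scheme as SpinLock: 0 = free, 1 = held
    static final AtomicInteger state = new AtomicInteger();

    // Hypothetical bounded version of SpinLock.lock(): gives up after maxSpins attempts
    static boolean tryLock(int maxSpins) {
        for (int i = 0; i < maxSpins; i++) {
            if (state.compareAndSet(0, 1)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(tryLock(1000)); // true: first acquisition succeeds
        // The same thread asks again: state is already 1, so every CAS fails.
        // With the unbounded SpinLock.lock() this thread would spin forever.
        System.out.println(tryLock(1000)); // false
    }
}
```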
    

    II. Making the Lock Reentrant

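    When we detect that the thread acquiring the lock already holds it, we increment state by 1 instead of failing.
    state then records how many times the lock is currently held, which is all that reentrancy requires.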
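
    As a worked illustration of this counting rule, here is a minimal self-contained sketch. Instead of the AbstractOwnableSynchronizer and Unsafe machinery introduced next, it assumes an AtomicReference<Thread> for the owner plus a plain hold counter that only the owner touches; the class name ReentrantSpinSketch and getHoldCount() are ours. Acquiring twice and releasing twice moves the count from 1 to 2, then back to 1 and 0.

```java
import java.util.concurrent.atomic.AtomicReference;

public class ReentrantSpinSketch {
    private final AtomicReference<Thread> owner = new AtomicReference<>();
    // Only read/written by the owner. It is safely handed over because the
    // releasing owner.set(null) is a volatile write and the next owner's
    // compareAndSet is a volatile read.
    private int holds;

    public void lock() {
        Thread current = Thread.currentThread();
        if (owner.get() == current) {
            holds++;              // re-entry by the owner: just count it
            return;
        }
        while (!owner.compareAndSet(null, current)) {
            Thread.onSpinWait();  // spin until the lock is free (Java 9+)
        }
        holds = 1;                // first acquisition
    }

    public void unlock() {
        if (owner.get() != Thread.currentThread()) {
            throw new IllegalMonitorStateException();
        }
        if (--holds == 0) {
            owner.set(null);      // fully released: let other threads in
        }
    }

    public int getHoldCount() {
        return owner.get() == Thread.currentThread() ? holds : 0;
    }

    public static void main(String[] args) {
        ReentrantSpinSketch lock = new ReentrantSpinSketch();
        lock.lock();
        lock.lock();                             // same thread re-enters
        System.out.println(lock.getHoldCount()); // 2
        lock.unlock();
        System.out.println(lock.getHoldCount()); // 1
        lock.unlock();
        System.out.println(lock.getHoldCount()); // 0
    }
}
```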



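    To make the later discussion of ReentrantLock easier, we refactor the code: we define an abstract class CustomAbstractQueuedSynchronizer that extends AbstractOwnableSynchronizer. AbstractOwnableSynchronizer is an abstract class provided by the JDK (in java.util.concurrent.locks) that records and returns the thread currently holding the lock; its source is reproduced below for reference. For convenience, state is now manipulated through Unsafe.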

    public abstract class AbstractOwnableSynchronizer
        implements java.io.Serializable {
    
        private static final long serialVersionUID = 3737899427754241961L;
    
     
        protected AbstractOwnableSynchronizer() { }
    
    
        private transient Thread exclusiveOwnerThread;
    
    
        protected final void setExclusiveOwnerThread(Thread thread) {
            exclusiveOwnerThread = thread;
        }
    
      
        protected final Thread getExclusiveOwnerThread() {
            return exclusiveOwnerThread;
        }
    }
    
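    import java.lang.reflect.Field;
    import java.util.concurrent.locks.AbstractOwnableSynchronizer;
    import sun.misc.Unsafe;

    public abstract class CustomAbstractQueuedSynchronizer extends AbstractOwnableSynchronizer {
        /**
         * The synchronization state.
         */
        private volatile int state;

        private static final Unsafe unsafe;
        private static final long stateOffset;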
    
        static {
            try {
                Field field =
                        Unsafe.class.getDeclaredField("theUnsafe");
                field.setAccessible(true);
                unsafe = (Unsafe) field.get(null);
    
                stateOffset = unsafe.objectFieldOffset
                        (CustomAbstractQueuedSynchronizer.class.getDeclaredField("state"));
            } catch (Exception ex) { throw new Error(ex); }
        }
    
        protected final int getState() {
            return state;
        }
    
        protected final void setState(int newState) {
            state = newState;
        }
    
        protected final boolean compareAndSetState(int expect, int update) {
            // Atomically sets state to update if it currently equals expect
            return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
        }
    
    }
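
    sun.misc.Unsafe is an unsupported internal API, and recent JDK releases have started deprecating its memory-access methods. Since JDK 9 the standard alternative is java.lang.invoke.VarHandle, which the JDK's own AbstractQueuedSynchronizer uses for its state field. The sketch below is a hypothetical VarHandle-based equivalent of the getState/setState/compareAndSetState trio above, assuming Java 9 or later; the class name VarHandleState is ours.

```java
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;

public class VarHandleState {
    private volatile int state;

    // VarHandle bound to the 'state' field; replaces Unsafe + stateOffset
    private static final VarHandle STATE;

    static {
        try {
            STATE = MethodHandles.lookup()
                    .findVarHandle(VarHandleState.class, "state", int.class);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    protected final int getState() {
        return state;
    }

    protected final void setState(int newState) {
        state = newState;
    }

    // Atomically sets state to update if it currently equals expect
    protected final boolean compareAndSetState(int expect, int update) {
        return STATE.compareAndSet(this, expect, update);
    }

    public static void main(String[] args) {
        VarHandleState s = new VarHandleState();
        System.out.println(s.compareAndSetState(0, 1)); // true
        System.out.println(s.compareAndSetState(0, 1)); // false: state is already 1
        System.out.println(s.getState());               // 1
    }
}
```

    No reflection on Unsafe.theUnsafe is needed, and the lookup fails at class-initialization time if the field name or type is wrong.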
    

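    The reentrant lock is implemented as follows.
    The logic is simple: when a thread acquires the lock, it calls setExclusiveOwnerThread to record itself as the owner. If the CAS fails, it checks whether it is the current owner acquiring the lock again; if so, state is incremented by 1. On release, state is decremented by 1, and once it reaches 0 the owner is cleared.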

    public class SpinReentrantLock implements Lock {
    
    
        private Sync sync;
    
        public SpinReentrantLock() {
            sync = new SimpleNonfairSync();
        }
    
        abstract static class Sync extends CustomAbstractQueuedSynchronizer {
            protected abstract void lock();

            protected abstract void unlock();
        }

        static final class SimpleNonfairSync extends Sync {
            @Override
            protected void lock() {
                Thread current = Thread.currentThread();
                boolean flag;
                do {
                    if (flag = compareAndSetState(0, 1)) {
                        // First acquisition: record the owner
                        // System.out.println(current.getName() + " acquired the lock");
                        setExclusiveOwnerThread(current);
                    } else if (current == getExclusiveOwnerThread()) {
                        // Re-entry by the owner: increment the hold count
                        int nextc = getState() + 1;
                        if (nextc < 0) {
                            // overflow
                            throw new Error("Maximum lock count exceeded");
                        }
                        // System.out.println(current.getName() + " re-entered, state: " + nextc);
                        setState(nextc);
                        flag = true;
                    }
                } while (!flag);
            }

            @Override
            protected void unlock() {
                int c = getState() - 1;
                if (Thread.currentThread() != getExclusiveOwnerThread())
                    throw new IllegalMonitorStateException();
                if (c == 0) {
                    // Fully released: clear the owner
                    setExclusiveOwnerThread(null);
                }
                // System.out.println(Thread.currentThread().getName() + " released, state: " + c);
                setState(c);
            }
        }

        @Override
        public void lock() {
            sync.lock();
        }

        @Override
        public void unlock() {
            sync.unlock();
        }
    }
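
    For comparison, the JDK's own java.util.concurrent.locks.ReentrantLock follows the same rule: each nested lock() by the owner increments a hold count, and the lock is only released when the count returns to zero. getHoldCount() exposes that count. The recursive depth method below is an illustrative example of ours, not JDK code.

```java
import java.util.concurrent.locks.ReentrantLock;

public class HoldCountDemo {
    static final ReentrantLock lock = new ReentrantLock();

    // Recursive method that re-acquires the lock at every level
    static int depth(int n) {
        lock.lock();
        try {
            int held = lock.getHoldCount(); // times the current thread holds the lock
            return n == 0 ? held : depth(n - 1);
        } finally {
            lock.unlock(); // every lock() must be matched by exactly one unlock()
        }
    }

    public static void main(String[] args) {
        System.out.println(depth(2));            // 3: locked at n = 2, 1 and 0
        System.out.println(lock.getHoldCount()); // 0: every hold was released
    }
}
```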
    

    III. Adding a Wait Queue

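    Under high contention, a large number of failed CAS attempts can make SpinReentrantLock inefficient, and spinning wastes CPU time. So when a thread fails to acquire the lock, we put it in a queue and park (suspend) it; when the lock is released, a parked thread is woken up.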
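
    Parking and waking are done with java.util.concurrent.locks.LockSupport. Two of its properties matter for the queue design: park() may return spuriously, so a caller must re-check its condition in a loop (which is why acquire loops around park), and an unpark() issued before park() is not lost, because it leaves a single permit that the next park() consumes. The sketch below shows both; the class ParkDemo and the resumed flag are illustrative.

```java
import java.util.concurrent.locks.LockSupport;

public class ParkDemo {
    static volatile boolean resumed;

    static boolean run() throws InterruptedException {
        resumed = false;
        Thread waiter = new Thread(() -> {
            // park() may return spuriously, so re-check the condition in a loop
            while (!resumed) {
                LockSupport.park();
            }
        });
        waiter.start();

        Thread.sleep(50);           // let the waiter park (timing only, not needed for correctness)
        resumed = true;             // publish the condition first...
        LockSupport.unpark(waiter); // ...then wake the waiter
        waiter.join();

        // An unpark() issued before park() leaves a permit,
        // so this park() consumes it and returns immediately
        LockSupport.unpark(Thread.currentThread());
        LockSupport.park();
        return !waiter.isAlive();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run()); // true
    }
}
```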


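    In the abstract class CustomAbstractQueuedSynchronizer we add a thread-safe linked queue, threadQueue, to hold the parked threads, and a head field that records the thread at the front of the queue. The acquire method follows the Template Method pattern: the actual acquisition logic, tryAcquire, is left to subclasses. When a thread fails to acquire the lock it parks itself with LockSupport.park(this); once it succeeds, it removes itself from the queue and updates head. The complete code is as follows: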

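    import java.lang.reflect.Field;
    import java.util.Queue;
    import java.util.concurrent.ConcurrentLinkedQueue;
    import java.util.concurrent.locks.AbstractOwnableSynchronizer;
    import java.util.concurrent.locks.LockSupport;
    import sun.misc.Unsafe;

    public abstract class CustomAbstractQueuedSynchronizer extends AbstractOwnableSynchronizer {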
        /**
         * The synchronization state.
         */
        private volatile int state;
    
        private static final Unsafe unsafe;
        private static final long stateOffset;
    
        private transient volatile Thread head;
    
        protected Queue<Thread> threadQueue = new ConcurrentLinkedQueue<>();
    
        static {
            try {
                Field field =
                        Unsafe.class.getDeclaredField("theUnsafe");
                field.setAccessible(true);
                unsafe = (Unsafe) field.get(null);
    
                stateOffset = unsafe.objectFieldOffset
                        (CustomAbstractQueuedSynchronizer.class.getDeclaredField("state"));
            } catch (Exception ex) {
                throw new Error(ex);
            }
        }
    
        protected final int getState() {
            return state;
        }
    
        protected final void setState(int newState) {
            state = newState;
        }
    
    
        protected final boolean compareAndSetState(int expect, int update) {
            // Atomically sets state to update if it currently equals expect
            return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
        }
    
        public Thread getHead() {
            return head;
        }
    
        public void setHead(Thread head) {
            this.head = head;
        }
    
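        /**
         * Lock-acquisition logic, implemented by subclasses.
         * @param arg number of holds to acquire
         * @return true if the lock was acquired
         */
        protected boolean tryAcquire(int arg) {
            throw new UnsupportedOperationException();
        }

        /**
         * Returns true if another thread is at the front of the queue,
         * i.e. the current thread has a queued predecessor and must not jump ahead.
         * @return true if the current thread has to wait behind a queued thread
         */
        public final boolean hasQueuedPredecessors() {
            Thread first = threadQueue.peek();
            return first != null && first != Thread.currentThread();
        }

        /**
         * Lock-release logic, implemented by subclasses.
         * @param arg number of holds to release
         * @return true if the lock is now completely free
         */
        protected boolean tryRelease(int arg) {
            throw new UnsupportedOperationException();
        }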
    
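        /**
         * Acquires the lock: try once, otherwise enqueue the current thread
         * and park until it reaches the front of the queue and acquires.
         * @param arg number of holds to acquire
         */
        public final void acquire(int arg) {
            Thread current = Thread.currentThread();
            // tryAcquire failed: put the current thread in the queue
            if (!tryAcquire(arg) && threadQueue.add(current)) {
                if (getHead() == null) {
                    setHead(threadQueue.peek());
                }
                // Only leave the loop after acquiring the lock
                for (; ; ) {
                    // Compare with threadQueue.peek() rather than the cached head field:
                    // head can be reset to null by a racing thread, which would leave
                    // the real front-of-queue thread parked forever
                    if (current == threadQueue.peek() && tryAcquire(arg)) {
                        // dequeue this thread
                        threadQueue.poll();
                        // update head now that the old head has been dequeued
                        setHead(threadQueue.peek());
                        return;
                    }
                    // System.out.println("Parking thread: " + current.getName() + " threadQueue: " + Arrays.toString(threadQueue.toArray()));
                    // failed to acquire: park until unparked, then retry
                    LockSupport.park(this);
                }
            }
        }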
    }
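
    This split (queueing and parking in the base class, tryAcquire/tryRelease in subclasses) is the same one used by the JDK's real java.util.concurrent.locks.AbstractQueuedSynchronizer, the subject of the next part. As a preview, here is a non-reentrant mutex built on it, modeled on the Mutex example in the AQS Javadoc; the JDK class manages its own wait queue internally instead of a ConcurrentLinkedQueue. The class name Mutex and the counter demo are illustrative.

```java
import java.util.concurrent.locks.AbstractQueuedSynchronizer;

public class Mutex {
    // Plugs the locking rules into the JDK's AQS via its template-method hooks
    private static final class Sync extends AbstractQueuedSynchronizer {
        @Override
        protected boolean tryAcquire(int arg) {
            if (compareAndSetState(0, 1)) {   // 0 = free, 1 = held
                setExclusiveOwnerThread(Thread.currentThread());
                return true;
            }
            return false;                     // AQS then enqueues and parks the caller
        }

        @Override
        protected boolean tryRelease(int arg) {
            if (getExclusiveOwnerThread() != Thread.currentThread()) {
                throw new IllegalMonitorStateException();
            }
            setExclusiveOwnerThread(null);
            setState(0);                      // volatile write publishes the release
            return true;                      // AQS then unparks a queued successor
        }

        @Override
        protected boolean isHeldExclusively() {
            return getExclusiveOwnerThread() == Thread.currentThread();
        }
    }

    private final Sync sync = new Sync();

    public void lock()   { sync.acquire(1); } // calls tryAcquire, queues on failure
    public void unlock() { sync.release(1); } // calls tryRelease, wakes a successor
    public boolean isHeldByCurrentThread() { return sync.isHeldExclusively(); }

    static int counter = 0;

    public static void main(String[] args) throws InterruptedException {
        Mutex mutex = new Mutex();
        Thread[] workers = new Thread[4];
        for (int t = 0; t < workers.length; t++) {
            workers[t] = new Thread(() -> {
                for (int i = 0; i < 10_000; i++) {
                    mutex.lock();
                    try {
                        counter++;
                    } finally {
                        mutex.unlock();
                    }
                }
            });
            workers[t].start();
        }
        for (Thread w : workers) {
            w.join();
        }
        System.out.println(counter); // 40000
    }
}
```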
    

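    The unlock method of Sync works as follows:

    1. Override tryRelease so that it returns true (the lock was released) once state drops to 0.
    2. If the release succeeded, call threadQueue.peek() to get the thread at the front of the queue and wake it with LockSupport.unpark(poll).
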
        abstract static class Sync extends CustomAbstractQueuedSynchronizer {
            protected abstract void lock();

            protected void unlock() {
                if (tryRelease(1)) {
                    // Lock is free: wake the thread at the front of the queue, if any
                    Thread poll = threadQueue.peek();
                    if (poll != null) {
                        // System.out.println(Thread.currentThread().getName() + " waking: " + poll.getName() + " size: " + threadQueue.size());
                        LockSupport.unpark(poll);
                    } else {
                        setHead(null);
                    }
                }
            }

            @Override
            protected boolean tryRelease(int arg) {
                int c = getState() - 1;
                if (Thread.currentThread() != getExclusiveOwnerThread()) {
                    throw new IllegalMonitorStateException();
                }
                boolean free = false;
                if (c == 0) {
                    free = true;
                    setExclusiveOwnerThread(null);
                }
                // System.out.println(Thread.currentThread().getName() + " released, state: " + c);
                setState(c);
                return free;
            }
        }
    
    

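    The NonfairSync class is shown below.
    As mentioned above, acquire follows the Template Method pattern, and the actual acquisition logic lives in tryAcquire (this implementation of tryAcquire is non-fair).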

        static final class NonfairSync extends Sync {

            @Override
            protected void lock() {
                Thread current = Thread.currentThread();
                // Non-fair fast path: try to grab the lock at once, even if others are queued
                if (compareAndSetState(0, 1)) {
                    // System.out.println(current.getName() + " acquired the lock");
                    setExclusiveOwnerThread(current);
                } else {
                    acquire(1);
                }
            }

            @Override
            protected boolean tryAcquire(int arg) {
                return nonfairTryAcquire(arg);
            }

            final boolean nonfairTryAcquire(int acquires) {
                final Thread current = Thread.currentThread();
                int c = getState();
                if (c == 0) {
                    // Lock is free: CAS without looking at the queue (non-fair)
                    if (compareAndSetState(0, acquires)) {
                        // System.out.println(current.getName() + " acquired the lock");
                        setExclusiveOwnerThread(current);
                        return true;
                    }
                } else if (current == getExclusiveOwnerThread()) {
                    // Re-entry by the owner
                    int nextc = c + acquires;
                    if (nextc < 0) // overflow
                        throw new Error("Maximum lock count exceeded");
                    // System.out.println(current.getName() + " re-entered, state: " + nextc);
                    setState(nextc);
                    return true;
                }
                return false;
            }
        }
    

    Test

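    import java.util.concurrent.CountDownLatch;
    import java.util.concurrent.CyclicBarrier;

    public class Main {

        static int value = 0;
        public static void main(String[] args) throws InterruptedException {
            // false selects the non-fair NonfairSync described in this section
            SpinReentrantLock spinReentrantLock = new SpinReentrantLock(false);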
            final CyclicBarrier cyclicBarrier = new CyclicBarrier(1000);
            final CountDownLatch countDownLatch = new CountDownLatch(1000);
            long start = System.currentTimeMillis();
            for (int i = 0; i < 1000 ; i++) {
                new Thread(new Runnable() {
                    public void run() {
                        try {
                            cyclicBarrier.await();
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                        spinReentrantLock.lock();
                        // System.out.println(Thread.currentThread().getName() + " acquired the lock");
                        for (int j = 0; j < 1000; j++) {
                            value++;
                        }
                        spinReentrantLock.unlock();
                        countDownLatch.countDown();
                    }
                },"thread:"+i).start();
            }
            countDownLatch.await();
            long end = System.currentTimeMillis();
            System.out.println("Elapsed time (ms): " + (end - start));
            System.out.println("value: " + value);
        }
    }
    
    Elapsed time (ms): 70
    value: 1000000
    

    IV. Fair Lock

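    Threads already waiting in the queue go first; a newly arriving thread has to queue up behind them. The implementation is shown below.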


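    Compared with the non-fair lock, the fair lock's lock() skips the immediate CAS and goes straight to acquire, and its tryAcquire calls hasQueuedPredecessors before attempting the CAS to check whether other threads are already waiting in the queue.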

        static final class FairSync extends Sync {
    
            @Override
            protected void lock() {
                acquire(1);
            }
    
            protected final boolean tryAcquire(int acquires) {
                final Thread current = Thread.currentThread();
                int c = getState();
                if (c == 0) {
                    if (!hasQueuedPredecessors() &&
                            compareAndSetState(0, acquires)) {
                        setExclusiveOwnerThread(current);
                        return true;
                    }
                } else if (current == getExclusiveOwnerThread()) {
                    int nextc = c + acquires;
                    if (nextc < 0)
                        throw new Error("Maximum lock count exceeded");
                    setState(nextc);
                    return true;
                }
                return false;
            }
        }
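
    The JDK's ReentrantLock offers the same choice through its constructor: new ReentrantLock(true) creates a fair lock. The sketch below makes the FIFO hand-off observable: while main holds the lock, three threads queue up one at a time (main waits until each is queued, using ReentrantLock.hasQueuedThread), and the order in which they then acquire the lock is recorded. The class FairOrderDemo and the thread names are illustrative.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.locks.ReentrantLock;

public class FairOrderDemo {
    static List<String> run() throws InterruptedException {
        ReentrantLock lock = new ReentrantLock(true);  // fair mode
        List<String> order = new CopyOnWriteArrayList<>();

        lock.lock();                                   // main holds the lock while others queue
        Thread[] waiters = new Thread[3];
        for (int i = 0; i < waiters.length; i++) {
            String name = "waiter-" + i;
            waiters[i] = new Thread(() -> {
                lock.lock();
                try {
                    order.add(name);                   // record acquisition order
                } finally {
                    lock.unlock();
                }
            }, name);
            waiters[i].start();
            // Make sure this thread is in the wait queue before starting the next one
            while (!lock.hasQueuedThread(waiters[i])) {
                Thread.yield();
            }
        }
        lock.unlock();                                 // hand over to the first queued thread
        for (Thread w : waiters) {
            w.join();
        }
        return order;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run()); // [waiter-0, waiter-1, waiter-2]
    }
}
```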
    

    Test

    import java.util.concurrent.CountDownLatch;
    import java.util.concurrent.CyclicBarrier;

    public class Main {
        static int value = 0;
        public static void main(String[] args) throws InterruptedException {
            SpinReentrantLock spinReentrantLock = new SpinReentrantLock(true);
            final CyclicBarrier cyclicBarrier = new CyclicBarrier(10);
            final CountDownLatch countDownLatch = new CountDownLatch(10);
            long start = System.currentTimeMillis();
            for (int i = 0; i < 10 ; i++) {
                new Thread(new Runnable() {
                    public void run() {
                        try {
                            cyclicBarrier.await();
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                        spinReentrantLock.lock();
                        // System.out.println(Thread.currentThread().getName() + " acquired the lock");
                        for (int j = 0; j < 1000; j++) {
                            value++;
                        }
                        spinReentrantLock.unlock();
                        countDownLatch.countDown();
                    }
                },"thread:"+i).start();
            }
            countDownLatch.await();
            long end = System.currentTimeMillis();
            System.out.println("Elapsed time (ms): " + (end - start));
            System.out.println("value: " + value);
        }
    }
    

    Result of one run with the debug output enabled: the threads acquire the lock in the order in which they were enqueued.

    thread:0 acquired the lock
    Parking thread: thread:6 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main]]
    Parking thread: thread:5 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main]]
    Parking thread: thread:1 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
    Parking thread: thread:2 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
    Parking thread: thread:3 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
    Parking thread: thread:4 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
    Parking thread: thread:7 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main], Thread[thread:7,5,main]]
    Parking thread: thread:9 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main], Thread[thread:7,5,main], Thread[thread:8,5,main]]
    Parking thread: thread:8 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main], Thread[thread:7,5,main], Thread[thread:8,5,main]]
    thread:9 acquired the lock
    thread:1 acquired the lock
    thread:2 acquired the lock
    thread:3 acquired the lock
    thread:4 acquired the lock
    thread:5 acquired the lock
    thread:6 acquired the lock
    thread:7 acquired the lock
    thread:8 acquired the lock
    Elapsed time (ms): 3
    value: 10000
    

    V. Summary

    Finally, here is the complete SpinReentrantLock implementation.

    import java.util.concurrent.locks.LockSupport;

    public class SpinReentrantLock implements Lock {
    
    
        private Sync sync;
    
        public SpinReentrantLock() {
            sync = new NonfairSync();
    
        }
    
        public SpinReentrantLock(boolean fair) {
            if (fair){
                sync = new FairSync();
            }else {
                sync = new NonfairSync();
            }
        }
    
        static final class FairSync extends Sync {
    
            @Override
            protected void lock() {
                acquire(1);
            }
    
    //        public final void acquire(int arg) {
    //            Thread current = Thread.currentThread();
    //            if (!tryAcquire(arg) &&threadQueue.add(current)) {
    //                if (getHead() == null) {
    //                    setHead(threadQueue.peek());
    //                }
    //                for (; ; ) {
    //                    if (current == getHead() && tryAcquire(arg)) {
    //                        threadQueue.poll();
    //                        // after the head is dequeued, update head
    //                        setHead(threadQueue.peek());
    //                        return;
    //                    }
    //                     System.out.println("Parking thread: " + current.getName() + " size: " + Arrays.toString(threadQueue.toArray()));
    //                    LockSupport.park(this);
    //                }
    //            }
    //        }
    
            protected final boolean tryAcquire(int acquires) {
                final Thread current = Thread.currentThread();
                int c = getState();
                if (c == 0) {
                    if (!hasQueuedPredecessors() &&
                            compareAndSetState(0, acquires)) {
                        setExclusiveOwnerThread(current);
                        return true;
                    }
                } else if (current == getExclusiveOwnerThread()) {
                    int nextc = c + acquires;
                    if (nextc < 0)
                        throw new Error("Maximum lock count exceeded");
                    setState(nextc);
                    return true;
                }
                return false;
            }
        }
    
        abstract static class Sync extends CustomAbstractQueuedSynchronizer {
            protected abstract void lock();
    
            protected void unlock() {
                if (tryRelease(1)) {
                    // Lock is free: wake the thread at the front of the queue, if any
                    Thread poll = threadQueue.peek();
                    if (poll != null) {
                        // System.out.println(Thread.currentThread().getName() + " waking: " + poll.getName() + " size: " + threadQueue.size());
                        LockSupport.unpark(poll);
                    } else {
                        setHead(null);
                    }
                }
            }
    
            @Override
            protected boolean tryRelease(int arg) {
                int c = getState() - 1;
                if (Thread.currentThread() != getExclusiveOwnerThread()){
                    throw new IllegalMonitorStateException();
                }
                boolean free = false;
                if (c == 0) {
                    free=true;
                    setExclusiveOwnerThread(null);
                }
                // System.out.println(Thread.currentThread().getName() + " released, state: " + c);
                setState(c);
                return free;
            }
        }
    
        static final class NonfairSync extends Sync {
    
            @Override
            protected void lock() {
                Thread current = Thread.currentThread();
                if (compareAndSetState(0, 1)) {
                    // System.out.println(current.getName() + " acquired the lock");
                    setExclusiveOwnerThread(current);
    
                }else {
                    acquire(1);
                }
    //            else if (!tryAcquire(1) && threadQueue.add(current)) {
    //                // always wake threads starting from the head of the queue
    //                if (getHead() == null) {
    //                    setHead(threadQueue.peek());
    //                }
    //                for (; ; ) {
    //                    if (current == getHead() && tryAcquire(1)) {
    //                        threadQueue.poll();
    //                        // after the head is dequeued, update head
    //                        setHead(threadQueue.peek());
    //                        return;
    //                    }
    //                  //  System.out.println("Parking thread: " + current.getName() + " size: " + Arrays.toString(threadQueue.toArray()));
    //                    LockSupport.park(this);
    //                }
    //            }
            }
    
            @Override
            protected boolean tryAcquire(int arg) {
                return nonfairTryAcquire(arg);
            }
    
            final boolean nonfairTryAcquire(int acquires) {
                final Thread current = Thread.currentThread();
                int c = getState();
                if (c == 0) {
                    if (compareAndSetState(0, acquires)) {
                        // System.out.println(current.getName() + " acquired the lock");
                        setExclusiveOwnerThread(current);
                        return true;
                    }
                } else if (current == getExclusiveOwnerThread()) {
                    int nextc = c + acquires;
                    if (nextc < 0) // overflow
                        throw new Error("Maximum lock count exceeded");
                    // System.out.println(current.getName() + " re-entered, state: " + nextc);
                    setState(nextc);
                    return true;
                }
                return false;
            }
        }
    
        static final class SimpleNonfairSync extends Sync {
            @Override
            protected void lock() {
                boolean flag;
                do {
                    Thread current = Thread.currentThread();
                    if (flag = compareAndSetState(0, 1)) {
                        // System.out.println(current.getName() + " acquired the lock");
                        setExclusiveOwnerThread(current);
    
                    } else if (current == getExclusiveOwnerThread()) {
                        int c = getState();
                        int nextc = c + 1;
                        if (nextc < 0) {
                            // overflow
                            throw new Error("Maximum lock count exceeded");
                        }
                        // System.out.println(current.getName() + " re-entered, state: " + nextc);
                        setState(nextc);
                        flag = true;
    
                    }
                }
                while (!flag);
    
            }
    
            @Override
            protected void unlock() {
                int c = getState() - 1;
                if (Thread.currentThread() != getExclusiveOwnerThread())
                    throw new IllegalMonitorStateException();
                if (c == 0) {
                    setExclusiveOwnerThread(null);
                }
                // System.out.println(Thread.currentThread().getName() + " released, state: " + c);
                setState(c);
            }
    
        }
    
        @Override
        public void lock() {
            sync.lock();
        }
    
        @Override
        public void unlock() {
            sync.unlock();
        }
    }
    
    

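    The lock implemented above is still fairly basic: for example, it does not yet support responding to interrupts or timed waits. These are not hard to add, so we won't go into them here.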
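
    For reference, the JDK's ReentrantLock already provides both features: tryLock(long, TimeUnit) gives up after a timeout, and lockInterruptibly() aborts with InterruptedException when the waiting thread is interrupted. In a design like ours they could be built on LockSupport.parkNanos and on checking Thread.interrupted() after each park; the sketch below only demonstrates the JDK behavior, and the class TimedLockDemo and its helper are ours.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;

public class TimedLockDemo {
    // Returns {acquired via tryLock(timeout), threw in lockInterruptibly()}
    static boolean[] run() throws InterruptedException {
        ReentrantLock lock = new ReentrantLock();
        CountDownLatch held = new CountDownLatch(1);
        CountDownLatch done = new CountDownLatch(1);

        // Helper thread takes the lock and keeps it until main is finished
        Thread holder = new Thread(() -> {
            lock.lock();
            try {
                held.countDown();
                done.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                lock.unlock();
            }
        });
        holder.start();
        held.await();

        // Timed attempt: gives up after 100 ms instead of waiting forever
        boolean acquired = lock.tryLock(100, TimeUnit.MILLISECONDS);
        if (acquired) {
            lock.unlock();
        }

        // Interruptible attempt: with the interrupt flag already set,
        // lockInterruptibly() throws instead of blocking
        boolean interrupted = false;
        Thread.currentThread().interrupt();
        try {
            lock.lockInterruptibly();
            lock.unlock();
        } catch (InterruptedException e) {
            interrupted = true; // the interrupt flag is cleared when this is thrown
        }

        done.countDown();
        holder.join();
        return new boolean[] {acquired, interrupted};
    }

    public static void main(String[] args) throws InterruptedException {
        boolean[] r = run();
        System.out.println("tryLock acquired: " + r[0]);        // false
        System.out.println("lockInterruptibly threw: " + r[1]); // true
    }
}
```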

    In the next part we will look at AQS (AbstractQueuedSynchronizer), the foundation of Java's concurrency utilities.


    Article link: https://www.haomeiwen.com/subject/yhhimltx.html