美文网首页
手写实现自定义线程池

手写实现自定义线程池

作者: 雨夜都行 | 来源:发表于2020-04-04 01:54 被阅读0次

    从今天开始养成写文章的习惯。同时想把自己知道的,学习到的java知识和大家一起分享,共同进步。

    动手实现一个简化版的线程池,可以通过这个例子,了解线程池大致的工作原理
    已实现的功能:
    1.阻塞等待队列
    2.自定义线程池
    3.拒绝策略
    4.线程池测试

    package threadpool;
    
    import java.time.LocalDateTime;
    
    import java.util.ArrayDeque;
    import java.util.Deque;
    import java.util.concurrent.RejectedExecutionException;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.atomic.AtomicInteger;
    import java.util.concurrent.locks.Condition;
    import java.util.concurrent.locks.ReentrantLock;
    
    /**
     * 自定义线程池测试
     *
     * @author
     * @date 2020.4.3
     */
    public class MyThreadPoolTest {
        public static void main(String[] args) {
            // MyThreadPool threadPool = new MyThreadPool(2, 1, (r) -> {
                // 拒绝策略1:打印日志
    //            System.out.println("拒绝执行"));
    //        MyThreadPool threadPool = new MyThreadPool(2, 1, (r) -> {
                // 拒绝策略2:抛出异常
    //            throw new RejectedExecutionException("阻塞等待队列已满");
    //        });
            MyThreadPool threadPool = new MyThreadPool(2, 1, (r) -> {
                // 拒绝策略3:由当前线程来执行
                r.run();
            });
            threadPool.execute(() -> {
                System.out.println("----->hello1 start");
                System.out.println(LocalDateTime.now());
                // 模拟需要执行很久
                sleep(2);
                System.out.println(LocalDateTime.now());
                System.out.println("----->hello1 end");
            });
            threadPool.execute(() -> {
                System.out.println("----->hello2 start");
                System.out.println(LocalDateTime.now());
                sleep(3);
                System.out.println(LocalDateTime.now());
                System.out.println("----->hello2 end");
            });
            // 先加入队列 , 2s 后执行
            threadPool.execute(() -> {
                System.out.println("----->hello3 start");
                System.out.println(LocalDateTime.now());
                sleep(1);
                System.out.println(LocalDateTime.now());
                System.out.println("----->hello3 end");
            });
            // 进入等待 2.2s 后 应该要进入 队列 . 3s 后执行
            sleep(0.2);
            threadPool.execute(() -> {
                System.out.println(Thread.currentThread().getName() + "----->hello4 start");
                System.out.println(LocalDateTime.now());
                System.out.println("----->hello4 end");
            });
        }
    
        public static void sleep(double sleepTime) {
            try {
                TimeUnit.MILLISECONDS.sleep((long) (sleepTime * 1000L));
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }
    
    /**
     * 自定义线程池
     */
    class MyThreadPool {
    
        // 核心线程数
        private int corePoolSize;
    
        // 提交的任务数
        private AtomicInteger count;
    
        // 阻塞队列
        private BlockingQueue<Runnable> blockingQueue;
    
        // 拒绝策略
        RejectedPolicy rejectedPolicy;
    
        public MyThreadPool(int corePoolSize, int QueueSize) {
            this.corePoolSize =  corePoolSize;
            this.count = new AtomicInteger();
            this.blockingQueue = new BlockingQueue<>(QueueSize);
        }
    
        public MyThreadPool(int corePoolSize, int QueueSize, RejectedPolicy rejectedPolicy) {
            this.corePoolSize =  corePoolSize;
            this.count = new AtomicInteger();
            this.blockingQueue = new BlockingQueue<>(QueueSize);
            this.rejectedPolicy = rejectedPolicy;
        }
    
        public void execute(Runnable task) {
            int curCount = count.getAndIncrement();
            if (curCount < corePoolSize) {
                Worker worker = new Worker(task, "worker" + curCount);
                worker.start();
            } else {
                if (rejectedPolicy != null) {
                    blockingQueue.put(task, rejectedPolicy);
    
                } else {
                    blockingQueue.put(task);
                }
            }
        }
    
        public boolean shutdown() {
            return true;
        }
    
        /**
         * 工作线程用来处理提交的任务
         */
        class Worker extends Thread {
    
            private Runnable task;
            private int count;
            private String workerName;
    
            public Worker(Runnable task, String workerName) {
                super(task);
                super.setName(workerName);
                this.workerName = workerName;
                this.task = task;
            }
    
            @Override
            public void run() {
                while(task != null || (task = blockingQueue.take()) != null) {
                    System.out.println(workerName + "执行任务");
                    task.run();
                    task = null;
                    System.out.println(workerName + "执行已执行任务数:" + ++count);
                }
            }
        }
    }
    
    /**
     * 拒绝策略
     */
    interface RejectedPolicy{
    
        void rejectedExecution(Runnable runnable);
    }
    
    /**
     * 阻塞队列
     * @param <T>
     */
    class BlockingQueue<T> {
    
        // 存放数据
        private Deque<T> blockingQueue;
    
        // 等待时间
        private int timeOut;
    
        // 队列大小
        private int capacity;
    
        // 生产者与消费者公用的锁
        private ReentrantLock lock = new ReentrantLock();
    
        // 条件变量
        private Condition notFull = lock.newCondition();
        // 条件变量
        private Condition notEmpty = lock.newCondition();
    
        public BlockingQueue(int capacity){
            this.capacity = capacity;
            this.blockingQueue = new ArrayDeque<>(capacity);
        }
    
    
        public void put(T ele) {
            if (ele == null) {
                return;
            }
            lock.lock();
            try {
                while(blockingQueue.size() == capacity) {
                    System.out.println("进入等待");
                    System.out.println(LocalDateTime.now());
                    notFull.await();
                }
                System.out.println("加入队列");
                blockingQueue.addLast(ele);
                notEmpty.signalAll();
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                lock.unlock();
            }
        }
    
        /**
         * 带拒绝策略的put
         * @param ele
         * @param rejectedPolicy
         */
        public void put(T ele, RejectedPolicy rejectedPolicy) {
            if (ele == null || rejectedPolicy == null) {
                return;
            }
            lock.lock();
            try {
                while(blockingQueue.size() == capacity) {
                    rejectedPolicy.rejectedExecution((Runnable)ele);
                    return;
                }
                System.out.println("加入队列");
                blockingQueue.addLast(ele);
                notEmpty.signalAll();
            } finally {
                lock.unlock();
            }
        }
    
        /**
         * 带超时时间的获取
         * @param timeOut
         * @return
         */
        public T poll(long timeOut) {
            T ele = null;
            long next = 0;
            lock.lock();
            try {
                while(blockingQueue.size() == 0) {
                    timeOut = timeOut - next;
                    if (timeOut <= 0) {
                        return null;
                    }
                    next = notEmpty.awaitNanos(timeOut);
                }
                ele = blockingQueue.pollFirst();
                notFull.signalAll();
                return ele;
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                lock.unlock();
            }
            return ele;
        }
    
        public T take() {
            T ele = null;
            lock.lock();
            try {
                while(blockingQueue.size() == 0) {
                    System.out.println("队列中的数据为空,"+ Thread.currentThread().getName() +"进入等待");
                    // 反过来看,等待队列中的数据不为空
                    notEmpty.await();
                }
                ele = blockingQueue.pollFirst();
                notFull.signalAll();
                return ele;
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                lock.unlock();
            }
            return ele;
        }
    }
    

    以上就是线程池简化版的实现,下个文章和大家分享AQS同步器的工作原理

    相关文章

      网友评论

          本文标题:手写实现自定义线程池

          本文链接:https://www.haomeiwen.com/subject/bbyrphtx.html