美文网首页
10行代码实现“生产者消费者”

10行代码实现“生产者消费者”

作者: 黄大海 | 来源:发表于2019-10-28 17:08 被阅读0次

请原谅我的标题,10行代码指的是模版类的调用方。。。

最近接到一个任务:对一个错综复杂的老系统优化性能。
好几处都用到了“生产者消费者”来做并行处理。
所以写了个模版类,把多线程涉及的系统功能和具体的业务处理分离开。
先贴上完整代码,之后再分段说明。

package com.xxx.xxx;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.LongStream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;

/**
 * 一个生产者对应多个消费者
 * @author huanghai
 * @date 2019/10/28
 */
public class ProducerConsumersTemplate<P, D> {
    private final static Logger log = LoggerFactory.getLogger(ProducerConsumersTemplate.class);
    /** 完成标志 */
    private final D MARK_AS_END;
    /** 执行器(线程池) */
    private Executor executor;
    /** 缓冲队列大小 */
    private Integer queueSize = 300;
    /** 消费者线程数 */
    private Integer consumerCount = 100;
    /** 生产者者处理回调 */
    private Function<P, Iterator<D>> producer;
    /** 消费者处理回调 */
    private BiConsumer<P, D> consumer;
    
    private ProducerConsumersTemplate(D markAsEnd) {
        this.MARK_AS_END = markAsEnd;
    }
    
    public static <P, D> Builder<P, D> builder(D markAsEnd){
        return new Builder<>(markAsEnd);
    }
    
    /**
     * 构建template, 校验参数
     */
    public static class Builder<P, D> {
        private ProducerConsumersTemplate<P, D> result;
        
        public Builder(D markAsEnd){
            this.result = new ProducerConsumersTemplate<P, D>(markAsEnd);
        }
        public Builder<P, D> setExecutor(Executor executor) {
            result.executor = executor;
            return this;
        }
        public Builder<P, D> setQueueSize(Integer queueSize) {
            result.queueSize = queueSize;
            return this;
        }
        public Builder<P, D> setConsumerCount(Integer consumerCount) {
            result.consumerCount = consumerCount;
            return this;
        }
        public Builder<P, D> setProducer(Function<P, Iterator<D>> producer) {
            result.producer = producer;
            return this;
        }
        public Builder<P, D> setConsumer(BiConsumer<P, D> consumer) {
            result.consumer = consumer;
            return this;
        }
        public ProducerConsumersTemplate<P, D> build(){
            Assert.notNull(result.MARK_AS_END, "MARK_AS_END required");
            Assert.notNull(result.executor, "executor required");
            Assert.notNull(result.queueSize, "queueSize required");
            Assert.notNull(result.consumerCount, "consumerCount required");
            Assert.notNull(result.producer, "producer required");
            Assert.notNull(result.consumer, "consumer required");
            return result;
        }
    }

    public CompletableFuture<Void> run(P param){
        
        Context context = new Context(param, queueSize, consumerCount);
        try {
            log.trace("开启消费线程 param: {}", param);
            startConsumers(context);
            
            log.trace("开启生产线程 param: {}", param);
            produceData(context);
            
            log.trace("等待处理完成 param: {}", param);
            return CompletableFuture.allOf(context.getRunnings()
                .toArray(new CompletableFuture[context.getRunnings().size()]));
        }catch(Exception e){
            log.trace("处理失败 param: {}", param);
            context.getRunnings().forEach(each -> each.cancel(true));
            throw new RuntimeException();
        }
    }

    public void produceData(Context context) throws Exception {
        CompletableFuture<Void> producerFuture = CompletableFuture.runAsync(() -> {
                try{
                    Iterator<D> it = producer.apply(context.getParam());
                    while(it.hasNext()){
                        context.getDataQueue().put(it.next());
                    }
                    for(int i = 0; i < consumerCount; i++){
                        context.getDataQueue().put(MARK_AS_END);//通过MARK_AS_END通知线程结束
                    }
                    log.debug("生产线程完成");
                }catch(Exception e){
                    throw new RuntimeException(e);
                }
            }, executor);
        context.getRunnings().add(producerFuture);
        log.trace("生产线程已启动");
    }
    
    protected void startConsumers(Context context) {
        for(int i = 0; i < consumerCount; i++){
            CompletableFuture<Void> running = CompletableFuture.runAsync(() -> {
                try {
                    while(true){
                        D data = context.getDataQueue().take();
                        if(data == MARK_AS_END){
                            log.debug("收到终止信号,消费线程退出");
                            return;
                        }
                        log.trace("处理数据, {}", data);
                        consumer.accept(context.getParam(), data);
                    }
                }catch(Exception e){
                    throw new RuntimeException(e);
                }
            }, executor);
            context.getRunnings().add(running);
        }
        log.trace("{} 消费线程已启动", consumerCount);
    }
    
    /**
     * 等待所有线程完成,合并消费者的结果
     * @param context 相关参数
     * @return 合并后的结果集
     */
    public class Context {
        /** 自定义参数 */
        private P param;
        /** 生产者/消费者-缓冲队列 */
        private BlockingQueue<D> dataQueue;
        /** 所有线程 */
        private List<CompletableFuture<Void>> runnings = new ArrayList<>();
        
        protected Context(P param, int queueSize, int threadCount){
            this.param = param;
            this.dataQueue = new LinkedBlockingQueue<>(queueSize);
            this.runnings = new ArrayList<>(threadCount);
        }
        public P getParam() {
            return param;
        }
        protected BlockingQueue<D> getDataQueue() {
            return dataQueue;
        }
        protected List<CompletableFuture<Void>> getRunnings() {
            return runnings;
        }
    }
    
    /**
     * 测试/示例
     */
    public static void main(String[] args) {
        Builder<Object, Long> builder = ProducerConsumersTemplate.builder(Long.MAX_VALUE)
            .setExecutor(Executors.newFixedThreadPool(101))
            .setConsumerCount(100)
            .setQueueSize(300);
        builder.setProducer(param -> LongStream.range(1L, 200L).iterator());//生产1-200的Long
        builder.setConsumer((param, data) -> {
            try {
                log.info("consume {}", data);
                Thread.sleep(100L);
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        });
        builder.build().run(new Object()).join();
        log.info("Done");
    }
}

我们从使用的角度直观得感受一下

    /**
     * 测试/示例
     */
    public static void main(String[] args) {
        Builder<Object, Long> builder = ProducerConsumersTemplate.builder(Long.MAX_VALUE)
            .setExecutor(Executors.newFixedThreadPool(101))
            .setConsumerCount(100)
            .setQueueSize(300);
        builder.setProducer(param -> LongStream.range(1L, 200L).iterator());//生产1-200的Long
        builder.setConsumer((param, data) -> {
            try {
                log.info("consume {}", data);
                Thread.sleep(100L);
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        });
        builder.build().run(new Object()).join();
        log.info("Done");
    }
  1. 首先构造一个Builder对象,该对象用来初始化模版Template的一些参数。
  2. 设置“生产者”和“消费者”的回调方法
  3. builder.build()方法返回构造好的Template对象
  4. run()方法会执行启动各种线程幷返回一个future,传入对象是自定义的业务对象,用来传递业务相关值
  5. join()方法会阻塞,等待执行完成。

构造template

public class ProducerConsumersTemplate<P, D> {
    private final static Logger log = LoggerFactory.getLogger(ProducerConsumersTemplate.class);
    /** 完成标志 */
    private final D MARK_AS_END;
    /** 执行器(线程池) */
    private Executor executor;
    /** 缓冲队列大小 */
    private Integer queueSize = 300;
    /** 消费者线程数 */
    private Integer consumerCount = 100;
    /** 生产者者处理回调 */
    private Function<P, Iterator<D>> producer;
    /** 消费者处理回调 */
    private BiConsumer<P, D> consumer;
    
    private ProducerConsumersTemplate(D markAsEnd) {
        this.MARK_AS_END = markAsEnd;
    }
    
    public static <P, D> Builder<P, D> builder(D markAsEnd){
        return new Builder<>(markAsEnd);
    }
    
    /**
     * 构建template, 校验参数
     */
    public static class Builder<P, D> {
        private ProducerConsumersTemplate<P, D> result;
        
        public Builder(D markAsEnd){
            this.result = new ProducerConsumersTemplate<P, D>(markAsEnd);
        }
        public Builder<P, D> setExecutor(Executor executor) {
            result.executor = executor;
            return this;
        }
        public Builder<P, D> setQueueSize(Integer queueSize) {
            result.queueSize = queueSize;
            return this;
        }
        public Builder<P, D> setConsumerCount(Integer consumerCount) {
            result.consumerCount = consumerCount;
            return this;
        }
        public Builder<P, D> setProducer(Function<P, Iterator<D>> producer) {
            result.producer = producer;
            return this;
        }
        public Builder<P, D> setConsumer(BiConsumer<P, D> consumer) {
            result.consumer = consumer;
            return this;
        }
        public ProducerConsumersTemplate<P, D> build(){
            Assert.notNull(result.MARK_AS_END, "MARK_AS_END required");
            Assert.notNull(result.executor, "executor required");
            Assert.notNull(result.queueSize, "queueSize required");
            Assert.notNull(result.consumerCount, "consumerCount required");
            Assert.notNull(result.producer, "producer required");
            Assert.notNull(result.consumer, "consumer required");
            return result;
        }
    }
    ...
  1. template需要比较多的参数,这里我们采用经典的静态内部类来实现Builder模式
  2. 泛型P指的是Parameter参数对象,用来传递业务参数
  3. 泛型D指的是Data数据对象,由生产者创建,放入队列,最后由消费者处理
  4. 各个参数的意义参考代码内注释,之后还会涉及

主处理线程

    public CompletableFuture<Void> run(P param){
        
        Context context = new Context(param, queueSize, consumerCount);
        try {
            log.trace("开启消费线程 param: {}", param);
            startConsumers(context);
            
            log.trace("开启生产线程 param: {}", param);
            produceData(context);
            
            log.trace("等待处理完成 param: {}", param);
            return CompletableFuture.allOf(context.getRunnings()
                .toArray(new CompletableFuture[context.getRunnings().size()]));
        }catch(Exception e){
            log.trace("处理失败 param: {}", param);
            context.getRunnings().forEach(each -> each.cancel(true));
            throw new RuntimeException();
        }
    }
  1. Context对象保存一些公用变量,方便传递和简化接口
  2. CompletableFuture.allOf()把所有线程合成一个future,又上层调用者决定是否同步执行
  3. 如果发生异常,把所有运行中的线程取消掉

生产者

    public void produceData(Context context) throws Exception {
        CompletableFuture<Void> producerFuture = CompletableFuture.runAsync(() -> {
                try{
                    Iterator<D> it = producer.apply(context.getParam());
                    while(it.hasNext()){
                        context.getDataQueue().put(it.next());
                    }
                    for(int i = 0; i < consumerCount; i++){
                        context.getDataQueue().put(MARK_AS_END);//通过MARK_AS_END通知线程结束
                    }
                    log.debug("生产线程完成");
                }catch(Exception e){
                    throw new RuntimeException(e);
                }
            }, executor);
        context.getRunnings().add(producerFuture);
        log.trace("生产线程已启动");
    }
  1. 生产者回调函数会传入业务Param,并返回一个Iterator对象
  2. 遍历Iterator,放入BlockingQueue, 队列满会阻塞
  3. 数据遍历结束后,放入MARK_AS_END这个特殊对象,这个对象用来通知消费者,生产者完成了。(也叫Poision Pill)

消费者

    protected void startConsumers(Context context) {
        for(int i = 0; i < consumerCount; i++){
            CompletableFuture<Void> running = CompletableFuture.runAsync(() -> {
                try {
                    while(true){
                        D data = context.getDataQueue().take();
                        if(data == MARK_AS_END){
                            log.debug("收到终止信号,消费线程退出");
                            return;
                        }
                        log.trace("处理数据, {}", data);
                        consumer.accept(context.getParam(), data);
                    }
                }catch(Exception e){
                    throw new RuntimeException(e);
                }
            }, executor);
            context.getRunnings().add(running);
        }
        log.trace("{} 消费线程已启动", consumerCount);
    }
  1. 消费者从队列里拿出对象,先判断是否== MARK_AS_END。如果相等。说明生产者已经结束,数据消费完了。直接退出
  2. 如果不相等,传给消费者回调

相关文章

网友评论

      本文标题:10行代码实现“生产者消费者”

      本文链接:https://www.haomeiwen.com/subject/pxvbvctx.html