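Please forgive the title: the "10 lines of code" refers to the code on the caller side of the template class...
I was recently given a task: optimize the performance of a tangled legacy system.
Several places in it use the producer-consumer pattern for parallel processing.
So I wrote a template class that separates the multithreading plumbing from the actual business logic.
Below is the complete code first, followed by a section-by-section explanation.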
```java
package com.xxx.xxx;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.LongStream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;

/**
 * One producer feeding multiple consumers.
 * @author huanghai
 * @date 2019/10/28
 */
public class ProducerConsumersTemplate<P, D> {
    private final static Logger log = LoggerFactory.getLogger(ProducerConsumersTemplate.class);
    /** End-of-data marker (poison pill) */
    private final D MARK_AS_END;
    /** Executor (thread pool); needs at least consumerCount + 1 threads */
    private Executor executor;
    /** Capacity of the buffer queue */
    private Integer queueSize = 300;
    /** Number of consumer threads */
    private Integer consumerCount = 100;
    /** Producer callback */
    private Function<P, Iterator<D>> producer;
    /** Consumer callback */
    private BiConsumer<P, D> consumer;

    private ProducerConsumersTemplate(D markAsEnd) {
        this.MARK_AS_END = markAsEnd;
    }

    public static <P, D> Builder<P, D> builder(D markAsEnd) {
        return new Builder<>(markAsEnd);
    }

    /**
     * Builds the template and validates its parameters.
     */
    public static class Builder<P, D> {
        private ProducerConsumersTemplate<P, D> result;

        public Builder(D markAsEnd) {
            this.result = new ProducerConsumersTemplate<P, D>(markAsEnd);
        }

        public Builder<P, D> setExecutor(Executor executor) {
            result.executor = executor;
            return this;
        }

        public Builder<P, D> setQueueSize(Integer queueSize) {
            result.queueSize = queueSize;
            return this;
        }

        public Builder<P, D> setConsumerCount(Integer consumerCount) {
            result.consumerCount = consumerCount;
            return this;
        }

        public Builder<P, D> setProducer(Function<P, Iterator<D>> producer) {
            result.producer = producer;
            return this;
        }

        public Builder<P, D> setConsumer(BiConsumer<P, D> consumer) {
            result.consumer = consumer;
            return this;
        }

        public ProducerConsumersTemplate<P, D> build() {
            Assert.notNull(result.MARK_AS_END, "MARK_AS_END required");
            Assert.notNull(result.executor, "executor required");
            Assert.notNull(result.queueSize, "queueSize required");
            Assert.notNull(result.consumerCount, "consumerCount required");
            Assert.notNull(result.producer, "producer required");
            Assert.notNull(result.consumer, "consumer required");
            return result;
        }
    }

    public CompletableFuture<Void> run(P param) {
        Context context = new Context(param, queueSize, consumerCount);
        try {
            log.trace("Starting consumer threads, param: {}", param);
            startConsumers(context);
            log.trace("Starting producer thread, param: {}", param);
            produceData(context);
            log.trace("Waiting for processing to finish, param: {}", param);
            return CompletableFuture.allOf(context.getRunnings()
                    .toArray(new CompletableFuture[context.getRunnings().size()]));
        } catch (Exception e) {
            log.error("Processing failed, param: {}", param, e);
            // cancel(true) only marks the futures as cancelled; it does not interrupt running threads
            context.getRunnings().forEach(each -> each.cancel(true));
            throw new RuntimeException(e);
        }
    }

    public void produceData(Context context) throws Exception {
        CompletableFuture<Void> producerFuture = CompletableFuture.runAsync(() -> {
            try {
                try {
                    Iterator<D> it = producer.apply(context.getParam());
                    while (it.hasNext()) {
                        context.getDataQueue().put(it.next());
                    }
                } finally {
                    // One MARK_AS_END per consumer tells each consumer to exit;
                    // sent even if the producer callback failed, so no consumer blocks forever
                    for (int i = 0; i < consumerCount; i++) {
                        context.getDataQueue().put(MARK_AS_END);
                    }
                }
                log.debug("Producer thread finished");
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }, executor);
        context.getRunnings().add(producerFuture);
        log.trace("Producer thread started");
    }

    protected void startConsumers(Context context) {
        for (int i = 0; i < consumerCount; i++) {
            CompletableFuture<Void> running = CompletableFuture.runAsync(() -> {
                try {
                    while (true) {
                        D data = context.getDataQueue().take();
                        // Reference comparison: only the MARK_AS_END instance itself ends the loop
                        if (data == MARK_AS_END) {
                            log.debug("Received end marker, consumer thread exiting");
                            return;
                        }
                        log.trace("Processing data: {}", data);
                        consumer.accept(context.getParam(), data);
                    }
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }, executor);
            context.getRunnings().add(running);
        }
        log.trace("{} consumer threads started", consumerCount);
    }

    /**
     * Per-run state shared by the producer and the consumers:
     * the business parameter, the buffer queue and the futures of all started tasks.
     */
    public class Context {
        /** Custom business parameter */
        private P param;
        /** Buffer queue between producer and consumers */
        private BlockingQueue<D> dataQueue;
        /** Futures of all started tasks */
        private List<CompletableFuture<Void>> runnings;

        protected Context(P param, int queueSize, int threadCount) {
            this.param = param;
            this.dataQueue = new LinkedBlockingQueue<>(queueSize);
            this.runnings = new ArrayList<>(threadCount + 1); // consumers + 1 producer
        }

        public P getParam() {
            return param;
        }

        protected BlockingQueue<D> getDataQueue() {
            return dataQueue;
        }

        protected List<CompletableFuture<Void>> getRunnings() {
            return runnings;
        }
    }

    /**
     * Test / example
     */
    public static void main(String[] args) {
        // 100 consumers + 1 producer
        ExecutorService executor = Executors.newFixedThreadPool(101);
        Builder<Object, Long> builder = ProducerConsumersTemplate.builder(Long.MAX_VALUE)
                .setExecutor(executor)
                .setConsumerCount(100)
                .setQueueSize(300);
        builder.setProducer(param -> LongStream.range(1L, 200L).iterator()); // produces the Longs 1 to 199 (end is exclusive)
        builder.setConsumer((param, data) -> {
            try {
                log.info("consume {}", data);
                Thread.sleep(100L);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        });
        try {
            builder.build().run(new Object()).join();
            log.info("Done");
        } finally {
            // Without this, the pool's non-daemon threads keep the JVM alive after main returns
            executor.shutdown();
        }
    }
}
```
Let's start with an intuitive look from the caller's perspective:
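```java
/**
 * Test / example
 */
public static void main(String[] args) {
    // 100 consumers + 1 producer
    ExecutorService executor = Executors.newFixedThreadPool(101);
    Builder<Object, Long> builder = ProducerConsumersTemplate.builder(Long.MAX_VALUE)
            .setExecutor(executor)
            .setConsumerCount(100)
            .setQueueSize(300);
    builder.setProducer(param -> LongStream.range(1L, 200L).iterator()); // produces the Longs 1 to 199 (end is exclusive)
    builder.setConsumer((param, data) -> {
        try {
            log.info("consume {}", data);
            Thread.sleep(100L);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    });
    try {
        builder.build().run(new Object()).join();
        log.info("Done");
    } finally {
        // Without this, the pool's non-daemon threads keep the JVM alive after main returns
        executor.shutdown();
    }
}
```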
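- First, create a Builder object, which initializes the template's parameters. The pool gets 101 threads for 100 consumers plus 1 producer: with a fixed-size pool like this one, the executor needs at least consumerCount + 1 threads, otherwise the consumers occupy every thread while blocked on the empty queue and the producer never gets to run.
- Set the "producer" and "consumer" callbacks.
- builder.build() returns the fully built template object.
- run() starts all the threads and returns a future. Its argument is a custom business object used to pass business-related values to the callbacks.
- join() blocks until processing is complete.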
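To see why the thread count matters, here is a self-contained sketch that uses plain java.util.concurrent instead of the template (the class PoolSizingDemo and its method finishes are hypothetical names). It submits blocking consumers first and the producer last, as run() does, on a fixed pool. With only as many threads as consumers, the consumers hold every thread in take(), the producer task never starts, and the combined future times out; one extra thread is enough for everything to finish.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/** Hypothetical demo: one producer and N blocking consumers sharing one fixed thread pool. */
public class PoolSizingDemo {
    /** Dedicated end marker, compared by reference like MARK_AS_END in the template. */
    private static final Object END = new Object();

    /** Returns true if all tasks finish within one second on a pool of poolSize threads. */
    static boolean finishes(int poolSize, int consumers) throws InterruptedException {
        ExecutorService pool = Executors.newFixedThreadPool(poolSize);
        BlockingQueue<Object> queue = new LinkedBlockingQueue<>();
        List<CompletableFuture<Void>> tasks = new ArrayList<>();
        try {
            // Consumers are submitted first, as in the template's run()
            for (int i = 0; i < consumers; i++) {
                tasks.add(CompletableFuture.runAsync(() -> {
                    try {
                        while (queue.take() != END) {
                            // a real consumer would process the item here
                        }
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    }
                }, pool));
            }
            // The producer only sends end markers, but it still needs a free thread to run at all
            tasks.add(CompletableFuture.runAsync(() -> {
                for (int i = 0; i < consumers; i++) {
                    queue.add(END);
                }
            }, pool));
            CompletableFuture.allOf(tasks.toArray(new CompletableFuture<?>[0])).get(1, TimeUnit.SECONDS);
            return true;
        } catch (TimeoutException e) {
            return false; // consumers hold every thread, so the producer was never scheduled
        } catch (ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdownNow(); // interrupts consumers still blocked in take()
        }
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println("4 threads, 4 consumers: " + finishes(4, 4)); // false
        System.out.println("5 threads, 4 consumers: " + finishes(5, 4)); // true
    }
}
```

The same reasoning applies to the template with any fixed-size pool; an unbounded pool such as Executors.newCachedThreadPool() would avoid the deadlock at the cost of no upper limit on threads.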
Building the template
```java
public class ProducerConsumersTemplate<P, D> {
    private final static Logger log = LoggerFactory.getLogger(ProducerConsumersTemplate.class);
    /** End-of-data marker (poison pill) */
    private final D MARK_AS_END;
    /** Executor (thread pool); needs at least consumerCount + 1 threads */
    private Executor executor;
    /** Capacity of the buffer queue */
    private Integer queueSize = 300;
    /** Number of consumer threads */
    private Integer consumerCount = 100;
    /** Producer callback */
    private Function<P, Iterator<D>> producer;
    /** Consumer callback */
    private BiConsumer<P, D> consumer;

    private ProducerConsumersTemplate(D markAsEnd) {
        this.MARK_AS_END = markAsEnd;
    }

    public static <P, D> Builder<P, D> builder(D markAsEnd) {
        return new Builder<>(markAsEnd);
    }

    /**
     * Builds the template and validates its parameters.
     */
    public static class Builder<P, D> {
        private ProducerConsumersTemplate<P, D> result;

        public Builder(D markAsEnd) {
            this.result = new ProducerConsumersTemplate<P, D>(markAsEnd);
        }

        public Builder<P, D> setExecutor(Executor executor) {
            result.executor = executor;
            return this;
        }

        public Builder<P, D> setQueueSize(Integer queueSize) {
            result.queueSize = queueSize;
            return this;
        }

        public Builder<P, D> setConsumerCount(Integer consumerCount) {
            result.consumerCount = consumerCount;
            return this;
        }

        public Builder<P, D> setProducer(Function<P, Iterator<D>> producer) {
            result.producer = producer;
            return this;
        }

        public Builder<P, D> setConsumer(BiConsumer<P, D> consumer) {
            result.consumer = consumer;
            return this;
        }

        public ProducerConsumersTemplate<P, D> build() {
            Assert.notNull(result.MARK_AS_END, "MARK_AS_END required");
            Assert.notNull(result.executor, "executor required");
            Assert.notNull(result.queueSize, "queueSize required");
            Assert.notNull(result.consumerCount, "consumerCount required");
            Assert.notNull(result.producer, "producer required");
            Assert.notNull(result.consumer, "consumer required");
            return result;
        }
    }
    ...
```
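- The template takes quite a few parameters, so it uses the classic static nested Builder class to implement the Builder pattern.
- Type parameter P is the Parameter object, used to pass business parameters.
- Type parameter D is the Data object: the producer creates it, puts it into the queue, and a consumer eventually processes it.
- See the comments in the code for what each parameter means; they also come up again below.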
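One property of this Builder is worth knowing: it writes straight into a single template instance (result), and build() returns that same instance, so a setter called after build() silently changes a template that may already be in use. The sketch below shows a common variant of the static nested Builder that keeps its own fields and copies them into an immutable object in build(). TemplateConfig is a hypothetical class that holds only three of the template's settings, uses primitive int fields, and validates with Objects.requireNonNull instead of Spring's Assert.

```java
import java.util.Objects;

/** Hypothetical immutable settings object built by a static nested Builder. */
public final class TemplateConfig<D> {
    private final D markAsEnd;
    private final int queueSize;
    private final int consumerCount;

    private TemplateConfig(Builder<D> b) {
        this.markAsEnd = b.markAsEnd;
        this.queueSize = b.queueSize;
        this.consumerCount = b.consumerCount;
    }

    public D markAsEnd() { return markAsEnd; }
    public int queueSize() { return queueSize; }
    public int consumerCount() { return consumerCount; }

    public static <D> Builder<D> builder(D markAsEnd) {
        return new Builder<>(markAsEnd);
    }

    public static final class Builder<D> {
        private final D markAsEnd;
        private int queueSize = 300;     // same defaults as the template
        private int consumerCount = 100;

        private Builder(D markAsEnd) {
            this.markAsEnd = markAsEnd;
        }

        public Builder<D> setQueueSize(int queueSize) {
            this.queueSize = queueSize;
            return this;
        }

        public Builder<D> setConsumerCount(int consumerCount) {
            this.consumerCount = consumerCount;
            return this;
        }

        /** Validates, then copies the current settings into a new immutable object. */
        public TemplateConfig<D> build() {
            Objects.requireNonNull(markAsEnd, "markAsEnd required");
            if (queueSize <= 0 || consumerCount <= 0) {
                throw new IllegalArgumentException("queueSize and consumerCount must be positive");
            }
            return new TemplateConfig<>(this);
        }
    }

    public static void main(String[] args) {
        Builder<Long> builder = TemplateConfig.builder(Long.MAX_VALUE).setConsumerCount(8);
        TemplateConfig<Long> config = builder.build();
        builder.setConsumerCount(16);               // does not affect the already-built config
        System.out.println(config.consumerCount()); // 8
    }
}
```

Each build() call returns a fresh object, so one Builder can safely serve as a base for several slightly different configurations.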
Main processing flow
```java
public CompletableFuture<Void> run(P param) {
    Context context = new Context(param, queueSize, consumerCount);
    try {
        log.trace("Starting consumer threads, param: {}", param);
        startConsumers(context);
        log.trace("Starting producer thread, param: {}", param);
        produceData(context);
        log.trace("Waiting for processing to finish, param: {}", param);
        return CompletableFuture.allOf(context.getRunnings()
                .toArray(new CompletableFuture[context.getRunnings().size()]));
    } catch (Exception e) {
        log.error("Processing failed, param: {}", param, e);
        // cancel(true) only marks the futures as cancelled; it does not interrupt running threads
        context.getRunnings().forEach(each -> each.cancel(true));
        throw new RuntimeException(e);
    }
}
```
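- The Context object holds the shared per-run state (parameter, queue, futures), which makes it easy to pass around and keeps the method signatures simple.
- CompletableFuture.allOf() combines all the tasks into a single future, so the caller decides whether to wait synchronously (for example with join()).
- If an exception is thrown while the tasks are being started (for example, the executor rejects a submission), every task started so far is cancelled. Exceptions thrown inside the producer or consumer callbacks never reach this catch block; they surface through the returned future instead. Note also that CompletableFuture.cancel(true) only marks a future as cancelled and does not interrupt the thread running it.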
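Both CompletableFuture behaviors that these bullets rely on can be checked in isolation. In the sketch below (hypothetical class FutureBehaviorDemo, JDK only), the future returned by allOf() completes once every component has completed, and if one of them failed, join() throws a CompletionException whose cause is the original exception. The second method shows that cancel(true) on a running task marks the future as cancelled without interrupting the thread, which matches the CompletableFuture Javadoc.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;

/** Hypothetical demo of allOf() failure propagation and cancel(true) semantics. */
public class FutureBehaviorDemo {

    /** Returns the cause that join() reports when one of the combined tasks fails. */
    static Throwable allOfFailureCause() {
        CompletableFuture<Void> ok = CompletableFuture.runAsync(() -> { });
        CompletableFuture<Void> failing = CompletableFuture.runAsync(() -> {
            throw new IllegalStateException("consumer failed");
        });
        try {
            CompletableFuture.allOf(ok, failing).join();
            return null;
        } catch (CompletionException e) {
            return e.getCause();
        }
    }

    /** Returns true if cancel(true) interrupted the thread running the task. */
    static boolean cancelInterruptsRunningTask() throws InterruptedException {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        CountDownLatch started = new CountDownLatch(1);
        CountDownLatch release = new CountDownLatch(1);
        CountDownLatch finished = new CountDownLatch(1);
        AtomicBoolean interrupted = new AtomicBoolean(false);
        try {
            CompletableFuture<Void> task = CompletableFuture.runAsync(() -> {
                started.countDown();
                try {
                    release.await();            // simulates a consumer blocked in take()
                } catch (InterruptedException e) {
                    interrupted.set(true);
                } finally {
                    finished.countDown();
                }
            }, pool);
            started.await();                    // make sure the task is really running
            boolean cancelled = task.cancel(true);
            release.countDown();                // the task still has to be released by hand
            finished.await();
            if (!cancelled || !task.isCancelled()) {
                throw new IllegalStateException("expected the future to be marked cancelled");
            }
            return interrupted.get();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(allOfFailureCause());           // java.lang.IllegalStateException: consumer failed
        System.out.println(cancelInterruptsRunningTask()); // false
    }
}
```

So if the template ever needs to stop consumers that are blocked in take(), it has to wake them up explicitly (for example with end markers) or use an ExecutorService's shutdownNow(), rather than rely on cancel(true).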
Producer
```java
public void produceData(Context context) throws Exception {
    CompletableFuture<Void> producerFuture = CompletableFuture.runAsync(() -> {
        try {
            try {
                Iterator<D> it = producer.apply(context.getParam());
                while (it.hasNext()) {
                    context.getDataQueue().put(it.next());
                }
            } finally {
                // One MARK_AS_END per consumer tells each consumer to exit;
                // sent even if the producer callback failed, so no consumer blocks forever
                for (int i = 0; i < consumerCount; i++) {
                    context.getDataQueue().put(MARK_AS_END);
                }
            }
            log.debug("Producer thread finished");
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }, executor);
    context.getRunnings().add(producerFuture);
    log.trace("Producer thread started");
}
```
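- The producer callback receives the business Param and returns an Iterator.
- The template iterates over it and puts each element into the BlockingQueue; put() blocks while the queue is full, which throttles the producer to the consumers' pace.
- Once iteration ends, it puts the special MARK_AS_END object into the queue, once per consumer, to tell the consumers that the producer has finished. (This technique is known as a "poison pill".)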
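The poison-pill protocol is easy to reproduce without the template. The sketch below (hypothetical class PoisonPillDemo, using plain threads over a bounded LinkedBlockingQueue) produces the Longs 1 to 199 like the example, sends one marker per consumer and counts the processed items. The markers are put in a finally block so that, if the producer fails partway through, the consumers still receive them and exit instead of blocking in take() forever; the failAfter parameter is a hypothetical knob that simulates such a failure.

```java
import java.util.Iterator;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.LongStream;

/** Hypothetical demo: one producer, several consumers, one poison pill per consumer. */
public class PoisonPillDemo {
    /** End marker, boxed once and compared by reference. */
    private static final Long END = Long.MAX_VALUE;

    /**
     * Produces the Longs 1..199, failing once a value exceeds failAfter,
     * and returns how many items the consumers processed.
     */
    static int produceAndConsume(int consumers, long failAfter) throws InterruptedException {
        BlockingQueue<Long> queue = new LinkedBlockingQueue<>(10); // bounded: put() blocks when full
        AtomicInteger processed = new AtomicInteger();
        Thread[] workers = new Thread[consumers];
        for (int i = 0; i < consumers; i++) {
            workers[i] = new Thread(() -> {
                try {
                    while (queue.take() != END) {
                        processed.incrementAndGet(); // stand-in for the consumer callback
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            workers[i].start();
        }
        Thread producer = new Thread(() -> {
            try {
                try {
                    Iterator<Long> it = LongStream.range(1L, 200L).iterator();
                    while (it.hasNext()) {
                        Long next = it.next();
                        if (next > failAfter) {
                            throw new IllegalStateException("producer failed at " + next);
                        }
                        queue.put(next);
                    }
                } finally {
                    // One marker per consumer, sent even if the loop above failed
                    for (int i = 0; i < consumers; i++) {
                        queue.put(END);
                    }
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch (IllegalStateException e) {
                System.out.println(e.getMessage());
            }
        });
        producer.start();
        producer.join();
        for (Thread worker : workers) {
            worker.join();
        }
        return processed.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(produceAndConsume(4, Long.MAX_VALUE)); // 199
        System.out.println(produceAndConsume(4, 50));             // prints "producer failed at 51", then 50
    }
}
```

Because the queue is FIFO and the markers are enqueued after every data item, each consumer only sees a marker once all data has been taken, so no item is lost even though consumers exit independently.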
Consumer
```java
protected void startConsumers(Context context) {
    for (int i = 0; i < consumerCount; i++) {
        CompletableFuture<Void> running = CompletableFuture.runAsync(() -> {
            try {
                while (true) {
                    D data = context.getDataQueue().take();
                    // Reference comparison: only the MARK_AS_END instance itself ends the loop
                    if (data == MARK_AS_END) {
                        log.debug("Received end marker, consumer thread exiting");
                        return;
                    }
                    log.trace("Processing data: {}", data);
                    consumer.accept(context.getParam(), data);
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }, executor);
        context.getRunnings().add(running);
    }
    log.trace("{} consumer threads started", consumerCount);
}
```
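- The consumer takes an object from the queue and first checks whether it is == MARK_AS_END. If so, the producer has finished and every data item has already been taken from the queue (the queue is FIFO and the markers come last), so the consumer exits.
- Otherwise, the object is passed to the consumer callback.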
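Because the check is a reference comparison (==), a data item only ends the loop if it is the very same object as MARK_AS_END. That works for Long.MAX_VALUE, but Long.valueOf always caches the values -128 to 127, so a small marker such as 0L would be the same instance as an ordinary data item 0L and would stop a consumer early. The sketch below (hypothetical class MarkerIdentityDemo, mirroring the consumer loop with a counter in place of the callback) shows both cases.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

/** Hypothetical demo of the consumer's identity check against the end marker. */
public class MarkerIdentityDemo {

    /** Consumer loop in the style of startConsumers(): counts items until it sees the marker. */
    static int consume(BlockingQueue<Long> queue, Long marker) throws InterruptedException {
        int processed = 0;
        while (true) {
            Long data = queue.take();
            if (data == marker) {       // reference comparison, as in the template
                return processed;
            }
            processed++;                // stand-in for consumer.accept(param, data)
        }
    }

    /** Queues the data items 1, 0, 2 followed by the marker, then runs the loop. */
    static int countWithMarker(long markerValue) throws InterruptedException {
        Long marker = Long.valueOf(markerValue);      // boxed once, like MARK_AS_END
        BlockingQueue<Long> queue = new LinkedBlockingQueue<>();
        for (long value : new long[] {1L, 0L, 2L}) {
            queue.put(Long.valueOf(value));           // small values come from the Long cache
        }
        queue.put(marker);
        return consume(queue, marker);
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(countWithMarker(Long.MAX_VALUE)); // 3: all data items processed
        System.out.println(countWithMarker(0L));             // 1: the data item 0L is the cached marker
    }
}
```

Picking a marker value the producer can never emit, as the example's Long.MAX_VALUE does for the range 1 to 199, avoids the problem.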