美文网首页
一起来学Java8(八)——ForkJoin

一起来学Java8(八)——ForkJoin

作者: 猿敲月下码 | 来源:发表于2020-02-21 09:24 被阅读0次

    一起来学Java8(七)——Stream中我们了解了reduce的用法,其中并行流的底层是使用了分支/合并框架

    分支/合并框架的核心思想是把一个大的任务拆分成多个子任务,然后把每个子任务的执行结果整合起来,返回一个最终结果。

    ForkJoinPool

    分支/合并框架的核心类是java.util.concurrent.ForkJoinPool,从名称中可以看到它是一个线程池,线程数量是默认处理器数量,可以通过下面这句话来改变线程数:

    System.setProperty("jaav.util.concurrent.ForkJoinPool.common.parallelism", "8");
    

    RecursiveTask<T>

    前面说到了ForkJoinPool类是一个线程池,那么RecursiveTask的做用就是生成一个任务,然后把这个任务放到ForkJoinPool当中去。

    RecursiveTask类是一个抽象类,并且有一个抽象方法

    protected abstract V compute();
    

    这个方法的主要功能是拆分任务逻辑,直到无法拆分时返回子任务的执行结果。

    下面从一个简单的例子来说明各个类的使用方式。

    这个例子演示将一个数组内的所有数字相加,得到一个总和。思想是将数组进行对半拆分,得到两个子数组,然后再对两个子数组进行拆分,以此类推,直到数组长度小于等于10的时候不再拆分。

    @AllArgsConstructor
        static class NumberAddTask extends RecursiveTask<Long> {
    
            // 存放数字
            private long[] numbers;
            // 计算的起始位置
            private int startIndex;
            // 计算的结束位置
            private int endIndex;
    
            @Override
            protected Long compute() {
                int len = endIndex - startIndex;
                // 数组长度小于10,无法拆分,开始运算
                if (len <= 10) {
                    return execute();
                }
                // 拆分左边的子任务
                NumberAddTask leftTask = new NumberAddTask(numbers, startIndex, startIndex + len / 2);
                // 将子任务加入到ForkJoinPool中去
                leftTask.fork();
                // 创建右边的任务
                NumberAddTask rightTask = new NumberAddTask(numbers, startIndex + len / 2, endIndex);
                // 执行右边的任务
                Long rightSum = rightTask.compute();
                // 读取左边的子任务结果,这里会阻塞
                Long leftSum = leftTask.join();
                // 合并结果
                return leftSum + rightSum;
            }
    
            private long execute() {
                long sum = 0;
                for (int i = startIndex; i < endIndex; i++) {
                    sum += numbers[i];
                }
                return sum;
            }
        }
    

    运行代码:

    long startTime = System.currentTimeMillis();
    // 生成一个数组,存放1,2,3,4....
    long[] numbers = LongStream.rangeClosed(1, 1000000).toArray();
    // 创建一个任务,起始位置是0,结束位置是数组的长度
    NumberAddTask numberAddTask = new NumberAddTask(numbers, 0, numbers.length);
    // 将任务加入到线程池中运行,得到总和
    Long sum = new ForkJoinPool().invoke(numberAddTask);
    long time = System.currentTimeMillis() - startTime;
    System.out.println("sum:" + sum + ", 耗时:" + time + "毫秒");
    

    打印:

    sum:500000500000, 耗时:63毫秒

    Spliterator

    Spliterator是Java8新增的一个接口,从名字上可解读出两种意思:split,iterate,即该接口提供分割迭代的功能。Spliterator接口需要配合Stream一起使用。

    使用方式:

    // 创建顺序执行的Stream
    Stream stream = StreamSupport.stream(Spliterator, false);
    
    // 创建并行的Stream
    Stream stream = StreamSupport.stream(Spliterator, true);
    

    这两行代码即为Collection.stream()方法的默认实现。

    Spliterator接口声明了4个抽象方法,需要开发者自己实现。

    public interface Spliterator<T> {
        boolean tryAdvance(Consumer<? super T> action);
    
        Spliterator<T> trySplit();
    
        long estimateSize();
    
        int characteristics();
    }
    

    boolean tryAdvance(Consumer<? super T> action):方法用于遍历获取元素,然后通过Consumer来执行,如果取到元素返回true,否则返回false

    Spliterator<T> trySplit():执行分割操作,如果集合还能再继续分割,则返回一个新的Spliterator,如果不能继续分割则返回null

    long estimateSize():用来返回剩余元素数量

    int characteristics():指定集合一些特性。

    characteristics特性列表如下:

    特性 含义
    ORDERD 集合中的的元素有顺序概念
    DISTINCT 对任意一对遍历过的元素x,y,x.equals(y)返回false
    SORTED 遍历的元素按照一个预定义的顺序排序
    SIZED 集合元素大小可确定的
    NONNULL 保证遍历元素没有null
    IMMUTABLE 集合元素不能修改
    CONCURRENT 集合可被其它线程同时修改,无需同步
    SUBSIZED 该Spliterator和所有从它拆分出来的子Spliterator都有SIZED特性

    characteristics()使用方式如下:

    @Override
    public int characteristics() {
        return Spliterator.SIZED | Spliterator.SUBSIZED;
    }
    

    下面我们来实现一个自定义的分割迭代器

    static class LongSpliterator implements Spliterator<Long> {
    
            private long[] array;
            private int index;
            private int end;
    
            public LongSpliterator(long[] array, int index, int end) {
                this.array = array;
                this.index = index;
                this.end = end;
            }
    
            @Override
            public boolean tryAdvance(Consumer<? super Long> action) {
                if (index >= 0 && index < end) {
                    // 取出元素
                    Long l = array[index++];
                    // 执行
                    action.accept(l);
                    return true;
                }
                return false;
            }
    
            /**
             * 尝试分割集合。
             * 切割规则:数组一分为2,留后半段,将前半段再一分为2。
             * @return 返回null结束分割
             */
            @Override
            public Spliterator<Long> trySplit() {
                int start = 0;
                // 取中间
                int middle = (start + end) >>> 1;
                if (start < middle) {
                    return null;
                }
                // 当前index变成中间值,即当前类的操作范围是:middle ~ end
                index = middle;
                // 将前半段再一分为2
                return new LongSpliterator(array, start, middle);
            }
    
            @Override
            public long estimateSize() {
                return end - index;
            }
    
            @Override
            public int characteristics() {
                return Spliterator.SIZED | Spliterator.SUBSIZED;
            }
        }
    

    该分割迭代器中的集合是一个long数组,分割规则是将数组一分为2,留后半段,将前半段再一分为2。

    现在来使用这个分割迭代器,计算从1加到1000000,测试用例如下:

        public void testDo() {
            // 创建一个数组
            long[] numbers = LongStream.rangeClosed(1, 1000000).toArray();
            long startTime = System.currentTimeMillis();
            LongSpliterator spliterator = new LongSpliterator(numbers, 0, numbers.length);
            // 申明一个并行的Stream
            Stream<Long> stream = StreamSupport.stream(spliterator, true);
            // 计算从1加到1000000,结果应该是:500000500000
            Long sum = stream.reduce((n1, n2) -> n1 + n2).orElse(0L);
            long time = System.currentTimeMillis() - startTime;
            System.out.println("sum:" + sum + ", 耗时:" + time + "毫秒");
        }
    

    打印:

    sum:500000500000, 耗时:32毫秒

    定期分享技术干货,一起学习,一起进步!微信公众号:猿敲月下码

    相关文章

      网友评论

          本文标题:一起来学Java8(八)——ForkJoin

          本文链接:https://www.haomeiwen.com/subject/cwbqqhtx.html