遇到的问题
Java 8 开始引入了 Stream,其 API 一直在不断地优化完善,Java 9 中更是引入了 ofNullable,以及 takeWhile 和 dropWhile 这两个关键 API。有时候,我们想对 Stream 中的对象进行排重(去重),默认可以用 distinct 这个 API,例如:
// Split the comma-separated sample first, then let Stream.distinct() drop the repeated values.
String[] samples = "test1,test2,test2,test3,test3".split(",");
List<String> collect = Arrays.stream(samples)
        .distinct()
        .collect(Collectors.toList());
在顺序流中,distinct 底层用 HashSet 记录已经出现过的元素(并行的有序流则归约到 LinkedHashSet),其效果和下面的实现几乎是等价的:
// Same de-duplication, expressed by collecting into an insertion-ordered set.
String[] items = "test1,test2,test2,test3,test3".split(",");
Set<String> collect = Arrays.stream(items)
        .collect(Collectors.toCollection(LinkedHashSet::new));
结果是一样的:靠 hashCode() 方法定位槽,靠 equals() 方法判断两个对象是否相等,相等则作为重复元素被去掉,不相等则保留;原始顺序由 LinkedHashSet(内部基于 LinkedHashMap)来保留。
但是,对于同一个对象,有时候我们排重的方式并不统一,所以最好像sorted
接口一样,能让我们传入比较器,来控制如何判断两个对象相等需要排重。
例如下面的这个对象,我们有时候想按照id
排重,有时候想按照name
进行排重。
/**
 * Minimal user model used by the de-duplication examples: the same list of users is
 * de-duplicated once by {@code id} and once by {@code name}.
 *
 * <p>Accessors are expected to come from the {@code @Data} annotation (presumably Lombok);
 * the examples rely on {@code getId()} and {@code getName()}.
 */
@Data
@NoArgsConstructor
public class User {
    /** Numeric identifier; one of the two de-duplication keys used in the examples. */
    private int id;
    /** Display name; the other de-duplication key used in the examples. */
    private String name;

    /**
     * Creates a fully initialised user.
     *
     * <p>Declared explicitly because the examples call {@code new User(id, name)} and the
     * class otherwise only exposes the no-argument constructor.
     *
     * @param id   the user's identifier
     * @param name the user's name
     */
    public User(int id, String name) {
        this.id = id;
        this.name = name;
    }
}
解决思考
先来实现这个 distinct 方法。首先,我们定义一个 Key 类,用来代理 hashCode 和 equals 方法:
/**
 * Map key that wraps an element and answers {@code hashCode()} / {@code equals()} through
 * caller-supplied functions instead of the element's own methods.
 *
 * @param <E> type of the wrapped element
 */
private static final class Key<E> {
    /** The wrapped element being compared. */
    private final E element;
    /** Computes the hash code for the wrapped element. */
    private final ToIntFunction<E> hasher;
    /** Decides whether two wrapped elements count as equal. */
    private final BiPredicate<E, E> equality;

    public Key(E e, ToIntFunction<E> hashCode,
               BiPredicate<E, E> equals) {
        this.element = e;
        this.hasher = hashCode;
        this.equality = equals;
    }

    /** Delegates to the supplied hash function. */
    @Override
    public int hashCode() {
        return hasher.applyAsInt(element);
    }

    /** Two keys are equal when the supplied predicate accepts their wrapped elements. */
    @Override
    public boolean equals(Object obj) {
        if (obj instanceof Key) {
            @SuppressWarnings("unchecked")
            Key<E> other = (Key<E>) obj;
            return equality.test(element, other.element);
        }
        // Anything that is not a Key (including null) can never be equal.
        return false;
    }
}
然后,增加新的distinct
方法:
/**
 * De-duplicates the elements of this stream using caller-supplied hash and equality functions.
 *
 * @param hashCode computes the hash used to bucket an element
 * @param equals   decides whether two elements are duplicates of each other
 * @param merger   given two duplicates, returns the one to keep
 * @return a stream over the surviving elements, in first-occurrence order
 */
public Stream<T> distinct(
        ToIntFunction<T> hashCode,
        BiPredicate<T, T> equals,
        BinaryOperator<T> merger
) {
    // Each element is wrapped so that the map uses the supplied hash/equality functions.
    Function<T, Key<T>> keyOf = t -> new Key<>(t, hashCode, equals);
    // A LinkedHashMap keeps the keys in the order they were first seen.
    LinkedHashMap<Key<T>, T> survivors = this.collect(
            Collectors.toMap(keyOf, Function.identity(), merger, LinkedHashMap::new));
    return survivors.values().stream();
}
然后,这个方法如何放入 Stream 呢? 我们首先想到的就是代理 Stream
接口,最简单的实现:
/**
 * A {@link Stream} decorator that adds a {@code distinct} overload driven by caller-supplied
 * hash and equality functions.
 *
 * <p>Every operation is forwarded to the wrapped stream. Operations that return an object
 * stream re-wrap the result so the extra {@code distinct} overload stays available along a
 * fluent chain; primitive-stream and terminal operations return the delegate's result as is.
 *
 * @param <T> element type
 */
public class EnhancedStream<T> implements Stream<T> {

    /** The wrapped stream that every operation is forwarded to. */
    private final Stream<T> delegate;

    /**
     * Wraps the given stream.
     *
     * @param delegate the stream to decorate; must not be {@code null}
     * @throws NullPointerException if {@code delegate} is {@code null}
     */
    public EnhancedStream(Stream<T> delegate) {
        if (delegate == null) {
            throw new NullPointerException("delegate");
        }
        this.delegate = delegate;
    }

    /**
     * Map key that wraps an element and answers {@code hashCode()} / {@code equals()} through
     * caller-supplied functions instead of the element's own methods.
     */
    private static final class Key<E> {
        /** The element being compared. */
        private final E e;
        /** Computes the hash code of the element. */
        private final ToIntFunction<E> hashCode;
        /** Decides whether two elements are equal. */
        private final BiPredicate<E, E> equals;

        public Key(E e, ToIntFunction<E> hashCode,
                   BiPredicate<E, E> equals) {
            this.e = e;
            this.hashCode = hashCode;
            this.equals = equals;
        }

        @Override
        public int hashCode() {
            return hashCode.applyAsInt(e);
        }

        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof Key)) {
                return false;
            }
            @SuppressWarnings("unchecked")
            Key<E> that = (Key<E>) obj;
            return equals.test(this.e, that.e);
        }
    }

    /**
     * De-duplicates the elements using the supplied hash and equality functions.
     *
     * <p>This consumes the wrapped stream immediately (it is implemented with a terminal
     * {@code collect}) and returns a new stream over the survivors in first-occurrence order.
     * Closing the returned stream also closes the wrapped stream, so close handlers registered
     * on the source are not lost. Elements must be non-null because
     * {@link Collectors#toMap} rejects null values.
     *
     * @param hashCode computes the hash used to bucket an element; must be consistent with
     *                 {@code equals}
     * @param equals   decides whether two elements are duplicates of each other
     * @param merger   given two duplicates, returns the one to keep
     * @return a stream over the de-duplicated elements
     */
    public EnhancedStream<T> distinct(
            ToIntFunction<T> hashCode,
            BiPredicate<T, T> equals,
            BinaryOperator<T> merger
    ) {
        return new EnhancedStream<>(
                delegate.collect(Collectors.toMap(
                        t -> new Key<>(t, hashCode, equals),
                        Function.identity(),
                        merger,
                        // LinkedHashMap keeps the original encounter order.
                        LinkedHashMap::new))
                        .values()
                        .stream()
                        // Propagate close() to the source stream that was just consumed.
                        .onClose(delegate::close)
        );
    }

    // ---- Intermediate operations returning an object stream: forward, then re-wrap ----

    @Override
    public EnhancedStream<T> filter(Predicate<? super T> predicate) {
        return new EnhancedStream<>(delegate.filter(predicate));
    }

    @Override
    public <R> EnhancedStream<R> map(Function<? super T, ? extends R> mapper) {
        return new EnhancedStream<>(delegate.map(mapper));
    }

    @Override
    public <R> EnhancedStream<R> flatMap(Function<? super T, ? extends Stream<? extends R>> mapper) {
        return new EnhancedStream<>(delegate.flatMap(mapper));
    }

    @Override
    public EnhancedStream<T> distinct() {
        return new EnhancedStream<>(delegate.distinct());
    }

    @Override
    public EnhancedStream<T> sorted() {
        return new EnhancedStream<>(delegate.sorted());
    }

    @Override
    public EnhancedStream<T> sorted(Comparator<? super T> comparator) {
        return new EnhancedStream<>(delegate.sorted(comparator));
    }

    @Override
    public EnhancedStream<T> peek(Consumer<? super T> action) {
        return new EnhancedStream<>(delegate.peek(action));
    }

    @Override
    public EnhancedStream<T> limit(long maxSize) {
        return new EnhancedStream<>(delegate.limit(maxSize));
    }

    @Override
    public EnhancedStream<T> skip(long n) {
        return new EnhancedStream<>(delegate.skip(n));
    }

    /** Overridden so the result keeps the enhanced {@code distinct} overload. */
    @Override
    public EnhancedStream<T> takeWhile(Predicate<? super T> predicate) {
        return new EnhancedStream<>(delegate.takeWhile(predicate));
    }

    /** Overridden so the result keeps the enhanced {@code distinct} overload. */
    @Override
    public EnhancedStream<T> dropWhile(Predicate<? super T> predicate) {
        return new EnhancedStream<>(delegate.dropWhile(predicate));
    }

    // ---- Conversions to primitive streams: plain forwarding ----

    @Override
    public IntStream mapToInt(ToIntFunction<? super T> mapper) {
        return delegate.mapToInt(mapper);
    }

    @Override
    public LongStream mapToLong(ToLongFunction<? super T> mapper) {
        return delegate.mapToLong(mapper);
    }

    @Override
    public DoubleStream mapToDouble(ToDoubleFunction<? super T> mapper) {
        return delegate.mapToDouble(mapper);
    }

    @Override
    public IntStream flatMapToInt(Function<? super T, ? extends IntStream> mapper) {
        return delegate.flatMapToInt(mapper);
    }

    @Override
    public LongStream flatMapToLong(Function<? super T, ? extends LongStream> mapper) {
        return delegate.flatMapToLong(mapper);
    }

    @Override
    public DoubleStream flatMapToDouble(Function<? super T, ? extends DoubleStream> mapper) {
        return delegate.flatMapToDouble(mapper);
    }

    // ---- Terminal operations: plain forwarding ----

    @Override
    public void forEach(Consumer<? super T> action) {
        delegate.forEach(action);
    }

    @Override
    public void forEachOrdered(Consumer<? super T> action) {
        delegate.forEachOrdered(action);
    }

    @Override
    public Object[] toArray() {
        return delegate.toArray();
    }

    @Override
    public <A> A[] toArray(IntFunction<A[]> generator) {
        return delegate.toArray(generator);
    }

    @Override
    public T reduce(T identity, BinaryOperator<T> accumulator) {
        return delegate.reduce(identity, accumulator);
    }

    @Override
    public Optional<T> reduce(BinaryOperator<T> accumulator) {
        return delegate.reduce(accumulator);
    }

    @Override
    public <U> U reduce(U identity, BiFunction<U, ? super T, U> accumulator, BinaryOperator<U> combiner) {
        return delegate.reduce(identity, accumulator, combiner);
    }

    @Override
    public <R> R collect(Supplier<R> supplier, BiConsumer<R, ? super T> accumulator, BiConsumer<R, R> combiner) {
        return delegate.collect(supplier, accumulator, combiner);
    }

    @Override
    public <R, A> R collect(Collector<? super T, A, R> collector) {
        return delegate.collect(collector);
    }

    @Override
    public Optional<T> min(Comparator<? super T> comparator) {
        return delegate.min(comparator);
    }

    @Override
    public Optional<T> max(Comparator<? super T> comparator) {
        return delegate.max(comparator);
    }

    @Override
    public long count() {
        return delegate.count();
    }

    @Override
    public boolean anyMatch(Predicate<? super T> predicate) {
        return delegate.anyMatch(predicate);
    }

    @Override
    public boolean allMatch(Predicate<? super T> predicate) {
        return delegate.allMatch(predicate);
    }

    @Override
    public boolean noneMatch(Predicate<? super T> predicate) {
        return delegate.noneMatch(predicate);
    }

    @Override
    public Optional<T> findFirst() {
        return delegate.findFirst();
    }

    @Override
    public Optional<T> findAny() {
        return delegate.findAny();
    }

    // ---- BaseStream operations ----

    @Override
    public Iterator<T> iterator() {
        return delegate.iterator();
    }

    @Override
    public Spliterator<T> spliterator() {
        return delegate.spliterator();
    }

    @Override
    public boolean isParallel() {
        return delegate.isParallel();
    }

    @Override
    public EnhancedStream<T> sequential() {
        return new EnhancedStream<>(delegate.sequential());
    }

    @Override
    public EnhancedStream<T> parallel() {
        return new EnhancedStream<>(delegate.parallel());
    }

    @Override
    public EnhancedStream<T> unordered() {
        return new EnhancedStream<>(delegate.unordered());
    }

    @Override
    public EnhancedStream<T> onClose(Runnable closeHandler) {
        return new EnhancedStream<>(delegate.onClose(closeHandler));
    }

    @Override
    public void close() {
        delegate.close();
    }
}
测试下:
/**
 * Demo: de-duplicates the same list of users twice, first by id and then by name
 * (ignoring case), keeping the first occurrence of each duplicate.
 *
 * @param args unused
 */
public static void main(String[] args) {
    // List.of avoids the double-brace idiom, which creates an anonymous ArrayList subclass.
    List<User> users = List.of(
            new User(1, "test1"),
            new User(2, "test1"),
            new User(2, "test2"),
            new User(3, "test3"),
            new User(3, "test4"));

    // De-duplicate by id.
    List<User> collect1 = new EnhancedStream<>(users.stream()).distinct(
            User::getId,
            (u1, u2) -> u1.getId() == u2.getId(),
            (u1, u2) -> u1
    ).collect(Collectors.toList());

    // De-duplicate by name, ignoring case. The hash and the equality test must agree:
    // both work on the lower-cased name, otherwise names differing only in case would
    // hash differently and never be compared.
    List<User> collect2 = new EnhancedStream<>(users.stream()).distinct(
            user -> user.getName().toLowerCase(java.util.Locale.ROOT).hashCode(),
            (u1, u2) -> u1.getName().toLowerCase(java.util.Locale.ROOT)
                    .equals(u2.getName().toLowerCase(java.util.Locale.ROOT)),
            (u1, u2) -> u1
    ).collect(Collectors.toList());
}
通过动态代理
上面这种实现有很多冗余代码,可以考虑改用动态代理实现。首先编写代理接口:让 EnhancedStream 继承 Stream 接口,增加新的 distinct 方法,并把所有返回 Stream 的方法重写为返回 EnhancedStream,这样链式调用的返回值上才仍然可以使用新的 distinct 方法。
/**
 * A {@link Stream} that additionally offers a {@code distinct} overload taking caller-supplied
 * hash and equality functions.
 *
 * <p>The stream-returning operations below are re-declared with {@code EnhancedStream} as
 * their return type so that the extra overload remains reachable along a fluent chain.
 *
 * @param <T> element type
 */
public interface EnhancedStream<T> extends Stream<T> {
/**
 * De-duplicates the elements using the supplied functions instead of the elements' own
 * {@code hashCode()} / {@code equals()}.
 *
 * @param hashCode computes the hash of an element
 * @param equals   decides whether two elements are duplicates of each other
 * @param merger   given two duplicates, returns the one to keep
 * @return a stream over the de-duplicated elements
 */
EnhancedStream<T> distinct(ToIntFunction<T> hashCode,
BiPredicate<T, T> equals,
BinaryOperator<T> merger);
// The overrides below only narrow the return type from Stream to EnhancedStream.
@Override
EnhancedStream<T> filter(Predicate<? super T> predicate);
@Override
<R> EnhancedStream<R> map(
Function<? super T, ? extends R> mapper);
@Override
<R> EnhancedStream<R> flatMap(
Function<? super T, ? extends Stream<? extends R>> mapper);
@Override
EnhancedStream<T> distinct();
@Override
EnhancedStream<T> sorted();
@Override
EnhancedStream<T> sorted(Comparator<? super T> comparator);
@Override
EnhancedStream<T> peek(Consumer<? super T> action);
@Override
EnhancedStream<T> limit(long maxSize);
@Override
EnhancedStream<T> skip(long n);
@Override
EnhancedStream<T> takeWhile(Predicate<? super T> predicate);
@Override
EnhancedStream<T> dropWhile(Predicate<? super T> predicate);
@Override
EnhancedStream<T> sequential();
@Override
EnhancedStream<T> parallel();
@Override
EnhancedStream<T> unordered();
@Override
EnhancedStream<T> onClose(Runnable closeHandler);
}
然后,编写调用处理器 EnhancedStreamHandler(实现 InvocationHandler),由它来完成方法代理:
public class EnhancedStreamHandler<T> implements InvocationHandler {
private Stream<T> delegate;
public EnhancedStreamHandler(Stream<T> delegate) {
this.delegate = delegate;
}
private static final Method ENHANCED_DISTINCT;
static {
try {
ENHANCED_DISTINCT = EnhancedStream.class.getMethod(
"distinct", ToIntFunction.class, BiPredicate.class,
BinaryOperator.class
);
} catch (NoSuchMethodException e) {
throw new Error(e);
}
}
/**
* 将EnhancedStream的方法与Stream的方法一一对应
*/
private static final Map<Method, Method> METHOD_MAP =
Stream.of(EnhancedStream.class.getMethods())
.filter(m -> !m.equals(ENHANCED_DISTINCT))
.filter(m -> !Modifier.isStatic(m.getModifiers()))
.collect(Collectors.toUnmodifiableMap(
Function.identity(),
m -> {
try {
return Stream.class.getMethod(
m.getName(), m.getParameterTypes());
} catch (NoSuchMethodException e) {
throw new Error(e);
}
}));
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
if (method.equals(ENHANCED_DISTINCT)) {
//调用方法为扩展方法distinct
return distinct(
(EnhancedStream<T>) proxy,
(ToIntFunction<T>) args[0],
(BiPredicate<T, T>) args[1],
(BinaryOperator<T>) args[2]);
} else if (method.getReturnType() == EnhancedStream.class) {
//对于返回类型为EnhancedStream的,证明是代理的方法调用,走代理
Method match = METHOD_MAP.get(method);
//更相信代理对象为新的Stream
this.delegate = (Stream) match.invoke(this.delegate, args);
return proxy;
} else {
//否则,直接用代理类调用
return method.invoke(this.delegate, args);
}
}
private static final class Key<E> {
private final E e;
private final ToIntFunction<E> hashCode;
private final BiPredicate<E, E> equals;
public Key(E e, ToIntFunction<E> hashCode,
BiPredicate<E, E> equals) {
this.e = e;
this.hashCode = hashCode;
this.equals = equals;
}
@Override
public int hashCode() {
return hashCode.applyAsInt(e);
}
@Override
public boolean equals(Object obj) {
if (!(obj instanceof Key)) {
return false;
}
@SuppressWarnings("unchecked")
Key<E> that = (Key<E>) obj;
return equals.test(this.e, that.e);
}
}
private EnhancedStream<T> distinct(EnhancedStream<T> proxy,
ToIntFunction<T> hashCode,
BiPredicate<T, T> equals,
BinaryOperator<T> merger) {
delegate = delegate.collect(Collectors.toMap(
t -> new Key<>(t, hashCode, equals),
Function.identity(),
merger,
//使用LinkedHashMap,保持入参原始顺序
LinkedHashMap::new))
.values()
.stream();
return proxy;
}
}
最后编写工厂类,生成EnhancedStream
代理类:
/**
 * Factory that builds {@code EnhancedStream} instances as dynamic proxies.
 */
public class EnhancedStreamFactory {

    /**
     * Wraps the given stream in an {@code EnhancedStream} proxy whose calls are handled by an
     * {@code EnhancedStreamHandler}.
     *
     * @param stream the stream to wrap
     * @param <E>    element type
     * @return a proxy implementing {@code EnhancedStream}
     */
    public static <E> EnhancedStream<E> newEnhancedStream(Stream<E> stream) {
        // Must be EnhancedStream's class loader, not Stream's: Stream is a JDK class and is
        // loaded by the root (bootstrap) class loader.
        ClassLoader loader = EnhancedStream.class.getClassLoader();
        // The interface the proxy implements.
        Class<?>[] proxiedInterfaces = {EnhancedStream.class};
        // The handler that receives every call made on the proxy.
        EnhancedStreamHandler<E> handler = new EnhancedStreamHandler<>(stream);

        @SuppressWarnings("unchecked")
        EnhancedStream<E> enhanced =
                (EnhancedStream<E>) Proxy.newProxyInstance(loader, proxiedInterfaces, handler);
        return enhanced;
    }
}
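这样,代码看上去更优雅了。就算 JDK 以后为 Stream 扩展更多方法,处理器本身也不用修改;不过要注意,新增的、返回 Stream 的方法只有在 EnhancedStream 接口中把返回类型重写为 EnhancedStream 之后,其返回值才能继续使用新的 distinct 方法。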
网友评论