美文网首页
ElasticsearchTemplate的使用

ElasticsearchTemplate的使用

作者: 不知名的蛋挞 | 来源:发表于2020-02-18 15:36 被阅读0次

    这里主要记录一下用法，所以不做过多的注释了。环境：

    elasticsearch:                  "org.elasticsearch:elasticsearch:7.5.1",
    elasticsearch_client:           "org.elasticsearch.client:transport:7.5.1",
    springboot_elasticsearch:       "org.springframework.boot:spring-boot-starter-data-elasticsearch:2.2.0.RELEASE"
    

    MyPageableRequest.java

    import org.springframework.data.domain.PageRequest;
    import org.springframework.data.domain.Sort;
    
    public class MyPageableRequest extends PageRequest {

        /**
         * Builds a sorted page request by delegating straight to {@link PageRequest}.
         *
         * <p>Exists because {@link PageRequest}'s constructor is not directly usable by callers here;
         * this subclass exposes it unchanged.
         *
         * @param pageIndex zero-based index of the requested page (must not be negative)
         * @param pageSize  number of elements per page (must be greater than 0)
         * @param sortOrder ordering to apply; never {@literal null} — pass {@link Sort#unsorted()} for none
         */
        public MyPageableRequest(int pageIndex, int pageSize, Sort sortOrder) {
            super(pageIndex, pageSize, sortOrder);
        }
    }
    

    SearchQueryEngine.java

    import com.es.exception.SearchQueryBuildException;
    import org.apache.commons.lang3.StringUtils;
    import org.elasticsearch.client.transport.TransportClient;
    import org.elasticsearch.cluster.metadata.IndexMetaData;
    import org.springframework.data.domain.Page;
    import org.springframework.data.domain.Pageable;
    import org.springframework.data.elasticsearch.annotations.Document;
    import org.springframework.data.elasticsearch.core.ElasticsearchTemplate;
    import org.springframework.data.elasticsearch.core.ScrolledPage;
    import java.lang.reflect.Field;
    import java.math.BigDecimal;
    import java.math.RoundingMode;
    import java.util.Date;
    import java.util.List;
    
    /**
     * Abstract base for query engines built on Spring Data Elasticsearch's {@link ElasticsearchTemplate}.
     *
     * <p>Query/document objects of type {@code T} must carry a {@link Document} annotation; the helpers
     * below read and write their fields reflectively.
     *
     * @param <T> type of the annotated query/document object
     */
    public abstract class SearchQueryEngine<T> {

        /** Template wrapping the {@link TransportClient} passed to the constructor. */
        protected ElasticsearchTemplate elasticsearchTemplate;

        public abstract void deleteIndex(String indexName);

        public abstract void createIndex(String indexName);

        public abstract int saveOrUpdate(List<T> list);

        public abstract <R> List<R> aggregation(T query, Class<R> clazz);

        public abstract <R> ScrolledPage<R> scroll(T query, Class<R> clazz, Pageable pageable);

        public abstract <R> List<R> find(T query, Class<R> clazz, int size);

        public abstract <R> Page<R> find(T query, Class<R> clazz, Pageable pageable);

        public abstract <R> R sum(T query, Class<R> clazz);

        public SearchQueryEngine(TransportClient client) {
            this.elasticsearchTemplate = new ElasticsearchTemplate(client);
        }

        /**
         * Returns the {@link Document} annotation declared on the runtime class of {@code t}.
         *
         * @param t annotated query/document object
         * @return the annotation
         * @throws SearchQueryBuildException if the class is not annotated with {@link Document}
         */
        protected Document getDocument(T t) {
            Document annotation = t.getClass().getAnnotation(Document.class);
            if (annotation == null) {
                throw new SearchQueryBuildException("Can't find annotation @Document on " + t.getClass().getName());
            }
            return annotation;
        }

        /**
         * Returns the Elasticsearch field name for a Java field: {@code column} when it is not blank,
         * otherwise the Java field's own name.
         *
         * @param field  the Java field
         * @param column explicit column name from an annotation; may be blank
         * @return the name to use in the query
         */
        protected String getFieldName(Field field, String column) {
            return StringUtils.isNotBlank(column) ? column : field.getName();
        }

        /**
         * Sets {@code field} on {@code obj}, first converting {@code value} to the field's declared type
         * for {@code BigDecimal}, {@code Long}, {@code Integer} and {@code Date} fields. Other types are
         * assigned as-is.
         *
         * <p>Values coming from Elasticsearch are often not of the field's type: sum aggregations
         * return {@code double} (e.g. {@code 12.0}) and terms bucket keys return {@code Long} or
         * {@code String}. Numeric targets are therefore converted via {@link BigDecimal} instead of
         * re-parsing {@code toString()}, which used to fail with {@code NumberFormatException} for
         * "12.0".
         *
         * @param field target field (accessibility is temporarily forced and then restored)
         * @param obj   instance to modify
         * @param value raw value; {@code null} is assigned unchanged
         * @throws SearchQueryBuildException if the field cannot be written
         * @throws ArithmeticException if a numeric value does not fit an {@code Integer}/{@code Long} field
         */
        protected void setFieldValue(Field field, Object obj, Object value) {
            boolean isAccessible = field.isAccessible();
            field.setAccessible(true);
            try {
                field.set(obj, convertValue(field.getType(), value));
            } catch (IllegalAccessException e) {
                throw new SearchQueryBuildException(e);
            } finally {
                field.setAccessible(isAccessible);
            }
        }

        /**
         * Converts a raw Elasticsearch value into an instance assignable to {@code type}, based on the
         * type's simple name (same dispatch as before the fix).
         */
        private static Object convertValue(Class<?> type, Object value) {
            if (value == null) {
                return null;
            }
            switch (type.getSimpleName()) {
                case "BigDecimal":
                    return toBigDecimal(value).setScale(5, RoundingMode.HALF_UP);
                case "Long":
                    // Round the (possibly fractional-looking, e.g. "12.0") number, then narrow without silent overflow.
                    return toBigDecimal(value).setScale(0, RoundingMode.HALF_UP).longValueExact();
                case "Integer":
                    return toBigDecimal(value).setScale(0, RoundingMode.HALF_UP).intValueExact();
                case "Date":
                    if (value instanceof Date) {
                        return value;
                    }
                    // Dates are exchanged with Elasticsearch as epoch milliseconds (see formatValue).
                    return new Date(toBigDecimal(value).setScale(0, RoundingMode.HALF_UP).longValueExact());
                default:
                    return value;
            }
        }

        /** Parses the decimal text form of {@code value}; works for boxed numbers and numeric strings. */
        private static BigDecimal toBigDecimal(Object value) {
            return new BigDecimal(value.toString());
        }

        /**
         * Reads {@code field} from {@code obj}, temporarily forcing accessibility.
         *
         * @param field the field to read
         * @param obj   instance to read from
         * @return the field value (possibly {@code null})
         * @throws SearchQueryBuildException if the field cannot be read
         */
        protected Object getFieldValue(Field field, Object obj) {
            boolean isAccessible = field.isAccessible();
            field.setAccessible(true);
            try {
                return field.get(obj);
            } catch (IllegalAccessException e) {
                throw new SearchQueryBuildException(e);
            } finally {
                field.setAccessible(isAccessible);
            }
        }

        /**
         * Converts a Java value into the form sent to Elasticsearch: {@link Date} becomes epoch
         * milliseconds, everything else is passed through.
         *
         * @param value value read from a query object
         * @return the value to put into a query
         */
        protected Object formatValue(Object value) {
            if (value instanceof Date) {
                return ((Date) value).getTime();
            } else {
                return value;
            }
        }

        /**
         * Reads the shard count of the index named by {@code t}'s {@link Document} annotation.
         *
         * @param t annotated query/document object
         * @return value of the index's {@code index.number_of_shards} setting
         * @throws SearchQueryBuildException if the setting is not present
         */
        protected int getNumberOfShards(T t) {
            String indexName = getDocument(t).indexName();
            Object shards = elasticsearchTemplate.getSetting(indexName).get(IndexMetaData.SETTING_NUMBER_OF_SHARDS);
            if (shards == null) {
                throw new SearchQueryBuildException("Can't read " + IndexMetaData.SETTING_NUMBER_OF_SHARDS + " of index " + indexName);
            }
            return Integer.parseInt(shards.toString());
        }
    }
    

    SimpleSearchQueryEngine.java

    import annotation.*;
    import com.es.common.Container;
    import com.es.common.Operator;
    import com.es.entity.AggregationResultsExtractor;
    import com.es.entity.MyPageableRequest;
    import com.es.exception.SearchQueryBuildException;
    import com.es.exception.SearchResultBuildException;
    import com.es.util.BeanPropertyUtil;
    import org.elasticsearch.action.update.UpdateRequest;
    import org.elasticsearch.client.transport.TransportClient;
    import org.elasticsearch.index.query.BoolQueryBuilder;
    import org.elasticsearch.index.query.QueryBuilder;
    import org.elasticsearch.search.aggregations.Aggregation;
    import org.elasticsearch.search.aggregations.AggregationBuilders;
    import org.elasticsearch.search.aggregations.Aggregations;
    import org.elasticsearch.search.aggregations.BucketOrder;
    import org.elasticsearch.search.aggregations.bucket.terms.Terms;
    import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
    import org.elasticsearch.search.aggregations.metrics.InternalSum;
    import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
    import org.springframework.beans.BeanUtils;
    import org.springframework.data.annotation.Id;
    import org.springframework.data.domain.Page;
    import org.springframework.data.domain.Pageable;
    import org.springframework.data.domain.Sort;
    import org.springframework.data.elasticsearch.annotations.Document;
    import org.springframework.data.elasticsearch.core.ScrolledPage;
    import org.springframework.data.elasticsearch.core.query.NativeSearchQueryBuilder;
    import org.springframework.data.elasticsearch.core.query.SearchQuery;
    import org.springframework.data.elasticsearch.core.query.UpdateQuery;
    import org.springframework.stereotype.Component;
    import org.springframework.util.CollectionUtils;
    import java.lang.reflect.Field;
    import java.util.*;
    import static org.elasticsearch.index.query.QueryBuilders.boolQuery;
    import static org.elasticsearch.index.query.QueryBuilders.matchQuery;
    import static org.elasticsearch.index.query.QueryBuilders.rangeQuery;
    
    @Component
    public class SimpleSearchQueryEngine<T> extends SearchQueryEngine<T> {
    
        private int numberOfRowsPerScan = 10;
    
        /**
         * scroll游标快照超时时间,单位ms
         */
        private static final long SCROLL_TIMEOUT = 3000;
    
        public SimpleSearchQueryEngine(TransportClient client) {
            super(client);
        }
    
        @Override
        public void deleteIndex(String indexName) {
            if(elasticsearchTemplate.indexExists(indexName)){
                elasticsearchTemplate.deleteIndex(indexName);
            }
        }
    
        @Override
        public void createIndex(String indexName) {
            if(!elasticsearchTemplate.indexExists(indexName)){
                elasticsearchTemplate.createIndex(indexName);
            }
        }
    
        @Override
        public int saveOrUpdate(List<T> list) {
            if (CollectionUtils.isEmpty(list)) {
                return 0;
            }
    
            T base = list.get(0);
            Field id = null;
            for (Field field : base.getClass().getDeclaredFields()) {
                Id businessID = field.getAnnotation(Id.class);
                if (businessID != null) {
                    id = field;
                    break;
                }
            }
            if (id == null) {
                throw new SearchQueryBuildException("Can't find @BusinessID on " + base.getClass().getName());
            }
    
            Document document = getDocument(base);
            List<UpdateQuery> bulkIndex = new ArrayList<>();
            for (T t : list) {
                UpdateQuery updateQuery = new UpdateQuery();
                updateQuery.setIndexName(document.indexName());
                updateQuery.setType(document.type());
                updateQuery.setId(getFieldValue(id, t).toString());
                // doc()的参数不能为json字符串或t对象,只能为map
                updateQuery.setUpdateRequest(new UpdateRequest(updateQuery.getIndexName(), updateQuery.getType(), updateQuery.getId()).doc(BeanPropertyUtil.beanToMap(t)));
                updateQuery.setDoUpsert(true);
                updateQuery.setClazz(t.getClass());
                bulkIndex.add(updateQuery);
            }
            elasticsearchTemplate.bulkUpdate(bulkIndex);
            return list.size();
        }
    
    
        @Override
        public <R> List<R> aggregation(T query, Class<R> clazz) {
            NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query);
            nativeSearchQueryBuilder.addAggregation(buildGroupBy(query));
            Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor());
            try {
                return transformList(null, aggregations, clazz.newInstance(), new ArrayList());
            } catch (Exception e) {
                throw new SearchResultBuildException(e);
            }
        }
    
    
        /**
         * 将Aggregations转为List
         */
        private <R> List<R> transformList(Aggregation terms, Aggregations aggregations, R baseObj, List<R> resultList) throws NoSuchFieldException, IllegalAccessException, InstantiationException {
            for (String column : aggregations.asMap().keySet()) {
                Aggregation childAggregation = aggregations.get(column);
                if (childAggregation instanceof InternalSum) {
                    // 使用@Sum
                    if (!(terms instanceof InternalSum)) {
                        R targetObj = (R) baseObj.getClass().newInstance();
                        BeanUtils.copyProperties(baseObj, targetObj);
                        resultList.add(targetObj);
                    }
                    setFieldValue(baseObj.getClass().getDeclaredField(column), resultList.get(resultList.size() - 1), ((InternalSum) childAggregation).getValue());
                    terms = childAggregation;
                } else {
                    Terms childTerms = (Terms) childAggregation;
                    for (Terms.Bucket bucket : childTerms.getBuckets()) {
                        if (CollectionUtils.isEmpty(bucket.getAggregations().asList())) {
                            // 未使用@Sum
                            R targetObj = (R) baseObj.getClass().newInstance();
                            BeanUtils.copyProperties(baseObj, targetObj);
                            setFieldValue(targetObj.getClass().getDeclaredField(column), targetObj, bucket.getKey());
                            resultList.add(targetObj);
                        } else {
                            setFieldValue(baseObj.getClass().getDeclaredField(column), baseObj, bucket.getKey());
                            transformList(childTerms, bucket.getAggregations(), baseObj, resultList);
                        }
                    }
                }
            }
            return resultList;
        }
    
        /**
         * 分页查询
         */
        @Override
        public <R> ScrolledPage<R> scroll(T query, Class<R> clazz, Pageable pageable) {
            if (pageable.getPageSize() % numberOfRowsPerScan > 0) {
                throw new SearchQueryBuildException("Page size must be an integral multiple of " + numberOfRowsPerScan);
            }
    
            int pageSize = numberOfRowsPerScan / getNumberOfShards(query);
            SearchQuery searchQuery = buildNativeSearchQueryBuilder(query)
                    .withPageable(new MyPageableRequest(pageable.getPageNumber(), pageSize, pageable.getSort()))
                    .build();
    
            ScrolledPage<R> page = (ScrolledPage<R>)elasticsearchTemplate.startScroll(SCROLL_TIMEOUT,searchQuery, clazz);
            elasticsearchTemplate.clearScroll(page.getScrollId());
            return page;
        }
    
        /**
         * 查询固定条数的记录
         * 有时候因为要获取的条数过大,而elasticsearch每次只返回10000条,因此这里用分页查询
         */
        @Override
        public <R> List<R> find(T query, Class<R> clazz, int size) {
            // Caused by: QueryPhaseExecutionException[Result window is too large, from + size must be less than or equal to: [10000] but was [2147483647].
            // See the scroll api for a more efficient way to request large data sets. This limit can be set by changing the [index.max_result_window] index level parameter.]
            if (size % numberOfRowsPerScan > 0) {
                throw new SearchQueryBuildException("Parameter 'size' must be an integral multiple of " + numberOfRowsPerScan);
            }
            int pageNum = 0;
            List<R> result = new ArrayList<>();
    
            int pageSize = numberOfRowsPerScan / getNumberOfShards(query);
    
            SearchQuery searchQuery = buildNativeSearchQueryBuilder(query)
                    .withPageable(new MyPageableRequest(pageNum, pageSize,buildDefaultSort()))
                    .build();
    
            ScrolledPage<R> page = (ScrolledPage<R>)elasticsearchTemplate.startScroll(SCROLL_TIMEOUT,searchQuery, clazz);
            while(page.hasContent() && result.size()<size){
                result.addAll(page.getContent());
                // 取下一页,scrollId在es服务器上可能会发生变化,需要用最新的。发起continueScroll请求会重新刷新快照保留时间
                page = (ScrolledPage<R>) elasticsearchTemplate.continueScroll(page.getScrollId() , SCROLL_TIMEOUT, clazz);
            }
    
            elasticsearchTemplate.clearScroll(page.getScrollId());
            return result;
        }
    
        @Override
        public <R> Page<R> find(T query, Class<R> clazz, Pageable pageable) {
            NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query).withPageable(pageable);
            return elasticsearchTemplate.queryForPage(nativeSearchQueryBuilder.build(), clazz);
        }
    
        @Override
        public <R> R sum(T query, Class<R> clazz) {
            NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query);
            for (SumAggregationBuilder sumBuilder : getSumBuilderList(query)) {
                nativeSearchQueryBuilder.addAggregation(sumBuilder);
            }
            Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor());
            try {
                return transformSumResult(aggregations, clazz);
            } catch (Exception e) {
                throw new SearchResultBuildException(e);
            }
        }
    
        private <R> R transformSumResult(Aggregations aggregations, Class<R> clazz) throws IllegalAccessException, InstantiationException, NoSuchFieldException {
            R targetObj = clazz.newInstance();
            for (Aggregation sum : aggregations.asList()) {
                if (sum instanceof InternalSum) {
                    setFieldValue(targetObj.getClass().getDeclaredField(sum.getName()), targetObj, ((InternalSum) sum).getValue());
                }
            }
            return targetObj;
        }
    
        /**
         * 构建查询对象
         */
        private NativeSearchQueryBuilder buildNativeSearchQueryBuilder(T query) {
            Document document = getDocument(query);
            NativeSearchQueryBuilder nativeSearchQueryBuilder = new NativeSearchQueryBuilder()
                    .withIndices(document.indexName())
                    .withTypes(document.type());
    
            QueryBuilder whereBuilder = buildBoolQuery(query);
            if (whereBuilder != null) {
                //将搜索条件设置到构建中
                nativeSearchQueryBuilder.withQuery(whereBuilder);
            }
    
            return nativeSearchQueryBuilder;
        }
    
        /**
         * 创建QueryBuilder(即设置查询条件)这儿创建的是组合查询(也叫多条件查询)
         */
        private BoolQueryBuilder buildBoolQuery(T query) {
            /**
             * 组合查询BoolQueryBuilder:builder下有must、should以及mustNot 相当于sql中的and、or以及not
             * must(QueryBuilders)   :AND
             * mustNot(QueryBuilders):NOT
             * should:               :OR
             */
            BoolQueryBuilder boolQueryBuilder = boolQuery();
            buildMatchQuery(boolQueryBuilder, query);
            buildRangeQuery(boolQueryBuilder, query);
            BoolQueryBuilder queryBuilder = boolQuery().must(boolQueryBuilder);
            return queryBuilder;
        }
    
        /**
         * and or 查询构建
         */
        private void buildMatchQuery(BoolQueryBuilder boolQueryBuilder, T query) {
            Class clazz = query.getClass();
            for (Field field : clazz.getDeclaredFields()) {
                MatchQuery annotation = field.getAnnotation(MatchQuery.class);
                Object value = getFieldValue(field, query);
                if (annotation == null || value == null) {
                    continue;
                }
                if (Container.must.equals(annotation.container())) {
                    boolQueryBuilder.must(matchQuery(getFieldName(field, annotation.column()), formatValue(value)));
                } else if (Container.should.equals(annotation.container())) {
                    if (value instanceof Collection) {
                        BoolQueryBuilder shouldQueryBuilder = boolQuery();
                        Collection tmp = (Collection) value;
                        for (Object obj : tmp) {
                            shouldQueryBuilder.should(matchQuery(getFieldName(field, annotation.column()), formatValue(obj)));
                        }
                        boolQueryBuilder.must(shouldQueryBuilder);
                    } else {
                        boolQueryBuilder.must(boolQuery().should(matchQuery(getFieldName(field, annotation.column()), formatValue(value))));
                    }
                }
            }
        }
    
        /**
         * 范围查询构建
         */
        private void buildRangeQuery(BoolQueryBuilder boolQueryBuilder, T query) {
            Class clazz = query.getClass();
            for (Field field : clazz.getDeclaredFields()) {
                TermRangeQuery annotation = field.getAnnotation(TermRangeQuery.class);
                Object value = getFieldValue(field, query);
                if (annotation == null || value == null) {
                    continue;
                }
                if (Operator.gt.equals(annotation.operator())) {
                    boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).gt(formatValue(value)));
                } else if (Operator.gte.equals(annotation.operator())) {
                    boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).gte(formatValue(value)));
                } else if (Operator.lt.equals(annotation.operator())) {
                    boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).lt(formatValue(value)));
                } else if (Operator.lte.equals(annotation.operator())) {
                    boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).lte(formatValue(value)));
                }
            }
        }
    
        /**
         * Sum构建(非嵌套获取数据求和)
         */
        private List<SumAggregationBuilder> getSumBuilderList(T query) {
            List<SumAggregationBuilder> list = new ArrayList<>();
            Class clazz = query.getClass();
            for (Field field : clazz.getDeclaredFields()) {
                Sum annotation = field.getAnnotation(Sum.class);
                if (annotation == null) {
                    continue;
                }
                list.add(AggregationBuilders.sum(field.getName()).field(field.getName()));
            }
            if (CollectionUtils.isEmpty(list)) {
                throw new SearchQueryBuildException("Can't find @Sum on " + clazz.getName());
            }
            return list;
        }
    
        /**
         * GroupBy构建
         */
        private TermsAggregationBuilder buildGroupBy(T query) {
            List<Field> sumList = new ArrayList<>();
            Object groupByCollection = null;
            Class clazz = query.getClass();
    
            for (Field field : clazz.getDeclaredFields()) {
                // 嵌套数据求和
                Sum sumAnnotation = field.getAnnotation(Sum.class);
                if (sumAnnotation != null) {
                    sumList.add(field);
                }
    
                // 获取含有GroupBy注解的字段(只能有一个),获取其值,其值为一个列表
                GroupBy groupByannotation = field.getAnnotation(GroupBy.class);
                Object value = getFieldValue(field, query);
                if (groupByannotation == null || value == null) {
                    continue;
                } else if (!(value instanceof Collection)) {
                    throw new SearchQueryBuildException("GroupBy filed must be collection");
                } else if (CollectionUtils.isEmpty((Collection<String>) value)) {
                    continue;
                } else if (groupByCollection != null) {
                    throw new SearchQueryBuildException("Only one @GroupBy is allowed");
                } else {
                    groupByCollection = value;
                }
            }
            Iterator<String> iterator = ((Collection<String>) groupByCollection).iterator();
            TermsAggregationBuilder termsBuilder = recursiveAddAggregation(iterator, sumList);
            return termsBuilder;
        }
    
        /**
         * 添加Aggregation
         */
        private TermsAggregationBuilder recursiveAddAggregation(Iterator<String> iterator, List<Field> sumList) {
            String groupBy = iterator.next();
    
            // 定义单个桶的类型term
            TermsAggregationBuilder termsBuilder = AggregationBuilders.terms(groupBy).field(groupBy).size(0);
            if (iterator.hasNext()) {
                termsBuilder.subAggregation(recursiveAddAggregation(iterator, sumList));
            } else {
                for (Field field : sumList) {
                    termsBuilder.subAggregation(AggregationBuilders.sum(field.getName()).field(field.getName()));
                }
                sumList.clear();
            }
    
            // Ordering the buckets alphabetically by their terms in an ascending manner
            return termsBuilder.order(BucketOrder.key(true));
        }
    
        private Sort buildDefaultSort(){
            List<Sort.Order> orders = new ArrayList<>();
            Sort sort = Sort.by(orders);
            return sort;
        }
    }
    

    相关文章

      网友评论

          本文标题:ElasticsearchTemplate的使用

          本文链接:https://www.haomeiwen.com/subject/rqbgfhtx.html