这里主要记录一下用法，所以不做过多的注释了。环境：
elasticsearch: "org.elasticsearch:elasticsearch:7.5.1",
elasticsearch_client: "org.elasticsearch.client:transport:7.5.1",
springboot_elasticsearch: "org.springframework.boot:spring-boot-starter-data-elasticsearch:2.2.0.RELEASE"
MyPageableRequest.java
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Sort;
/**
 * Thin {@link PageRequest} subclass that exposes a public constructor.
 * <p>
 * It adds no behaviour of its own: every argument is forwarded unchanged to
 * {@link PageRequest}.
 */
public class MyPageableRequest extends PageRequest {

    /**
     * Builds a sorted page request.
     *
     * @param pageIndex zero-based index of the requested page; must not be negative
     * @param pageSize  number of elements per page; must be greater than zero
     * @param ordering  sort to apply; must not be {@literal null}, use {@link Sort#unsorted()}
     *                  when no ordering is wanted
     */
    public MyPageableRequest(int pageIndex, int pageSize, Sort ordering) {
        super(pageIndex, pageSize, ordering);
    }
}
SearchQueryEngine.java
import com.es.exception.SearchQueryBuildException;
import org.apache.commons.lang3.StringUtils;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.cluster.metadata.IndexMetaData;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.annotations.Document;
import org.springframework.data.elasticsearch.core.ElasticsearchTemplate;
import org.springframework.data.elasticsearch.core.ScrolledPage;
import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.util.Date;
import java.util.List;
/**
 * Abstract definition of the search engine API.
 * <p>
 * Declares the index, write and query operations that concrete engines implement. It also
 * provides the reflection helpers they share and an {@link ElasticsearchTemplate} built from
 * the {@link TransportClient} passed to the constructor.
 *
 * @param <T> type of the query/document objects; its class must carry a {@link Document}
 *            annotation (see {@link #getDocument(Object)})
 */
public abstract class SearchQueryEngine<T> {

    /** Template created from the transport client supplied to the constructor. */
    protected ElasticsearchTemplate elasticsearchTemplate;

    public abstract void deleteIndex(String indexName);

    public abstract void createIndex(String indexName);

    public abstract int saveOrUpdate(List<T> list);

    public abstract <R> List<R> aggregation(T query, Class<R> clazz);

    public abstract <R> ScrolledPage<R> scroll(T query, Class<R> clazz, Pageable pageable);

    public abstract <R> List<R> find(T query, Class<R> clazz, int size);

    public abstract <R> Page<R> find(T query, Class<R> clazz, Pageable pageable);

    public abstract <R> R sum(T query, Class<R> clazz);

    /**
     * @param client transport client used to build {@link #elasticsearchTemplate}
     */
    public SearchQueryEngine(TransportClient client) {
        this.elasticsearchTemplate = new ElasticsearchTemplate(client);
    }

    /**
     * Returns the {@link Document} annotation declared on the runtime class of {@code t}.
     *
     * @param t query/document object
     * @return the class's {@link Document} annotation, never {@code null}
     * @throws SearchQueryBuildException if the class is not annotated with {@link Document}
     */
    protected Document getDocument(T t) {
        Document annotation = t.getClass().getAnnotation(Document.class);
        if (annotation == null) {
            throw new SearchQueryBuildException("Can't find annotation @Document on " + t.getClass().getName());
        }
        return annotation;
    }

    /**
     * Returns the name to use for a field in Elasticsearch.
     *
     * @param field  the Java field
     * @param column explicit column name from an annotation; may be blank
     * @return {@code column} when it is not blank, otherwise the Java field name
     */
    protected String getFieldName(Field field, String column) {
        return StringUtils.isNotBlank(column) ? column : field.getName();
    }

    /**
     * Writes {@code value} into {@code field} of {@code obj}, converting it to the field's type.
     * <p>
     * Values returned by Elasticsearch do not always match the field's type. Sum aggregations
     * yield {@code double}, so a {@code Long} field may receive {@code 3.0}, and terms bucket
     * keys may be {@code Long}. Numeric targets are therefore converted through
     * {@link BigDecimal} or {@link Number} instead of parsing {@code toString()} with
     * {@code Long}/{@code Integer}, which fails on {@code "3.0"} or {@code "1.0E7"}. Integral
     * conversions are exact: a fractional or out-of-range value raises
     * {@link ArithmeticException} rather than being silently truncated.
     *
     * @param field target field; its accessibility is forced for the write and restored afterwards
     * @param obj   object whose field is written
     * @param value value to write; {@code null} is written as {@code null}
     * @throws SearchQueryBuildException if the field cannot be written reflectively
     */
    protected void setFieldValue(Field field, Object obj, Object value) {
        boolean isAccessible = field.isAccessible();
        field.setAccessible(true);
        try {
            field.set(obj, convertValue(field.getType(), value));
        } catch (IllegalAccessException e) {
            throw new SearchQueryBuildException(e);
        } finally {
            field.setAccessible(isAccessible);
        }
    }

    /**
     * Converts {@code value} to the type expected by a field whose type is {@code targetType}.
     * Types not listed below are returned unchanged and left to {@link Field#set}.
     */
    private static Object convertValue(Class<?> targetType, Object value) {
        if (value == null) {
            return null;
        }
        switch (targetType.getSimpleName()) {
            case "BigDecimal":
                return toBigDecimal(value).setScale(5, java.math.RoundingMode.HALF_UP);
            case "Long":
            case "long":
                return toBigDecimal(value).longValueExact();
            case "Integer":
            case "int":
                return toBigDecimal(value).intValueExact();
            case "Double":
            case "double":
                return value instanceof Number ? ((Number) value).doubleValue() : Double.valueOf(value.toString());
            case "Date":
                // Elasticsearch returns dates as epoch milliseconds (see formatValue for the reverse direction)
                return value instanceof Date ? value : new Date(toBigDecimal(value).longValueExact());
            default:
                return value;
        }
    }

    /** Parses {@code value} as a decimal; its {@code toString()} must be a valid decimal literal. */
    private static BigDecimal toBigDecimal(Object value) {
        return value instanceof BigDecimal ? (BigDecimal) value : new BigDecimal(value.toString());
    }

    /**
     * Reads {@code field} of {@code obj} reflectively.
     *
     * @param field field to read; its accessibility is forced for the read and restored afterwards
     * @param obj   object to read from
     * @return the field value, possibly {@code null}
     * @throws SearchQueryBuildException if the field cannot be read reflectively
     */
    protected Object getFieldValue(Field field, Object obj) {
        boolean isAccessible = field.isAccessible();
        field.setAccessible(true);
        try {
            return field.get(obj);
        } catch (IllegalAccessException e) {
            throw new SearchQueryBuildException(e);
        } finally {
            field.setAccessible(isAccessible);
        }
    }

    /**
     * Converts a Java value into the form sent to Elasticsearch.
     *
     * @param value value taken from a query object
     * @return epoch milliseconds for a {@link Date}, otherwise {@code value} unchanged
     */
    protected Object formatValue(Object value) {
        if (value instanceof Date) {
            return ((Date) value).getTime();
        }
        return value;
    }

    /**
     * Returns the number of primary shards of the index named by {@code t}'s {@link Document}.
     *
     * @param t query/document object identifying the index
     * @return the index's {@code number_of_shards} setting
     * @throws SearchQueryBuildException if the setting is not present in the index settings
     */
    protected int getNumberOfShards(T t) {
        String indexName = getDocument(t).indexName();
        Object shards = elasticsearchTemplate.getSetting(indexName).get(IndexMetaData.SETTING_NUMBER_OF_SHARDS);
        if (shards == null) {
            throw new SearchQueryBuildException("Can't read " + IndexMetaData.SETTING_NUMBER_OF_SHARDS + " of index " + indexName);
        }
        return Integer.parseInt(shards.toString());
    }
}
SimpleSearchQueryEngine.java
import annotation.*;
import com.es.common.Container;
import com.es.common.Operator;
import com.es.entity.AggregationResultsExtractor;
import com.es.entity.MyPageableRequest;
import com.es.exception.SearchQueryBuildException;
import com.es.exception.SearchResultBuildException;
import com.es.util.BeanPropertyUtil;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.InternalSum;
import org.elasticsearch.search.aggregations.metrics.SumAggregationBuilder;
import org.springframework.beans.BeanUtils;
import org.springframework.data.annotation.Id;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.elasticsearch.annotations.Document;
import org.springframework.data.elasticsearch.core.ScrolledPage;
import org.springframework.data.elasticsearch.core.query.NativeSearchQueryBuilder;
import org.springframework.data.elasticsearch.core.query.SearchQuery;
import org.springframework.data.elasticsearch.core.query.UpdateQuery;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.lang.reflect.Field;
import java.util.*;
import static org.elasticsearch.index.query.QueryBuilders.boolQuery;
import static org.elasticsearch.index.query.QueryBuilders.matchQuery;
import static org.elasticsearch.index.query.QueryBuilders.rangeQuery;
/**
 * Annotation-driven {@link SearchQueryEngine}.
 * <p>
 * Query conditions are read reflectively from the declared fields of the query object:
 * <ul>
 *   <li>{@code @MatchQuery} fields become match clauses.</li>
 *   <li>{@code @TermRangeQuery} fields become range clauses.</li>
 *   <li>{@code @Sum} fields become sum aggregations.</li>
 *   <li>The single {@code @GroupBy} field lists the field names to build nested terms
 *       aggregations on.</li>
 * </ul>
 * Documents are upserted using the field annotated with {@link Id} as document id.
 *
 * @param <T> query/document type; its class must carry a {@link Document} annotation
 */
@Component
public class SimpleSearchQueryEngine<T> extends SearchQueryEngine<T> {

    /** Base number of rows per scroll round trip; divided by the shard count in {@link #scrollBatchSize}. */
    private int numberOfRowsPerScan = 10;

    /**
     * Scroll context keep-alive, in milliseconds.
     */
    private static final long SCROLL_TIMEOUT = 3000;

    /**
     * Bucket count requested for every group-by level.
     * <p>
     * Terms aggregations in ES 5+ reject {@code size(0)} ("[size] must be greater than 0"), which
     * made every {@link #aggregation} call fail. The largest int keeps the old
     * "return all buckets" intent of {@code size(0)}.
     */
    private static final int GROUP_BY_BUCKET_SIZE = Integer.MAX_VALUE;

    public SimpleSearchQueryEngine(TransportClient client) {
        super(client);
    }

    /** Deletes {@code indexName} if it exists; does nothing otherwise. */
    @Override
    public void deleteIndex(String indexName) {
        if (elasticsearchTemplate.indexExists(indexName)) {
            elasticsearchTemplate.deleteIndex(indexName);
        }
    }

    /** Creates {@code indexName} if it does not exist yet; does nothing otherwise. */
    @Override
    public void createIndex(String indexName) {
        if (!elasticsearchTemplate.indexExists(indexName)) {
            elasticsearchTemplate.createIndex(indexName);
        }
    }

    /**
     * Upserts every element of {@code list} in a single bulk update.
     * <p>
     * The index and type come from the {@link Document} of the first element. The document id
     * comes from the {@link Id}-annotated field declared on the first element's class. Each
     * element is sent as the update document (a map built by {@code BeanPropertyUtil.beanToMap})
     * with upsert enabled.
     *
     * @param list objects to save; {@code null} or empty is a no-op
     * @return number of elements submitted
     * @throws SearchQueryBuildException if no {@link Id} field exists or an element's id is {@code null}
     */
    @Override
    public int saveOrUpdate(List<T> list) {
        if (CollectionUtils.isEmpty(list)) {
            return 0;
        }
        T base = list.get(0);
        Field id = findIdField(base.getClass());
        Document document = getDocument(base);
        List<UpdateQuery> bulkIndex = new ArrayList<>(list.size());
        for (T t : list) {
            Object idValue = getFieldValue(id, t);
            if (idValue == null) {
                throw new SearchQueryBuildException("@Id field '" + id.getName() + "' is null on " + t.getClass().getName());
            }
            UpdateQuery updateQuery = new UpdateQuery();
            updateQuery.setIndexName(document.indexName());
            updateQuery.setType(document.type());
            updateQuery.setId(idValue.toString());
            // doc() must receive a Map here; a JSON string or the bean itself does not work
            updateQuery.setUpdateRequest(new UpdateRequest(updateQuery.getIndexName(), updateQuery.getType(), updateQuery.getId())
                    .doc(BeanPropertyUtil.beanToMap(t)));
            updateQuery.setDoUpsert(true);
            updateQuery.setClazz(t.getClass());
            bulkIndex.add(updateQuery);
        }
        elasticsearchTemplate.bulkUpdate(bulkIndex);
        return list.size();
    }

    /** Returns the first declared field of {@code type} annotated with {@link Id}. */
    private Field findIdField(Class<?> type) {
        for (Field field : type.getDeclaredFields()) {
            if (field.getAnnotation(Id.class) != null) {
                return field;
            }
        }
        throw new SearchQueryBuildException("Can't find @Id on " + type.getName());
    }

    /**
     * Runs the query with nested terms aggregations (built from the {@code @GroupBy} field, with
     * {@code @Sum} fields attached at the innermost level). The buckets are flattened into
     * {@code clazz} instances, one per innermost bucket.
     *
     * @throws SearchQueryBuildException  if no non-empty {@code @GroupBy} field is present
     * @throws SearchResultBuildException if the aggregation result cannot be mapped onto {@code clazz}
     */
    @Override
    public <R> List<R> aggregation(T query, Class<R> clazz) {
        NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query);
        nativeSearchQueryBuilder.addAggregation(buildGroupBy(query));
        Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor());
        try {
            return transformList(null, aggregations, clazz.newInstance(), new ArrayList<>());
        } catch (Exception e) {
            throw new SearchResultBuildException(e);
        }
    }

    /**
     * Flattens (possibly nested) terms/sum aggregations into a list of result objects.
     * <p>
     * {@code baseObj} accumulates the bucket keys of the enclosing levels. Each innermost bucket
     * appends a property copy of it to {@code resultList`}. Sum values are written onto the
     * object most recently appended.
     *
     * @param terms        the parent aggregation ({@code null} at top level); reassigned to the last
     *                     sum handled so that sibling sums of one bucket share one result object
     * @param aggregations aggregations of the current level
     * @param baseObj      object carrying the keys of the enclosing buckets
     * @param resultList   list the flattened results are appended to
     * @return {@code resultList}
     */
    @SuppressWarnings("unchecked")
    private <R> List<R> transformList(Aggregation terms, Aggregations aggregations, R baseObj, List<R> resultList)
            throws NoSuchFieldException, IllegalAccessException, InstantiationException {
        for (String column : aggregations.asMap().keySet()) {
            Aggregation childAggregation = aggregations.get(column);
            if (childAggregation instanceof InternalSum) {
                // @Sum level: the first sum of a bucket starts a new result object,
                // the following sums of the same bucket fill that same object
                if (!(terms instanceof InternalSum)) {
                    R targetObj = (R) baseObj.getClass().newInstance();
                    BeanUtils.copyProperties(baseObj, targetObj);
                    resultList.add(targetObj);
                }
                setFieldValue(baseObj.getClass().getDeclaredField(column), resultList.get(resultList.size() - 1),
                        ((InternalSum) childAggregation).getValue());
                terms = childAggregation;
            } else {
                Terms childTerms = (Terms) childAggregation;
                for (Terms.Bucket bucket : childTerms.getBuckets()) {
                    if (CollectionUtils.isEmpty(bucket.getAggregations().asList())) {
                        // innermost group-by level without @Sum: one result object per bucket
                        R targetObj = (R) baseObj.getClass().newInstance();
                        BeanUtils.copyProperties(baseObj, targetObj);
                        setFieldValue(targetObj.getClass().getDeclaredField(column), targetObj, bucket.getKey());
                        resultList.add(targetObj);
                    } else {
                        // intermediate level: remember this bucket's key and descend
                        setFieldValue(baseObj.getClass().getDeclaredField(column), baseObj, bucket.getKey());
                        transformList(childTerms, bucket.getAggregations(), baseObj, resultList);
                    }
                }
            }
        }
        return resultList;
    }

    /**
     * Opens a scroll for {@code query}, returns its first batch and clears the scroll context
     * right away.
     * <p>
     * NOTE(review): the batch size comes from {@link #scrollBatchSize}, not from
     * {@code pageable.getPageSize()}. The page size is only checked for divisibility. Confirm
     * this is intended.
     *
     * @throws SearchQueryBuildException if the page size is not a multiple of {@link #numberOfRowsPerScan}
     */
    @Override
    public <R> ScrolledPage<R> scroll(T query, Class<R> clazz, Pageable pageable) {
        if (pageable.getPageSize() % numberOfRowsPerScan > 0) {
            throw new SearchQueryBuildException("Page size must be an integral multiple of " + numberOfRowsPerScan);
        }
        SearchQuery searchQuery = buildNativeSearchQueryBuilder(query)
                .withPageable(new MyPageableRequest(pageable.getPageNumber(), scrollBatchSize(query), pageable.getSort()))
                .build();
        ScrolledPage<R> page = (ScrolledPage<R>) elasticsearchTemplate.startScroll(SCROLL_TIMEOUT, searchQuery, clazz);
        elasticsearchTemplate.clearScroll(page.getScrollId());
        return page;
    }

    /**
     * Collects at most {@code size} hits by scrolling through the results.
     * <p>
     * Plain from/size paging fails once {@code from + size} exceeds
     * {@code index.max_result_window} (10000 in the original error), hence the scroll.
     * The scroll context is cleared even if a round trip fails.
     *
     * @throws SearchQueryBuildException if {@code size} is not a multiple of {@link #numberOfRowsPerScan}
     */
    @Override
    public <R> List<R> find(T query, Class<R> clazz, int size) {
        if (size % numberOfRowsPerScan > 0) {
            throw new SearchQueryBuildException("Parameter 'size' must be an integral multiple of " + numberOfRowsPerScan);
        }
        SearchQuery searchQuery = buildNativeSearchQueryBuilder(query)
                .withPageable(new MyPageableRequest(0, scrollBatchSize(query), buildDefaultSort()))
                .build();
        List<R> result = new ArrayList<>();
        ScrolledPage<R> page = (ScrolledPage<R>) elasticsearchTemplate.startScroll(SCROLL_TIMEOUT, searchQuery, clazz);
        try {
            while (result.size() < size && page.hasContent()) {
                List<R> batch = page.getContent();
                int remaining = size - result.size();
                if (batch.size() >= remaining) {
                    // last batch needed: take only what is missing so the result never exceeds size
                    result.addAll(batch.subList(0, remaining));
                } else {
                    result.addAll(batch);
                    // the scroll id may change between round trips, so always continue with the latest;
                    // continuing also renews the context keep-alive
                    page = (ScrolledPage<R>) elasticsearchTemplate.continueScroll(page.getScrollId(), SCROLL_TIMEOUT, clazz);
                }
            }
        } finally {
            elasticsearchTemplate.clearScroll(page.getScrollId());
        }
        return result;
    }

    /**
     * Rows requested per scroll round trip: {@link #numberOfRowsPerScan} divided by the index's
     * shard count, but never less than one. {@code PageRequest} requires a size greater than 0,
     * which the plain division violated for indices with more than {@link #numberOfRowsPerScan}
     * shards.
     * <p>
     * NOTE(review): the per-shard division presumably stems from older Elasticsearch scan
     * semantics; confirm whether it is still wanted.
     */
    private int scrollBatchSize(T query) {
        return Math.max(1, numberOfRowsPerScan / getNumberOfShards(query));
    }

    /** Runs a regular paged search built from the annotated query object. */
    @Override
    public <R> Page<R> find(T query, Class<R> clazz, Pageable pageable) {
        NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query).withPageable(pageable);
        return elasticsearchTemplate.queryForPage(nativeSearchQueryBuilder.build(), clazz);
    }

    /**
     * Computes one sum aggregation per {@code @Sum} field (without grouping) and maps them onto a
     * new {@code clazz} instance, matching each sum to the field of the same name.
     *
     * @throws SearchQueryBuildException  if the query class has no {@code @Sum} field
     * @throws SearchResultBuildException if the sums cannot be mapped onto {@code clazz}
     */
    @Override
    public <R> R sum(T query, Class<R> clazz) {
        NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query);
        for (SumAggregationBuilder sumBuilder : getSumBuilderList(query)) {
            nativeSearchQueryBuilder.addAggregation(sumBuilder);
        }
        Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor());
        try {
            return transformSumResult(aggregations, clazz);
        } catch (Exception e) {
            throw new SearchResultBuildException(e);
        }
    }

    /** Creates a {@code clazz} instance and writes every sum aggregation into the field of the same name. */
    private <R> R transformSumResult(Aggregations aggregations, Class<R> clazz)
            throws IllegalAccessException, InstantiationException, NoSuchFieldException {
        R targetObj = clazz.newInstance();
        for (Aggregation sum : aggregations.asList()) {
            if (sum instanceof InternalSum) {
                setFieldValue(targetObj.getClass().getDeclaredField(sum.getName()), targetObj, ((InternalSum) sum).getValue());
            }
        }
        return targetObj;
    }

    /**
     * Creates a search builder targeting the {@link Document} index/type of {@code query}, with
     * the bool query derived from its annotated fields.
     */
    private NativeSearchQueryBuilder buildNativeSearchQueryBuilder(T query) {
        Document document = getDocument(query);
        NativeSearchQueryBuilder nativeSearchQueryBuilder = new NativeSearchQueryBuilder()
                .withIndices(document.indexName())
                .withTypes(document.type());
        QueryBuilder whereBuilder = buildBoolQuery(query);
        if (whereBuilder != null) {
            nativeSearchQueryBuilder.withQuery(whereBuilder);
        }
        return nativeSearchQueryBuilder;
    }

    /**
     * Builds the compound (multi-condition) query.
     * <p>
     * In a {@link BoolQueryBuilder}, {@code must}, {@code mustNot} and {@code should} roughly
     * correspond to SQL {@code AND}, {@code NOT} and {@code OR}.
     */
    private BoolQueryBuilder buildBoolQuery(T query) {
        BoolQueryBuilder boolQueryBuilder = boolQuery();
        buildMatchQuery(boolQueryBuilder, query);
        buildRangeQuery(boolQueryBuilder, query);
        return boolQuery().must(boolQueryBuilder);
    }

    /**
     * Adds a clause for every non-null {@code @MatchQuery} field.
     * <ul>
     *   <li>A "must" container adds a must-match clause.</li>
     *   <li>A "should" container adds a must clause wrapping should-matches, one per element
     *       when the value is a collection.</li>
     * </ul>
     */
    private void buildMatchQuery(BoolQueryBuilder boolQueryBuilder, T query) {
        for (Field field : query.getClass().getDeclaredFields()) {
            MatchQuery annotation = field.getAnnotation(MatchQuery.class);
            if (annotation == null) {
                continue;
            }
            Object value = getFieldValue(field, query);
            if (value == null) {
                continue;
            }
            String fieldName = getFieldName(field, annotation.column());
            if (Container.must.equals(annotation.container())) {
                boolQueryBuilder.must(matchQuery(fieldName, formatValue(value)));
            } else if (Container.should.equals(annotation.container())) {
                if (value instanceof Collection) {
                    BoolQueryBuilder shouldQueryBuilder = boolQuery();
                    for (Object obj : (Collection<?>) value) {
                        shouldQueryBuilder.should(matchQuery(fieldName, formatValue(obj)));
                    }
                    boolQueryBuilder.must(shouldQueryBuilder);
                } else {
                    boolQueryBuilder.must(boolQuery().should(matchQuery(fieldName, formatValue(value))));
                }
            }
        }
    }

    /** Adds a must range clause (gt/gte/lt/lte per the annotation) for every non-null {@code @TermRangeQuery} field. */
    private void buildRangeQuery(BoolQueryBuilder boolQueryBuilder, T query) {
        for (Field field : query.getClass().getDeclaredFields()) {
            TermRangeQuery annotation = field.getAnnotation(TermRangeQuery.class);
            if (annotation == null) {
                continue;
            }
            Object value = getFieldValue(field, query);
            if (value == null) {
                continue;
            }
            String fieldName = getFieldName(field, annotation.column());
            Object esValue = formatValue(value);
            if (Operator.gt.equals(annotation.operator())) {
                boolQueryBuilder.must(rangeQuery(fieldName).gt(esValue));
            } else if (Operator.gte.equals(annotation.operator())) {
                boolQueryBuilder.must(rangeQuery(fieldName).gte(esValue));
            } else if (Operator.lt.equals(annotation.operator())) {
                boolQueryBuilder.must(rangeQuery(fieldName).lt(esValue));
            } else if (Operator.lte.equals(annotation.operator())) {
                boolQueryBuilder.must(rangeQuery(fieldName).lte(esValue));
            }
        }
    }

    /**
     * Builds one top-level (non-nested) sum aggregation per {@code @Sum} field. Each aggregation
     * is named after its field.
     *
     * @throws SearchQueryBuildException if the query class has no {@code @Sum} field
     */
    private List<SumAggregationBuilder> getSumBuilderList(T query) {
        List<SumAggregationBuilder> list = new ArrayList<>();
        Class<?> clazz = query.getClass();
        for (Field field : clazz.getDeclaredFields()) {
            if (field.getAnnotation(Sum.class) != null) {
                list.add(AggregationBuilders.sum(field.getName()).field(field.getName()));
            }
        }
        if (CollectionUtils.isEmpty(list)) {
            throw new SearchQueryBuildException("Can't find @Sum on " + clazz.getName());
        }
        return list;
    }

    /**
     * Builds the nested terms aggregation for {@link #aggregation}.
     * <p>
     * The value of the single non-empty {@code @GroupBy} field lists the field names to group by,
     * outermost first. All {@code @Sum} fields are summed inside the innermost level.
     *
     * @throws SearchQueryBuildException if the {@code @GroupBy} value is not a collection, if more
     *                                   than one non-empty {@code @GroupBy} exists, or if none exists
     */
    @SuppressWarnings("unchecked")
    private TermsAggregationBuilder buildGroupBy(T query) {
        List<Field> sumList = new ArrayList<>();
        Collection<String> groupByCollection = null;
        Class<?> clazz = query.getClass();
        for (Field field : clazz.getDeclaredFields()) {
            // sums nested inside the innermost group-by bucket
            if (field.getAnnotation(Sum.class) != null) {
                sumList.add(field);
            }
            GroupBy groupByAnnotation = field.getAnnotation(GroupBy.class);
            if (groupByAnnotation == null) {
                continue;
            }
            Object value = getFieldValue(field, query);
            if (value == null) {
                continue;
            }
            if (!(value instanceof Collection)) {
                throw new SearchQueryBuildException("GroupBy filed must be collection");
            }
            if (CollectionUtils.isEmpty((Collection<?>) value)) {
                continue;
            }
            if (groupByCollection != null) {
                throw new SearchQueryBuildException("Only one @GroupBy is allowed");
            }
            groupByCollection = (Collection<String>) value;
        }
        if (groupByCollection == null) {
            // previously an NPE further down; fail with an explicit message instead
            throw new SearchQueryBuildException("Can't find a non-empty @GroupBy on " + clazz.getName());
        }
        return recursiveAddAggregation(groupByCollection.iterator(), sumList);
    }

    /**
     * Builds one terms aggregation per remaining group-by name, each nested in the previous one.
     * The sums are attached to the innermost level, and buckets are ordered by key, ascending.
     */
    private TermsAggregationBuilder recursiveAddAggregation(Iterator<String> iterator, List<Field> sumList) {
        String groupBy = iterator.next();
        // terms aggregation named after the field it groups on; size(0) is rejected by ES 5+
        TermsAggregationBuilder termsBuilder = AggregationBuilders.terms(groupBy).field(groupBy).size(GROUP_BY_BUCKET_SIZE);
        if (iterator.hasNext()) {
            termsBuilder.subAggregation(recursiveAddAggregation(iterator, sumList));
        } else {
            for (Field field : sumList) {
                termsBuilder.subAggregation(AggregationBuilders.sum(field.getName()).field(field.getName()));
            }
            sumList.clear();
        }
        return termsBuilder.order(BucketOrder.key(true));
    }

    /** Returns a sort with no orders. */
    private Sort buildDefaultSort() {
        List<Sort.Order> orders = new ArrayList<>();
        return Sort.by(orders);
    }
}
网友评论