美文网首页
Spring Boot集成Elasticsearch自定义序列化

Spring Boot集成Elasticsearch自定义序列化

作者: 阿拉喵_d271 | 来源:发表于2018-10-18 18:29 被阅读0次

    Spring Data Elasticsearch（spring-boot-starter-data-elasticsearch）在把搜索结果反序列化为实体时不会读取 `_score` 值，而业务中又必须用到它，所以需要修改框架中的结果映射（反序列化）代码。

    1. 继承 `AbstractResultMapper`，增加读取 `_score` 并将其写入实体的代码。具体做法是自定义一个 `@Score` 注解，在实体中用它标记 `_score` 要写入的字段。
    
    import java.lang.annotation.ElementType;
    import java.lang.annotation.Retention;
    import java.lang.annotation.RetentionPolicy;
    import java.lang.annotation.Target;
    
    /**
     * Marker for the entity property that should receive a search hit's {@code _score}.
     *
     * <p>It can be placed on a field, in which case the matching {@code setXxx} setter is used. It can also be
     * placed directly on the setter method. It is retained at runtime so that it can be discovered reflectively.
     *
     * @author weizhiwen
     * @date 2018/7/4
     */
    @Target({ElementType.FIELD, ElementType.METHOD, ElementType.ANNOTATION_TYPE})
    @Retention(RetentionPolicy.RUNTIME)
    public @interface Score {
    }
    
    
    import com.fasterxml.jackson.core.JsonEncoding;
    import com.fasterxml.jackson.core.JsonFactory;
    import com.fasterxml.jackson.core.JsonGenerator;
    import lombok.extern.slf4j.Slf4j;
    import org.apache.commons.lang.StringUtils;
    import org.apache.commons.text.WordUtils;
    import org.elasticsearch.action.get.GetResponse;
    import org.elasticsearch.action.get.MultiGetItemResponse;
    import org.elasticsearch.action.get.MultiGetResponse;
    import org.elasticsearch.action.search.SearchResponse;
    import org.elasticsearch.search.SearchHit;
    import org.elasticsearch.search.SearchHitField;
    import org.springframework.data.domain.Pageable;
    import org.springframework.data.elasticsearch.ElasticsearchException;
    import org.springframework.data.elasticsearch.annotations.Document;
    import org.springframework.data.elasticsearch.annotations.ScriptedField;
    import org.springframework.data.elasticsearch.core.AbstractResultMapper;
    import org.springframework.data.elasticsearch.core.DefaultEntityMapper;
    import org.springframework.data.elasticsearch.core.EntityMapper;
    import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage;
    import org.springframework.data.elasticsearch.core.aggregation.impl.AggregatedPageImpl;
    import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentEntity;
    import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentProperty;
    import org.springframework.data.mapping.context.MappingContext;
    import org.springframework.util.Assert;
    
    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.lang.reflect.Field;
    import java.lang.reflect.Method;
    import java.nio.charset.Charset;
    import java.util.*;
    import java.util.concurrent.ConcurrentHashMap;
    
    /**
     * Search result mapper that maps hits to entities like the stock Spring Data Elasticsearch mapper. In addition,
     * it copies each search hit's {@code _score} into the {@link Score}-marked setter(s) of the entity.
     *
     * <p>Score setters are discovered reflectively once per entity class. They are cached in a static map that is
     * shared by all instances of this mapper.
     *
     * @author weizhiwen
     * @date 2018/10/18
     */
    @Slf4j
    public class ZwResultMapper extends AbstractResultMapper {
        /**
         * Per-entity-class cache of the one-argument {@code setXxx(Float|float)} methods that receive the hit score.
         */
        private static final Map<Class<?>, List<Method>> SCORE_SET_METHOD_CACHE = new ConcurrentHashMap<>();

        /** Name prefix identifying setter methods. */
        public static final String SET = "set";

        /** Mapping metadata used to populate id/version properties; when {@code null} those steps are skipped. */
        private MappingContext<? extends ElasticsearchPersistentEntity<?>, ElasticsearchPersistentProperty> mappingContext;

        public ZwResultMapper() {
            super(new DefaultEntityMapper());
        }

        public ZwResultMapper(MappingContext<? extends ElasticsearchPersistentEntity<?>, ElasticsearchPersistentProperty> mappingContext) {
            super(new DefaultEntityMapper());
            this.mappingContext = mappingContext;
        }

        public ZwResultMapper(EntityMapper entityMapper) {
            super(entityMapper);
        }

        public ZwResultMapper(
                MappingContext<? extends ElasticsearchPersistentEntity<?>, ElasticsearchPersistentProperty> mappingContext,
                EntityMapper entityMapper) {
            super(entityMapper);
            this.mappingContext = mappingContext;
        }

        /**
         * Maps a search response to a page of entities. Score, id, version and scripted fields are populated.
         *
         * <p>The response's aggregations and scroll id are carried over to the returned page. Without them,
         * aggregation queries and scrolling break once this mapper replaces the default one.
         *
         * @param response the raw search response
         * @param clazz    entity type to map each hit to
         * @param pageable paging information passed through to the page
         * @return the mapped hits, total hit count, aggregations and scroll id
         */
        @Override
        public <T> AggregatedPage<T> mapResults(SearchResponse response, Class<T> clazz, Pageable pageable) {
            long totalHits = response.getHits().getTotalHits();
            List<T> results = new ArrayList<>();
            for (SearchHit hit : response.getHits()) {
                if (hit != null) {
                    T result;
                    if (StringUtils.isNotBlank(hit.getSourceAsString())) {
                        result = mapEntity(hit.getSourceAsString(), clazz);
                    } else {
                        result = mapEntity(hit.getFields().values(), clazz);
                    }
                    // Custom addition: copy the hit's _score into the @Score-marked property.
                    setEntityScore(result, hit.getScore(), clazz);
                    setPersistentEntityId(result, hit.getId(), clazz);
                    setPersistentEntityVersion(result, hit.getVersion(), clazz);
                    populateScriptFields(result, hit);
                    results.add(result);
                }
            }

            return new AggregatedPageImpl<>(results, pageable, totalHits, response.getAggregations(),
                    response.getScrollId());
        }

        /**
         * Copies script field values from the hit into the entity's {@link ScriptedField}-annotated fields.
         * Only fields declared directly on the entity's runtime class are considered.
         */
        private <T> void populateScriptFields(T result, SearchHit hit) {
            if (hit.getFields() != null && !hit.getFields().isEmpty() && result != null) {
                for (Field field : result.getClass().getDeclaredFields()) {
                    ScriptedField scriptedField = field.getAnnotation(ScriptedField.class);
                    if (scriptedField != null) {
                        String name = scriptedField.name().isEmpty() ? field.getName() : scriptedField.name();
                        SearchHitField searchHitField = hit.getFields().get(name);
                        if (searchHitField != null) {
                            field.setAccessible(true);
                            try {
                                field.set(result, searchHitField.getValue());
                            } catch (IllegalArgumentException e) {
                                throw new ElasticsearchException("failed to set scripted field: " + name + " with value: "
                                        + searchHitField.getValue(), e);
                            } catch (IllegalAccessException e) {
                                throw new ElasticsearchException("failed to access scripted field: " + name, e);
                            }
                        }
                    }
                }
            }
        }

        /** Maps a hit without {@code _source} by first rebuilding a JSON object from its returned fields. */
        private <T> T mapEntity(Collection<SearchHitField> values, Class<T> clazz) {
            return mapEntity(buildJSONFromFields(values), clazz);
        }

        /**
         * Serializes hit fields into a JSON object string. Multi-valued fields become JSON arrays.
         *
         * @return the JSON text, or {@code null} if serialization fails
         */
        private String buildJSONFromFields(Collection<SearchHitField> values) {
            JsonFactory nodeFactory = new JsonFactory();
            ByteArrayOutputStream stream = new ByteArrayOutputStream();
            // try-with-resources so the generator is always closed; closing a ByteArrayOutputStream is a no-op.
            try (JsonGenerator generator = nodeFactory.createGenerator(stream, JsonEncoding.UTF8)) {
                generator.writeStartObject();
                for (SearchHitField value : values) {
                    if (value.getValues().size() > 1) {
                        generator.writeArrayFieldStart(value.getName());
                        for (Object val : value.getValues()) {
                            generator.writeObject(val);
                        }
                        generator.writeEndArray();
                    } else {
                        generator.writeObjectField(value.getName(), value.getValue());
                    }
                }
                generator.writeEndObject();
                generator.flush();
            } catch (IOException e) {
                log.warn("failed to build JSON from search hit fields", e);
                return null;
            }
            return new String(stream.toByteArray(), Charset.forName("UTF-8"));
        }

        /** Maps a single get response, populating id and version when an entity was found. */
        @Override
        public <T> T mapResult(GetResponse response, Class<T> clazz) {
            T result = mapEntity(response.getSourceAsString(), clazz);
            if (result != null) {
                setPersistentEntityId(result, response.getId(), clazz);
                setPersistentEntityVersion(result, response.getVersion(), clazz);
            }
            return result;
        }

        /** Maps a multi-get response, skipping failed items and documents that do not exist. */
        @Override
        public <T> LinkedList<T> mapResults(MultiGetResponse responses, Class<T> clazz) {
            LinkedList<T> list = new LinkedList<>();
            for (MultiGetItemResponse response : responses.getResponses()) {
                if (!response.isFailed() && response.getResponse().isExists()) {
                    T result = mapEntity(response.getResponse().getSourceAsString(), clazz);
                    setPersistentEntityId(result, response.getResponse().getId(), clazz);
                    setPersistentEntityVersion(result, response.getResponse().getVersion(), clazz);
                    list.add(result);
                }
            }
            return list;
        }

        /** Writes the document id into the entity's String id property, if mapping metadata is available. */
        private <T> void setPersistentEntityId(T result, String id, Class<T> clazz) {
            if (result == null || mappingContext == null || !clazz.isAnnotationPresent(Document.class)) {
                return;
            }
            ElasticsearchPersistentEntity<?> persistentEntity = mappingContext.getPersistentEntity(clazz);
            if (persistentEntity == null) {
                return;
            }
            ElasticsearchPersistentProperty idProperty = persistentEntity.getIdProperty();

            // Only deal with String because ES generated Ids are strings !
            if (idProperty != null && idProperty.getType().isAssignableFrom(String.class)) {
                persistentEntity.getPropertyAccessor(result).setProperty(idProperty, id);
            }
        }

        /**
         * Invokes every cached {@link Score} setter of a {@code @Document} entity with the hit score.
         * Invocation failures are logged with their cause and do not abort the mapping.
         */
        private <T> void setEntityScore(T result, float score, Class<T> clazz) {
            if (result == null || !clazz.isAnnotationPresent(Document.class)) {
                return;
            }
            // computeIfAbsent avoids the check-then-act race of containsKey/get/put.
            for (Method method : SCORE_SET_METHOD_CACHE.computeIfAbsent(clazz, this::findScoreSetters)) {
                try {
                    method.invoke(result, score);
                } catch (Exception e) {
                    log.error("{} set score error", clazz.getSimpleName(), e);
                }
            }
        }

        /** Collects the public one-argument {@code setXxx(Float|float)} methods of {@code type} marked with {@link Score}. */
        private List<Method> findScoreSetters(Class<?> type) {
            List<Method> setters = new ArrayList<>();
            for (Method method : type.getMethods()) {
                if (StringUtils.startsWith(method.getName(), SET)
                        && method.getParameterCount() == 1
                        && isFloatType(method.getParameterTypes()[0])
                        && hasAnnotation(type, method)) {
                    setters.add(method);
                }
            }
            return Collections.unmodifiableList(setters);
        }

        /** Whether the parameter type can receive a float score (boxed or primitive). */
        private static boolean isFloatType(Class<?> parameterType) {
            return Float.class.equals(parameterType) || float.class.equals(parameterType);
        }

        /**
         * Checks whether a setter is a score target, in one of two ways.
         * <ul>
         * <li>The setter itself is annotated with {@link Score}.</li>
         * <li>A field whose name matches the setter name minus {@code "set"}, compared case-insensitively, is
         * annotated with {@link Score}.</li>
         * </ul>
         * Field lookup walks up the superclass chain, so annotated fields declared on a base class are found too.
         *
         * @param clazz  the entity class whose fields are searched
         * @param method the candidate setter
         * @return {@code true} if the setter should receive the score
         */
        public boolean hasAnnotation(Class<?> clazz, Method method) {
            if (method.getAnnotation(Score.class) != null) {
                return true;
            }
            String findFieldName = StringUtils.removeStart(method.getName(), SET);
            for (Class<?> type = clazz; type != null && type != Object.class; type = type.getSuperclass()) {
                for (Field field : type.getDeclaredFields()) {
                    if (field.getName().equalsIgnoreCase(findFieldName) && field.getAnnotation(Score.class) != null) {
                        return true;
                    }
                }
            }
            return false;
        }

        /** Writes the document version into the entity's Long version property, if mapping metadata is available. */
        private <T> void setPersistentEntityVersion(T result, long version, Class<T> clazz) {
            if (result == null || mappingContext == null || !clazz.isAnnotationPresent(Document.class)) {
                return;
            }
            ElasticsearchPersistentEntity<?> persistentEntity = mappingContext.getPersistentEntity(clazz);
            if (persistentEntity == null) {
                return;
            }
            ElasticsearchPersistentProperty versionProperty = persistentEntity.getVersionProperty();

            // Only deal with Long because ES versions are longs !
            if (versionProperty != null && versionProperty.getType().isAssignableFrom(Long.class)) {
                // check that a version was actually returned in the response, -1 would indicate that
                // a search didn't request the version ids in the response, which would be an issue
                Assert.isTrue(version != -1, "Version in response is -1");
                persistentEntity.getPropertyAccessor(result).setProperty(versionProperty, version);
            }
        }
    }
    
    2. 定制 `ElasticsearchTemplate`，把默认使用的 `org.springframework.data.elasticsearch.core.DefaultResultMapper` 替换为 `ZwResultMapper`
    
    import org.elasticsearch.client.Client;
    import org.springframework.data.elasticsearch.core.ElasticsearchTemplate;
    import org.springframework.data.elasticsearch.core.EntityMapper;
    import org.springframework.data.elasticsearch.core.ResultsMapper;
    import org.springframework.data.elasticsearch.core.convert.ElasticsearchConverter;
    
    /**
     * @author weizhiwen
     * @date 2018/7/4
     */
    public class ZwElasticsearchTemplate extends ElasticsearchTemplate {
    
    
        public ZwElasticsearchTemplate(Client client) {
            super(client);
        }
    
        public ZwElasticsearchTemplate(Client client, EntityMapper entityMapper) {
            super(client, entityMapper);
        }
    
        public ZwElasticsearchTemplate(Client client, ElasticsearchConverter elasticsearchConverter, EntityMapper entityMapper) {
            this(client, elasticsearchConverter,
                    new ZwResultMapper(elasticsearchConverter.getMappingContext(), entityMapper));
        }
    
        public ZwElasticsearchTemplate(Client client, ResultsMapper resultsMapper) {
            super(client, resultsMapper);
        }
    
        public ZwElasticsearchTemplate(Client client, ElasticsearchConverter elasticsearchConverter) {
            this(client, elasticsearchConverter, new ZwResultMapper(elasticsearchConverter.getMappingContext()));
        }
    
        public ZwElasticsearchTemplate(Client client, ElasticsearchConverter elasticsearchConverter, ResultsMapper resultsMapper) {
            super(client, elasticsearchConverter, resultsMapper);
        }
    }
    
    3. 将 `ZwElasticsearchTemplate` 注册为 Spring Bean
        /**
         * Exposes a {@link ZwElasticsearchTemplate} as the application's {@link ElasticsearchTemplate} bean.
         * Any exception thrown during construction is rethrown wrapped in an {@link IllegalStateException}.
         */
        @Bean
        public ElasticsearchTemplate elasticsearchTemplate(Client client, ElasticsearchConverter converter) {
            final ElasticsearchTemplate template;
            try {
                template = new ZwElasticsearchTemplate(client, converter);
            } catch (Exception ex) {
                throw new IllegalStateException(ex);
            }
            return template;
        }
    

    使用

    实体类需加上 `@Document` 注解；用于接收得分的字段需加上 `@Score` 注解，且类型为 `Float`。

    import lombok.Data;
    import org.springframework.data.elasticsearch.annotations.Document;
    
    /**
     * Example search entity showing how to receive {@code _score}. The class carries {@code @Document}, and the
     * {@code Float} field marked with {@code @Score} is populated from each search hit's score.
     *
     * <p>NOTE(review): the {@code score} field is presumably also serialized when this entity is indexed, so it
     * may end up in the stored source and index mapping. Confirm whether that is intended.
     *
     * @author weizhiwen
     * @date 2018/10/18
     */
    @Data
    @Document(indexName = "asd")
    public class AudioDocumentDTO extends BaseAudioDocument {
        /** Receives the search hit's {@code _score} when this entity is mapped from a search response. */
        @Score
        private Float score;
    }
    

    相关文章

      网友评论

          本文标题:Spring Boot集成Elasticsearch自定义序列化

          本文链接:https://www.haomeiwen.com/subject/hzvszftx.html