美文网首页
2023-06-20 springboot 接入gpt stre

2023-06-20 springboot 接入gpt stream

作者: 江江江123 | 来源:发表于2023-06-19 16:28 被阅读0次

    之前用go gin部署了gpt stream,但是目前项目整体框架用的java, 为了和业务结合,还是实现了springboot版的。

    为什么会有stream?

    gpt是生成式的,stream模式非常适合。
    一个不得不用的理由,maxToken过大的话用普通的模式容易接口超时。

    技术点

    还是sse,springboot api返回 SseEmitter
    okhttp接收stream数据,需要引入okhttp-sse包

    pom引入

            <!-- Spring Boot web starter; the embedded Tomcat starter it would pull in is excluded below. -->
            <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-starter-web</artifactId>
                <exclusions>
                    <exclusion>
                        <groupId>org.springframework.boot</groupId>
                        <artifactId>spring-boot-starter-tomcat</artifactId>
                    </exclusion>
                </exclusions>
            </dependency>
            <!-- Undertow starter, added in place of the excluded Tomcat starter. -->
            <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-starter-undertow</artifactId>
            </dependency>
            <!-- OkHttp client. Its transitive okio is excluded so that the explicit okio dependency below decides the version. -->
            <!-- NOTE(review): no <version> is given for okhttp / okhttp-sse here; presumably managed by a parent POM or BOM, confirm. -->
            <dependency>
                <groupId>com.squareup.okhttp3</groupId>
                <artifactId>okhttp</artifactId>
                <exclusions>
                    <exclusion>
                        <groupId>com.squareup.okio</groupId>
                        <artifactId>okio</artifactId>
                    </exclusion>
                </exclusions>
            </dependency>
            <!-- okio declared explicitly; the okio.version property is defined outside this snippet. -->
            <dependency>
                <groupId>com.squareup.okio</groupId>
                <artifactId>okio</artifactId>
                <version>${okio.version}</version>
            </dependency>
            <!-- Server-sent-events add-on for OkHttp, needed to receive the streamed response. -->
            <dependency>
                <groupId>com.squareup.okhttp3</groupId>
                <artifactId>okhttp-sse</artifactId>
            </dependency>
    

    entity

        /**
         * Request parameters accepted by the {@code /openai/gpt35/stream} controller endpoint.
         */
        @Data
        public class ChatGpt35Dto {
            // Message text supplied by the caller; presumably the user's prompt, confirm against the service.
            private String msg;
        }
    
        /**
         * Request payload that is serialized to JSON and posted to the OpenAI endpoint
         * by {@code okHttpEvent}. Field names are part of the JSON that is sent, so they
         * must not be renamed.
         */
        @Data
        public static class OpenAiRequest {
            // Model identifier; the chat service sets it to "gpt-3.5-turbo".
            public String model;
            // Conversation messages; not populated in the visible service code (left as a TODO there).
            public List<GptMessage> messages;
            // Sampling temperature; the chat service sets it to 0.7.
            public Double temperature;
            // Token limit for the reply; not set in the visible service code, so it is serialized as absent/null.
            public Integer max_tokens;
            // Streaming switch; the chat service sets it to true.
            public Boolean stream;
        }
    
        /**
         * A single chat message. Used both as an element of {@code OpenAiRequest.messages}
         * and as the {@code delta} of a streamed response choice.
         */
        @Data
        public static class GptMessage {
            // Speaker role of the message; allowed values are not shown here, see the OpenAI API docs.
            public String role;
            // Message text; for streamed deltas this is the fragment that okHttpEvent appends to the answer.
            public String content;
            // Optional name field; not read or written by the visible code.
            public String name;
        }
    
        /**
         * One streamed chunk as parsed from the {@code data} payload of an SSE event
         * in {@code okHttpEvent}. Field names mirror the JSON keys being parsed.
         */
        @Data
        public static class OpenAISteamResponse {
    
            // Identifier of the response as reported by the API.
            public String id;
            // Object type string as reported by the API.
            public String object;
            // Creation time as reported by the API; presumably Unix seconds, confirm.
            public int created;
            // Model name as reported by the API.
            public String model;
            // Choices of this chunk; okHttpEvent only reads the first element.
            public List<ChoicesBean> choices;
    
            /**
             * One choice inside a streamed chunk.
             */
            @Data
            public static class ChoicesBean {
                // Position of this choice in the choices list.
                public int index;
                // Incremental message fragment; its content may be empty, in which case okHttpEvent skips it.
                public GptMessage delta;
                // Reason the generation stopped; not read by the visible code.
                public String finish_reason;
            }
        }
    

    okhttp sse utils方法定义

        /**
         * Posts a JSON body to {@code url} and attaches {@code listener} to the
         * resulting server-sent-event stream.
         *
         * @param url      target endpoint
         * @param json     request body, sent as {@code application/json; charset=utf-8}
         * @param headers  HTTP headers to put on the request
         * @param listener callback that receives the stream events
         */
        public static void sse(String url, String json, Map<String, String> headers, EventSourceListener listener) {
            okhttp3.MediaType jsonType = okhttp3.MediaType.parse("application/json; charset=utf-8");
            okhttp3.RequestBody payload = okhttp3.RequestBody.create(jsonType, json);

            Request.Builder builder = new Request.Builder();
            builder.url(url);
            builder.headers(Headers.of(headers));
            builder.post(payload);
            Request postRequest = builder.build();

            // Creating the event source starts the call; events are delivered to the listener.
            EventSources.createFactory(OkHttpUtils.getOkHttpClient()).newEventSource(postRequest, listener);
        }
        /**
         * Reports an error to the client as an SSE event named {@code error} and then
         * finishes the emitter.
         *
         * <p>If the event cannot be delivered, the emitter is finished with
         * {@code completeWithError} only; {@code complete()} is not called on top of it.
         *
         * @param sseEmitter   emitter of the current request
         * @param errorMessage text sent as the event data and written to the log
         */
        public static void sendSseError(SseEmitter sseEmitter, String errorMessage) {
            log.error("sse error {}", errorMessage);
            try {
                sseEmitter.send(SseEmitter.event().name("error").data(errorMessage));
            } catch (Exception e) {
                // Delivery failed: finish with the error and stop, so the emitter is not completed twice.
                log.error("failed to deliver sse error event", e);
                sseEmitter.completeWithError(e);
                return;
            }
            sseEmitter.complete();
        }
    
    

    controller

    /**
     * REST controller exposing the GPT-3.5 chat as a server-sent-event stream.
     */
    @RestController
    public class ChatGptController {
        @Autowired
        ChatGptService chatGptService;

        /**
         * Handles {@code GET /openai/gpt35/stream} and hands the request to the service.
         *
         * @param dto request parameters bound from the query string
         * @return the emitter produced by the service, served as {@code text/event-stream}
         */
        @GetMapping(path = "/openai/gpt35/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
        public SseEmitter streamOpenAI(ChatGpt35Dto dto) {
            SseEmitter emitter = chatGptService.gpt35Stream(dto);
            return emitter;
        }
    }
    

    service

    省略接口的定义,直接写一下实现

        /**
         * Starts a streamed GPT-3.5 chat for the given question.
         *
         * <p>A blank or missing question is answered with an SSE {@code error} event on
         * the returned emitter. Otherwise the request is built and handed to
         * {@code okHttpEvent}, which feeds the emitter as data arrives.
         *
         * @param dto request carrying the question
         * @return the emitter to hand back to Spring MVC
         */
        public SseEmitter chat(ChatVirtualRoleDto dto) {
            SseEmitter sseEmitter = new SseEmitter();
            String question = dto.getQuestion();
            if (question == null || question.trim().isEmpty()) {
                sendSseError(sseEmitter, "question is error");
                return sseEmitter;
            }
            OpenAiRequest openAiRequest = new OpenAiRequest();
            openAiRequest.model = "gpt-3.5-turbo";
            openAiRequest.temperature = 0.7;
            openAiRequest.stream = true;
            // TODO: build the messages list from the business context
            log.info("open ai request: {}", JSON.toJSONString(openAiRequest));
            okHttpEvent(sseEmitter, openAiRequest, openAiStreamHeaders, answer -> {
                // TODO: handle the complete answer (e.g. persist it) and return its id
                // The callback type is Function<String, Long>, so the literal must be a long.
                return 1L;
            });
            return sseEmitter;
        }
    
        /**
         * Sends {@code openAiRequest} to the OpenAI endpoint as a streaming call and relays
         * the streamed fragments to {@code emitter}.
         *
         * <ul>
         *   <li>Each non-empty delta is appended to the accumulated answer and forwarded as
         *       an SSE event named {@code message}.</li>
         *   <li>On the {@code [DONE]} marker, {@code function} is applied to the full answer
         *       and its result is sent as an SSE event named {@code stop}; the emitter is
         *       then completed.</li>
         *   <li>On upstream failure, an SSE {@code error} event is sent via
         *       {@code sendSseError}.</li>
         *   <li>If writing to the client fails, the emitter is completed and the upstream
         *       stream is cancelled.</li>
         * </ul>
         *
         * @param emitter             emitter of the current client request
         * @param openAiRequest       request payload, serialized to JSON
         * @param openAiStreamHeaders HTTP headers for the upstream request
         * @param function            callback receiving the full answer; its result is sent with the stop event
         */
        public static void okHttpEvent(SseEmitter emitter, OpenAiRequest openAiRequest, Map<String, String> openAiStreamHeaders, Function<String, Long> function) {
            StringBuilder answer = new StringBuilder();
            OkHttpUtils.sse(openAiUrl, JSON.toJSONString(openAiRequest), openAiStreamHeaders, new EventSourceListener() {
                // Set once this listener has finished the emitter, so later callbacks leave it alone.
                private boolean finished;

                @Override
                public void onOpen(EventSource eventSource, Response response) {
                    // Nothing to do when the stream opens.
                }

                @Override
                public void onEvent(EventSource eventSource, String id, String type, String data) {
                    if (finished) {
                        return;
                    }
                    // The stream ends with a literal [DONE] marker.
                    if ("[DONE]".equals(data)) {
                        // Pass the complete answer to the business callback.
                        Long historyId = function.apply(answer.toString());
                        try {
                            emitter.send(SseEmitter.event().name("stop").data(historyId));
                        } catch (IOException e) {
                            log.warn("failed to send stop event to client", e);
                        }
                        finished = true;
                        emitter.complete();
                        return;
                    }

                    OpenAISteamResponse openAiResponse = JSONObject.parseObject(data, OpenAISteamResponse.class);
                    if (openAiResponse == null || openAiResponse.choices == null || openAiResponse.choices.isEmpty()) {
                        return;
                    }
                    OpenAISteamResponse.ChoicesBean choicesBean = openAiResponse.choices.get(0);
                    // Skip empty deltas, otherwise the front end receives many null fragments.
                    if (choicesBean.delta == null || StringUtils.isEmpty(choicesBean.delta.content)) {
                        return;
                    }

                    // Accumulate the answer and forward the fragment.
                    answer.append(choicesBean.delta.content);
                    try {
                        emitter.send(SseEmitter.event().name("message").data(JSON.toJSONString(choicesBean)));
                    } catch (IOException e) {
                        // The client can no longer be written to: finish and stop the upstream stream.
                        log.warn("failed to send message event to client, cancelling upstream stream", e);
                        finished = true;
                        emitter.complete();
                        eventSource.cancel();
                    }
                }

                @Override
                public void onClosed(EventSource eventSource) {
                    if (!finished) {
                        finished = true;
                        emitter.complete();
                    }
                }

                @Override
                public void onFailure(EventSource eventSource, Throwable t, Response response) {
                    if (finished) {
                        // Typically the result of our own cancel(); the emitter is already done.
                        log.debug("event source failure after completion", t);
                        return;
                    }
                    finished = true;
                    // response is null when no HTTP response was received (e.g. connection failure).
                    // In most cases seen in practice the response code is 429.
                    String status = response == null ? "upstream unavailable" : response.code() + "";
                    log.error("event source failure, status {}", status, t);
                    sendSseError(emitter, status);
                }
            });
        }
    

    关于异常处理

    大部分情况下,springboot项目会接入统一的异常处理,把错误包装成标准对象返回给前端;但是sse接口发生异常时,如果返回的是标准对象而不是SseEmitter,springmvc就会抛出异常。
    所以要专门捕获sse中抛出的异常。

    /**
     * Controller advice that handles exceptions raised by the SSE endpoints separately.
     */
    @Slf4j
    @RestControllerAdvice
    public class ControllerAdviceConf {
        /**
         * Handles the timeout of an asynchronous request by logging it; nothing is returned.
         *
         * @param ex the timeout exception raised by Spring MVC
         */
        @ExceptionHandler(AsyncRequestTimeoutException.class)
        public void myExceptionHandler(AsyncRequestTimeoutException ex) {
            // Record the failure here (log, database, or other handling); omitted in this example.
            log.error("接口异常 async timeout");
        }
    }
    

    相关文章

      网友评论

          本文标题:2023-06-20 springboot 接入gpt stre

          本文链接:https://www.haomeiwen.com/subject/heyeydtx.html