之前用 Go 的 Gin 框架部署过 GPT 流式（stream）接口，但目前项目整体框架是 Java；为了和业务结合，又实现了一个 Spring Boot 版本。
为什么会有stream?
gpt是生成式的,stream模式非常适合。
一个不得不用的理由：max_tokens 设置较大时，普通（非流式）模式需要等全部内容生成完才返回，很容易导致接口超时。
技术点
仍然使用 SSE：Spring Boot 接口直接返回 SseEmitter。
用 OkHttp 接收 stream 数据，需要额外引入 okhttp-sse 包。
pom引入
<!-- Spring MVC web starter. The embedded Tomcat starter is excluded so that Undertow (declared next) serves requests instead. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
<exclusions>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-tomcat</artifactId>
</exclusion>
</exclusions>
</dependency>
<!-- Undertow as the embedded servlet container, replacing the excluded Tomcat. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-undertow</artifactId>
</dependency>
<!-- OkHttp client used for the upstream request. No <version> is given here: assumes it is managed by a parent POM / BOM (TODO confirm).
     Its transitive okio is excluded and re-declared below with an explicit version, presumably to control which okio is used. -->
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<exclusions>
<exclusion>
<groupId>com.squareup.okio</groupId>
<artifactId>okio</artifactId>
</exclusion>
</exclusions>
</dependency>
<!-- okio pinned to ${okio.version}; that property must be defined elsewhere in the POM (not shown here). -->
<dependency>
<groupId>com.squareup.okio</groupId>
<artifactId>okio</artifactId>
<version>${okio.version}</version>
</dependency>
<!-- okhttp-sse: provides EventSource / EventSourceListener for reading the upstream SSE stream.
     Also unversioned here; keep it aligned with the okhttp version. -->
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp-sse</artifactId>
</dependency>
entity
/**
 * Parameters of the GPT-3.5 stream endpoint.
 *
 * <p>The controller takes this type as an un-annotated handler argument of a GET mapping, so
 * Spring binds it from request parameters. Lombok {@code @Data} generates its accessors.
 */
@Data
public class ChatGpt35Dto {
// The user's message text.
// NOTE(review): the service example reads getQuestion() from ChatVirtualRoleDto instead;
// confirm which DTO/field the endpoint is meant to use.
private String msg;
}
/**
 * Body of the upstream chat-completion request.
 *
 * <p>It is turned into JSON with {@code JSON.toJSONString(...)} before being posted (presumably
 * fastjson). Field names presumably map 1:1 to JSON keys, which is why {@code max_tokens} keeps
 * its snake_case name. Do not rename the fields.
 */
@Data
public static class OpenAiRequest {
// Model id, e.g. "gpt-3.5-turbo".
public String model;
// Conversation messages sent to the model.
public List<GptMessage> messages;
// Sampling temperature (the example service uses 0.7).
public Double temperature;
// Optional cap on generated tokens; left unset (null) in the example service.
public Integer max_tokens;
// true requests a streamed (SSE) response instead of a single JSON body.
public Boolean stream;
}
/**
 * A single chat message. It is used both for request messages and as the incremental
 * {@code delta} of a streamed response chunk.
 */
@Data
public static class GptMessage {
// Author role as defined by the OpenAI chat API (e.g. "user", "assistant").
public String role;
// Message text; in a streamed delta this is only the newly generated fragment and may be empty.
public String content;
// Optional participant name.
public String name;
}
/**
 * One chunk of a streamed chat completion, parsed from the {@code data} payload of each SSE event.
 *
 * <p>NOTE: "Steam" is a typo for "Stream"; the name is kept because other code refers to it.
 */
@Data
public static class OpenAISteamResponse {
public String id;
public String object;
// Presumably a Unix timestamp in seconds (TODO confirm against the API).
public int created;
public String model;
// Choices in this chunk; only the first one is consumed by the streaming code.
public List<ChoicesBean> choices;
/** A single choice within a streamed chunk. */
@Data
public static class ChoicesBean {
public int index;
// Incremental message fragment; only delta.content is appended to the answer.
public GptMessage delta;
// Why generation stopped; presumably null on intermediate chunks (TODO confirm).
public String finish_reason;
}
}
okhttp sse utils方法定义
/**
 * Opens a streaming (Server-Sent Events) POST request and dispatches the events to {@code listener}.
 *
 * <p>{@code newEventSource} starts the call asynchronously (per okhttp-sse), so this method returns
 * immediately and the listener callbacks run later on OkHttp's threads.
 *
 * <p>NOTE(review): the client's read timeout applies between SSE events; make sure the client
 * returned by {@code OkHttpUtils.getOkHttpClient()} tolerates slow generations.
 *
 * @param url      endpoint to POST to
 * @param json     request body, sent as {@code application/json; charset=utf-8}
 * @param headers  request headers; {@code null} is treated as "no extra headers"
 * @param listener receives the onOpen / onEvent / onClosed / onFailure callbacks
 */
public static void sse(String url, String json, Map<String, String> headers, EventSourceListener listener) {
okhttp3.RequestBody body = okhttp3.RequestBody.create(okhttp3.MediaType.parse("application/json; charset=utf-8"), json);
// Headers.of(Map) rejects a null map, so fall back to an empty header set.
Headers requestHeaders = headers == null ? new Headers.Builder().build() : Headers.of(headers);
Request request = new Request.Builder().url(url).headers(requestHeaders)
.post(body).build();
EventSource.Factory factory = EventSources.createFactory(OkHttpUtils.getOkHttpClient());
factory.newEventSource(request, listener);
}
/**
 * Reports an error to the SSE client and then completes the emitter.
 *
 * <p>Sends a named {@code "error"} event whose data is {@code errorMessage}, then calls
 * {@code complete()}. If the event cannot be delivered (client gone, emitter already
 * completed), the emitter is completed with that exception instead. It is never completed twice.
 *
 * @param sseEmitter   emitter to report the error on
 * @param errorMessage payload of the error event
 */
public static void sendSseError(SseEmitter sseEmitter, String errorMessage) {
log.error("sse error {}", errorMessage);
try {
sseEmitter.send(SseEmitter.event().name("error").data(errorMessage));
} catch (Exception e) {
// Boundary catch: send() may throw IOException (client gone) or IllegalStateException (already completed).
log.warn("failed to deliver sse error event", e);
sseEmitter.completeWithError(e);
return;
}
sseEmitter.complete();
}
controller
/**
 * REST entry point for streamed GPT-3.5 chat.
 *
 * <p>The handler returns an {@link SseEmitter}. The response is declared as
 * {@code text/event-stream}, and the service pushes events through the returned emitter.
 */
@RestController
public class ChatGptController {
@Autowired
ChatGptService chatGptService;

/**
 * GET /openai/gpt35/stream. The request parameters are bound onto {@link ChatGpt35Dto}.
 *
 * @param dto bound request parameters
 * @return the emitter produced by the service
 */
@GetMapping(path = "/openai/gpt35/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter streamOpenAI(ChatGpt35Dto dto) {
SseEmitter emitter = chatGptService.gpt35Stream(dto);
return emitter;
}
}
service
省略接口的定义,直接写一下实现
/**
 * Starts a streamed chat completion and returns the emitter that the controller hands back to Spring MVC.
 *
 * <p>A missing or blank question is reported to the client as an {@code "error"} SSE event
 * (via {@code sendSseError}) rather than as an HTTP error, because the endpoint produces
 * {@code text/event-stream}.
 *
 * @param dto request parameters; its question must be non-blank
 * @return the emitter that receives the streamed events
 */
public SseEmitter chat(ChatVirtualRoleDto dto) {
// NOTE(review): the no-arg constructor uses the MVC async default timeout; long generations
// may end in AsyncRequestTimeoutException. Consider passing an explicit timeout.
SseEmitter sseEmitter = new SseEmitter();
String question = dto == null ? null : dto.getQuestion();
if (question == null || StringUtils.isEmpty(question.trim())) {
sendSseError(sseEmitter, "question is error");
return sseEmitter;
}
OpenAiRequest openAiRequest = new OpenAiRequest();
openAiRequest.model = "gpt-3.5-turbo";
openAiRequest.temperature = 0.7;
openAiRequest.stream = true;
// TODO: build the business-specific messages
log.info("open ai request: {}", JSON.toJSONString(openAiRequest));
okHttpEvent(sseEmitter, openAiRequest, openAiStreamHeaders, answer -> {
// TODO: handle the complete answer (e.g. persist it) and return its id.
// The callback is a Function<String, Long>, so the placeholder must be a long literal;
// a bare int literal does not compile here.
return 1L;
});
return sseEmitter;
}
/**
 * Streams a chat completion from {@code openAiUrl} and relays each content delta to {@code emitter}.
 *
 * <p>Events sent to the client:
 * <ul>
 *   <li>{@code "message"}: the JSON of each first choice whose delta carries non-empty content;</li>
 *   <li>{@code "stop"}: sent once on {@code [DONE]}, with the value returned by {@code function};</li>
 *   <li>{@code "error"}: sent via {@code sendSseError} when the upstream request fails.</li>
 * </ul>
 *
 * <p>Listener callbacks run asynchronously on OkHttp's threads, not on the request thread.
 * If the client disconnects, the upstream stream is cancelled so that no further tokens are consumed.
 *
 * @param emitter             emitter returned to the browser
 * @param openAiRequest       request body (serialized to JSON)
 * @param openAiStreamHeaders headers for the upstream request
 * @param function            invoked once with the concatenated answer when the stream finishes;
 *                            its result becomes the data of the {@code "stop"} event
 */
public static void okHttpEvent(SseEmitter emitter, OpenAiRequest openAiRequest, Map<String, String> openAiStreamHeaders, Function<String, Long> function) {
StringBuilder answer = new StringBuilder();
OkHttpUtils.sse(openAiUrl, JSON.toJSONString(openAiRequest), openAiStreamHeaders, new EventSourceListener() {
// Set when we cancel the upstream stream ourselves because the browser went away.
private boolean clientDisconnected;

@Override
public void onOpen(EventSource eventSource, Response response) {
}

@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
// Normal flow: the stream ends with a literal "[DONE]" data payload.
if ("[DONE]".equals(data)) {
// Hand the full answer to the caller's business logic.
Long historyId = function.apply(answer.toString());
try {
emitter.send(SseEmitter.event().name("stop").data(historyId));
} catch (IOException e) {
log.warn("failed to send stop event", e);
}
emitter.complete();
return;
}
OpenAISteamResponse.ChoicesBean choicesBean = firstChoiceWithContent(data);
// Skip chunks without text; otherwise the front end receives lots of nulls.
if (choicesBean == null) {
return;
}
answer.append(choicesBean.delta.content);
try {
emitter.send(SseEmitter.event().name("message").data(JSON.toJSONString(choicesBean)));
} catch (IOException e) {
// The browser is gone: stop consuming the upstream stream instead of failing on every chunk.
log.warn("client disconnected, cancelling upstream stream", e);
clientDisconnected = true;
eventSource.cancel();
emitter.complete();
}
}

@Override
public void onClosed(EventSource eventSource) {
emitter.complete();
}

@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
if (clientDisconnected) {
// Failure caused by our own cancel(); there is nobody left to notify.
return;
}
// response is null for transport-level failures; in most other cases the code is 429.
String reason = response != null ? String.valueOf(response.code()) : "upstream error";
sendSseError(emitter, reason);
// t is passed as the trailing argument (no placeholder) so SLF4J logs its stack trace.
log.error("event source failure, reason={}", reason, t);
}
});
}

/**
 * Parses one streamed chunk and returns its first choice if that choice carries non-empty
 * {@code delta.content}; otherwise returns {@code null} (empty/missing choices or delta included).
 */
private static OpenAISteamResponse.ChoicesBean firstChoiceWithContent(String data) {
OpenAISteamResponse chunk = JSONObject.parseObject(data, OpenAISteamResponse.class);
if (chunk == null || chunk.choices == null || chunk.choices.isEmpty()) {
return null;
}
OpenAISteamResponse.ChoicesBean choice = chunk.choices.get(0);
if (choice == null || choice.delta == null || StringUtils.isEmpty(choice.delta.content)) {
return null;
}
return choice;
}
关于异常处理
大部分情况下，Spring Boot 项目会接入统一的异常处理并把结果返回给前端；但 SSE 接口出现异常时，如果统一异常处理返回的是标准对象而不是 SseEmitter，Spring MVC 会再次抛出异常（响应的 Content-Type 是 text/event-stream，无法按普通格式写出该对象）。
所以要专门捕获sse中抛出的异常
/**
 * Global MVC exception handlers.
 *
 * <p>SSE endpoints cannot receive a regular error object, because the response is
 * {@code text/event-stream}. Async timeouts are therefore only logged here, and the handler
 * returns nothing.
 */
@Slf4j
@RestControllerAdvice
public class ControllerAdviceConf {

/**
 * Handles {@link AsyncRequestTimeoutException}, e.g. when an {@code SseEmitter} stream outlives
 * its async timeout. It only logs and returns no body.
 *
 * @param ex the timeout exception
 */
@ExceptionHandler(AsyncRequestTimeoutException.class)
public void myExceptionHandler(AsyncRequestTimeoutException ex) {
log.error("接口异常 async timeout");
// Further handling (persisting the error to a database, metrics, ...) is omitted here.
}
}
网友评论