|
@@ -11,11 +11,13 @@ import lombok.RequiredArgsConstructor;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import okhttp3.*;
|
|
|
|
|
|
+import org.ruoyi.chat.config.ChatConfig;
|
|
|
+import org.ruoyi.chat.listener.SSEEventSourceListener;
|
|
|
import org.ruoyi.chat.service.chat.IChatCostService;
|
|
|
import org.ruoyi.chat.service.chat.IChatService;
|
|
|
import org.ruoyi.chat.service.chat.ISseService;
|
|
|
import org.ruoyi.chat.util.IpUtil;
|
|
|
-import org.ruoyi.chat.util.SSEUtil;
|
|
|
+import org.ruoyi.common.chat.config.LocalCache;
|
|
|
import org.ruoyi.common.chat.request.ChatRequest;
|
|
|
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
|
|
|
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
|
@@ -33,7 +35,9 @@ import org.ruoyi.common.core.utils.file.MimeTypeUtils;
|
|
|
|
|
|
import org.ruoyi.common.redis.utils.RedisUtils;
|
|
|
|
|
|
+import org.ruoyi.domain.vo.ChatModelVo;
|
|
|
import org.ruoyi.service.EmbeddingService;
|
|
|
+import org.ruoyi.service.IChatModelService;
|
|
|
import org.ruoyi.service.VectorStoreService;
|
|
|
import org.springframework.core.io.InputStreamResource;
|
|
|
import org.springframework.core.io.Resource;
|
|
@@ -74,27 +78,35 @@ public class SseServiceImpl implements ISseService {
|
|
|
|
|
|
private final IChatService chatService;
|
|
|
|
|
|
+ private final IChatModelService chatModelService;
|
|
|
|
|
|
private static final String requestIdTemplate = "company-%d";
|
|
|
|
|
|
private static final ObjectMapper mapper = new ObjectMapper();
|
|
|
|
|
|
+ private final ChatConfig chatConfig;
|
|
|
+
|
|
|
@Override
|
|
|
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
|
|
- SseEmitter sseEmitter = new SseEmitter(0L);
|
|
|
+ SseEmitter sseEmitter = new SseEmitter();
|
|
|
try {
|
|
|
// 构建消息列表增加联网、知识库等内容
|
|
|
buildChatMessageList(chatRequest);
|
|
|
+ if (!StpUtil.isLogin()) {
|
|
|
+ // 未登录用户限制对话次数
|
|
|
+ checkUnauthenticatedUserChatLimit(request);
|
|
|
+ }else {
|
|
|
+ LocalCache.CACHE.put("userId", chatCostService.getUserId());
|
|
|
+
|
|
|
+ chatRequest.setUserId(chatCostService.getUserId());
|
|
|
+ // 保存消息记录 并扣除费用
|
|
|
+ // chatCostService.deductToken(chatRequest);
|
|
|
+ }
|
|
|
// 根据模型名称前缀调用不同的处理逻辑
|
|
|
switchModelAndHandle(chatRequest,sseEmitter);
|
|
|
- // 未登录用户限制对话次数
|
|
|
- checkUnauthenticatedUserChatLimit(request);
|
|
|
- // 保存消息记录 并扣除费用
|
|
|
- chatCostService.deductToken(chatRequest);
|
|
|
} catch (Exception e) {
|
|
|
- String message = e.getMessage();
|
|
|
- SSEUtil.sendErrorEvent(sseEmitter, message);
|
|
|
- return sseEmitter;
|
|
|
+ log.error(e.getMessage(),e);
|
|
|
+ sseEmitter.completeWithError(e);
|
|
|
}
|
|
|
return sseEmitter;
|
|
|
}
|
|
@@ -106,8 +118,7 @@ public class SseServiceImpl implements ISseService {
|
|
|
* @throws ServiceException 如果当日免费次数已用完
|
|
|
*/
|
|
|
public void checkUnauthenticatedUserChatLimit(HttpServletRequest request) throws ServiceException {
|
|
|
- // 未登录用户限制对话次数
|
|
|
- if (!StpUtil.isLogin()) {
|
|
|
+
|
|
|
String clientIp = IpUtil.getClientIp(request);
|
|
|
// 访客每天默认只能对话5次
|
|
|
int timeWindowInSeconds = 5;
|
|
@@ -125,13 +136,14 @@ public class SseServiceImpl implements ISseService {
|
|
|
count++;
|
|
|
RedisUtils.setCacheObject(redisKey, count);
|
|
|
}
|
|
|
- }
|
|
|
+
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
* 根据模型名称前缀调用不同的处理逻辑
|
|
|
*/
|
|
|
private void switchModelAndHandle(ChatRequest chatRequest,SseEmitter emitter) {
|
|
|
+ SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(emitter);
|
|
|
String model = chatRequest.getModel();
|
|
|
// 如果模型名称以ollama开头,则调用ollama中部署的本地模型
|
|
|
if (model.startsWith("ollama-")) {
|
|
@@ -142,8 +154,24 @@ public class SseServiceImpl implements ISseService {
|
|
|
} else {
|
|
|
throw new IllegalArgumentException("Invalid ollama model name: " + chatRequest.getModel());
|
|
|
}
|
|
|
- } else if (model.startsWith("gpt-4-gizmo")) {
|
|
|
- chatRequest.setModel("gpt-4-gizmo");
|
|
|
+ } else {
|
|
|
+
|
|
|
+ if (model.startsWith("gpt-4-gizmo")) {
|
|
|
+ chatRequest.setModel("gpt-4-gizmo");
|
|
|
+ }
|
|
|
+ ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
|
|
+ //openAiStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
|
|
|
+
|
|
|
+ ChatCompletion completion = ChatCompletion
|
|
|
+ .builder()
|
|
|
+ .messages(chatRequest.getMessages())
|
|
|
+ .model(chatRequest.getModel())
|
|
|
+ .temperature(0.2)
|
|
|
+ .topP(1.0)
|
|
|
+ .stream(true)
|
|
|
+ .build();
|
|
|
+ openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
|
|
|
+
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -151,9 +179,10 @@ public class SseServiceImpl implements ISseService {
|
|
|
* 构建消息列表
|
|
|
*/
|
|
|
private void buildChatMessageList(ChatRequest chatRequest){
|
|
|
+ ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
|
|
// 获取对话消息列表
|
|
|
List<Message> messages = chatRequest.getMessages();
|
|
|
- String sysPrompt = chatRequest.getSysPrompt();
|
|
|
+ String sysPrompt = chatModelVo.getSystemPrompt();
|
|
|
if(StringUtils.isEmpty(sysPrompt)){
|
|
|
sysPrompt ="你是一个由RuoYI-AI开发的人工智能助手,名字叫熊猫助手。你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。" +
|
|
|
"当前时间:"+ DateUtils.getDate();
|
|
@@ -162,8 +191,9 @@ public class SseServiceImpl implements ISseService {
|
|
|
Message sysMessage = Message.builder().content(sysPrompt).role(Message.Role.SYSTEM).build();
|
|
|
messages.add(0,sysMessage);
|
|
|
|
|
|
+ chatRequest.setSysPrompt(sysPrompt);
|
|
|
// 查询向量库相关信息加入到上下文
|
|
|
- if(chatRequest.getKid()!=null){
|
|
|
+ if(StringUtils.isNotEmpty(chatRequest.getKid())){
|
|
|
List<Message> knMessages = new ArrayList<>();
|
|
|
String content = messages.get(messages.size() - 1).getContent().toString();
|
|
|
List<String> nearestList;
|