|
@@ -2,39 +2,29 @@ package org.ruoyi.chat.service.chat.impl;
|
|
|
|
|
|
import cn.dev33.satoken.stp.StpUtil;
|
|
|
import cn.hutool.core.collection.CollectionUtil;
|
|
|
-import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
import com.google.protobuf.ServiceException;
|
|
|
-import com.zhipu.oapi.ClientV4;
|
|
|
-import com.zhipu.oapi.service.v4.tools.*;
|
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
-import okhttp3.*;
|
|
|
-
|
|
|
-import org.ruoyi.chat.config.ChatConfig;
|
|
|
-import org.ruoyi.chat.listener.SSEEventSourceListener;
|
|
|
+import okhttp3.ResponseBody;
|
|
|
+import org.ruoyi.chat.enums.ChatModeType;
|
|
|
import org.ruoyi.chat.service.chat.IChatCostService;
|
|
|
-import org.ruoyi.chat.service.chat.IChatService;
|
|
|
import org.ruoyi.chat.service.chat.ISseService;
|
|
|
import org.ruoyi.chat.util.IpUtil;
|
|
|
import org.ruoyi.common.chat.config.LocalCache;
|
|
|
-import org.ruoyi.common.chat.request.ChatRequest;
|
|
|
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
|
|
|
import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
|
|
import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
|
|
-
|
|
|
import org.ruoyi.common.chat.entity.chat.Message;
|
|
|
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
|
|
|
import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
|
|
|
import org.ruoyi.common.chat.openai.OpenAiStreamClient;
|
|
|
-import org.ruoyi.common.core.service.ConfigService;
|
|
|
+import org.ruoyi.common.chat.request.ChatRequest;
|
|
|
import org.ruoyi.common.core.utils.DateUtils;
|
|
|
import org.ruoyi.common.core.utils.StringUtils;
|
|
|
import org.ruoyi.common.core.utils.file.FileUtils;
|
|
|
import org.ruoyi.common.core.utils.file.MimeTypeUtils;
|
|
|
-
|
|
|
import org.ruoyi.common.redis.utils.RedisUtils;
|
|
|
-
|
|
|
import org.ruoyi.domain.vo.ChatModelVo;
|
|
|
import org.ruoyi.service.EmbeddingService;
|
|
|
import org.ruoyi.service.IChatModelService;
|
|
@@ -55,11 +45,7 @@ import java.nio.file.Files;
|
|
|
import java.nio.file.Path;
|
|
|
import java.time.Duration;
|
|
|
import java.util.ArrayList;
|
|
|
-import java.util.Collections;
|
|
|
import java.util.List;
|
|
|
-import java.util.concurrent.TimeUnit;
|
|
|
-import java.util.concurrent.atomic.AtomicBoolean;
|
|
|
-import java.util.concurrent.atomic.AtomicReference;
|
|
|
|
|
|
@Service
|
|
|
@Slf4j
|
|
@@ -72,19 +58,16 @@ public class SseServiceImpl implements ISseService {
|
|
|
|
|
|
private final VectorStoreService vectorStore;
|
|
|
|
|
|
- private final ConfigService configService;
|
|
|
-
|
|
|
private final IChatCostService chatCostService;
|
|
|
|
|
|
- private final IChatService chatService;
|
|
|
-
|
|
|
private final IChatModelService chatModelService;
|
|
|
|
|
|
- private static final String requestIdTemplate = "company-%d";
|
|
|
+ private final OpenAIServiceImpl openAIService;
|
|
|
+
|
|
|
+ private final OllamaServiceImpl ollamaService;
|
|
|
|
|
|
- private static final ObjectMapper mapper = new ObjectMapper();
|
|
|
+ private ChatModelVo chatModelVo;
|
|
|
|
|
|
- private final ChatConfig chatConfig;
|
|
|
|
|
|
@Override
|
|
|
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
|
@@ -100,7 +83,7 @@ public class SseServiceImpl implements ISseService {
|
|
|
|
|
|
chatRequest.setUserId(chatCostService.getUserId());
|
|
|
// 保存消息记录 并扣除费用
|
|
|
- // chatCostService.deductToken(chatRequest);
|
|
|
+ chatCostService.deductToken(chatRequest);
|
|
|
}
|
|
|
// 根据模型名称前缀调用不同的处理逻辑
|
|
|
switchModelAndHandle(chatRequest,sseEmitter);
|
|
@@ -143,35 +126,11 @@ public class SseServiceImpl implements ISseService {
|
|
|
* 根据模型名称前缀调用不同的处理逻辑
|
|
|
*/
|
|
|
private void switchModelAndHandle(ChatRequest chatRequest,SseEmitter emitter) {
|
|
|
- SSEEventSourceListener openAIEventSourceListener = new SSEEventSourceListener(emitter);
|
|
|
- String model = chatRequest.getModel();
|
|
|
- // 如果模型名称以ollama开头,则调用ollama中部署的本地模型
|
|
|
- if (model.startsWith("ollama-")) {
|
|
|
- String[] parts = chatRequest.getModel().split("ollama-", 2);
|
|
|
- if (parts.length > 1) {
|
|
|
- chatRequest.setModel(parts[1]);
|
|
|
- chatService.mcpChat(chatRequest,emitter);
|
|
|
- } else {
|
|
|
- throw new IllegalArgumentException("Invalid ollama model name: " + chatRequest.getModel());
|
|
|
- }
|
|
|
+ // 调用ollama中部署的本地模型
|
|
|
+ if (ChatModeType.OLLAMA.getCode().equals(chatModelVo.getCategory())) {
|
|
|
+ ollamaService.chat(chatRequest,emitter);
|
|
|
} else {
|
|
|
-
|
|
|
- if (model.startsWith("gpt-4-gizmo")) {
|
|
|
- chatRequest.setModel("gpt-4-gizmo");
|
|
|
- }
|
|
|
- ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
|
|
- //openAiStreamClient = chatConfig.createOpenAiStreamClient(chatModelVo.getApiHost(), chatModelVo.getApiKey());
|
|
|
-
|
|
|
- ChatCompletion completion = ChatCompletion
|
|
|
- .builder()
|
|
|
- .messages(chatRequest.getMessages())
|
|
|
- .model(chatRequest.getModel())
|
|
|
- .temperature(0.2)
|
|
|
- .topP(1.0)
|
|
|
- .stream(true)
|
|
|
- .build();
|
|
|
- openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener);
|
|
|
-
|
|
|
+ openAIService.chat(chatRequest,emitter);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -179,7 +138,7 @@ public class SseServiceImpl implements ISseService {
|
|
|
* 构建消息列表
|
|
|
*/
|
|
|
private void buildChatMessageList(ChatRequest chatRequest){
|
|
|
- ChatModelVo chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
|
|
+ chatModelVo = chatModelService.selectModelByName(chatRequest.getModel());
|
|
|
// 获取对话消息列表
|
|
|
List<Message> messages = chatRequest.getMessages();
|
|
|
String sysPrompt = chatModelVo.getSystemPrompt();
|
|
@@ -220,11 +179,6 @@ public class SseServiceImpl implements ISseService {
|
|
|
}
|
|
|
// 设置对话信息
|
|
|
chatRequest.setPrompt(chatString);
|
|
|
- // 加载联网信息
|
|
|
- if(chatRequest.getSearch()){
|
|
|
- Message message = Message.builder().role(Message.Role.ASSISTANT).content("联网信息:"+webSearch(chatString)).build();
|
|
|
- messages.add(message);
|
|
|
- }
|
|
|
}
|
|
|
|
|
|
|
|
@@ -333,58 +287,4 @@ public class SseServiceImpl implements ISseService {
|
|
|
return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
|
|
|
}
|
|
|
|
|
|
- @Override
|
|
|
- public String webSearch (String prompt) {
|
|
|
- String zpValue = configService.getConfigValue("zhipu", "key");
|
|
|
- if(StringUtils.isEmpty(zpValue)){
|
|
|
- throw new IllegalStateException("请在chat_config中配置智谱key信息");
|
|
|
- }else {
|
|
|
- ClientV4 client = new ClientV4.Builder(zpValue)
|
|
|
- .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
|
|
|
- .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
|
|
|
- .build();
|
|
|
-
|
|
|
- SearchChatMessage jsonNodes = new SearchChatMessage();
|
|
|
- jsonNodes.setRole(Message.Role.USER.getName());
|
|
|
- jsonNodes.setContent(prompt);
|
|
|
-
|
|
|
- String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
|
|
- WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
|
|
|
- .model("web-search-pro")
|
|
|
- .stream(Boolean.TRUE)
|
|
|
- .messages(Collections.singletonList(jsonNodes))
|
|
|
- .requestId(requestId)
|
|
|
- .build();
|
|
|
- WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest);
|
|
|
- List<ChoiceDelta> choices = new ArrayList<>();
|
|
|
- if (webSearchApiResponse.isSuccess()) {
|
|
|
- AtomicBoolean isFirst = new AtomicBoolean(true);
|
|
|
-
|
|
|
- AtomicReference<WebSearchPro> lastAccumulator = new AtomicReference<>();
|
|
|
-
|
|
|
- webSearchApiResponse.getFlowable().map(result -> result)
|
|
|
- .doOnNext(accumulator -> {
|
|
|
- {
|
|
|
- if (isFirst.getAndSet(false)) {
|
|
|
- log.info("Response: ");
|
|
|
- }
|
|
|
- ChoiceDelta delta = accumulator.getChoices().get(0).getDelta();
|
|
|
- if (delta != null && delta.getToolCalls() != null) {
|
|
|
- log.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls()));
|
|
|
- }
|
|
|
- choices.add(delta);
|
|
|
- }
|
|
|
- })
|
|
|
- .doOnComplete(() -> System.out.println("Stream completed."))
|
|
|
- .doOnError(throwable -> System.err.println("Error: " + throwable))
|
|
|
- .blockingSubscribe();
|
|
|
-
|
|
|
- WebSearchPro chatMessageAccumulator = lastAccumulator.get();
|
|
|
- webSearchApiResponse.setFlowable(null);
|
|
|
- webSearchApiResponse.setData(chatMessageAccumulator);
|
|
|
- }
|
|
|
- return choices.get(1).getToolCalls().toString();
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
}
|