|
@@ -1,8 +1,10 @@
|
|
|
package org.ruoyi.system.service.impl;
|
|
|
|
|
|
import cn.dev33.satoken.stp.StpUtil;
|
|
|
-import cn.hutool.core.collection.CollectionUtil;
|
|
|
import com.alibaba.fastjson.JSONObject;
|
|
|
+import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
+import com.zhipu.oapi.ClientV4;
|
|
|
+import com.zhipu.oapi.service.v4.tools.*;
|
|
|
import io.github.ollama4j.OllamaAPI;
|
|
|
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
|
|
|
import io.github.ollama4j.models.chat.OllamaChatRequestBuilder;
|
|
@@ -17,10 +19,7 @@ import org.ruoyi.common.chat.config.LocalCache;
|
|
|
import org.ruoyi.common.chat.domain.request.ChatRequest;
|
|
|
import org.ruoyi.common.chat.domain.request.Dall3Request;
|
|
|
import org.ruoyi.common.chat.entity.Tts.TextToSpeech;
|
|
|
-import org.ruoyi.common.chat.entity.chat.ChatCompletion;
|
|
|
-import org.ruoyi.common.chat.entity.chat.ChatCompletionResponse;
|
|
|
-import org.ruoyi.common.chat.entity.chat.Content;
|
|
|
-import org.ruoyi.common.chat.entity.chat.Message;
|
|
|
+import org.ruoyi.common.chat.entity.chat.*;
|
|
|
import org.ruoyi.common.chat.entity.files.UploadFileResponse;
|
|
|
import org.ruoyi.common.chat.entity.images.Image;
|
|
|
import org.ruoyi.common.chat.entity.images.ImageResponse;
|
|
@@ -33,17 +32,15 @@ import org.ruoyi.common.chat.plugin.CmdPlugin;
|
|
|
import org.ruoyi.common.chat.plugin.CmdReq;
|
|
|
import org.ruoyi.common.chat.plugin.SqlPlugin;
|
|
|
import org.ruoyi.common.chat.plugin.SqlReq;
|
|
|
-import org.ruoyi.common.chat.sse.ConsoleEventSourceListener;
|
|
|
import org.ruoyi.common.chat.utils.TikTokensUtil;
|
|
|
import org.ruoyi.common.core.domain.model.LoginUser;
|
|
|
import org.ruoyi.common.core.exception.base.BaseException;
|
|
|
import org.ruoyi.common.core.service.ConfigService;
|
|
|
import org.ruoyi.common.core.utils.StringUtils;
|
|
|
import org.ruoyi.common.satoken.utils.LoginHelper;
|
|
|
+import org.ruoyi.system.domain.SysModel;
|
|
|
import org.ruoyi.system.domain.bo.ChatMessageBo;
|
|
|
-import org.ruoyi.system.domain.bo.SysModelBo;
|
|
|
import org.ruoyi.system.domain.request.translation.TranslationRequest;
|
|
|
-import org.ruoyi.system.domain.vo.SysModelVo;
|
|
|
import org.ruoyi.system.listener.SSEEventSourceListener;
|
|
|
import org.ruoyi.system.service.*;
|
|
|
import org.springframework.core.io.InputStreamResource;
|
|
@@ -65,6 +62,9 @@ import java.util.ArrayList;
|
|
|
import java.util.Collections;
|
|
|
import java.util.List;
|
|
|
import java.util.concurrent.CompletableFuture;
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
+import java.util.concurrent.atomic.AtomicBoolean;
|
|
|
+import java.util.concurrent.atomic.AtomicReference;
|
|
|
|
|
|
|
|
|
@Service
|
|
@@ -76,18 +76,21 @@ public class SseServiceImpl implements ISseService {
|
|
|
|
|
|
private final ChatConfig chatConfig;
|
|
|
|
|
|
+
|
|
|
private final IChatCostService chatService;
|
|
|
|
|
|
private final IChatMessageService chatMessageService;
|
|
|
|
|
|
private final ISysModelService sysModelService;
|
|
|
|
|
|
- private final ISysUserService userService;
|
|
|
-
|
|
|
private final ConfigService configService;
|
|
|
|
|
|
static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
|
|
|
|
|
|
+ private static final String requestIdTemplate = "mycompany-%d";
|
|
|
+
|
|
|
+ private static final ObjectMapper mapper = new ObjectMapper();
|
|
|
+
|
|
|
@Override
|
|
|
public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
|
|
|
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
|
@@ -96,11 +99,10 @@ public class SseServiceImpl implements ISseService {
|
|
|
// 获取对话消息列表
|
|
|
List<Message> messages = chatRequest.getMessages();
|
|
|
try {
|
|
|
+ String chatString = null;
|
|
|
if (StpUtil.isLogin()) {
|
|
|
LocalCache.CACHE.put("userId", getUserId());
|
|
|
Object content = messages.get(messages.size() - 1).getContent();
|
|
|
-
|
|
|
- String chatString = "";
|
|
|
if (content instanceof List<?> listContent) {
|
|
|
if (!listContent.isEmpty() && listContent.get(0) instanceof Content) {
|
|
|
chatString = ((Content) listContent.get(0)).getText();
|
|
@@ -123,39 +125,89 @@ public class SseServiceImpl implements ISseService {
|
|
|
throw new BaseException("文本不合规,请修改!");
|
|
|
}
|
|
|
}
|
|
|
- //根据模型名称查询模型信息
|
|
|
- SysModelBo sysModelBo = new SysModelBo();
|
|
|
+ String model = chatRequest.getModel();
|
|
|
// 如果是gpts系列模型
|
|
|
if (chatRequest.getModel().startsWith("gpt-4-gizmo")) {
|
|
|
- sysModelBo.setModelName("gpt-4-gizmo");
|
|
|
- } else {
|
|
|
- sysModelBo.setModelName(chatRequest.getModel());
|
|
|
+ model = "gpt-4-gizmo";
|
|
|
}
|
|
|
- List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
|
|
|
-
|
|
|
- if (CollectionUtil.isEmpty(sysModelList)) {
|
|
|
+ SysModel sysModel = sysModelService.selectModelByName(model);
|
|
|
+ if (sysModel != null) {
|
|
|
// 如果模型不存在默认使用token扣费方式
|
|
|
processByToken(chatRequest.getModel(), chatString, chatMessageBo);
|
|
|
} else {
|
|
|
- openAiStreamClient = chatConfig.createOpenAiStreamClient(sysModelList.get(0).getApiHost(), sysModelList.get(0).getApiKey());
|
|
|
+ openAiStreamClient = chatConfig.createOpenAiStreamClient(sysModel.getApiHost(), sysModel.getApiKey());
|
|
|
// 模型设置默认提示词
|
|
|
- SysModelVo firstModel = sysModelList.get(0);
|
|
|
- if (StringUtils.isNotEmpty(firstModel.getSystemPrompt())) {
|
|
|
- Message sysMessage = Message.builder().content(firstModel.getSystemPrompt()).role(Message.Role.SYSTEM).build();
|
|
|
+
|
|
|
+ if (StringUtils.isNotEmpty(sysModel.getSystemPrompt())) {
|
|
|
+ Message sysMessage = Message.builder().content(sysModel.getSystemPrompt()).role(Message.Role.SYSTEM).build();
|
|
|
messages.add(sysMessage);
|
|
|
}
|
|
|
// 计费类型: 1 token扣费 2 次数扣费
|
|
|
- if ("2".equals(firstModel.getModelType())) {
|
|
|
- processByModelPrice(firstModel, chatMessageBo);
|
|
|
+ if ("2".equals(sysModel.getModelType())) {
|
|
|
+ processByModelPrice(sysModel, chatMessageBo);
|
|
|
} else {
|
|
|
- processByToken(chatRequest.getModel(), chatString, chatMessageBo);
|
|
|
+ processByToken(chatRequest.getModel(), chatString, chatMessageBo);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- if("openCmd".equals(chatRequest.getModel())) {
|
|
|
+ String configValue = configService.getConfigValue("zhipu", "key");
|
|
|
+ // 添加联网信息
|
|
|
+ if(StringUtils.isNotEmpty(configValue)){
|
|
|
+ ClientV4 client = new ClientV4.Builder(configValue)
|
|
|
+ .networkConfig(300, 100, 100, 100, TimeUnit.SECONDS)
|
|
|
+ .connectionPool(new okhttp3.ConnectionPool(8, 1, TimeUnit.SECONDS))
|
|
|
+ .build();
|
|
|
+
|
|
|
+ SearchChatMessage jsonNodes = new SearchChatMessage();
|
|
|
+ jsonNodes.setRole(Message.Role.USER.getName());
|
|
|
+ jsonNodes.setContent(chatString);
|
|
|
+
|
|
|
+ String requestId = String.format(requestIdTemplate, System.currentTimeMillis());
|
|
|
+ WebSearchParamsRequest chatCompletionRequest = WebSearchParamsRequest.builder()
|
|
|
+ .model("web-search-pro")
|
|
|
+ .stream(Boolean.TRUE)
|
|
|
+ .messages(Collections.singletonList(jsonNodes))
|
|
|
+ .requestId(requestId)
|
|
|
+ .build();
|
|
|
+ WebSearchApiResponse webSearchApiResponse = client.webSearchProStreamingInvoke(chatCompletionRequest);
|
|
|
+ List<ChoiceDelta> choices = new ArrayList<>();
|
|
|
+ if (webSearchApiResponse.isSuccess()) {
|
|
|
+ AtomicBoolean isFirst = new AtomicBoolean(true);
|
|
|
+
|
|
|
+ AtomicReference<WebSearchPro> lastAccumulator = new AtomicReference<>();
|
|
|
+
|
|
|
+ webSearchApiResponse.getFlowable().map(result -> result)
|
|
|
+ .doOnNext(accumulator -> {
|
|
|
+ {
|
|
|
+ if (isFirst.getAndSet(false)) {
|
|
|
+ log.info("Response: ");
|
|
|
+ }
|
|
|
+ ChoiceDelta delta = accumulator.getChoices().get(0).getDelta();
|
|
|
+ if (delta != null && delta.getToolCalls() != null) {
|
|
|
+ log.info("tool_calls: {}", mapper.writeValueAsString(delta.getToolCalls()));
|
|
|
+ }
|
|
|
+ choices.add(delta);
|
|
|
+ }
|
|
|
+ })
|
|
|
+ .doOnComplete(() -> System.out.println("Stream completed."))
|
|
|
+ .doOnError(throwable -> System.err.println("Error: " + throwable))
|
|
|
+ .blockingSubscribe();
|
|
|
+
|
|
|
+ WebSearchPro chatMessageAccumulator = lastAccumulator.get();
|
|
|
+
|
|
|
+ webSearchApiResponse.setFlowable(null);// 打印前置空
|
|
|
+ webSearchApiResponse.setData(chatMessageAccumulator);
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ Message message = Message.builder().role(Message.Role.ASSISTANT).content(choices.get(1).getToolCalls().toString()).build();
|
|
|
+ messages.add(message);
|
|
|
+ }
|
|
|
+
|
|
|
+ if ("openCmd".equals(chatRequest.getModel())) {
|
|
|
sseEmitter.send(cmdPlugin(messages));
|
|
|
sseEmitter.complete();
|
|
|
- }else if ("sqlPlugin".equals(chatRequest.getModel())){
|
|
|
+ } else if ("sqlPlugin".equals(chatRequest.getModel())) {
|
|
|
sseEmitter.send(sqlPlugin(messages));
|
|
|
sseEmitter.complete();
|
|
|
} else {
|
|
@@ -229,7 +281,7 @@ public class SseServiceImpl implements ISseService {
|
|
|
* @param model 模型信息
|
|
|
* @param chatMessageBo 对话信息
|
|
|
*/
|
|
|
- private void processByModelPrice(SysModelVo model, ChatMessageBo chatMessageBo) {
|
|
|
+ private void processByModelPrice(SysModel model, ChatMessageBo chatMessageBo) {
|
|
|
double cost = model.getModelPrice();
|
|
|
chatService.deductUserBalance(getUserId(), cost);
|
|
|
chatMessageBo.setDeductCost(cost);
|
|
@@ -316,16 +368,14 @@ public class SseServiceImpl implements ISseService {
|
|
|
.style(request.getStyle())
|
|
|
.build();
|
|
|
ImageResponse imageResponse = openAiStreamClient.genImages(image);
|
|
|
- SysModelBo sysModelBo = new SysModelBo();
|
|
|
- sysModelBo.setModelName(request.getModel());
|
|
|
- List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
|
|
|
+ SysModel sysModel = sysModelService.selectModelByName(request.getModel());
|
|
|
//chatService.deductUserBalance(getUserId(),sysModelList.get(0).getModelPrice());
|
|
|
// 保存消息记录
|
|
|
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
|
|
chatMessageBo.setUserId(getUserId());
|
|
|
chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
|
|
|
chatMessageBo.setContent(request.getPrompt());
|
|
|
- chatMessageBo.setDeductCost(sysModelList.get(0).getModelPrice());
|
|
|
+ chatMessageBo.setDeductCost(sysModel.getModelPrice());
|
|
|
chatMessageBo.setTotalTokens(0);
|
|
|
chatMessageService.insertByBo(chatMessageBo);
|
|
|
return imageResponse.getData();
|
|
@@ -342,16 +392,14 @@ public class SseServiceImpl implements ISseService {
|
|
|
.n(1)
|
|
|
.build();
|
|
|
ImageResponse imageResponse = openAiStreamClient.genImages(image);
|
|
|
- SysModelBo sysModelBo = new SysModelBo();
|
|
|
- sysModelBo.setModelName("dall3");
|
|
|
- List<SysModelVo> sysModelList = sysModelService.queryList(sysModelBo);
|
|
|
+ SysModel dall3 = sysModelService.selectModelByName("dall3");
|
|
|
chatService.deductUserBalance(Long.valueOf(userId), 0.3);
|
|
|
// 保存消息记录
|
|
|
ChatMessageBo chatMessageBo = new ChatMessageBo();
|
|
|
chatMessageBo.setUserId(getUserId());
|
|
|
chatMessageBo.setModelName(Image.Model.DALL_E_3.getName());
|
|
|
chatMessageBo.setContent(prompt);
|
|
|
- chatMessageBo.setDeductCost(sysModelList.get(0).getModelPrice());
|
|
|
+ chatMessageBo.setDeductCost(dall3.getModelPrice());
|
|
|
chatMessageBo.setTotalTokens(0);
|
|
|
chatMessageService.insertByBo(chatMessageBo);
|
|
|
return imageResponse.getData();
|
|
@@ -527,12 +575,9 @@ public class SseServiceImpl implements ISseService {
|
|
|
chatMessageBo.setDeductCost(0.01);
|
|
|
chatMessageBo.setTotalTokens(0);
|
|
|
chatMessageService.insertByBo(chatMessageBo);
|
|
|
-
|
|
|
openAiStreamClient = chatConfig.getOpenAiStreamClient();
|
|
|
-
|
|
|
List<Message> messageList = new ArrayList<>();
|
|
|
-
|
|
|
- Message sysMessage = Message.builder().role(Message.Role.SYSTEM).content("你是一名翻译老师\n" +
|
|
|
+ Message sysMessage = Message.builder().role(Message.Role.SYSTEM).content("你是一位精通各国语言的翻译大师\n" +
|
|
|
"\n" +
|
|
|
"请将用户输入词语翻译成{" + translationRequest.getTargetLanguage() + "}\n" +
|
|
|
"\n" +
|
|
@@ -563,25 +608,21 @@ public class SseServiceImpl implements ISseService {
|
|
|
|
|
|
@Override
|
|
|
public SseEmitter ollamaChat(ChatRequest chatRequest) {
|
|
|
+ String[] parts = chatRequest.getModel().split("ollama-");
|
|
|
+ SysModel sysModel = sysModelService.selectModelByName(parts[1]);
|
|
|
final SseEmitter emitter = new SseEmitter();
|
|
|
- String host = "http://localhost:11434/";
|
|
|
-
|
|
|
+ String host = sysModel.getApiHost();
|
|
|
List<Message> msgList = chatRequest.getMessages();
|
|
|
Message message = msgList.get(msgList.size() - 1);
|
|
|
-
|
|
|
- OllamaAPI ollamaAPI = new OllamaAPI(host);
|
|
|
-
|
|
|
- ollamaAPI.setRequestTimeoutSeconds(100);
|
|
|
-
|
|
|
- OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance("qwen2.5:7b");
|
|
|
-
|
|
|
+ OllamaAPI api = new OllamaAPI(host);
|
|
|
+ api.setRequestTimeoutSeconds(100);
|
|
|
+ OllamaChatRequestBuilder builder = OllamaChatRequestBuilder.getInstance(sysModel.getModelName());
|
|
|
OllamaChatRequestModel requestModel = builder
|
|
|
.withMessage(OllamaChatMessageRole.USER,
|
|
|
message.getContent().toString())
|
|
|
.build();
|
|
|
|
|
|
-
|
|
|
- // 异步执行 Ollama API 调用
|
|
|
+ // 异步执行 OllAma API 调用
|
|
|
CompletableFuture.runAsync(() -> {
|
|
|
try {
|
|
|
StringBuilder response = new StringBuilder();
|
|
@@ -595,14 +636,12 @@ public class SseServiceImpl implements ISseService {
|
|
|
sendErrorEvent(emitter, e.getMessage());
|
|
|
}
|
|
|
};
|
|
|
- ollamaAPI.chat(requestModel, streamHandler);
|
|
|
+ api.chat(requestModel, streamHandler);
|
|
|
emitter.complete();
|
|
|
} catch (Exception e) {
|
|
|
sendErrorEvent(emitter, e.getMessage());
|
|
|
}
|
|
|
});
|
|
|
-
|
|
|
-
|
|
|
return emitter;
|
|
|
}
|
|
|
|
|
@@ -620,6 +659,4 @@ public class SseServiceImpl implements ISseService {
|
|
|
ChatCompletionResponse chatCompletionResponse = openAiStreamClient.chatCompletion(chatCompletion);
|
|
|
return chatCompletionResponse.getChoices().get(0).getMessage().getContent().toString();
|
|
|
}
|
|
|
-
|
|
|
-
|
|
|
}
|