Przeglądaj źródła

本地向量化

jiahao.he@vtradex.com 2 miesięcy temu
rodzic
commit
4967c3f906

+ 38 - 0
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchRequest.java

@@ -0,0 +1,38 @@
+package org.ruoyi.common.chat.entity.models;
+
+import lombok.Data;
+
+import java.util.List;
+
+/**
+ * @program: RUOYIAI
+ * @ClassName LocalModelsSearchRequest
+ * @description:
+ * @author: hejh
+ * @create: 2025-03-15 17:22
+ * @Version 1.0
+ **/
+@Data
+public class LocalModelsSearchRequest {
+
+    private List<String> text;
+    private String model_name;
+    private String delimiter;
+    private int k;
+    private int block_size;
+    private int overlap_chars;
+
+    // 构造函数、Getter 和 Setter
+    public LocalModelsSearchRequest(List<String> text, String model_name, String delimiter, int k, int block_size, int overlap_chars) {
+        this.text = text;
+        this.model_name = model_name;
+        this.delimiter = delimiter;
+        this.k = k;
+        this.block_size = block_size;
+        this.overlap_chars = overlap_chars;
+    }
+
+
+}
+
+

+ 20 - 0
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/entity/models/LocalModelsSearchResponse.java

@@ -0,0 +1,20 @@
+package org.ruoyi.common.chat.entity.models;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Data;
+
+import java.util.List;
+
+@Data
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class LocalModelsSearchResponse {
+    @JsonProperty("topKEmbeddings")
+
+    private List<List<List<Double>>> topKEmbeddings;  // 处理三层嵌套数组
+
+    // 默认构造函数
+    public LocalModelsSearchResponse() {}
+
+
+
+}

+ 198 - 0
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/LocalModelsofitClient.java

@@ -0,0 +1,198 @@
+package org.ruoyi.common.chat.localModels;
+
+import io.micrometer.common.util.StringUtils;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.OkHttpClient;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
+import org.springframework.stereotype.Service;
+import retrofit2.Call;
+import retrofit2.Callback;
+import retrofit2.Response;
+import retrofit2.Retrofit;
+import retrofit2.converter.jackson.JacksonConverterFactory;
+
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+
+@Slf4j
+@Service
+public class LocalModelsofitClient {
+    private static final String BASE_URL = "http://127.0.0.1:5000"; // Flask 服务的 URL
+    private static Retrofit retrofit = null;
+
+    // 获取 Retrofit 实例
+    public static Retrofit getRetrofitInstance() {
+        if (retrofit == null) {
+            OkHttpClient client = new OkHttpClient.Builder()
+                    .build();
+
+            retrofit = new Retrofit.Builder()
+                    .baseUrl(BASE_URL)
+                    .client(client)
+                    .addConverterFactory(JacksonConverterFactory.create()) // 使用 Jackson 处理 JSON 转换
+                    .build();
+        }
+        return retrofit;
+    }
+
+    /**
+     * 向 Flask 服务发送文本向量化请求
+     *
+     * @param queries 查询文本列表
+     * @param modelName 模型名称
+     * @param delimiter 文本分隔符
+     * @param topK 返回的结果数
+     * @param blockSize 文本块大小
+     * @param overlapChars 重叠字符数
+     * @return 返回计算得到的 Top K 嵌入向量列表
+     */
+
+    public static List<List<Double>> getTopKEmbeddings(
+            List<String> queries,
+            String modelName,
+            String delimiter,
+            int topK,
+            int blockSize,
+            int overlapChars) {
+
+        modelName = (!StringUtils.isEmpty(modelName)) ? modelName : "msmarco-distilbert-base-tas-b"; // 默认模型名称
+        delimiter = (!StringUtils.isEmpty(delimiter) ) ? delimiter : ".";                             // 默认分隔符
+        topK = (topK > 0) ? topK : 3;                                                  // 默认返回 3 个结果
+        blockSize = (blockSize > 0) ? blockSize : 500;                                 // 默认文本块大小为 500
+        overlapChars = (overlapChars > 0) ? overlapChars : 50;                         // 默认重叠字符数为 50
+
+        // 创建 Retrofit 实例
+        Retrofit retrofit = getRetrofitInstance();
+
+        // 创建 SearchService 接口
+        SearchService service = retrofit.create(SearchService.class);
+
+        // 创建请求对象 LocalModelsSearchRequest
+        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
+                queries,            // 查询文本列表
+                modelName,          // 模型名称
+                delimiter,          // 文本分隔符
+                topK,               // 返回的结果数
+                blockSize,          // 文本块大小
+                overlapChars        // 重叠字符数
+        );
+
+        final CountDownLatch latch = new CountDownLatch(1);  // 创建一个 CountDownLatch
+        final List<List<Double>>[] topKEmbeddings = new List[]{null}; // 使用数组来存储结果(因为 Java 不支持直接修改 List)
+
+        // 发起异步请求
+        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
+            @Override
+            public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
+                if (response.isSuccessful()) {
+                    LocalModelsSearchResponse searchResponse = response.body();
+                    if (searchResponse != null) {
+                        topKEmbeddings[0] = searchResponse.getTopKEmbeddings().get(0);  // 获取结果
+                        log.info("Successfully retrieved embeddings");
+                    } else {
+                        log.error("Response body is null");
+                    }
+                } else {
+                    log.error("Request failed. HTTP error code: " + response.code());
+                }
+                latch.countDown();  // 请求完成,减少计数
+            }
+
+            @Override
+            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
+                t.printStackTrace();
+                log.error("Request failed: ", t);
+                latch.countDown();  // 请求失败,减少计数
+            }
+        });
+
+        try {
+            latch.await();  // 等待请求完成
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+
+        return topKEmbeddings[0];  // 返回结果
+    }
+
+//    public static void main(String[] args) {
+//        // 示例调用
+//        List<String> queries = Arrays.asList("What is artificial intelligence?", "AI is transforming industries.");
+//        String modelName = "msmarco-distilbert-base-tas-b";
+//        String delimiter = ".";
+//        int topK = 3;
+//        int blockSize = 500;
+//        int overlapChars = 50;
+//
+//        List<List<Double>> topKEmbeddings = getTopKEmbeddings(queries, modelName, delimiter, topK, blockSize, overlapChars);
+//
+//        // 打印结果
+//        if (topKEmbeddings != null) {
+//            System.out.println("Top K embeddings: ");
+//            for (List<Double> embedding : topKEmbeddings) {
+//                System.out.println(embedding);
+//            }
+//        } else {
+//            System.out.println("No embeddings returned.");
+//        }
+//    }
+
+
+//    public static void main(String[] args) {
+//        // 创建 Retrofit 实例
+//        Retrofit retrofit = LocalModelsofitClient.getRetrofitInstance();
+//
+//        // 创建 SearchService 接口
+//        SearchService service = retrofit.create(SearchService.class);
+//
+//        // 创建请求对象 LocalModelsSearchRequest
+//        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
+//                Arrays.asList("What is artificial intelligence?", "AI is transforming industries."), // 查询文本列表
+//                "msmarco-distilbert-base-tas-b",  // 模型名称
+//                ".",  // 分隔符
+//                3,  // 返回的结果数
+//                500,  // 文本块大小
+//                50  // 重叠字符数
+//        );
+//
+//        // 发起请求
+//        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
+//            @Override
+//            public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
+//                if (response.isSuccessful()) {
+//                    LocalModelsSearchResponse searchResponse = response.body();
+//                    System.out.println("Response Body: " + response.body());  // Print the whole response body for debugging
+//
+//                    if (searchResponse != null) {
+//                        // If the response is not null, process it.
+//                        // Example: Extract the embeddings and print them
+//                        List<List<List<Double>>> topKEmbeddings = searchResponse.getTopKEmbeddings();
+//                        if (topKEmbeddings != null) {
+//                            // Print the Top K embeddings
+//
+//                        } else {
+//                            System.err.println("Top K embeddings are null");
+//                        }
+//
+//                        // If there is more information you want to process, handle it here
+//
+//                    } else {
+//                        System.err.println("Response body is null");
+//                    }
+//                } else {
+//                    System.err.println("Request failed. HTTP error code: " + response.code());
+//                    log.error("Failed to retrieve data. HTTP error code: " + response.code());
+//                }
+//            }
+//
+//            @Override
+//            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
+//                // 请求失败,打印错误
+//                t.printStackTrace();
+//                log.error("Request failed: ", t);
+//            }
+//        });
+//    }
+
+}

+ 25 - 0
ruoyi-common/ruoyi-common-chat/src/main/java/org/ruoyi/common/chat/localModels/SearchService.java

@@ -0,0 +1,25 @@
+package org.ruoyi.common.chat.localModels;
+
+
+
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
+import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
+import retrofit2.Call;
+import retrofit2.http.Body;
+import retrofit2.http.POST;
+/**
+ * @program: RUOYIAI
+ * @ClassName SearchService
+ * @description: 请求模型
+ * @author: hejh
+ * @create: 2025-03-15 17:27
+ * @Version 1.0
+ **/
+
+
+public interface SearchService {
+    @POST("/vectorize") // 与 Flask 服务中的路由匹配
+    Call<LocalModelsSearchResponse> vectorize(@Body LocalModelsSearchRequest request);
+}
+
+

+ 92 - 0
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/LocalModelsVectorization.java

@@ -0,0 +1,92 @@
+package org.ruoyi.knowledge.chain.vectorizer;
+
+import jakarta.annotation.Resource;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.ruoyi.common.chat.config.ChatConfig;
+import org.ruoyi.common.chat.localModels.LocalModelsofitClient;
+import org.ruoyi.common.chat.openai.OpenAiStreamClient;
+import org.ruoyi.knowledge.domain.vo.KnowledgeInfoVo;
+import org.ruoyi.knowledge.service.IKnowledgeInfoService;
+import org.springframework.stereotype.Component;
+
+import java.util.ArrayList;
+import java.util.List;
+
+@Component
+@Slf4j
+@RequiredArgsConstructor
+public class LocalModelsVectorization   {
+    @Resource
+    private IKnowledgeInfoService knowledgeInfoService;
+
+    @Resource
+    private LocalModelsofitClient localModelsofitClient;
+
+    @Getter
+    private OpenAiStreamClient openAiStreamClient;
+
+    private final ChatConfig chatConfig;
+
+    /**
+     * 批量向量化
+     *
+     * @param chunkList 文本块列表
+     * @param kid 知识 ID
+     * @return 向量化结果
+     */
+
+    public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
+        logVectorizationRequest(kid, chunkList);  // 在向量化开始前记录日志
+        openAiStreamClient = chatConfig.getOpenAiStreamClient(); // 获取 OpenAi 客户端
+        KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid)); // 查询知识信息
+        // 调用 localModelsofitClient 获取 Top K 嵌入向量
+        try {
+            return localModelsofitClient.getTopKEmbeddings(
+                    chunkList,
+                    knowledgeInfoVo.getVector(),
+                    knowledgeInfoVo.getKnowledgeSeparator(),
+                    knowledgeInfoVo.getRetrieveLimit(),
+                    knowledgeInfoVo.getTextBlockSize(),
+                    knowledgeInfoVo.getOverlapChar()
+            );
+        } catch (Exception e) {
+            log.error("Failed to perform batch vectorization for knowledgeId: {}", kid, e);
+            throw new RuntimeException("Batch vectorization failed", e);
+        }
+    }
+
+    /**
+     * 单一文本块向量化
+     *
+     * @param chunk 单一文本块
+     * @param kid 知识 ID
+     * @return 向量化结果
+     */
+
+    public List<Double> singleVectorization(String chunk, String kid) {
+        List<String> chunkList = new ArrayList<>();
+        chunkList.add(chunk);
+
+        // 调用批量向量化方法
+        List<List<Double>> vectorList = batchVectorization(chunkList, kid);
+
+        if (vectorList.isEmpty()) {
+            log.warn("Vectorization returned empty list for chunk: {}", chunk);
+            return new ArrayList<>();
+        }
+
+        return vectorList.get(0); // 返回第一个向量
+    }
+
+    /**
+     * 提供更简洁的日志记录方法
+     *
+     * @param kid 知识 ID
+     * @param chunkList 文本块列表
+     */
+    private void logVectorizationRequest(String kid, List<String> chunkList) {
+        log.info("Starting vectorization for Knowledge ID: {} with {} chunks.", kid, chunkList.size());
+    }
+}

+ 51 - 9
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/OpenAiVectorization.java

@@ -18,6 +18,7 @@ import org.springframework.stereotype.Component;
 import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.stream.Collectors;
 
 @Component
 @Slf4j
@@ -27,6 +28,9 @@ public class OpenAiVectorization implements Vectorization {
     @Lazy
     @Resource
     private IKnowledgeInfoService knowledgeInfoService;
+    @Lazy
+    @Resource
+    private LocalModelsVectorization localModelsVectorization;
 
     @Getter
     private OpenAiStreamClient openAiStreamClient;
@@ -35,25 +39,63 @@ public class OpenAiVectorization implements Vectorization {
 
     @Override
     public List<List<Double>> batchVectorization(List<String> chunkList, String kid) {
-        openAiStreamClient = chatConfig.getOpenAiStreamClient();
+        List<List<Double>> vectorList = new ArrayList<>();
+
+        // 获取知识库信息
         KnowledgeInfoVo knowledgeInfoVo = knowledgeInfoService.queryById(Long.valueOf(kid));
-        Embedding embedding = Embedding.builder()
-            .input(chunkList)
-            .model(knowledgeInfoVo.getVectorModel())
-            .build();
+
+        // 如果使用本地模型
+        try {
+            return localModelsVectorization.batchVectorization(chunkList, kid);
+        } catch (Exception e) {
+            log.error("Local models vectorization failed, falling back to OpenAI embeddings", e);
+        }
+
+        // 如果本地模型失败,则调用 OpenAI 服务进行向量化
+        Embedding embedding = buildEmbedding(chunkList, knowledgeInfoVo);
         EmbeddingResponse embeddings = openAiStreamClient.embeddings(embedding);
+
+        // 处理 OpenAI 返回的嵌入数据
+        vectorList = processOpenAiEmbeddings(embeddings);
+
+        return vectorList;
+    }
+
+    /**
+     * 构建 Embedding 对象
+     */
+    private Embedding buildEmbedding(List<String> chunkList, KnowledgeInfoVo knowledgeInfoVo) {
+        return Embedding.builder()
+                .input(chunkList)
+                .model(knowledgeInfoVo.getVectorModel())
+                .build();
+    }
+
+    /**
+     * 处理 OpenAI 返回的嵌入数据
+     */
+    private List<List<Double>> processOpenAiEmbeddings(EmbeddingResponse embeddings) {
         List<List<Double>> vectorList = new ArrayList<>();
+
         embeddings.getData().forEach(data -> {
             List<BigDecimal> vector = data.getEmbedding();
-            List<Double> doubleVector = new ArrayList<>();
-            for (BigDecimal bd : vector) {
-                doubleVector.add(bd.doubleValue());
-            }
+            List<Double> doubleVector = convertToDoubleList(vector);
             vectorList.add(doubleVector);
         });
+
         return vectorList;
     }
 
+    /**
+     * 将 BigDecimal 转换为 Double 列表
+     */
+    private List<Double> convertToDoubleList(List<BigDecimal> vector) {
+        return vector.stream()
+                .map(BigDecimal::doubleValue)
+                .collect(Collectors.toList());
+    }
+
+
     @Override
     public List<Double> singleVectorization(String chunk, String kid) {
         List<String> chunkList = new ArrayList<>();

+ 15 - 0
ruoyi-modules/ruoyi-knowledge/src/main/java/org/ruoyi/knowledge/chain/vectorizer/VectorizationType.java

@@ -0,0 +1,15 @@
+package org.ruoyi.knowledge.chain.vectorizer;
+
+public enum VectorizationType {
+    OPENAI,    // OpenAI 向量化
+    LOCAL;     // 本地模型向量化
+
+    public static VectorizationType fromString(String type) {
+        for (VectorizationType v : values()) {
+            if (v.name().equalsIgnoreCase(type)) {
+                return v;
+            }
+        }
+        throw new IllegalArgumentException("Unknown VectorizationType: " + type);
+    }
+}