app.py

from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import json

app = Flask(__name__)

# Global dictionary used to cache loaded models
model_cache = {}

# Split a text into blocks of at most block_size characters
def split_text(text, block_size, overlap_chars, delimiter):
    chunks = text.split(delimiter)
    text_blocks = []
    current_block = ""
    for chunk in chunks:
        if len(current_block) + len(chunk) + 1 <= block_size:
            if current_block:
                current_block += " " + chunk
            else:
                current_block = chunk
        else:
            text_blocks.append(current_block)
            current_block = chunk
    if current_block:
        text_blocks.append(current_block)
    # For every block after the first, also keep a variant prefixed with the
    # last overlap_chars characters of the previous block
    overlap_blocks = []
    for i in range(len(text_blocks)):
        if i > 0:
            overlap_block = text_blocks[i - 1][-overlap_chars:] + text_blocks[i]
            overlap_blocks.append(overlap_block)
        overlap_blocks.append(text_blocks[i])
    return overlap_blocks

# Vectorize a list of text blocks with the given model
def vectorize_text_blocks(text_blocks, model):
    return model.encode(text_blocks)

# Retrieve the k text blocks of the knowledge base most similar to the query
def retrieve_top_k(query, knowledge_base, k, block_size, overlap_chars, delimiter, model):
    # Split the knowledge base into text blocks
    text_blocks = split_text(knowledge_base, block_size, overlap_chars, delimiter)
    # Vectorize the text blocks
    knowledge_vectors = vectorize_text_blocks(text_blocks, model)
    # Vectorize the query text
    query_vector = model.encode([query]).reshape(1, -1)
    # Compute cosine similarities between the query and every block
    similarities = cosine_similarity(query_vector, knowledge_vectors)
    # Indices of the k most similar text blocks, highest similarity first
    top_k_indices = similarities[0].argsort()[-k:][::-1]
    # Return the text blocks and their vectors
    top_k_texts = [text_blocks[i] for i in top_k_indices]
    top_k_embeddings = [knowledge_vectors[i] for i in top_k_indices]
    return top_k_texts, top_k_embeddings

@app.route('/vectorize', methods=['POST'])
def vectorize_text():
    # Get the JSON data from the request
    data = request.json
    print(f"Received request data: {data}")  # Debug: log the request data
    text_list = data.get("text", [])
    model_name = data.get("model_name", "msmarco-distilbert-base-tas-b")  # Default model
    delimiter = data.get("delimiter", "\n")  # Default delimiter
    k = int(data.get("k", 3))  # Default number of blocks to retrieve
    block_size = int(data.get("block_size", 500))  # Default text block size
    overlap_chars = int(data.get("overlap_chars", 50))  # Default number of overlapping characters

    if not text_list:
        return jsonify({"error": "Text is required."}), 400

    # Load the model if it has not been loaded yet
    if model_name not in model_cache:
        try:
            model = SentenceTransformer(model_name)
            model_cache[model_name] = model  # Cache the model
        except Exception as e:
            return jsonify({"error": f"Failed to load model: {e}"}), 500
    model = model_cache[model_name]

    top_k_texts_all = []
    top_k_embeddings_all = []
    if len(text_list) == 1:
        # A single query text: it is also used as the knowledge base
        top_k_texts, top_k_embeddings = retrieve_top_k(text_list[0], text_list[0], k, block_size, overlap_chars, delimiter, model)
        top_k_texts_all.append(top_k_texts)
        top_k_embeddings_all.append(top_k_embeddings)
    elif len(text_list) > 1:
        # Multiple query texts: process them one by one, using the first text as the knowledge base
        for query in text_list:
            top_k_texts, top_k_embeddings = retrieve_top_k(query, text_list[0], k, block_size, overlap_chars, delimiter, model)
            top_k_texts_all.append(top_k_texts)
            top_k_embeddings_all.append(top_k_embeddings)

    # Convert the embeddings (ndarrays) into JSON-serializable lists
    top_k_embeddings_all = [[embedding.tolist() for embedding in embeddings] for embeddings in top_k_embeddings_all]
    print(f"Top K texts: {top_k_texts_all}")  # Debug: log the retrieved texts
    print(f"Top K embeddings: {top_k_embeddings_all}")  # Debug: log the retrieved embeddings

    # Return the data as JSON
    return jsonify({
        "topKEmbeddings": top_k_embeddings_all  # Return the embeddings
    })

if __name__ == '__main__':
    app.run(host="0.0.0.0", port=5000, debug=True)
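
# Usage sketch (not part of the service itself): one way to call the /vectorize
# endpoint defined above. It assumes the server is already running locally on
# port 5000 and that the `requests` package is installed, which app.py itself
# does not require. The payload keys and defaults mirror the ones read by
# vectorize_text().
#
#   import requests
#
#   payload = {
#       "text": ["your query / knowledge base text here"],
#       "model_name": "msmarco-distilbert-base-tas-b",
#       "delimiter": "\n",
#       "k": 3,
#       "block_size": 500,
#       "overlap_chars": 50,
#   }
#   resp = requests.post("http://localhost:5000/vectorize", json=payload)
#   print(resp.json()["topKEmbeddings"])  # one list of k embeddings per query text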