# retrieval_deepseek_api.py (8.1 KB)

  1. import os
  2. import time
  3. import json
  4. import numpy as np
  5. import faiss
  6. from openai import OpenAI
  7. from text2vec import SentenceModel
  8. from langchain_ollama import ChatOllama
  9. from langchain_core.prompts import ChatPromptTemplate
  10. from langchain_core.output_parsers import StrOutputParser
  11. from langchain_core.runnables import RunnableLambda
  12. import config.config
  13. # 向量存储目录结构
  14. VECTOR_STORE_BASE = config.config.VECTOR_STORE_BASE
  15. MODEL_PATH = config.config.MODEL_PATH
  16. # 加载文本向量化模型
  17. def load_embedding_model(device="cuda"):
  18. """加载text2vec-base-chinese模型"""
  19. print(f"正在加载向量化模型: text2vec-base-chinese")
  20. print(f"使用设备: {device}")
  21. return SentenceModel(MODEL_PATH, device=device)
  22. # 转译用户问题的链
  23. def question_translation_deepseek_api(question):
  24. """转译用户问题"""
  25. client = OpenAI(api_key=config.config.DEEPSEEK_API, base_url="https://api.deepseek.com")
  26. system_template = """
  27. 你是一个法律问题转译专家。你的任务是将用户的自然语言法律问题转译为更加专业、简洁、直接的形式,以便于向量检索系统能够找到最相关的法律条文。
  28. 遵循以下原则:
  29. 1. 提取问题中的关键法律概念和术语
  30. 2. 使用法律专业术语重新表述问题
  31. 3. 去除无关信息,保留核心法律问题
  32. 4. 确保转译后的问题简洁明了,直指法律要点
  33. 请直接输出转译后的问题,不要包含额外解释。
  34. """
  35. response = client.chat.completions.create(
  36. model="deepseek-chat",
  37. messages=[
  38. {"role": "system", "content": system_template},
  39. {"role": "user", "content": question},
  40. ],
  41. stream=False
  42. )
  43. return response.choices[0].message.content
  44. # 向量检索相关函数
  45. def load_vector_store(level_dir):
  46. """加载指定层级的向量存储"""
  47. store_dir = os.path.join(VECTOR_STORE_BASE, level_dir)
  48. index_path = os.path.join(store_dir, "index")
  49. if not os.path.exists(f"{index_path}_index"):
  50. raise FileNotFoundError(f"向量索引文件不存在: {index_path}_index")
  51. index = faiss.read_index(f"{index_path}_index")
  52. # 加载向量、文本和元数据
  53. vectors = np.load(f"{index_path}_vectors.npy")
  54. # texts 现在包含的是原始完整文本(包含"第xx条 "前缀)
  55. texts = np.load(f"{index_path}_texts.npy", allow_pickle=True)
  56. metadata = np.load(f"{index_path}_metadata.npy", allow_pickle=True)
  57. print(f"成功加载{level_dir}向量库: {len(vectors)}条记录")
  58. return {
  59. "index": index,
  60. "vectors": vectors,
  61. "texts": texts,
  62. "metadata": metadata
  63. }
  64. def search_in_level(query_embedding, level_dir, top_k=10, store=None):
  65. """在指定层级中搜索,返回top_k个结果"""
  66. if store is None:
  67. store = load_vector_store(level_dir)
  68. # 搜索
  69. distances, indices = store["index"].search(query_embedding.reshape(1, -1), top_k)
  70. results = []
  71. for i, idx in enumerate(indices[0]):
  72. if idx != -1: # Faiss返回-1表示找不到足够的结果
  73. results.append({
  74. "text": store["texts"][idx],
  75. "metadata": store["metadata"][idx],
  76. "score": float(1.0 - distances[0][i]), # 将距离转换为相似度分数
  77. "index": int(idx),
  78. "level": level_dir
  79. })
  80. return results
  81. def rank_final_results(results, query_embedding, top_k=5):
  82. """对所有结果重新排序,返回top_k个最相关的结果"""
  83. # 提取所有文本并计算与查询的相似度
  84. texts = [item["text"] for item in results]
  85. # 加载模型
  86. embedding_model = load_embedding_model()
  87. # 计算文本向量
  88. embeddings = embedding_model.encode(texts)
  89. # 计算相似度
  90. similarities = np.dot(embeddings, query_embedding.T) / (
  91. np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
  92. )
  93. # 更新相似度分数
  94. for i, item in enumerate(results):
  95. item["score"] = float(similarities[i])
  96. # 先按相似度排序,选出前20个结果
  97. sorted_by_similarity = sorted(results, key=lambda x: x["score"], reverse=True)[:20]
  98. # 去重 - 基于文本内容
  99. unique_results = []
  100. seen_texts = set()
  101. for item in sorted_by_similarity:
  102. if item["text"] not in seen_texts:
  103. seen_texts.add(item["text"])
  104. unique_results.append(item)
  105. if len(unique_results) >= top_k:
  106. break
  107. # 对筛选后的top_k个不重复结果按层级优先级排序
  108. level_priority = {"level_1": 1, "level_2": 2, "level_3": 3}
  109. final_results = sorted(unique_results, key=lambda x: level_priority.get(x["level"], 999))
  110. return final_results
  111. def hierarchical_search(user_query, l1_results=10, l2_sub_searches=3, l3_sub_searches=1, final_results=5):
  112. """执行三层次的向量检索"""
  113. print(f"开始处理用户查询: {user_query}")
  114. # 1. 转译用户查询
  115. translated_query = question_translation_deepseek_api(user_query)
  116. print(f"转译后的查询: {translated_query}")
  117. # 2. 加载向量化模型
  118. embedding_model = load_embedding_model()
  119. # 3. 对转译后的查询进行向量化
  120. query_embedding = embedding_model.encode(translated_query)
  121. # 4. 加载各层级向量库
  122. level_1_store = load_vector_store("level_1")
  123. level_2_store = load_vector_store("level_2")
  124. level_3_store = load_vector_store("level_3")
  125. # 5. 在level_1中搜索前l1_results个结果
  126. level_1_results = search_in_level(query_embedding, "level_1", l1_results, level_1_store)
  127. print(f"一级检索结果: {len(level_1_results)}条")
  128. # 6. 对每个level_1结果在level_2中搜索
  129. all_results = []
  130. all_results.extend(level_1_results)
  131. level_2_results = []
  132. for l1_item in level_1_results:
  133. # 用level_1中的文本创建查询向量
  134. l1_text_embedding = embedding_model.encode(l1_item["text"])
  135. # 在level_2中查找相关内容
  136. l2_items = search_in_level(l1_text_embedding, "level_2", l2_sub_searches, level_2_store)
  137. level_2_results.extend(l2_items)
  138. print(f"二级检索结果: {len(level_2_results)}条")
  139. all_results.extend(level_2_results)
  140. # 7. 对每个level_2结果在level_3中搜索
  141. # level_3_results = []
  142. # for l2_item in level_2_results:
  143. # # 用level_2中的文本创建查询向量
  144. # l2_text_embedding = embedding_model.encode(l2_item["text"])
  145. # # 在level_3中查找相关内容
  146. # l3_items = search_in_level(l2_text_embedding, "level_3", l3_sub_searches, level_3_store)
  147. # level_3_results.extend(l3_items)
  148. # print(f"二级检索结果: {len(level_3_results)}条")
  149. # all_results.extend(level_3_results)
  150. print(f"所有层级检索结果总量: {len(all_results)}条")
  151. # 8. 对所有结果重新排序,选取最相关的前final_results个
  152. final_results = rank_final_results(all_results, query_embedding, final_results)
  153. print(f"最终筛选结果: {len(final_results)}条")
  154. return {
  155. "original_query": user_query,
  156. "translated_query": translated_query,
  157. "results": final_results
  158. }
  159. def main():
  160. while True:
  161. user_query = input("\n请输入您的法律问题 (输入'q'退出): ")
  162. if user_query.lower() == 'q':
  163. break
  164. start_time = time.time()
  165. results = hierarchical_search(user_query)
  166. print("\n检索结果:")
  167. print(f"原始问题: {results['original_query']}")
  168. print(f"转译问题: {results['translated_query']}")
  169. print(f"耗时: {time.time() - start_time:.2f}秒")
  170. for i, item in enumerate(results['results']):
  171. print(f"\n[{i+1}] 相似度: {item['score']:.4f} - 层级: {item['level']}")
  172. print(f"法律ID: {item['metadata'].get('law_id', '未知')}")
  173. print(f"内容: {item['text']}") # 这里输出的是原始完整文本
  174. if __name__ == "__main__":
  175. main()