three_variant_retrieval_deepseek_api.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. import os
  2. import time
  3. import json
  4. import numpy as np
  5. import faiss
  6. from openai import OpenAI
  7. from text2vec import SentenceModel
  8. from langchain_ollama import ChatOllama
  9. from langchain_core.prompts import ChatPromptTemplate
  10. from langchain_core.output_parsers import StrOutputParser
  11. from langchain_core.runnables import RunnableLambda
  12. import config.config
  13. # 向量存储目录结构
  14. VECTOR_STORE_BASE = config.config.VECTOR_STORE_BASE
  15. MODEL_PATH = config.config.MODEL_PATH
  16. # 加载文本向量化模型
  17. def load_embedding_model(device="cuda"):
  18. """加载text2vec-base-chinese模型"""
  19. print(f"正在加载向量化模型: text2vec-base-chinese")
  20. print(f"使用设备: {device}")
  21. return SentenceModel(MODEL_PATH, device=device)
  22. # 生成三个相似问题变体
  23. def generate_question_variants_deepseek_api(question):
  24. """生成三个不同的问题变体"""
  25. client = OpenAI(api_key=config.config.DEEPSEEK_API, base_url="https://api.deepseek.com")
  26. system_template = """
  27. 你是一个法律问题转译专家。你的任务是将用户的自然语言法律问题转译为三个不同但相关的专业法律表述。
  28. 遵循以下原则:
  29. 1. 第一个变体:提取核心法律概念,使用最规范的法律术语重新表述
  30. 2. 第二个变体:从不同角度或法律分支解读问题
  31. 3. 第三个变体:扩展问题范围,考虑相关联的法律情况
  32. 所有变体都应该:
  33. - 使用专业法律术语
  34. - 保持简洁明了
  35. - 直指法律要点
  36. - 去除无关信息
  37. 请按以下JSON格式输出三个变体:
  38. {
  39. "variant_1": "第一个变体",
  40. "variant_2": "第二个变体",
  41. "variant_3": "第三个变体"
  42. }
  43. 只输出JSON格式的结果,不要包含任何额外的解释或文本。
  44. """
  45. response = client.chat.completions.create(
  46. model="deepseek-chat",
  47. messages=[
  48. {"role": "system", "content": system_template},
  49. {"role": "user", "content": question},
  50. ],
  51. stream=False
  52. )
  53. try:
  54. # 解析JSON响应
  55. variants_text = response.choices[0].message.content
  56. variants = json.loads(variants_text)
  57. return variants
  58. except json.JSONDecodeError:
  59. # 如果返回的不是有效JSON,尝试提取变体
  60. print(f"警告:API返回的不是有效JSON,尝试手动解析:\n{variants_text}")
  61. # 简单回退方案
  62. return {
  63. "variant_1": question,
  64. "variant_2": question,
  65. "variant_3": question
  66. }
  67. # 向量检索相关函数
  68. def load_vector_store(level_dir):
  69. """加载指定层级的向量存储"""
  70. store_dir = os.path.join(VECTOR_STORE_BASE, level_dir)
  71. index_path = os.path.join(store_dir, "index")
  72. if not os.path.exists(f"{index_path}_index"):
  73. raise FileNotFoundError(f"向量索引文件不存在: {index_path}_index")
  74. index = faiss.read_index(f"{index_path}_index")
  75. # 加载向量、文本和元数据
  76. vectors = np.load(f"{index_path}_vectors.npy")
  77. # texts 现在包含的是原始完整文本(包含"第xx条 "前缀)
  78. texts = np.load(f"{index_path}_texts.npy", allow_pickle=True)
  79. metadata = np.load(f"{index_path}_metadata.npy", allow_pickle=True)
  80. print(f"成功加载{level_dir}向量库: {len(vectors)}条记录")
  81. return {
  82. "index": index,
  83. "vectors": vectors,
  84. "texts": texts,
  85. "metadata": metadata
  86. }
  87. def search_in_level(query_embedding, level_dir, top_k=10, store=None):
  88. """在指定层级中搜索,返回top_k个结果"""
  89. if store is None:
  90. store = load_vector_store(level_dir)
  91. # 搜索
  92. distances, indices = store["index"].search(query_embedding.reshape(1, -1), top_k)
  93. results = []
  94. for i, idx in enumerate(indices[0]):
  95. if idx != -1: # Faiss返回-1表示找不到足够的结果
  96. results.append({
  97. "text": store["texts"][idx],
  98. "metadata": store["metadata"][idx],
  99. "score": float(1.0 - distances[0][i]), # 将距离转换为相似度分数
  100. "index": int(idx),
  101. "level": level_dir
  102. })
  103. return results
  104. def rank_final_results(results, original_query_embedding, top_k=5):
  105. """对所有结果重新排序,返回top_k个最相关的结果"""
  106. # 提取所有文本并计算与查询的相似度
  107. texts = [item["text"] for item in results]
  108. # 加载模型
  109. embedding_model = load_embedding_model()
  110. # 计算文本向量
  111. embeddings = embedding_model.encode(texts)
  112. # 计算相似度
  113. similarities = np.dot(embeddings, original_query_embedding.T) / (
  114. np.linalg.norm(embeddings, axis=1) * np.linalg.norm(original_query_embedding)
  115. )
  116. # 更新相似度分数
  117. for i, item in enumerate(results):
  118. item["score"] = float(similarities[i])
  119. # 先按相似度排序,选出前20个结果
  120. sorted_by_similarity = sorted(results, key=lambda x: x["score"], reverse=True)[:20]
  121. # 去重 - 基于文本内容
  122. unique_results = []
  123. seen_texts = set()
  124. for item in sorted_by_similarity:
  125. if item["text"] not in seen_texts:
  126. seen_texts.add(item["text"])
  127. unique_results.append(item)
  128. if len(unique_results) >= top_k:
  129. break
  130. # 对筛选后的top_k个不重复结果按层级优先级排序
  131. level_priority = {"level_1": 1, "level_2": 2, "level_3": 3}
  132. final_results = sorted(unique_results, key=lambda x: level_priority.get(x["level"], 999))
  133. return final_results
  134. def hierarchical_search_for_variant(variant_query, variant_name, embedding_model, stores, l1_results=10, l2_sub_searches=3, final_results=5):
  135. """为单个变体执行分层检索"""
  136. print(f"处理{variant_name}: {variant_query}")
  137. # 对变体进行向量化
  138. query_embedding = embedding_model.encode(variant_query)
  139. # 在level_1中搜索
  140. level_1_results = search_in_level(query_embedding, "level_1", l1_results, stores["level_1"])
  141. # 收集所有结果
  142. all_results = []
  143. all_results.extend(level_1_results)
  144. # # 对level_1结果在level_2中搜索
  145. # level_2_results = []
  146. # for l1_item in level_1_results:
  147. # l1_text_embedding = embedding_model.encode(l1_item["text"])
  148. # l2_items = search_in_level(l1_text_embedding, "level_2", l2_sub_searches, stores["level_2"])
  149. # level_2_results.extend(l2_items)
  150. #
  151. # all_results.extend(level_2_results)
  152. # 对所有结果重新排序,返回最相关的前final_results个
  153. variant_results = rank_final_results(all_results, query_embedding, final_results)
  154. # 添加变体信息
  155. for item in variant_results:
  156. item["variant"] = variant_name
  157. item["variant_query"] = variant_query
  158. return variant_results
  159. def multi_variant_search(user_query, variants_per_query=3, final_results=15, top_display=5):
  160. """执行多变体的向量检索"""
  161. print(f"开始处理用户查询: {user_query}")
  162. # 1. 生成三个查询变体
  163. variants = generate_question_variants_deepseek_api(user_query)
  164. print(f"生成的三个变体:")
  165. for k, v in variants.items():
  166. print(f"{k}: {v}")
  167. # 2. 加载向量化模型
  168. embedding_model = load_embedding_model()
  169. # 3. 原始查询的向量化 (用于最终排序)
  170. original_query_embedding = embedding_model.encode(user_query)
  171. # 4. 预加载所有向量库
  172. stores = {
  173. "level_1": load_vector_store("level_1"),
  174. # "level_2": load_vector_store("level_2"),
  175. # "level_3": load_vector_store("level_3")
  176. }
  177. # 5. 对每个变体执行分层检索
  178. all_variant_results = []
  179. for variant_key, variant_query in variants.items():
  180. variant_results = hierarchical_search_for_variant(
  181. variant_query,
  182. variant_key,
  183. embedding_model,
  184. stores,
  185. l1_results=10, # 每个变体在level_1中搜索的结果数
  186. l2_sub_searches=3, # 每个level_1结果在level_2中搜索的结果数
  187. final_results=5 # 每个变体返回的最终结果数
  188. )
  189. all_variant_results.extend(variant_results)
  190. print(f"所有变体检索结果总量: {len(all_variant_results)}条")
  191. # 6. 汇总所有变体的结果,并按与原始查询的相似度排序
  192. merged_results = []
  193. seen_texts = set()
  194. # 去重
  195. for item in all_variant_results:
  196. if item["text"] not in seen_texts:
  197. seen_texts.add(item["text"])
  198. merged_results.append(item)
  199. # 按原始查询相似度重新排序
  200. for item in merged_results:
  201. text_embedding = embedding_model.encode(item["text"])
  202. similarity = np.dot(text_embedding, original_query_embedding) / (
  203. np.linalg.norm(text_embedding) * np.linalg.norm(original_query_embedding)
  204. )
  205. item["original_score"] = float(similarity)
  206. # 按原始查询相似度排序,选出top结果
  207. final_results_list = sorted(merged_results, key=lambda x: x["original_score"], reverse=True)[:final_results]
  208. # 获取前5条结果,但按照层级优先级排序
  209. level_priority = {"level_1": 1, "level_2": 2, "level_3": 3}
  210. top_display_results = sorted(final_results_list[:top_display], key=lambda x: level_priority.get(x["level"], 999))
  211. return {
  212. "original_query": user_query,
  213. "variants": variants,
  214. "all_results": final_results_list,
  215. "top_results": top_display_results
  216. }
  217. def main():
  218. while True:
  219. user_query = input("\n请输入您的法律问题 (输入'q'退出): ")
  220. if user_query.lower() == 'q':
  221. break
  222. start_time = time.time()
  223. results = multi_variant_search(user_query)
  224. print("\n检索结果:")
  225. print(f"原始问题: {results['original_query']}")
  226. print(f"问题变体:")
  227. for k, v in results['variants'].items():
  228. print(f" {k}: {v}")
  229. print(f"\n总共找到 {len(results['all_results'])} 条相关法律条文")
  230. print(f"耗时: {time.time() - start_time:.2f}秒")
  231. # 先展示所有15条结果
  232. print("\n所有检索结果:")
  233. for i, item in enumerate(results['all_results']):
  234. print(f"\n[{i+1}] 相似度: {item['original_score']:.4f} - 层级: {item['level']} - 变体: {item['variant']}")
  235. print(f"法律ID: {item['metadata'].get('law_id', '未知')}")
  236. print(f"内容: {item['text']}")
  237. # 再展示按层级排序的前5条
  238. print("\n\n按层级优先级排序的前5条结果:")
  239. for i, item in enumerate(results['top_results']):
  240. print(f"\n[{i+1}] 相似度: {item['original_score']:.4f} - 层级: {item['level']} - 变体: {item['variant']}")
  241. print(f"法律ID: {item['metadata'].get('law_id', '未知')}")
  242. print(f"内容: {item['text']}")
  243. if __name__ == "__main__":
  244. main()