retrieval.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import os
  2. import time
  3. import json
  4. import numpy as np
  5. import faiss
  6. from text2vec import SentenceModel
  7. from openai import OpenAI
  8. from langchain_core.prompts import ChatPromptTemplate
  9. from langchain_core.output_parsers import StrOutputParser
  10. from langchain_core.runnables import RunnableLambda
  11. from langchain_openai import ChatOpenAI
  12. import config.config
  13. # 向量存储目录结构
  14. VECTOR_STORE_BASE = config.config.VECTOR_STORE_BASE
  15. MODEL_PATH = config.config.MODEL_PATH
  16. # DeepSeek API配置
  17. DEEPSEEK_API_KEY = config.config.DEEPSEEK_API # 请替换为您的实际API密钥
  18. DEEPSEEK_BASE_URL = "https://api.deepseek.com"
  19. # 加载文本向量化模型
  20. def load_embedding_model(device="cuda"):
  21. """加载text2vec-base-chinese模型"""
  22. print(f"正在加载向量化模型: text2vec-base-chinese")
  23. print(f"使用设备: {device}")
  24. return SentenceModel(MODEL_PATH, device=device)
  25. # 转译用户问题的链
  26. def build_question_translation_chain():
  27. """构建用于转译用户问题的LLM链"""
  28. llm = ChatOpenAI(
  29. api_key=DEEPSEEK_API_KEY,
  30. base_url=DEEPSEEK_BASE_URL,
  31. model="deepseek-chat",
  32. temperature=0.3,
  33. )
  34. system_template = """
  35. 你是一个法律问题转译专家。你的任务是将用户的自然语言法律问题转译为更加专业、简洁、直接的形式,以便于向量检索系统能够找到最相关的法律条文。
  36. 遵循以下原则:
  37. 1. 提取问题中的关键法律概念和术语
  38. 2. 使用法律专业术语重新表述问题
  39. 3. 去除无关信息,保留核心法律问题
  40. 4. 确保转译后的问题简洁明了,直指法律要点
  41. 请直接输出转译后的问题,不要包含额外解释。
  42. """
  43. prompt = ChatPromptTemplate(
  44. [
  45. ("system", system_template),
  46. ("human", "{question}"),
  47. ]
  48. )
  49. return prompt | llm | StrOutputParser()
  50. # 使用DeepSeek API直接调用
  51. def translate_question_with_deepseek(question):
  52. """使用DeepSeek API直接调用模型进行问题转译"""
  53. client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_BASE_URL)
  54. system_template = """
  55. 你是一个法律问题转译专家。你的任务是将用户的自然语言法律问题转译为更加专业、简洁、直接的形式,以便于向量检索系统能够找到最相关的法律条文。
  56. 遵循以下原则:
  57. 1. 提取问题中的关键法律概念和术语
  58. 2. 使用法律专业术语重新表述问题
  59. 3. 去除无关信息,保留核心法律问题
  60. 4. 确保转译后的问题简洁明了,直指法律要点
  61. 请直接输出转译后的问题,不要包含额外解释。
  62. """
  63. response = client.chat.completions.create(
  64. model="deepseek-chat",
  65. messages=[
  66. {"role": "system", "content": system_template},
  67. {"role": "user", "content": question},
  68. ],
  69. temperature=0.3,
  70. stream=False
  71. )
  72. return response.choices[0].message.content
  73. # 向量检索相关函数
  74. def load_vector_store(level_dir):
  75. """加载指定层级的向量存储"""
  76. store_dir = os.path.join(VECTOR_STORE_BASE, level_dir)
  77. index_path = os.path.join(store_dir, "index")
  78. if not os.path.exists(f"{index_path}_index"):
  79. raise FileNotFoundError(f"向量索引文件不存在: {index_path}_index")
  80. index = faiss.read_index(f"{index_path}_index")
  81. # 加载向量、文本和元数据
  82. vectors = np.load(f"{index_path}_vectors.npy")
  83. # texts 现在包含的是原始完整文本(包含"第xx条 "前缀)
  84. texts = np.load(f"{index_path}_texts.npy", allow_pickle=True)
  85. metadata = np.load(f"{index_path}_metadata.npy", allow_pickle=True)
  86. print(f"成功加载{level_dir}向量库: {len(vectors)}条记录")
  87. return {
  88. "index": index,
  89. "vectors": vectors,
  90. "texts": texts,
  91. "metadata": metadata
  92. }
  93. def search_in_level(query_embedding, level_dir, top_k=10, store=None):
  94. """在指定层级中搜索,返回top_k个结果"""
  95. if store is None:
  96. store = load_vector_store(level_dir)
  97. # 搜索
  98. distances, indices = store["index"].search(query_embedding.reshape(1, -1), top_k)
  99. results = []
  100. for i, idx in enumerate(indices[0]):
  101. if idx != -1: # Faiss返回-1表示找不到足够的结果
  102. results.append({
  103. "text": store["texts"][idx],
  104. "metadata": store["metadata"][idx],
  105. "score": float(1.0 - distances[0][i]), # 将距离转换为相似度分数
  106. "index": int(idx),
  107. "level": level_dir
  108. })
  109. return results
  110. def rank_final_results(results, query_embedding, top_k=5):
  111. """对所有结果重新排序,返回top_k个最相关的结果"""
  112. # 提取所有文本并计算与查询的相似度
  113. texts = [item["text"] for item in results]
  114. # 加载模型
  115. embedding_model = load_embedding_model()
  116. # 计算文本向量
  117. embeddings = embedding_model.encode(texts)
  118. # 计算相似度
  119. similarities = np.dot(embeddings, query_embedding.T) / (
  120. np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
  121. )
  122. # 更新相似度分数
  123. for i, item in enumerate(results):
  124. item["score"] = float(similarities[i])
  125. # 先按相似度排序,选出前20个结果
  126. sorted_by_similarity = sorted(results, key=lambda x: x["score"], reverse=True)[:20]
  127. # 去重 - 基于文本内容
  128. unique_results = []
  129. seen_texts = set()
  130. for item in sorted_by_similarity:
  131. if item["text"] not in seen_texts:
  132. seen_texts.add(item["text"])
  133. unique_results.append(item)
  134. if len(unique_results) >= top_k:
  135. break
  136. # 对筛选后的top_k个不重复结果按层级优先级排序
  137. level_priority = {"level_1": 1, "level_2": 2, "level_3": 3}
  138. final_results = sorted(unique_results, key=lambda x: level_priority.get(x["level"], 999))
  139. return final_results
  140. def hierarchical_search(user_query, l1_results=10, l2_sub_searches=3, l3_sub_searches=1, final_results=5):
  141. """执行三层次的向量检索"""
  142. print(f"开始处理用户查询: {user_query}")
  143. # 1. 转译用户查询 - 使用DeepSeek API
  144. translated_query = translate_question_with_deepseek(user_query)
  145. print(f"转译后的查询: {translated_query}")
  146. # 2. 加载向量化模型
  147. embedding_model = load_embedding_model()
  148. # 3. 对转译后的查询进行向量化
  149. query_embedding = embedding_model.encode(translated_query)
  150. # 4. 加载各层级向量库
  151. level_1_store = load_vector_store("level_1")
  152. level_2_store = load_vector_store("level_2")
  153. level_3_store = load_vector_store("level_3")
  154. # 5. 在level_1中搜索前l1_results个结果
  155. level_1_results = search_in_level(query_embedding, "level_1", l1_results, level_1_store)
  156. print(f"一级检索结果: {len(level_1_results)}条")
  157. # 6. 对每个level_1结果在level_2中搜索
  158. all_results = []
  159. all_results.extend(level_1_results)
  160. level_2_results = []
  161. for l1_item in level_1_results:
  162. # 用level_1中的文本创建查询向量
  163. l1_text_embedding = embedding_model.encode(l1_item["text"])
  164. # 在level_2中查找相关内容
  165. l2_items = search_in_level(l1_text_embedding, "level_2", l2_sub_searches, level_2_store)
  166. level_2_results.extend(l2_items)
  167. print(f"二级检索结果: {len(level_2_results)}条")
  168. all_results.extend(level_2_results)
  169. # 7. 对每个level_2结果在level_3中搜索
  170. # level_3_results = []
  171. # for l2_item in level_2_results:
  172. # # 用level_2中的文本创建查询向量
  173. # l2_text_embedding = embedding_model.encode(l2_item["text"])
  174. # # 在level_3中查找相关内容
  175. # l3_items = search_in_level(l2_text_embedding, "level_3", l3_sub_searches, level_3_store)
  176. # level_3_results.extend(l3_items)
  177. # print(f"二级检索结果: {len(level_3_results)}条")
  178. # all_results.extend(level_3_results)
  179. print(f"所有层级检索结果总量: {len(all_results)}条")
  180. # 8. 对所有结果重新排序,选取最相关的前final_results个
  181. final_results = rank_final_results(all_results, query_embedding, final_results)
  182. print(f"最终筛选结果: {len(final_results)}条")
  183. return {
  184. "original_query": user_query,
  185. "translated_query": translated_query,
  186. "results": final_results
  187. }
  188. def main():
  189. while True:
  190. user_query = input("\n请输入您的法律问题 (输入'q'退出): ")
  191. if user_query.lower() == 'q':
  192. break
  193. start_time = time.time()
  194. results = hierarchical_search(user_query)
  195. print("\n检索结果:")
  196. print(f"原始问题: {results['original_query']}")
  197. print(f"转译问题: {results['translated_query']}")
  198. print(f"耗时: {time.time() - start_time:.2f}秒")
  199. for i, item in enumerate(results['results']):
  200. print(f"\n[{i+1}] 相似度: {item['score']:.4f} - 层级: {item['level']}")
  201. print(f"法律ID: {item['metadata'].get('law_id', '未知')}")
  202. print(f"内容: {item['text']}") # 这里输出的是原始完整文本
  203. if __name__ == "__main__":
  204. main()