high_level_retrieval.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. import os
  2. import time
  3. import json
  4. import sys
  5. import numpy as np
  6. from openai import OpenAI
  7. import re
  8. import config.config
# Import the sibling module directly by name instead of through a package.
# This file's own directory is appended to sys.path first, so that
# `three_variant_retrieval_deepseek_api` (presumably located next to this file)
# can be found. Because the directory is *appended*, an identically named
# module on an earlier sys.path entry would still take precedence.
# The sys.path change must stay before the import that depends on it.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
import three_variant_retrieval_deepseek_api
# Module-level alias for the low-level retrieval function; HighLevelRetriever.search
# calls it for atomic, supplementary and broad queries.
multi_variant_search = three_variant_retrieval_deepseek_api.multi_variant_search
  14. class HighLevelRetriever:
  15. def __init__(self):
  16. """初始化高层次检索器"""
  17. self.client = OpenAI(api_key=config.config.DEEPSEEK_API, base_url="https://api.deepseek.com")
  18. print("高层次检索器初始化完成")
  19. def decompose_query(self, user_query):
  20. """将复杂查询分解为多个原子查询"""
  21. system_template = """
  22. 你是一个法律问题分析专家。现在需要你将一个复杂的法律问题分解成三个简单的原子问题。
  23. 请遵循以下原则:
  24. 1. 识别复杂问题中的多个法律概念或问题点
  25. 2. 将每个概念或问题点转化为一个独立的、简洁的原子问题
  26. 3. 确保原子问题涵盖原始复杂问题的所有关键方面
  27. 4. 每个原子问题应该是明确的、可搜索的
  28. 请输出JSON格式的结果:
  29. {
  30. "atomic_queries": [
  31. {
  32. "query": "原子问题1",
  33. "aspect": "这个问题关注的法律方面"
  34. },
  35. {
  36. "query": "原子问题2",
  37. "aspect": "这个问题关注的法律方面"
  38. },
  39. {
  40. "query": "原子问题3",
  41. "aspect": "这个问题关注的法律方面"
  42. }
  43. ]
  44. }
  45. 只输出JSON格式的结果,不要包含任何额外的解释或文本。
  46. """
  47. try:
  48. print("正在分解查询...")
  49. response = self.client.chat.completions.create(
  50. model="deepseek-chat",
  51. messages=[
  52. {"role": "system", "content": system_template},
  53. {"role": "user", "content": user_query},
  54. ],
  55. stream=False,
  56. timeout=60 # 增加超时时间
  57. )
  58. atomic_queries_text = response.choices[0].message.content
  59. print(f"API返回的分解结果: {atomic_queries_text[:200]}...")
  60. # 尝试提取JSON部分
  61. json_match = re.search(r'```(?:json)?\s*(.*?)\s*```', atomic_queries_text, re.DOTALL)
  62. if json_match:
  63. atomic_queries_text = json_match.group(1)
  64. # 尝试加载JSON
  65. try:
  66. atomic_queries = json.loads(atomic_queries_text)
  67. if "atomic_queries" not in atomic_queries:
  68. raise ValueError("返回的JSON缺少'atomic_queries'字段")
  69. except:
  70. # 尝试进行第二次解析,寻找可能的JSON结构
  71. json_pattern = r'({[\s\S]*})'
  72. match = re.search(json_pattern, atomic_queries_text)
  73. if match:
  74. try:
  75. atomic_queries = json.loads(match.group(1))
  76. except:
  77. raise ValueError("无法解析有效的JSON结构")
  78. else:
  79. raise ValueError("无法找到JSON结构")
  80. # 确保至少返回了一个查询
  81. if not atomic_queries.get("atomic_queries") or len(atomic_queries["atomic_queries"]) == 0:
  82. raise ValueError("没有生成有效的原子查询")
  83. return atomic_queries
  84. except Exception as e:
  85. print(f"查询分解失败: {e}")
  86. print("使用备选分解方法...")
  87. # 备选分解方法:简单地将原始查询拆分为几个关键词查询
  88. words = user_query.split()
  89. # 如果查询很短,直接使用原始查询
  90. if len(words) <= 5:
  91. return {
  92. "atomic_queries": [
  93. {"query": user_query, "aspect": "主要问题"}
  94. ]
  95. }
  96. # 否则,尝试提取几个子查询
  97. return {
  98. "atomic_queries": [
  99. {"query": user_query, "aspect": "完整问题"},
  100. {"query": " ".join(words[:len(words)//2]), "aspect": "问题前半部分"},
  101. {"query": " ".join(words[len(words)//2:]), "aspect": "问题后半部分"}
  102. ]
  103. }
  104. def validate_results(self, user_query, all_results):
  105. """验证检索结果是否足以回答用户查询"""
  106. if not all_results:
  107. print("没有找到任何法律条文,验证失败")
  108. return False, {"is_sufficient": False, "missing_aspects": ["未找到相关法律条文"], "explanation": "未找到相关法律条文"}
  109. # 提取所有检索到的法律条文文本
  110. result_texts = [item["text"] for item in all_results]
  111. combined_results = "\n\n".join(result_texts[:15]) # 限制文本长度,只取前15条
  112. system_template = """
  113. 你是一个法律验证专家。你的任务是分析检索到的法律条文是否足以回答用户的原始问题。
  114. 请评估以下几点:
  115. 1. 检索到的法律条文是否涵盖了用户问题的所有方面
  116. 2. 是否存在问题中提及但法律条文中未涉及的关键概念
  117. 3. 是否需要额外的法律信息来完整回答问题
  118. 请输出JSON格式的结果:
  119. {
  120. "is_sufficient": true/false,
  121. "missing_aspects": ["缺失方面1", "缺失方面2", ...],
  122. "explanation": "简要解释为什么结果不足以回答问题或为什么足够"
  123. }
  124. 如果条文确实完全足够回答问题,请将"is_sufficient"设置为true,否则设为false。
  125. 只输出JSON格式的结果,不要包含任何额外的解释或文本。
  126. """
  127. try:
  128. print("开始验证检索结果...")
  129. response = self.client.chat.completions.create(
  130. model="deepseek-chat",
  131. messages=[
  132. {"role": "system", "content": system_template},
  133. {"role": "user", "content": f"用户问题:\n{user_query}\n\n检索到的法律条文:\n{combined_results}"},
  134. ],
  135. stream=False,
  136. timeout=60 # 增加超时时间
  137. )
  138. validation_text = response.choices[0].message.content
  139. print(f"API返回的验证结果: {validation_text[:200]}...")
  140. # 尝试提取JSON部分
  141. json_match = re.search(r'```(?:json)?\s*(.*?)\s*```', validation_text, re.DOTALL)
  142. if json_match:
  143. validation_text = json_match.group(1)
  144. # 尝试解析JSON
  145. try:
  146. validation = json.loads(validation_text)
  147. except:
  148. # 尝试进行第二次解析,寻找可能的JSON结构
  149. json_pattern = r'({[\s\S]*})'
  150. match = re.search(json_pattern, validation_text)
  151. if match:
  152. try:
  153. validation = json.loads(match.group(1))
  154. except:
  155. print("无法解析有效的验证结果JSON")
  156. # 如果解析失败,基于结果数量和文本进行简单判断
  157. if len(all_results) >= 15:
  158. return True, {"is_sufficient": True, "missing_aspects": [], "explanation": "已检索到足够的法律条文"}
  159. elif "足够" in validation_text or "充分" in validation_text:
  160. return True, {"is_sufficient": True, "missing_aspects": [], "explanation": "API返回值判断为足够"}
  161. else:
  162. return False, {"is_sufficient": False, "missing_aspects": ["API返回解析失败"], "explanation": "无法解析API返回值"}
  163. else:
  164. # 如果找不到JSON结构,基于结果数量进行判断
  165. if len(all_results) >= 15:
  166. return True, {"is_sufficient": True, "missing_aspects": [], "explanation": "已检索到足够的法律条文"}
  167. else:
  168. return False, {"is_sufficient": False, "missing_aspects": ["API返回解析失败"], "explanation": "无法解析API返回值"}
  169. # 确保验证结果包含所需字段
  170. if "is_sufficient" not in validation:
  171. print("验证结果缺少'is_sufficient'字段,检查其他信息...")
  172. # 如果关键字段缺失,使用启发式方法判断
  173. if len(all_results) >= 15:
  174. return True, {"is_sufficient": True, "missing_aspects": [], "explanation": "已检索到大量法律条文"}
  175. elif validation.get("missing_aspects") and len(validation.get("missing_aspects", [])) > 0:
  176. return False, {"is_sufficient": False, "missing_aspects": validation.get("missing_aspects", ["未指明的缺失"]), "explanation": validation.get("explanation", "需要额外信息")}
  177. else:
  178. return False, {"is_sufficient": False, "missing_aspects": ["验证结果不完整"], "explanation": "需要额外信息"}
  179. return validation.get("is_sufficient", False), validation
  180. except Exception as e:
  181. print(f"验证过程出错: {e}")
  182. # 出错时的回退策略:根据结果数量进行简单判断
  183. if len(all_results) >= 20:
  184. return True, {"is_sufficient": True, "missing_aspects": [], "explanation": "已检索到大量法律条文,假定足够"}
  185. elif len(all_results) >= 10:
  186. return False, {"is_sufficient": False, "missing_aspects": ["可能需要更多条文"], "explanation": "验证失败,需要更多信息"}
  187. else:
  188. return False, {"is_sufficient": False, "missing_aspects": ["条文数量不足"], "explanation": "检索到的条文数量较少,需要补充"}
  189. def generate_supplementary_queries(self, user_query, all_results, validation_result):
  190. """根据验证结果生成补充查询"""
  191. # 提取所有检索到的法律条文文本
  192. result_texts = [item["text"] for item in all_results]
  193. combined_results = "\n\n".join(result_texts[:8]) # 限制文本长度
  194. missing_aspects = validation_result.get("missing_aspects", [])
  195. if not missing_aspects:
  196. missing_aspects = ["未明确指出的缺失信息"]
  197. missing_aspects_text = "\n".join([f"- {aspect}" for aspect in missing_aspects])
  198. system_template = """
  199. 你是一个法律问题补充专家。基于用户的原始问题和已检索到的法律条文,你需要生成补充查询来获取缺失的信息。
  200. 请考虑以下因素:
  201. 1. 已检索到的法律条文中缺少哪些关键信息
  202. 2. 哪些额外的法律概念需要被检索
  203. 3. 如何构建精确的补充查询以获取这些缺失信息
  204. 请输出JSON格式的结果:
  205. {
  206. "supplementary_queries": [
  207. {
  208. "query": "补充查询1",
  209. "purpose": "这个查询的目的是什么"
  210. },
  211. {
  212. "query": "补充查询2",
  213. "purpose": "这个查询的目的是什么"
  214. },
  215. ...
  216. ]
  217. }
  218. 请确保生成2-3个具体的补充查询,以获取缺失的法律信息。
  219. 只输出JSON格式的结果,不要包含任何额外的解释或文本。
  220. """
  221. try:
  222. print("生成补充查询...")
  223. response = self.client.chat.completions.create(
  224. model="deepseek-chat",
  225. messages=[
  226. {"role": "system", "content": system_template},
  227. {"role": "user", "content": f"用户问题:\n{user_query}\n\n已检索到的法律条文:\n{combined_results}\n\n缺失的方面:\n{missing_aspects_text}"},
  228. ],
  229. stream=False,
  230. timeout=60 # 增加超时时间
  231. )
  232. supplementary_text = response.choices[0].message.content
  233. print(f"API返回的补充查询结果: {supplementary_text[:200]}...")
  234. # 尝试提取JSON部分
  235. json_match = re.search(r'```(?:json)?\s*(.*?)\s*```', supplementary_text, re.DOTALL)
  236. if json_match:
  237. supplementary_text = json_match.group(1)
  238. # 尝试解析JSON
  239. try:
  240. supplementary = json.loads(supplementary_text)
  241. supplementary_queries = supplementary.get("supplementary_queries", [])
  242. # 确保至少有一个补充查询
  243. if not supplementary_queries:
  244. raise ValueError("没有生成补充查询")
  245. return supplementary_queries
  246. except:
  247. # 尝试进行第二次解析,寻找可能的JSON结构
  248. json_pattern = r'({[\s\S]*})'
  249. match = re.search(json_pattern, supplementary_text)
  250. if match:
  251. try:
  252. supplementary = json.loads(match.group(1))
  253. supplementary_queries = supplementary.get("supplementary_queries", [])
  254. if supplementary_queries:
  255. return supplementary_queries
  256. except:
  257. print("无法解析有效的补充查询JSON")
  258. # 如果仍然无法解析,使用备选方法
  259. raise ValueError("无法解析补充查询JSON")
  260. except Exception as e:
  261. print(f"生成补充查询失败: {e}")
  262. # 根据缺失方面生成默认补充查询
  263. default_queries = []
  264. # 为每个缺失方面生成查询
  265. for aspect in missing_aspects[:2]: # 最多使用两个缺失方面
  266. default_queries.append({
  267. "query": f"关于{aspect}的法律规定",
  268. "purpose": f"查找关于{aspect}的缺失信息"
  269. })
  270. # 如果没有缺失方面或生成查询,添加通用查询
  271. if not default_queries:
  272. words = user_query.split()
  273. if len(words) >= 4:
  274. # 使用问题的另一部分
  275. half = len(words) // 2
  276. default_queries.append({
  277. "query": " ".join(words[:half]) + "法律规定",
  278. "purpose": "查找与问题前半部分相关的法律条文"
  279. })
  280. default_queries.append({
  281. "query": " ".join(words[half:]) + "法律依据",
  282. "purpose": "查找与问题后半部分相关的法律条文"
  283. })
  284. else:
  285. # 简单问题,添加通用补充查询
  286. default_queries.append({
  287. "query": f"{user_query}的法律依据",
  288. "purpose": "查找相关法律依据"
  289. })
  290. default_queries.append({
  291. "query": f"{user_query}的相关规定",
  292. "purpose": "查找更多相关法律规定"
  293. })
  294. return default_queries
  295. def search(self, user_query, max_iterations=3): # 恢复为3次迭代
  296. """执行高层次搜索"""
  297. print(f"开始处理复杂查询: {user_query}")
  298. start_time = time.time()
  299. # 第一步:分解查询
  300. decomposed = self.decompose_query(user_query)
  301. atomic_queries = decomposed.get("atomic_queries", [])
  302. print(f"将复杂查询分解为{len(atomic_queries)}个原子查询:")
  303. for i, query in enumerate(atomic_queries):
  304. print(f" [{i+1}] {query['query']} (关注点: {query['aspect']})")
  305. # 存储所有检索结果
  306. all_results = []
  307. # 记录搜索过程信息
  308. search_logs = {
  309. "initial_decomposition": atomic_queries,
  310. "iterations": []
  311. }
  312. # 对每个原子查询执行低层次检索
  313. for atomic_query in atomic_queries:
  314. try:
  315. print(f"\n执行原子查询: {atomic_query['query']}")
  316. results = multi_variant_search(atomic_query['query'])
  317. all_results.extend(results['all_results'])
  318. except Exception as e:
  319. print(f"执行原子查询失败: {e}")
  320. continue
  321. # 去重
  322. unique_results = []
  323. seen_texts = set()
  324. for item in all_results:
  325. if item["text"] not in seen_texts:
  326. seen_texts.add(item["text"])
  327. item["source"] = "初始查询"
  328. unique_results.append(item)
  329. # 如果没有找到任何结果,直接返回
  330. if not unique_results:
  331. return {
  332. "original_query": user_query,
  333. "atomic_queries": atomic_queries,
  334. "results": [],
  335. "total_results": 0,
  336. "time_taken": time.time() - start_time,
  337. "search_logs": search_logs
  338. }
  339. all_results = unique_results
  340. # 迭代补充查询过程
  341. iteration = 1
  342. while iteration <= max_iterations and len(all_results) < 50: # 限制结果总数
  343. print(f"\n开始第{iteration}轮验证和补充...")
  344. iteration_log = {
  345. "iteration": iteration,
  346. "results_before": len(all_results),
  347. "validation": {},
  348. "supplementary_queries": []
  349. }
  350. # 验证当前结果是否足够
  351. try:
  352. is_sufficient, validation_result = self.validate_results(user_query, all_results)
  353. iteration_log["validation"] = validation_result
  354. if is_sufficient:
  355. print("验证通过,检索到的法律条文足以回答用户查询")
  356. search_logs["iterations"].append(iteration_log)
  357. break
  358. explanation = validation_result.get("explanation", "未提供原因")
  359. print(f"验证未通过: {explanation}")
  360. missing_aspects = validation_result.get("missing_aspects", [])
  361. if missing_aspects:
  362. print(f"缺失方面: {', '.join(missing_aspects)}")
  363. # 生成补充查询
  364. supplementary_queries = self.generate_supplementary_queries(user_query, all_results, validation_result)
  365. iteration_log["supplementary_queries"] = supplementary_queries
  366. # 如果没有生成补充查询,退出循环
  367. if not supplementary_queries:
  368. print("未生成补充查询,结束迭代")
  369. search_logs["iterations"].append(iteration_log)
  370. break
  371. print(f"生成{len(supplementary_queries)}个补充查询:")
  372. for i, query in enumerate(supplementary_queries):
  373. print(f" [{i+1}] {query['query']} (目的: {query['purpose']})")
  374. # 执行补充查询
  375. supplementary_results = []
  376. for supp_query in supplementary_queries:
  377. try:
  378. print(f"\n执行补充查询: {supp_query['query']}")
  379. results = multi_variant_search(supp_query['query'])
  380. for item in results['all_results']:
  381. item["source"] = f"补充查询 (轮次 {iteration})"
  382. item["purpose"] = supp_query['purpose']
  383. supplementary_results.extend(results['all_results'])
  384. except Exception as e:
  385. print(f"执行补充查询失败: {e}")
  386. continue
  387. # 合并结果并去重
  388. new_items_count = 0
  389. for item in supplementary_results:
  390. if item["text"] not in seen_texts:
  391. seen_texts.add(item["text"])
  392. all_results.append(item)
  393. new_items_count += 1
  394. iteration_log["new_items_added"] = new_items_count
  395. iteration_log["results_after"] = len(all_results)
  396. search_logs["iterations"].append(iteration_log)
  397. print(f"第{iteration}轮补充查询添加了{new_items_count}条新的法律条文")
  398. # 如果没有添加新条文,增加一次额外的补充查询机会
  399. if new_items_count == 0 and iteration < max_iterations:
  400. print("未找到新的法律条文,尝试更广泛的查询...")
  401. broad_query = f"{user_query}相关法律"
  402. try:
  403. print(f"\n执行广泛查询: {broad_query}")
  404. results = multi_variant_search(broad_query)
  405. for item in results['all_results']:
  406. item["source"] = f"广泛查询 (轮次 {iteration})"
  407. if item["text"] not in seen_texts:
  408. seen_texts.add(item["text"])
  409. all_results.append(item)
  410. new_items_count += 1
  411. print(f"广泛查询添加了{new_items_count}条新的法律条文")
  412. except Exception as e:
  413. print(f"执行广泛查询失败: {e}")
  414. iteration += 1
  415. except Exception as e:
  416. print(f"第{iteration}轮验证和补充失败: {e}")
  417. search_logs["error"] = str(e)
  418. break
  419. # 按相关性排序
  420. final_results = sorted(all_results, key=lambda x: x.get("original_score", 0), reverse=True)
  421. end_time = time.time()
  422. return {
  423. "original_query": user_query,
  424. "atomic_queries": atomic_queries,
  425. "results": final_results,
  426. "total_results": len(final_results),
  427. "time_taken": end_time - start_time,
  428. "search_logs": search_logs
  429. }
  430. def main():
  431. try:
  432. retriever = HighLevelRetriever()
  433. while True:
  434. try:
  435. user_query = input("\n请输入您的复杂法律问题 (输入'q'退出): ")
  436. if user_query.lower() == 'q':
  437. break
  438. results = retriever.search(user_query)
  439. print(f"\n总共找到 {results['total_results']} 条相关法律条文")
  440. print(f"耗时: {results['time_taken']:.2f}秒")
  441. # 输出搜索日志
  442. search_logs = results.get("search_logs", {})
  443. iterations = search_logs.get("iterations", [])
  444. if iterations:
  445. print("\n搜索过程详情:")
  446. print(f"初始分解: {len(search_logs.get('initial_decomposition', []))}个原子查询")
  447. for iteration in iterations:
  448. print(f"第{iteration.get('iteration')}轮: 验证{'通过' if iteration.get('validation', {}).get('is_sufficient', False) else '未通过'}")
  449. if 'new_items_added' in iteration:
  450. print(f" 添加了{iteration.get('new_items_added')}条新法律条文")
  451. if results['total_results'] == 0:
  452. print("未找到相关法律条文。请尝试使用其他关键词或更详细的问题描述。")
  453. continue
  454. # 输出所有检索到的法律条文
  455. print("\n检索到的法律条文:")
  456. for i, item in enumerate(results["results"]):
  457. print(f"\n[{i+1}] 相似度: {item.get('original_score', 0):.4f}")
  458. print(f"来源: {item.get('source', '未知')}")
  459. if "purpose" in item:
  460. print(f"目的: {item['purpose']}")
  461. print(f"法律ID: {item['metadata'].get('law_id', '未知')}")
  462. print(f"内容: {item['text']}")
  463. # 每10条暂停显示
  464. if (i + 1) % 10 == 0 and i + 1 < len(results["results"]):
  465. try:
  466. continue_viewing = input("\n继续查看更多条文? (y/n): ")
  467. if continue_viewing.lower() != 'y':
  468. break
  469. except KeyboardInterrupt:
  470. print("\n显示中断")
  471. break
  472. except KeyboardInterrupt:
  473. print("\n操作已中断。您可以继续输入新的查询,或输入'q'退出。")
  474. continue
  475. except Exception as e:
  476. print(f"\n处理查询时出错: {e}")
  477. continue
  478. except Exception as e:
  479. print(f"程序出错: {e}")
  480. if __name__ == "__main__":
  481. main()