vectrization_text2vec.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. import os
  2. import time
  3. import json
  4. import shutil
  5. import re
  6. from text2vec import SentenceModel
  7. from langchain_core.documents import Document
  8. import numpy as np
  9. import torch
  10. import faiss
  11. import config.config
  12. # 向量存储目录结构
  13. VECTOR_STORE_BASE = config.config.VECTOR_STORE_BASE
  14. FILE_STORAGE_BASE = config.config.FILE_STORAGE_BASE
  15. MODEL_PATH = config.config.MODEL_PATH
  16. # 层级目录列表
  17. LEVEL_DIRS = ["level_1", "level_2", "level_3"]
  18. def extract_content_from_law(text):
  19. """
  20. 从法律条文中提取具体内容
  21. 例如:"第三十二条 劳动者拒绝用人单位管理人员违章指挥、强令冒险作业的,不视为违反劳动合同。"
  22. 将提取为:"劳动者拒绝用人单位管理人员违章指挥、强令冒险作业的,不视为违反劳动合同。"
  23. """
  24. # 使用正则表达式匹配"第xx条 "模式,其中xx可以是各种中文数字
  25. pattern = r'^第[一二三四五六七八九十百千万零]+条\s+'
  26. # 替换匹配的内容为空字符串
  27. content = re.sub(pattern, '', text)
  28. return content
  29. def load_json_documents(json_file_path: str):
  30. """加载并处理JSON文件,将其转换为Document对象列表"""
  31. print(f"正在加载JSON文件: {json_file_path}")
  32. try:
  33. with open(json_file_path, 'r', encoding='utf-8') as f:
  34. data = json.load(f)
  35. documents = []
  36. for item in data:
  37. # JSON结构: {"law_id": "xxx", "text": "xxx"}
  38. text = item.get("text", "")
  39. # 提取具体内容,去掉"第xx条 "前缀
  40. content = extract_content_from_law(text)
  41. # 确保所有元数据都被保存,包括原始text和law_id
  42. metadata = {k: v for k, v in item.items()}
  43. metadata["source"] = json_file_path # 添加源文件信息
  44. metadata["original_text"] = text # 保存原始文本
  45. doc = Document(
  46. page_content=content, # 只用提取后的具体内容进行向量化
  47. metadata=metadata # 保存完整的元数据,包括完整的text和law_id
  48. )
  49. documents.append(doc)
  50. print(f"成功从JSON文件加载 {len(documents)} 条记录")
  51. return documents
  52. except Exception as e:
  53. print(f"加载JSON文件失败: {str(e)}")
  54. return []
  55. def load_py_documents(py_file_path: str):
  56. """
  57. 加载并处理 law_list.py 文件,将其中的 docs 列表转换为 Document 对象列表。
  58. 假设文件内容格式为:docs = ["条文1", "条文2", ...]
  59. """
  60. print(f"正在加载Python文件: {py_file_path}")
  61. try:
  62. # 读取文件内容并执行,提取 docs 变量
  63. with open(py_file_path, 'r', encoding='utf-8') as f:
  64. file_content = f.read()
  65. # 创建一个安全的命名空间,通过 exec 提取 docs
  66. namespace = {}
  67. exec(file_content, namespace)
  68. docs = namespace.get('docs', [])
  69. if not isinstance(docs, list):
  70. print(f"文件中未找到列表变量 'docs' 或类型错误")
  71. return []
  72. documents = []
  73. base_name = os.path.splitext(os.path.basename(py_file_path))[0] # 文件名(不含扩展名)
  74. for idx, text in enumerate(docs):
  75. if not isinstance(text, str):
  76. continue
  77. # 生成 law_id,例如 law_list_0
  78. law_id = f"{base_name}第{idx+1}条"
  79. # 提取具体内容,去掉"第xx条 "前缀
  80. content = extract_content_from_law(text)
  81. metadata = {
  82. "law_id": law_id,
  83. "source": py_file_path,
  84. "original_text": text # 保存原始文本
  85. }
  86. doc = Document(
  87. page_content=content,
  88. metadata=metadata
  89. )
  90. documents.append(doc)
  91. print(f"成功从Python文件加载 {len(documents)} 条记录")
  92. return documents
  93. except Exception as e:
  94. print(f"加载Python文件失败: {str(e)}")
  95. return []
  96. def clear_directory(directory):
  97. """清空目录中的所有文件,但保留目录结构"""
  98. if os.path.exists(directory):
  99. print(f"清空目录: {directory}")
  100. for filename in os.listdir(directory):
  101. file_path = os.path.join(directory, filename)
  102. try:
  103. if os.path.isfile(file_path):
  104. os.unlink(file_path)
  105. print(f" 已删除文件: {filename}")
  106. elif os.path.isdir(file_path):
  107. shutil.rmtree(file_path)
  108. print(f" 已删除子目录: {filename}")
  109. except Exception as e:
  110. print(f" 删除 {file_path} 失败: {e}")
  111. else:
  112. os.makedirs(directory, exist_ok=True)
  113. print(f"创建目录: {directory}")
  114. def create_vector_store(documents, save_dir):
  115. """创建向量数据库并保存到指定目录"""
  116. # 清空并重新创建目录
  117. clear_directory(save_dir)
  118. save_file = os.path.join(save_dir, "index")
  119. print(f"开始创建向量库: {save_file}")
  120. print(f"使用模型: text2vec-base-chinese")
  121. print(f"文档数量: {len(documents)}")
  122. # 检查GPU是否可用
  123. device = "cuda" if torch.cuda.is_available() else "cpu"
  124. print(f"使用设备: {device}")
  125. try:
  126. start_time = time.time()
  127. # 加载模型
  128. model = SentenceModel(MODEL_PATH, device=device)
  129. # 提取文本内容 - 仅使用page_content(处理后的具体内容)进行向量化
  130. texts = [doc.page_content for doc in documents]
  131. # 保存完整元数据,以便检索时能够还原所有字段
  132. metadata = [doc.metadata for doc in documents]
  133. np.save(f"{save_file}_metadata.npy", metadata)
  134. # 同时保存原始的完整文本,用于最终展示
  135. original_texts = [doc.metadata.get("original_text", doc.page_content) for doc in documents]
  136. np.save(f"{save_file}_texts.npy", original_texts)
  137. # 生成向量
  138. embeddings = model.encode(texts)
  139. # 保存向量
  140. np.save(f"{save_file}_vectors.npy", embeddings)
  141. # 创建Faiss索引
  142. dimension = embeddings.shape[1]
  143. index = faiss.IndexFlatL2(dimension)
  144. index.add(embeddings)
  145. # 保存Faiss索引
  146. faiss.write_index(index, f"{save_file}_index")
  147. print(f"\n向量化完成!耗时 {time.time() - start_time:.2f} 秒")
  148. print(f"向量维度:{dimension}")
  149. print(f"向量数量:{len(embeddings)}")
  150. print(f"数据保存位置:{save_file}")
  151. return True
  152. except Exception as e:
  153. print(f"向量化失败:{str(e)}")
  154. import traceback
  155. traceback.print_exc()
  156. return False
  157. def process_level_dir(level_dir):
  158. """处理指定层级目录中的JSON或Python文件"""
  159. source_dir = os.path.join(FILE_STORAGE_BASE, level_dir)
  160. target_dir = os.path.join(VECTOR_STORE_BASE, level_dir)
  161. print("=" * 50)
  162. print(f"开始处理 {level_dir} 目录下的文件")
  163. print(f"源目录: {source_dir}")
  164. print(f"目标目录: {target_dir}")
  165. # 确保源目录存在
  166. if not os.path.exists(source_dir):
  167. print(f"源目录 {source_dir} 不存在,跳过处理")
  168. return False
  169. all_documents = []
  170. # 1. 优先处理JSON文件(原有逻辑)
  171. json_files = [f for f in os.listdir(source_dir) if f.lower().endswith('.json')]
  172. for json_file in json_files:
  173. json_path = os.path.join(source_dir, json_file)
  174. docs = load_json_documents(json_path)
  175. all_documents.extend(docs)
  176. # 2. 如果没有JSON文件,尝试处理Python文件(如 law_list.py)
  177. if not json_files:
  178. py_files = [f for f in os.listdir(source_dir) if f.lower().endswith('.py')]
  179. for py_file in py_files:
  180. py_path = os.path.join(source_dir, py_file)
  181. docs = load_py_documents(py_path)
  182. all_documents.extend(docs)
  183. if not all_documents:
  184. print(f"目录 {source_dir} 中未找到有效文档,跳过处理")
  185. return False
  186. # 创建向量存储
  187. success = create_vector_store(all_documents, target_dir)
  188. print(f"{level_dir} 处理完成!共处理 {len(all_documents)} 条记录")
  189. print("=" * 50)
  190. return success
  191. def main():
  192. print("=" * 50)
  193. print("开始分层向量化处理...")
  194. print(f"文件存储基础目录: {FILE_STORAGE_BASE}")
  195. print(f"向量存储基础目录: {VECTOR_STORE_BASE}")
  196. print("=" * 50)
  197. # 确保向量存储基础目录存在
  198. os.makedirs(VECTOR_STORE_BASE, exist_ok=True)
  199. success_count = 0
  200. # 依次处理每个层级目录
  201. for level_dir in LEVEL_DIRS:
  202. if process_level_dir(level_dir):
  203. success_count += 1
  204. print("=" * 50)
  205. if success_count > 0:
  206. print(f"成功处理了 {success_count}/{len(LEVEL_DIRS)} 个层级目录")
  207. else:
  208. print("所有层级目录处理失败")
  209. print("=" * 50)
  210. if __name__ == "__main__":
  211. main()