db.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. import os
  2. import json
  3. from datetime import datetime
  4. from typing import List, Dict, Any
  5. try:
  6. import pymysql
  7. except Exception as exc:
  8. pymysql = None
  9. _MYSQL_IMPORT_ERROR = str(exc)
  10. def get_mysql_config() -> Dict[str, Any]:
  11. return {
  12. "host": os.getenv("MYSQL_HOST", "127.0.0.1"),
  13. "port": int(os.getenv("MYSQL_PORT", "3306")),
  14. "user": os.getenv("MYSQL_USER", "root"),
  15. "password": os.getenv("MYSQL_PASSWORD", "123456"),
  16. "database": os.getenv("MYSQL_DATABASE", "arbitration_system"),
  17. "charset": "utf8mb4"
  18. }
  19. def get_mysql_connection(use_database: bool = True):
  20. if pymysql is None:
  21. raise RuntimeError(f"MySQL驱动不可用: {_MYSQL_IMPORT_ERROR}")
  22. cfg = get_mysql_config()
  23. if not use_database:
  24. cfg = {k: v for k, v in cfg.items() if k != "database"}
  25. return pymysql.connect(**cfg)
  26. def index_exists(cursor, database: str, table: str, index_name: str) -> bool:
  27. cursor.execute(
  28. """
  29. SELECT 1
  30. FROM information_schema.statistics
  31. WHERE table_schema=%s AND table_name=%s AND index_name=%s
  32. LIMIT 1
  33. """,
  34. (database, table, index_name)
  35. )
  36. return cursor.fetchone() is not None
  37. def init_case_db() -> None:
  38. cfg = get_mysql_config()
  39. conn = get_mysql_connection(use_database=False)
  40. try:
  41. cursor = conn.cursor()
  42. cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{cfg['database']}` CHARACTER SET utf8mb4")
  43. conn.commit()
  44. finally:
  45. conn.close()
  46. conn = get_mysql_connection(use_database=True)
  47. try:
  48. cursor = conn.cursor()
  49. cursor.execute(
  50. """
  51. CREATE TABLE IF NOT EXISTS cases (
  52. id BIGINT PRIMARY KEY AUTO_INCREMENT,
  53. case_id VARCHAR(255),
  54. summary_text TEXT,
  55. case_profile_json LONGTEXT,
  56. dispute_points_json LONGTEXT,
  57. law_results_json LONGTEXT,
  58. evidence_results_json LONGTEXT,
  59. final_judgement_json LONGTEXT,
  60. embedding_json LONGTEXT,
  61. created_at DATETIME
  62. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
  63. """
  64. )
  65. if not index_exists(cursor, cfg["database"], "cases", "idx_cases_case_id"):
  66. cursor.execute("CREATE INDEX idx_cases_case_id ON cases(case_id)")
  67. cursor.execute(
  68. """
  69. CREATE TABLE IF NOT EXISTS case_management (
  70. id BIGINT PRIMARY KEY AUTO_INCREMENT,
  71. case_id VARCHAR(255) UNIQUE,
  72. title VARCHAR(255),
  73. description TEXT,
  74. status VARCHAR(50),
  75. stage VARCHAR(50),
  76. updated_at DATETIME
  77. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
  78. """
  79. )
  80. if not index_exists(cursor, cfg["database"], "case_management", "idx_case_management_case_id"):
  81. cursor.execute("CREATE INDEX idx_case_management_case_id ON case_management(case_id)")
  82. conn.commit()
  83. finally:
  84. conn.close()
  85. def store_case_record(
  86. case_id: str,
  87. summary_text: str,
  88. case_profile: Dict[str, Any],
  89. dispute_points: List[str],
  90. law_results: Dict[str, Any],
  91. evidence_results: Dict[str, Any],
  92. final_judgement: Dict[str, Any],
  93. embedding: List[float]
  94. ) -> None:
  95. conn = get_mysql_connection(use_database=True)
  96. try:
  97. cursor = conn.cursor()
  98. cursor.execute(
  99. """
  100. INSERT INTO cases (
  101. case_id, summary_text, case_profile_json, dispute_points_json,
  102. law_results_json, evidence_results_json, final_judgement_json,
  103. embedding_json, created_at
  104. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
  105. """,
  106. (
  107. case_id,
  108. summary_text,
  109. json.dumps(case_profile, ensure_ascii=False),
  110. json.dumps(dispute_points, ensure_ascii=False),
  111. json.dumps(law_results, ensure_ascii=False),
  112. json.dumps(evidence_results, ensure_ascii=False),
  113. json.dumps(final_judgement, ensure_ascii=False),
  114. json.dumps(embedding, ensure_ascii=False),
  115. datetime.utcnow()
  116. )
  117. )
  118. conn.commit()
  119. finally:
  120. conn.close()
  121. def fetch_similar_cases(embedding: List[float], top_k: int = 3) -> List[Dict[str, Any]]:
  122. conn = get_mysql_connection(use_database=True)
  123. try:
  124. cursor = conn.cursor()
  125. cursor.execute(
  126. """
  127. SELECT case_id, summary_text, final_judgement_json, embedding_json
  128. FROM cases
  129. WHERE embedding_json IS NOT NULL AND embedding_json != ''
  130. """
  131. )
  132. rows = cursor.fetchall()
  133. finally:
  134. conn.close()
  135. scored = []
  136. for case_id, summary_text, final_judgement_json, embedding_json in rows:
  137. try:
  138. vec = json.loads(embedding_json)
  139. except Exception:
  140. vec = []
  141. scored.append(
  142. {
  143. "case_id": case_id,
  144. "summary_text": summary_text,
  145. "final_judgement_json": final_judgement_json,
  146. "embedding": vec
  147. }
  148. )
  149. return scored
  150. def parse_case_description(description: str) -> Dict[str, Any]:
  151. if not description or not isinstance(description, str):
  152. return {}
  153. try:
  154. data = json.loads(description)
  155. return data if isinstance(data, dict) else {}
  156. except Exception:
  157. return {}
  158. def normalize_case_description(case_id: str, description: str) -> str:
  159. if not case_id:
  160. return description
  161. existing = fetch_case_management(case_id)
  162. existing_desc = existing.get("description", "") if existing else ""
  163. existing_data = parse_case_description(existing_desc)
  164. new_data = parse_case_description(description)
  165. if existing_data.get("materials") and not new_data.get("materials"):
  166. desc_text = description or existing_data.get("description", "")
  167. merged = {"description": desc_text, "materials": existing_data.get("materials", {})}
  168. return json.dumps(merged, ensure_ascii=False)
  169. return description
  170. def upsert_case_management(case_id: str, title: str, description: str, status: str, stage: str) -> Dict[str, Any]:
  171. conn = get_mysql_connection(use_database=True)
  172. try:
  173. cursor = conn.cursor()
  174. description = normalize_case_description(case_id, description)
  175. cursor.execute(
  176. """
  177. INSERT INTO case_management (case_id, title, description, status, stage, updated_at)
  178. VALUES (%s, %s, %s, %s, %s, %s)
  179. ON DUPLICATE KEY UPDATE
  180. title=VALUES(title),
  181. description=VALUES(description),
  182. status=VALUES(status),
  183. stage=VALUES(stage),
  184. updated_at=VALUES(updated_at)
  185. """,
  186. (case_id, title, description, status, stage, datetime.utcnow())
  187. )
  188. conn.commit()
  189. finally:
  190. conn.close()
  191. return {"case_id": case_id, "title": title, "description": description, "status": status, "stage": stage}
  192. def update_case_materials(case_id: str, materials: Dict[str, Any]) -> None:
  193. if not case_id:
  194. return
  195. existing = fetch_case_management(case_id)
  196. existing_desc = existing.get("description", "") if existing else ""
  197. existing_data = parse_case_description(existing_desc)
  198. description_text = existing_desc if not existing_data else existing_data.get("description", "")
  199. payload = {
  200. "description": description_text,
  201. "materials": materials or {}
  202. }
  203. conn = get_mysql_connection(use_database=True)
  204. try:
  205. cursor = conn.cursor()
  206. cursor.execute(
  207. """
  208. UPDATE case_management
  209. SET description=%s, updated_at=%s
  210. WHERE case_id=%s
  211. """,
  212. (json.dumps(payload, ensure_ascii=False), datetime.utcnow(), case_id)
  213. )
  214. conn.commit()
  215. finally:
  216. conn.close()
  217. def list_case_management() -> List[Dict[str, Any]]:
  218. conn = get_mysql_connection(use_database=True)
  219. try:
  220. cursor = conn.cursor()
  221. cursor.execute(
  222. """
  223. SELECT case_id, title, description, status, stage, updated_at
  224. FROM case_management
  225. ORDER BY updated_at DESC
  226. """
  227. )
  228. rows = cursor.fetchall()
  229. finally:
  230. conn.close()
  231. results = []
  232. for case_id, title, description, status, stage, updated_at in rows:
  233. results.append(
  234. {
  235. "case_id": case_id,
  236. "title": title,
  237. "description": description,
  238. "status": status,
  239. "stage": stage,
  240. "updated_at": updated_at.isoformat() if updated_at else ""
  241. }
  242. )
  243. return results
  244. def fetch_case_management(case_id: str) -> Dict[str, Any]:
  245. if not case_id:
  246. return {}
  247. conn = get_mysql_connection(use_database=True)
  248. try:
  249. cursor = conn.cursor()
  250. cursor.execute(
  251. """
  252. SELECT case_id, title, description, status, stage, updated_at
  253. FROM case_management
  254. WHERE case_id=%s
  255. LIMIT 1
  256. """,
  257. (case_id,)
  258. )
  259. row = cursor.fetchone()
  260. finally:
  261. conn.close()
  262. if not row:
  263. return {}
  264. case_id, title, description, status, stage, updated_at = row
  265. return {
  266. "case_id": case_id,
  267. "title": title,
  268. "description": description or "",
  269. "status": status,
  270. "stage": stage,
  271. "updated_at": updated_at.isoformat() if updated_at else ""
  272. }
  273. def fetch_case_record(case_id: str) -> Dict[str, Any]:
  274. if not case_id:
  275. return {}
  276. conn = get_mysql_connection(use_database=True)
  277. try:
  278. cursor = conn.cursor()
  279. cursor.execute(
  280. """
  281. SELECT summary_text, case_profile_json, dispute_points_json, law_results_json,
  282. evidence_results_json, final_judgement_json, embedding_json, created_at
  283. FROM cases
  284. WHERE case_id=%s
  285. ORDER BY created_at DESC
  286. LIMIT 1
  287. """,
  288. (case_id,)
  289. )
  290. row = cursor.fetchone()
  291. finally:
  292. conn.close()
  293. if not row:
  294. return {}
  295. (
  296. summary_text,
  297. case_profile_json,
  298. dispute_points_json,
  299. law_results_json,
  300. evidence_results_json,
  301. final_judgement_json,
  302. embedding_json,
  303. created_at
  304. ) = row
  305. def parse_json(value: str, fallback):
  306. if not value:
  307. return fallback
  308. try:
  309. parsed = json.loads(value)
  310. return parsed if parsed is not None else fallback
  311. except Exception:
  312. return fallback
  313. return {
  314. "summary_text": summary_text or "",
  315. "case_profile": parse_json(case_profile_json, {}),
  316. "dispute_points": parse_json(dispute_points_json, []),
  317. "law_results": parse_json(law_results_json, {}),
  318. "evidence_results": parse_json(evidence_results_json, {}),
  319. "final_judgement": parse_json(final_judgement_json, {}),
  320. "embedding": parse_json(embedding_json, []),
  321. "created_at": created_at.isoformat() if created_at else ""
  322. }
  323. def delete_case_management(case_id: str) -> None:
  324. conn = get_mysql_connection(use_database=True)
  325. try:
  326. cursor = conn.cursor()
  327. cursor.execute("DELETE FROM case_management WHERE case_id=%s", (case_id,))
  328. conn.commit()
  329. finally:
  330. conn.close()