main
fengruixiang 2026-05-15 12:52:42 +08:00
parent 09f5be0180
commit 4a32806c13
1 changed files with 64 additions and 16 deletions

View File

@ -372,28 +372,43 @@ def gather_chunks(text: str,
return "\n\n[…]\n\n".join(pieces), len(hits)
def gather_all(text: str) -> dict[str, str]:
def gather_all(text: str,
head_length: int = 5000,
tail_length: int = 5000,
entities_window: int = 400,
entities_max: int = 6500,
analysis_window: int = 500,
analysis_max: int = 6500) -> dict[str, str]:
"""為每個 group 召回對應的上下文片段
優化策略
1. 基礎信息當事人直接取開頭5000字符不使用關鍵詞召回
2. 事由與標的reason_object直接取開頭5000字符不使用關鍵詞召回
3. 判決結果取開頭5000字符 + 尾部5000字符
1. 基礎信息當事人直接取開頭 head_length 字符不使用關鍵詞召回
2. 事由與標的reason_object直接取開頭 head_length 字符不使用關鍵詞召回
3. 判決結果取開頭 head_length 字符 + 尾部 tail_length 字符
4. 其他字段保持關鍵詞召回策略
Args:
text: 判決書全文
head_length: 開頭截取長度默認 5000
tail_length: 尾部截取長度默認 5000
entities_window: 實體關鍵詞窗口半徑默認 400
entities_max: 實體片段最大總長度默認 6500
analysis_window: 分析關鍵詞窗口半徑默認 500
analysis_max: 分析片段最大總長度默認 6500
"""
out: dict[str, str] = {}
# 1. 當事人信息直接從開頭5000字符提取
out["parties"] = text[:5000]
# 1. 當事人信息:直接從開頭
out["parties"] = text[:head_length]
out["_parties_hits"] = "0" # 不使用關鍵詞標記為0
# 2. 事由與標的直接從開頭5000字符提取
out["reason_object"] = text[:5000]
# 2. 事由與標的:直接從開頭
out["reason_object"] = text[:head_length]
out["_reason_object_hits"] = "0" # 不使用關鍵詞標記為0
# 3. 判決結果:取開頭5000字符 + 尾部5000字符
head_text = text[:5000] if len(text) > 5000 else text
tail_text = text[-5000:] if len(text) > 5000 else ""
# 3. 判決結果:取開頭 + 尾部
head_text = text[:head_length] if len(text) > head_length else text
tail_text = text[-tail_length:] if len(text) > tail_length else ""
# 如果文本足夠長,拼接頭尾;否則只用全文
if tail_text and head_text != tail_text:
out["judgment_result"] = head_text + "\n\n[…]\n\n" + tail_text
@ -403,8 +418,8 @@ def gather_all(text: str) -> dict[str, str]:
# 4. 其他字段:使用關鍵詞召回
params: dict[str, tuple[int, int]] = {
"entities": (400, 6500),
"analysis": (500, 6500),
"entities": (entities_window, entities_max),
"analysis": (analysis_window, analysis_max),
}
for group in ["entities", "analysis"]:
kws = KEYWORD_GROUPS[group]
@ -1108,7 +1123,13 @@ def validate_and_fix(result: dict, lang: str = 'zh') -> tuple[dict, list[str]]:
# 5. 主管線
# =============================================================================
def run_pipeline(text: str, model: str, base_url: str, api_key: str) -> dict:
def run_pipeline(text: str, model: str, base_url: str, api_key: str,
head_length: int = 5000,
tail_length: int = 5000,
entities_window: int = 400,
entities_max: int = 6500,
analysis_window: int = 500,
analysis_max: int = 6500) -> dict:
log = lambda m: print(m, file=sys.stderr)
log("[0/7] 檢測語言...")
@ -1118,7 +1139,9 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str) -> dict:
log("[1/7] 預處理 + 關鍵詞召回...")
text = clean_text(text)
meta = extract_metadata_by_rule(text, lang)
ctx = gather_all(text)
ctx = gather_all(text, head_length, tail_length,
entities_window, entities_max,
analysis_window, analysis_max)
log(f" 規則元數據:{meta}")
log(f" 召回片段:")
@ -1229,6 +1252,13 @@ def main() -> None:
--base-url https://api.openai.com/v1 \\
--model gpt-4 \\
--api-key your-api-key
# 調整截取長度
python hk_case_extractor.py case.txt \\
--head-length 8000 \\
--tail-length 8000 \\
--entities-max 10000 \\
--analysis-max 10000
""",
formatter_class=argparse.RawDescriptionHelpFormatter
)
@ -1242,6 +1272,21 @@ def main() -> None:
ap.add_argument("--out", default=None, help="輸出 YAML 路徑(默認 stdout")
ap.add_argument("--debug-dump", default=None,
help="額外輸出原始 JSON 結果到該路徑(便於 diff")
# 截取長度控制參數
ap.add_argument("--head-length", type=int, default=5000,
help="開頭截取長度默認5000")
ap.add_argument("--tail-length", type=int, default=5000,
help="尾部截取長度默認5000")
ap.add_argument("--entities-window", type=int, default=400,
help="實體關鍵詞窗口半徑默認400")
ap.add_argument("--entities-max", type=int, default=6500,
help="實體片段最大總長度默認6500")
ap.add_argument("--analysis-window", type=int, default=500,
help="分析關鍵詞窗口半徑默認500")
ap.add_argument("--analysis-max", type=int, default=6500,
help="分析片段最大總長度默認6500")
args = ap.parse_args()
# 支持从 .json 文件的 content 字段读取
@ -1256,7 +1301,10 @@ def main() -> None:
else:
text = input_path.read_text(encoding="utf-8")
result = run_pipeline(text, args.model, args.base_url, args.api_key)
result = run_pipeline(text, args.model, args.base_url, args.api_key,
args.head_length, args.tail_length,
args.entities_window, args.entities_max,
args.analysis_window, args.analysis_max)
if args.debug_dump:
Path(args.debug_dump).write_text(