diff --git a/hk_case_extractor.py b/hk_case_extractor.py index 3e1a6d7..13e6372 100644 --- a/hk_case_extractor.py +++ b/hk_case_extractor.py @@ -372,28 +372,43 @@ def gather_chunks(text: str, return "\n\n[…]\n\n".join(pieces), len(hits) -def gather_all(text: str) -> dict[str, str]: +def gather_all(text: str, + head_length: int = 5000, + tail_length: int = 5000, + entities_window: int = 400, + entities_max: int = 6500, + analysis_window: int = 500, + analysis_max: int = 6500) -> dict[str, str]: """為每個 group 召回對應的上下文片段 優化策略: - 1. 基礎信息(當事人):直接取開頭5000字符,不使用關鍵詞召回 - 2. 事由與標的(reason_object):直接取開頭5000字符,不使用關鍵詞召回 - 3. 判決結果:取開頭5000字符 + 尾部5000字符 + 1. 基礎信息(當事人):直接取開頭 head_length 字符,不使用關鍵詞召回 + 2. 事由與標的(reason_object):直接取開頭 head_length 字符,不使用關鍵詞召回 + 3. 判決結果:取開頭 head_length 字符 + 尾部 tail_length 字符 4. 其他字段:保持關鍵詞召回策略 + + Args: + text: 判決書全文 + head_length: 開頭截取長度(默認 5000) + tail_length: 尾部截取長度(默認 5000) + entities_window: 實體關鍵詞窗口半徑(默認 400) + entities_max: 實體片段最大總長度(默認 6500) + analysis_window: 分析關鍵詞窗口半徑(默認 500) + analysis_max: 分析片段最大總長度(默認 6500) """ out: dict[str, str] = {} - # 1. 當事人信息:直接從開頭5000字符提取 - out["parties"] = text[:5000] + # 1. 當事人信息:直接從開頭截取 + out["parties"] = text[:head_length] out["_parties_hits"] = "0" # 不使用關鍵詞,標記為0 - # 2. 事由與標的:直接從開頭5000字符提取 - out["reason_object"] = text[:5000] + # 2. 事由與標的:直接從開頭截取 + out["reason_object"] = text[:head_length] out["_reason_object_hits"] = "0" # 不使用關鍵詞,標記為0 - # 3. 判決結果:取開頭5000字符 + 尾部5000字符 - head_text = text[:5000] if len(text) > 5000 else text - tail_text = text[-5000:] if len(text) > 5000 else "" + # 3. 判決結果:取開頭 + 尾部 + head_text = text[:head_length] if len(text) > head_length else text + tail_text = text[-tail_length:] if len(text) > tail_length else "" # 如果文本足夠長,拼接頭尾;否則只用全文 if tail_text and head_text != tail_text: out["judgment_result"] = head_text + "\n\n[…]\n\n" + tail_text @@ -403,8 +418,8 @@ def gather_all(text: str) -> dict[str, str]: # 4. 其他字段:使用關鍵詞召回 params: dict[str, tuple[int, int]] = { - "entities": (400, 6500), - "analysis": (500, 6500), + "entities": (entities_window, entities_max), + "analysis": (analysis_window, analysis_max), } for group in ["entities", "analysis"]: kws = KEYWORD_GROUPS[group] @@ -1108,7 +1123,13 @@ def validate_and_fix(result: dict, lang: str = 'zh') -> tuple[dict, list[str]]: # 5. 主管線 # ============================================================================= -def run_pipeline(text: str, model: str, base_url: str, api_key: str) -> dict: +def run_pipeline(text: str, model: str, base_url: str, api_key: str, + head_length: int = 5000, + tail_length: int = 5000, + entities_window: int = 400, + entities_max: int = 6500, + analysis_window: int = 500, + analysis_max: int = 6500) -> dict: log = lambda m: print(m, file=sys.stderr) log("[0/7] 檢測語言...") @@ -1118,7 +1139,9 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str) -> dict: log("[1/7] 預處理 + 關鍵詞召回...") text = clean_text(text) meta = extract_metadata_by_rule(text, lang) - ctx = gather_all(text) + ctx = gather_all(text, head_length, tail_length, + entities_window, entities_max, + analysis_window, analysis_max) log(f" 規則元數據:{meta}") log(f" 召回片段:") @@ -1229,6 +1252,13 @@ def main() -> None: --base-url https://api.openai.com/v1 \\ --model gpt-4 \\ --api-key your-api-key + + # 調整截取長度 + python hk_case_extractor.py case.txt \\ + --head-length 8000 \\ + --tail-length 8000 \\ + --entities-max 10000 \\ + --analysis-max 10000 """, formatter_class=argparse.RawDescriptionHelpFormatter ) @@ -1242,6 +1272,21 @@ def main() -> None: ap.add_argument("--out", default=None, help="輸出 YAML 路徑(默認 stdout)") ap.add_argument("--debug-dump", default=None, help="額外輸出原始 JSON 結果到該路徑(便於 diff)") + + # 截取長度控制參數 + ap.add_argument("--head-length", type=int, default=5000, + help="開頭截取長度(默認:5000)") + ap.add_argument("--tail-length", type=int, default=5000, + help="尾部截取長度(默認:5000)") + ap.add_argument("--entities-window", type=int, default=400, + help="實體關鍵詞窗口半徑(默認:400)") + ap.add_argument("--entities-max", type=int, default=6500, + help="實體片段最大總長度(默認:6500)") + ap.add_argument("--analysis-window", type=int, default=500, + help="分析關鍵詞窗口半徑(默認:500)") + ap.add_argument("--analysis-max", type=int, default=6500, + help="分析片段最大總長度(默認:6500)") + args = ap.parse_args() # 支持从 .json 文件的 content 字段读取 @@ -1256,7 +1301,10 @@ def main() -> None: else: text = input_path.read_text(encoding="utf-8") - result = run_pipeline(text, args.model, args.base_url, args.api_key) + result = run_pipeline(text, args.model, args.base_url, args.api_key, + args.head_length, args.tail_length, + args.entities_window, args.entities_max, + args.analysis_window, args.analysis_max) if args.debug_dump: Path(args.debug_dump).write_text(