update
parent
09f5be0180
commit
4a32806c13
|
|
@ -372,28 +372,43 @@ def gather_chunks(text: str,
|
|||
return "\n\n[…]\n\n".join(pieces), len(hits)
|
||||
|
||||
|
||||
def gather_all(text: str) -> dict[str, str]:
|
||||
def gather_all(text: str,
|
||||
head_length: int = 5000,
|
||||
tail_length: int = 5000,
|
||||
entities_window: int = 400,
|
||||
entities_max: int = 6500,
|
||||
analysis_window: int = 500,
|
||||
analysis_max: int = 6500) -> dict[str, str]:
|
||||
"""為每個 group 召回對應的上下文片段
|
||||
|
||||
優化策略:
|
||||
1. 基礎信息(當事人):直接取開頭5000字符,不使用關鍵詞召回
|
||||
2. 事由與標的(reason_object):直接取開頭5000字符,不使用關鍵詞召回
|
||||
3. 判決結果:取開頭5000字符 + 尾部5000字符
|
||||
1. 基礎信息(當事人):直接取開頭 head_length 字符,不使用關鍵詞召回
|
||||
2. 事由與標的(reason_object):直接取開頭 head_length 字符,不使用關鍵詞召回
|
||||
3. 判決結果:取開頭 head_length 字符 + 尾部 tail_length 字符
|
||||
4. 其他字段:保持關鍵詞召回策略
|
||||
|
||||
Args:
|
||||
text: 判決書全文
|
||||
head_length: 開頭截取長度(默認 5000)
|
||||
tail_length: 尾部截取長度(默認 5000)
|
||||
entities_window: 實體關鍵詞窗口半徑(默認 400)
|
||||
entities_max: 實體片段最大總長度(默認 6500)
|
||||
analysis_window: 分析關鍵詞窗口半徑(默認 500)
|
||||
analysis_max: 分析片段最大總長度(默認 6500)
|
||||
"""
|
||||
out: dict[str, str] = {}
|
||||
|
||||
# 1. 當事人信息:直接從開頭5000字符提取
|
||||
out["parties"] = text[:5000]
|
||||
# 1. 當事人信息:直接從開頭截取
|
||||
out["parties"] = text[:head_length]
|
||||
out["_parties_hits"] = "0" # 不使用關鍵詞,標記為0
|
||||
|
||||
# 2. 事由與標的:直接從開頭5000字符提取
|
||||
out["reason_object"] = text[:5000]
|
||||
# 2. 事由與標的:直接從開頭截取
|
||||
out["reason_object"] = text[:head_length]
|
||||
out["_reason_object_hits"] = "0" # 不使用關鍵詞,標記為0
|
||||
|
||||
# 3. 判決結果:取開頭5000字符 + 尾部5000字符
|
||||
head_text = text[:5000] if len(text) > 5000 else text
|
||||
tail_text = text[-5000:] if len(text) > 5000 else ""
|
||||
# 3. 判決結果:取開頭 + 尾部
|
||||
head_text = text[:head_length] if len(text) > head_length else text
|
||||
tail_text = text[-tail_length:] if len(text) > tail_length else ""
|
||||
# 如果文本足夠長,拼接頭尾;否則只用全文
|
||||
if tail_text and head_text != tail_text:
|
||||
out["judgment_result"] = head_text + "\n\n[…]\n\n" + tail_text
|
||||
|
|
@ -403,8 +418,8 @@ def gather_all(text: str) -> dict[str, str]:
|
|||
|
||||
# 4. 其他字段:使用關鍵詞召回
|
||||
params: dict[str, tuple[int, int]] = {
|
||||
"entities": (400, 6500),
|
||||
"analysis": (500, 6500),
|
||||
"entities": (entities_window, entities_max),
|
||||
"analysis": (analysis_window, analysis_max),
|
||||
}
|
||||
for group in ["entities", "analysis"]:
|
||||
kws = KEYWORD_GROUPS[group]
|
||||
|
|
@ -1108,7 +1123,13 @@ def validate_and_fix(result: dict, lang: str = 'zh') -> tuple[dict, list[str]]:
|
|||
# 5. 主管線
|
||||
# =============================================================================
|
||||
|
||||
def run_pipeline(text: str, model: str, base_url: str, api_key: str) -> dict:
|
||||
def run_pipeline(text: str, model: str, base_url: str, api_key: str,
|
||||
head_length: int = 5000,
|
||||
tail_length: int = 5000,
|
||||
entities_window: int = 400,
|
||||
entities_max: int = 6500,
|
||||
analysis_window: int = 500,
|
||||
analysis_max: int = 6500) -> dict:
|
||||
log = lambda m: print(m, file=sys.stderr)
|
||||
|
||||
log("[0/7] 檢測語言...")
|
||||
|
|
@ -1118,7 +1139,9 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str) -> dict:
|
|||
log("[1/7] 預處理 + 關鍵詞召回...")
|
||||
text = clean_text(text)
|
||||
meta = extract_metadata_by_rule(text, lang)
|
||||
ctx = gather_all(text)
|
||||
ctx = gather_all(text, head_length, tail_length,
|
||||
entities_window, entities_max,
|
||||
analysis_window, analysis_max)
|
||||
|
||||
log(f" 規則元數據:{meta}")
|
||||
log(f" 召回片段:")
|
||||
|
|
@ -1229,6 +1252,13 @@ def main() -> None:
|
|||
--base-url https://api.openai.com/v1 \\
|
||||
--model gpt-4 \\
|
||||
--api-key your-api-key
|
||||
|
||||
# 調整截取長度
|
||||
python hk_case_extractor.py case.txt \\
|
||||
--head-length 8000 \\
|
||||
--tail-length 8000 \\
|
||||
--entities-max 10000 \\
|
||||
--analysis-max 10000
|
||||
""",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
|
|
@ -1242,6 +1272,21 @@ def main() -> None:
|
|||
ap.add_argument("--out", default=None, help="輸出 YAML 路徑(默認 stdout)")
|
||||
ap.add_argument("--debug-dump", default=None,
|
||||
help="額外輸出原始 JSON 結果到該路徑(便於 diff)")
|
||||
|
||||
# 截取長度控制參數
|
||||
ap.add_argument("--head-length", type=int, default=5000,
|
||||
help="開頭截取長度(默認:5000)")
|
||||
ap.add_argument("--tail-length", type=int, default=5000,
|
||||
help="尾部截取長度(默認:5000)")
|
||||
ap.add_argument("--entities-window", type=int, default=400,
|
||||
help="實體關鍵詞窗口半徑(默認:400)")
|
||||
ap.add_argument("--entities-max", type=int, default=6500,
|
||||
help="實體片段最大總長度(默認:6500)")
|
||||
ap.add_argument("--analysis-window", type=int, default=500,
|
||||
help="分析關鍵詞窗口半徑(默認:500)")
|
||||
ap.add_argument("--analysis-max", type=int, default=6500,
|
||||
help="分析片段最大總長度(默認:6500)")
|
||||
|
||||
args = ap.parse_args()
|
||||
|
||||
# 支持从 .json 文件的 content 字段读取
|
||||
|
|
@ -1256,7 +1301,10 @@ def main() -> None:
|
|||
else:
|
||||
text = input_path.read_text(encoding="utf-8")
|
||||
|
||||
result = run_pipeline(text, args.model, args.base_url, args.api_key)
|
||||
result = run_pipeline(text, args.model, args.base_url, args.api_key,
|
||||
args.head_length, args.tail_length,
|
||||
args.entities_window, args.entities_max,
|
||||
args.analysis_window, args.analysis_max)
|
||||
|
||||
if args.debug_dump:
|
||||
Path(args.debug_dump).write_text(
|
||||
|
|
|
|||
Loading…
Reference in New Issue