add multi-model concurrent extraction with per-model output files

--config and --model now accept comma-separated lists to run several
models on one document. Runs execute concurrently via a thread pool
(--max-workers controls parallelism; 1 forces sequential), with logs
prefixed by the config/model label.

Each model writes to its own YAML/cost/debug files by inserting a
sanitized label into the --out name (or the input name when --out is
omitted), and cost is computed per-model from its own profile.
Single-model behavior, including stdout output, is unchanged.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
main
fengruixiang 2026-05-28 15:16:04 +08:00
parent bc2b9826d5
commit abcb2103f1
1 changed files with 183 additions and 54 deletions

View File

@ -68,6 +68,7 @@ import json
import re import re
import sys import sys
import time import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -1319,8 +1320,9 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str,
entities_window: int = 400, entities_window: int = 400,
entities_max: int = 6500, entities_max: int = 6500,
analysis_window: int = 500, analysis_window: int = 500,
analysis_max: int = 6500) -> tuple[dict, OpenAICompatibleClient]: analysis_max: int = 6500,
log = lambda m: print(m, file=sys.stderr) log_prefix: str = "") -> tuple[dict, OpenAICompatibleClient]:
log = lambda m: print(f"{log_prefix}{m}", file=sys.stderr)
log("[0/7] 檢測語言...") log("[0/7] 檢測語言...")
lang = detect_language(text) lang = detect_language(text)
@ -1467,6 +1469,93 @@ def to_yaml(result: dict) -> str:
default_flow_style=False, width=100) default_flow_style=False, width=100)
# =============================================================================
# 7. 多模型運行:解析運行規格 + 輸出路徑推導
# =============================================================================
@dataclass
class RunSpec:
"""單個模型的運行規格(標籤 + 連接參數 + 計費配置)"""
label: str # 用於日誌前綴與輸出文件名(配置名或模型名)
model: str
base_url: str
api_key: str
profile: dict | None = None # models.json 配置(用於成本計算),無則 None
def _safe_label(name: str) -> str:
"""把配置/模型名轉成可安全用於文件名的標籤(如 anthropic/claude → anthropic_claude"""
safe = re.sub(r"[^0-9A-Za-z._-]+", "_", name).strip("_")
return safe or "model"
def parse_run_specs(args) -> list[RunSpec]:
    """Expand ``--config`` / ``--model`` (each may be comma-separated) into run specs.

    Precedence:
      - ``--config``: every listed name is loaded with ``load_model_profile``.
        The profile's ``model`` / ``BaseApiUrl`` / ``ApiKey`` entries are used
        when truthy; otherwise the matching command-line value is used.
      - otherwise ``--model``: every listed model shares ``--base-url`` and
        ``--api-key`` and carries no profile.

    Names are stripped, blank entries are dropped, and a name that is repeated
    is kept only once (first occurrence wins). The label is what the output
    file names are derived from, so two specs with the same label would write
    to the same files, possibly at the same time.

    A ``--config`` value that contains no usable name (for example ``","``) is
    treated like an absent ``--config`` and falls through to ``--model``, so
    the returned list is never empty.

    Args:
        args: Parsed CLI namespace. Reads ``config``, ``models_file``,
            ``model``, ``base_url`` and ``api_key``.

    Returns:
        One ``RunSpec`` per distinct name, in the order given on the command line.
    """
    def split_names(raw: str) -> list[str]:
        # dict.fromkeys de-duplicates while preserving first-seen order.
        stripped = (part.strip() for part in raw.split(","))
        return list(dict.fromkeys(part for part in stripped if part))

    config_names = split_names(args.config) if args.config else []
    if config_names:
        specs: list[RunSpec] = []
        for name in config_names:
            profile = load_model_profile(name, args.models_file)
            specs.append(RunSpec(
                label=name,
                # Fields missing from the profile fall back to the CLI values.
                model=profile.get("model") or args.model,
                base_url=profile.get("BaseApiUrl") or args.base_url,
                api_key=profile.get("ApiKey") or args.api_key,
                profile=profile,
            ))
        return specs

    # No usable config name: run the model(s) named by --model. If splitting
    # yields nothing, keep the raw value so exactly one spec is produced.
    model_names = split_names(args.model) or [args.model]
    return [
        RunSpec(
            label=name,
            model=name,
            base_url=args.base_url,
            api_key=args.api_key,
            profile=None,
        )
        for name in model_names
    ]
def build_out_path(args, input_path: Path, label: str, multi: bool) -> Path | None:
"""推導某個模型的 YAML 輸出路徑。
- 單模型沿用原行為--out 指定則用之否則 None 表示輸出到 stdout
- 多模型在文件名中插入標籤未給 --out 時用輸入名_標籤.yaml
"""
if not multi:
return Path(args.out) if args.out else None
safe = _safe_label(label)
if args.out:
base = Path(args.out)
return base.with_name(f"{base.stem}_{safe}{base.suffix or '.yaml'}")
return input_path.with_name(f"{input_path.stem}_{safe}.yaml")
def build_cost_path(out_path: Path | None, input_path: Path,
label: str, multi: bool) -> Path:
"""成本文件路徑有輸出文件時用「輸出名_cost.json」否則回退到輸入名。"""
if out_path is not None:
return out_path.with_name(out_path.stem + "_cost.json")
if multi:
return input_path.with_name(f"{input_path.stem}_{_safe_label(label)}_cost.json")
return input_path.with_name(input_path.stem + "_cost.json")
def build_debug_path(args, label: str, multi: bool) -> Path | None:
"""debug-dump 路徑:多模型時在文件名中插入標籤。"""
if not args.debug_dump:
return None
base = Path(args.debug_dump)
if multi:
return base.with_name(f"{base.stem}_{_safe_label(label)}{base.suffix or '.json'}")
return base
# ============================================================================= # =============================================================================
# CLI # CLI
# ============================================================================= # =============================================================================
@ -1495,6 +1584,17 @@ def main() -> None:
--model gpt-4 \\ --model gpt-4 \\
--api-key your-api-key --api-key your-api-key
# 同時跑多個模型(逗號分隔),並發執行,分別輸出到不同文件
python hk_case_extractor.py case.txt \\
--config openrouter-claude-sonnet,openrouter-gpt4o,ollama-qwen \\
--out result.yaml --cost
# 生成 result_openrouter-claude-sonnet.yaml / result_openrouter-gpt4o.yaml / ...
# 及各自的 *_cost.json
# 多個本地 Ollama 模型共用同一端點
python hk_case_extractor.py case.txt \\
--model qwen2.5:7b-instruct,llama3.1:8b --out result.yaml
# 調整截取長度 # 調整截取長度
python hk_case_extractor.py case.txt \\ python hk_case_extractor.py case.txt \\
--head-length 8000 \\ --head-length 8000 \\
@ -1506,21 +1606,24 @@ def main() -> None:
) )
ap.add_argument("input", help="判決書文本路徑(.txt 或 .json") ap.add_argument("input", help="判決書文本路徑(.txt 或 .json")
ap.add_argument("--config", default=None, ap.add_argument("--config", default=None,
help="models.json 中的配置名稱Name" help="models.json 中的配置名稱Name可逗號分隔多個以同時運行多個模型,"
"使用後可省略 --model/--base-url/--api-key") "使用後可省略 --model/--base-url/--api-key")
ap.add_argument("--models-file", default=DEFAULT_MODELS_FILE, ap.add_argument("--models-file", default=DEFAULT_MODELS_FILE,
help=f"模型配置文件路徑(默認:{DEFAULT_MODELS_FILE}") help=f"模型配置文件路徑(默認:{DEFAULT_MODELS_FILE}")
ap.add_argument("--model", default=DEFAULT_MODEL, ap.add_argument("--model", default=DEFAULT_MODEL,
help=f"模型名稱(默認:{DEFAULT_MODEL}") help=f"模型名稱(默認:{DEFAULT_MODEL},可逗號分隔多個(共用 --base-url/--api-key")
ap.add_argument("--base-url", default=DEFAULT_BASE_URL, ap.add_argument("--base-url", default=DEFAULT_BASE_URL,
help=f"API base URL默認{DEFAULT_BASE_URL}") help=f"API base URL默認{DEFAULT_BASE_URL}")
ap.add_argument("--api-key", default=DEFAULT_API_KEY, ap.add_argument("--api-key", default=DEFAULT_API_KEY,
help="API keyOllama 可忽略)") help="API keyOllama 可忽略)")
ap.add_argument("--out", default=None, help="輸出 YAML 路徑(默認 stdout") ap.add_argument("--out", default=None,
help="輸出 YAML 路徑(單模型默認 stdout多模型時在文件名中插入標籤")
ap.add_argument("--max-workers", type=int, default=0,
help="多模型時的並發數(默認 0 = 模型數量;設為 1 則順序執行)")
ap.add_argument("--cost", action="store_true", ap.add_argument("--cost", action="store_true",
help="輸出成本統計到 {輸出文件名}_cost.json默認不輸出") help="輸出成本統計到 {輸出文件名}_cost.json默認不輸出")
ap.add_argument("--debug-dump", default=None, ap.add_argument("--debug-dump", default=None,
help="額外輸出原始 JSON 結果到該路徑(便於 diff") help="額外輸出原始 JSON 結果到該路徑(多模型時在文件名中插入標籤")
# 截取長度控制參數 # 截取長度控制參數
ap.add_argument("--head-length", type=int, default=5000, ap.add_argument("--head-length", type=int, default=5000,
@ -1538,24 +1641,17 @@ def main() -> None:
args = ap.parse_args() args = ap.parse_args()
# 解析模型配置:--config 優先,未命中的字段回退到命令行/默認值 # 解析運行規格(--config / --model 均支持逗號分隔多個)
profile: dict | None = None specs = parse_run_specs(args)
if args.config: multi = len(specs) > 1
profile = load_model_profile(args.config, args.models_file) for spec in specs:
model = profile.get("model") or args.model if spec.profile is not None:
base_url = profile.get("BaseApiUrl") or args.base_url print(f"使用配置 '{spec.label}'model={spec.model}, base_url={spec.base_url}",
api_key = profile.get("ApiKey") or args.api_key file=sys.stderr)
print(f"使用配置 '{args.config}'model={model}, base_url={base_url}",
file=sys.stderr)
else:
model = args.model
base_url = args.base_url
api_key = args.api_key
# 支持从 .json 文件的 content 字段读取 # 支持从 .json 文件的 content 字段读取
input_path = Path(args.input) input_path = Path(args.input)
if input_path.suffix.lower() == '.json': if input_path.suffix.lower() == '.json':
import json
data = json.loads(input_path.read_text(encoding="utf-8")) data = json.loads(input_path.read_text(encoding="utf-8"))
text = data.get("content", "") text = data.get("content", "")
if not text: if not text:
@ -1563,44 +1659,77 @@ def main() -> None:
sys.exit(1) sys.exit(1)
else: else:
text = input_path.read_text(encoding="utf-8") text = input_path.read_text(encoding="utf-8")
start = time.perf_counter()
result, client = run_pipeline(text, model, base_url, api_key,
args.head_length, args.tail_length,
args.entities_window, args.entities_max,
args.analysis_window, args.analysis_max)
elapsed = time.perf_counter() - start
# 成本統計:僅在 --cost 時輸出到 {summary_file_name}_cost.json def run_and_emit(spec: RunSpec) -> str:
if args.cost: """跑單個模型並寫出其 YAML / cost / debug 文件,返回輸出路徑描述。"""
cost = compute_cost(client, elapsed, profile) log_prefix = f"[{spec.label}] " if multi else ""
if args.out: start = time.perf_counter()
cost_path = Path(args.out).with_name(Path(args.out).stem + "_cost.json") result, client = run_pipeline(
else: text, spec.model, spec.base_url, spec.api_key,
cost_path = input_path.with_name(input_path.stem + "_cost.json") args.head_length, args.tail_length,
cost_path.parent.mkdir(parents=True, exist_ok=True) args.entities_window, args.entities_max,
cost_path.write_text(json.dumps(cost, ensure_ascii=False, indent=2), args.analysis_window, args.analysis_max,
encoding="utf-8") log_prefix=log_prefix)
print(f"💰 成本統計已寫入 {cost_path}:耗時 {cost['elapsed_seconds']}s" elapsed = time.perf_counter() - start
f"input={cost['input_tokens']} output={cost['output_tokens']} "
f"total_cost={cost['total_cost']} {cost['price_unit'] or ''}",
file=sys.stderr)
if args.debug_dump: out_path = build_out_path(args, input_path, spec.label, multi)
debug_path = Path(args.debug_dump)
debug_path.parent.mkdir(parents=True, exist_ok=True)
debug_path.write_text(
json.dumps(result, ensure_ascii=False, indent=2),
encoding="utf-8")
yaml_str = to_yaml(result) # 成本統計:僅在 --cost 時輸出到 {輸出文件名}_cost.json
if args.out: if args.cost:
out_path = Path(args.out) cost = compute_cost(client, elapsed, spec.profile)
out_path.parent.mkdir(parents=True, exist_ok=True) cost_path = build_cost_path(out_path, input_path, spec.label, multi)
out_path.write_text(yaml_str, encoding="utf-8") cost_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n✅ 已寫入 {args.out}", file=sys.stderr) cost_path.write_text(json.dumps(cost, ensure_ascii=False, indent=2),
else: encoding="utf-8")
print(f"{log_prefix}💰 成本統計已寫入 {cost_path}:耗時 {cost['elapsed_seconds']}s"
f"input={cost['input_tokens']} output={cost['output_tokens']} "
f"total_cost={cost['total_cost']} {cost['price_unit'] or ''}",
file=sys.stderr)
# debug dump原始 JSON
debug_path = build_debug_path(args, spec.label, multi)
if debug_path is not None:
debug_path.parent.mkdir(parents=True, exist_ok=True)
debug_path.write_text(json.dumps(result, ensure_ascii=False, indent=2),
encoding="utf-8")
# YAML 輸出
yaml_str = to_yaml(result)
if out_path is not None:
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(yaml_str, encoding="utf-8")
print(f"{log_prefix}✅ 已寫入 {out_path}", file=sys.stderr)
return str(out_path)
print(yaml_str) print(yaml_str)
return "(stdout)"
if multi and args.max_workers != 1:
# 並發執行:每個模型一個線程,日誌以 [標籤] 前綴區分
workers = args.max_workers if args.max_workers > 0 else len(specs)
print(f"⏳ 同時運行 {len(specs)} 個模型(並發 {workers}"
f"{[s.label for s in specs]}", file=sys.stderr)
summary: list[tuple[str, str]] = []
with ThreadPoolExecutor(max_workers=workers) as ex:
futures = {ex.submit(run_and_emit, spec): spec for spec in specs}
for fut in as_completed(futures):
spec = futures[fut]
try:
summary.append((spec.label, fut.result()))
except Exception as e:
print(f"[{spec.label}] ❌ 運行失敗:{e}", file=sys.stderr)
summary.append((spec.label, f"FAILED: {e}"))
print("\n=== 多模型運行結果 ===", file=sys.stderr)
for label, out in sorted(summary):
print(f" {label:30s}{out}", file=sys.stderr)
else:
# 單模型,或多模型但顯式 --max-workers 1 順序執行
for spec in specs:
try:
run_and_emit(spec)
except Exception as e:
if not multi:
raise
print(f"[{spec.label}] ❌ 運行失敗:{e}", file=sys.stderr)
if __name__ == "__main__": if __name__ == "__main__":