add multi-model concurrent extraction with per-model output files
--config and --model now accept comma-separated lists to run several models on one document. Runs execute concurrently via a thread pool (--max-workers controls parallelism; 1 forces sequential), with logs prefixed by the config/model label. Each model writes to its own YAML/cost/debug files by inserting a sanitized label into the --out name (or the input name when --out is omitted), and cost is computed per-model from its own profile. Single-model behavior, including stdout output, is unchanged. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>main
parent
bc2b9826d5
commit
abcb2103f1
|
|
@ -68,6 +68,7 @@ import json
|
|||
import re
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
|
@ -1319,8 +1320,9 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str,
|
|||
entities_window: int = 400,
|
||||
entities_max: int = 6500,
|
||||
analysis_window: int = 500,
|
||||
analysis_max: int = 6500) -> tuple[dict, OpenAICompatibleClient]:
|
||||
log = lambda m: print(m, file=sys.stderr)
|
||||
analysis_max: int = 6500,
|
||||
log_prefix: str = "") -> tuple[dict, OpenAICompatibleClient]:
|
||||
log = lambda m: print(f"{log_prefix}{m}", file=sys.stderr)
|
||||
|
||||
log("[0/7] 檢測語言...")
|
||||
lang = detect_language(text)
|
||||
|
|
@ -1467,6 +1469,93 @@ def to_yaml(result: dict) -> str:
|
|||
default_flow_style=False, width=100)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 7. 多模型運行:解析運行規格 + 輸出路徑推導
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class RunSpec:
|
||||
"""單個模型的運行規格(標籤 + 連接參數 + 計費配置)"""
|
||||
label: str # 用於日誌前綴與輸出文件名(配置名或模型名)
|
||||
model: str
|
||||
base_url: str
|
||||
api_key: str
|
||||
profile: dict | None = None # models.json 配置(用於成本計算),無則 None
|
||||
|
||||
|
||||
def _safe_label(name: str) -> str:
|
||||
"""把配置/模型名轉成可安全用於文件名的標籤(如 anthropic/claude → anthropic_claude)"""
|
||||
safe = re.sub(r"[^0-9A-Za-z._-]+", "_", name).strip("_")
|
||||
return safe or "model"
|
||||
|
||||
|
||||
def parse_run_specs(args) -> list[RunSpec]:
|
||||
"""解析 --config / --model(均支持逗號分隔)為一組運行規格。
|
||||
|
||||
優先級:
|
||||
- 有 --config:逐個從 models.json 加載配置(自帶 model/base_url/api_key/價格)
|
||||
- 否則:用 --model(可逗號分隔多個),共用 --base-url / --api-key
|
||||
"""
|
||||
specs: list[RunSpec] = []
|
||||
if args.config:
|
||||
names = [n.strip() for n in args.config.split(",") if n.strip()]
|
||||
for name in names:
|
||||
profile = load_model_profile(name, args.models_file)
|
||||
specs.append(RunSpec(
|
||||
label=name,
|
||||
model=profile.get("model") or args.model,
|
||||
base_url=profile.get("BaseApiUrl") or args.base_url,
|
||||
api_key=profile.get("ApiKey") or args.api_key,
|
||||
profile=profile,
|
||||
))
|
||||
else:
|
||||
names = [n.strip() for n in args.model.split(",") if n.strip()] or [args.model]
|
||||
for name in names:
|
||||
specs.append(RunSpec(
|
||||
label=name,
|
||||
model=name,
|
||||
base_url=args.base_url,
|
||||
api_key=args.api_key,
|
||||
profile=None,
|
||||
))
|
||||
return specs
|
||||
|
||||
|
||||
def build_out_path(args, input_path: Path, label: str, multi: bool) -> Path | None:
|
||||
"""推導某個模型的 YAML 輸出路徑。
|
||||
|
||||
- 單模型:沿用原行為(--out 指定則用之,否則 None 表示輸出到 stdout)
|
||||
- 多模型:在文件名中插入標籤;未給 --out 時用「輸入名_標籤.yaml」
|
||||
"""
|
||||
if not multi:
|
||||
return Path(args.out) if args.out else None
|
||||
safe = _safe_label(label)
|
||||
if args.out:
|
||||
base = Path(args.out)
|
||||
return base.with_name(f"{base.stem}_{safe}{base.suffix or '.yaml'}")
|
||||
return input_path.with_name(f"{input_path.stem}_{safe}.yaml")
|
||||
|
||||
|
||||
def build_cost_path(out_path: Path | None, input_path: Path,
|
||||
label: str, multi: bool) -> Path:
|
||||
"""成本文件路徑:有輸出文件時用「輸出名_cost.json」,否則回退到輸入名。"""
|
||||
if out_path is not None:
|
||||
return out_path.with_name(out_path.stem + "_cost.json")
|
||||
if multi:
|
||||
return input_path.with_name(f"{input_path.stem}_{_safe_label(label)}_cost.json")
|
||||
return input_path.with_name(input_path.stem + "_cost.json")
|
||||
|
||||
|
||||
def build_debug_path(args, label: str, multi: bool) -> Path | None:
|
||||
"""debug-dump 路徑:多模型時在文件名中插入標籤。"""
|
||||
if not args.debug_dump:
|
||||
return None
|
||||
base = Path(args.debug_dump)
|
||||
if multi:
|
||||
return base.with_name(f"{base.stem}_{_safe_label(label)}{base.suffix or '.json'}")
|
||||
return base
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CLI
|
||||
# =============================================================================
|
||||
|
|
@ -1495,6 +1584,17 @@ def main() -> None:
|
|||
--model gpt-4 \\
|
||||
--api-key your-api-key
|
||||
|
||||
# 同時跑多個模型(逗號分隔),並發執行,分別輸出到不同文件
|
||||
python hk_case_extractor.py case.txt \\
|
||||
--config openrouter-claude-sonnet,openrouter-gpt4o,ollama-qwen \\
|
||||
--out result.yaml --cost
|
||||
# 生成 result_openrouter-claude-sonnet.yaml / result_openrouter-gpt4o.yaml / ...
|
||||
# 及各自的 *_cost.json
|
||||
|
||||
# 多個本地 Ollama 模型共用同一端點
|
||||
python hk_case_extractor.py case.txt \\
|
||||
--model qwen2.5:7b-instruct,llama3.1:8b --out result.yaml
|
||||
|
||||
# 調整截取長度
|
||||
python hk_case_extractor.py case.txt \\
|
||||
--head-length 8000 \\
|
||||
|
|
@ -1506,21 +1606,24 @@ def main() -> None:
|
|||
)
|
||||
ap.add_argument("input", help="判決書文本路徑(.txt 或 .json)")
|
||||
ap.add_argument("--config", default=None,
|
||||
help="models.json 中的配置名稱(Name),"
|
||||
help="models.json 中的配置名稱(Name),可逗號分隔多個以同時運行多個模型,"
|
||||
"使用後可省略 --model/--base-url/--api-key")
|
||||
ap.add_argument("--models-file", default=DEFAULT_MODELS_FILE,
|
||||
help=f"模型配置文件路徑(默認:{DEFAULT_MODELS_FILE})")
|
||||
ap.add_argument("--model", default=DEFAULT_MODEL,
|
||||
help=f"模型名稱(默認:{DEFAULT_MODEL})")
|
||||
help=f"模型名稱(默認:{DEFAULT_MODEL}),可逗號分隔多個(共用 --base-url/--api-key)")
|
||||
ap.add_argument("--base-url", default=DEFAULT_BASE_URL,
|
||||
help=f"API base URL(默認:{DEFAULT_BASE_URL})")
|
||||
ap.add_argument("--api-key", default=DEFAULT_API_KEY,
|
||||
help="API key(Ollama 可忽略)")
|
||||
ap.add_argument("--out", default=None, help="輸出 YAML 路徑(默認 stdout)")
|
||||
ap.add_argument("--out", default=None,
|
||||
help="輸出 YAML 路徑(單模型默認 stdout);多模型時在文件名中插入標籤")
|
||||
ap.add_argument("--max-workers", type=int, default=0,
|
||||
help="多模型時的並發數(默認 0 = 模型數量;設為 1 則順序執行)")
|
||||
ap.add_argument("--cost", action="store_true",
|
||||
help="輸出成本統計到 {輸出文件名}_cost.json(默認不輸出)")
|
||||
ap.add_argument("--debug-dump", default=None,
|
||||
help="額外輸出原始 JSON 結果到該路徑(便於 diff)")
|
||||
help="額外輸出原始 JSON 結果到該路徑(多模型時在文件名中插入標籤)")
|
||||
|
||||
# 截取長度控制參數
|
||||
ap.add_argument("--head-length", type=int, default=5000,
|
||||
|
|
@ -1538,24 +1641,17 @@ def main() -> None:
|
|||
|
||||
args = ap.parse_args()
|
||||
|
||||
# 解析模型配置:--config 優先,未命中的字段回退到命令行/默認值
|
||||
profile: dict | None = None
|
||||
if args.config:
|
||||
profile = load_model_profile(args.config, args.models_file)
|
||||
model = profile.get("model") or args.model
|
||||
base_url = profile.get("BaseApiUrl") or args.base_url
|
||||
api_key = profile.get("ApiKey") or args.api_key
|
||||
print(f"使用配置 '{args.config}':model={model}, base_url={base_url}",
|
||||
file=sys.stderr)
|
||||
else:
|
||||
model = args.model
|
||||
base_url = args.base_url
|
||||
api_key = args.api_key
|
||||
# 解析運行規格(--config / --model 均支持逗號分隔多個)
|
||||
specs = parse_run_specs(args)
|
||||
multi = len(specs) > 1
|
||||
for spec in specs:
|
||||
if spec.profile is not None:
|
||||
print(f"使用配置 '{spec.label}':model={spec.model}, base_url={spec.base_url}",
|
||||
file=sys.stderr)
|
||||
|
||||
# 支持从 .json 文件的 content 字段读取
|
||||
input_path = Path(args.input)
|
||||
if input_path.suffix.lower() == '.json':
|
||||
import json
|
||||
data = json.loads(input_path.read_text(encoding="utf-8"))
|
||||
text = data.get("content", "")
|
||||
if not text:
|
||||
|
|
@ -1563,44 +1659,77 @@ def main() -> None:
|
|||
sys.exit(1)
|
||||
else:
|
||||
text = input_path.read_text(encoding="utf-8")
|
||||
|
||||
start = time.perf_counter()
|
||||
result, client = run_pipeline(text, model, base_url, api_key,
|
||||
args.head_length, args.tail_length,
|
||||
args.entities_window, args.entities_max,
|
||||
args.analysis_window, args.analysis_max)
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
# 成本統計:僅在 --cost 時輸出到 {summary_file_name}_cost.json
|
||||
if args.cost:
|
||||
cost = compute_cost(client, elapsed, profile)
|
||||
if args.out:
|
||||
cost_path = Path(args.out).with_name(Path(args.out).stem + "_cost.json")
|
||||
else:
|
||||
cost_path = input_path.with_name(input_path.stem + "_cost.json")
|
||||
cost_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cost_path.write_text(json.dumps(cost, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8")
|
||||
print(f"💰 成本統計已寫入 {cost_path}:耗時 {cost['elapsed_seconds']}s,"
|
||||
f"input={cost['input_tokens']} output={cost['output_tokens']} "
|
||||
f"total_cost={cost['total_cost']} {cost['price_unit'] or ''}",
|
||||
file=sys.stderr)
|
||||
def run_and_emit(spec: RunSpec) -> str:
|
||||
"""跑單個模型並寫出其 YAML / cost / debug 文件,返回輸出路徑描述。"""
|
||||
log_prefix = f"[{spec.label}] " if multi else ""
|
||||
start = time.perf_counter()
|
||||
result, client = run_pipeline(
|
||||
text, spec.model, spec.base_url, spec.api_key,
|
||||
args.head_length, args.tail_length,
|
||||
args.entities_window, args.entities_max,
|
||||
args.analysis_window, args.analysis_max,
|
||||
log_prefix=log_prefix)
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
if args.debug_dump:
|
||||
debug_path = Path(args.debug_dump)
|
||||
debug_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
debug_path.write_text(
|
||||
json.dumps(result, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8")
|
||||
out_path = build_out_path(args, input_path, spec.label, multi)
|
||||
|
||||
yaml_str = to_yaml(result)
|
||||
if args.out:
|
||||
out_path = Path(args.out)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path.write_text(yaml_str, encoding="utf-8")
|
||||
print(f"\n✅ 已寫入 {args.out}", file=sys.stderr)
|
||||
else:
|
||||
# 成本統計:僅在 --cost 時輸出到 {輸出文件名}_cost.json
|
||||
if args.cost:
|
||||
cost = compute_cost(client, elapsed, spec.profile)
|
||||
cost_path = build_cost_path(out_path, input_path, spec.label, multi)
|
||||
cost_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cost_path.write_text(json.dumps(cost, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8")
|
||||
print(f"{log_prefix}💰 成本統計已寫入 {cost_path}:耗時 {cost['elapsed_seconds']}s,"
|
||||
f"input={cost['input_tokens']} output={cost['output_tokens']} "
|
||||
f"total_cost={cost['total_cost']} {cost['price_unit'] or ''}",
|
||||
file=sys.stderr)
|
||||
|
||||
# debug dump(原始 JSON)
|
||||
debug_path = build_debug_path(args, spec.label, multi)
|
||||
if debug_path is not None:
|
||||
debug_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
debug_path.write_text(json.dumps(result, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8")
|
||||
|
||||
# YAML 輸出
|
||||
yaml_str = to_yaml(result)
|
||||
if out_path is not None:
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path.write_text(yaml_str, encoding="utf-8")
|
||||
print(f"{log_prefix}✅ 已寫入 {out_path}", file=sys.stderr)
|
||||
return str(out_path)
|
||||
print(yaml_str)
|
||||
return "(stdout)"
|
||||
|
||||
if multi and args.max_workers != 1:
|
||||
# 並發執行:每個模型一個線程,日誌以 [標籤] 前綴區分
|
||||
workers = args.max_workers if args.max_workers > 0 else len(specs)
|
||||
print(f"⏳ 同時運行 {len(specs)} 個模型(並發 {workers}):"
|
||||
f"{[s.label for s in specs]}", file=sys.stderr)
|
||||
summary: list[tuple[str, str]] = []
|
||||
with ThreadPoolExecutor(max_workers=workers) as ex:
|
||||
futures = {ex.submit(run_and_emit, spec): spec for spec in specs}
|
||||
for fut in as_completed(futures):
|
||||
spec = futures[fut]
|
||||
try:
|
||||
summary.append((spec.label, fut.result()))
|
||||
except Exception as e:
|
||||
print(f"[{spec.label}] ❌ 運行失敗:{e}", file=sys.stderr)
|
||||
summary.append((spec.label, f"FAILED: {e}"))
|
||||
print("\n=== 多模型運行結果 ===", file=sys.stderr)
|
||||
for label, out in sorted(summary):
|
||||
print(f" {label:30s} → {out}", file=sys.stderr)
|
||||
else:
|
||||
# 單模型,或多模型但顯式 --max-workers 1 順序執行
|
||||
for spec in specs:
|
||||
try:
|
||||
run_and_emit(spec)
|
||||
except Exception as e:
|
||||
if not multi:
|
||||
raise
|
||||
print(f"[{spec.label}] ❌ 運行失敗:{e}", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue