diff --git a/.gitignore b/.gitignore
index 11df575..0646eac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,5 @@
 *.docx
 *.html

-result.yaml
\ No newline at end of file
+result.yaml
+__pycache__/
\ No newline at end of file
diff --git a/hk_case_extractor.py b/hk_case_extractor.py
index 6dd883f..ca5e71e 100644
--- a/hk_case_extractor.py
+++ b/hk_case_extractor.py
@@ -67,7 +67,8 @@ import argparse
 import json
 import re
 import sys
-from dataclasses import dataclass
+import time
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any

@@ -85,6 +86,45 @@ DEFAULT_API_KEY = "ollama" # Ollama 不需要真實 key,但 API 需要提供
 DEFAULT_TIMEOUT = 600 # 增加到 10 分鐘,適應遠程服務器
 MAX_RETRIES = 2

+DEFAULT_MODELS_FILE = "models.json"
+
+
+# =============================================================================
+# Model profiles (models.json): load base_url / api_key / model / pricing by profile name
+# =============================================================================
+
+def load_model_profile(name: str, models_file: str = DEFAULT_MODELS_FILE) -> dict:
+    """Load one model profile from models.json by its ``Name`` field.
+
+    Expected file format (a JSON array of objects):
+    [{"Name":..., "source":..., "BaseApiUrl":..., "ApiKey":...,
+      "model":..., "input_price":..., "output_price":..., "price_unit":...}]
+
+    Exits with status 1 if the file is missing, malformed, or lacks the profile.
+    """
+    path = Path(models_file)
+    if not path.exists():
+        print(f"❌ 找不到模型配置文件:{models_file}", file=sys.stderr)
+        sys.exit(1)
+
+    try:
+        profiles = json.loads(path.read_text(encoding="utf-8"))
+    except json.JSONDecodeError as e:
+        print(f"❌ 模型配置文件不是有效的 JSON:{e}", file=sys.stderr)
+        sys.exit(1)
+
+    if not isinstance(profiles, list):
+        print("❌ 模型配置文件應為配置對象組成的數組", file=sys.stderr)
+        sys.exit(1)
+
+    for p in profiles:
+        if isinstance(p, dict) and p.get("Name") == name:
+            return p
+
+    available = [p.get("Name") for p in profiles if isinstance(p, dict)]
+    print(f"❌ 配置 '{name}' 不存在。可用配置:{available}", file=sys.stderr)
+    sys.exit(1)
+

 # =============================================================================
 # 0. 語言檢測
@@ -450,6 +490,15 @@ class OpenAICompatibleClient:
     api_key: str = DEFAULT_API_KEY
     timeout: int = DEFAULT_TIMEOUT

+    # Cumulative token usage (across all calls, including retries)
+    total_input_tokens: int = field(default=0, init=False)
+    total_output_tokens: int = field(default=0, init=False)
+    num_calls: int = field(default=0, init=False)
+
+    @property
+    def total_tokens(self) -> int:
+        return self.total_input_tokens + self.total_output_tokens
+

     def chat_json(self, system: str, user: str, schema: dict,
                   temperature: float = 0.0, max_tokens: int = 4096) -> dict:
@@ -494,6 +543,12 @@ class OpenAICompatibleClient:
             print(f" 響應內容: {r.text[:500]}", file=sys.stderr)
             raise

+        # Accumulate token usage (OpenAI-compatible endpoints usually return it in the usage field)
+        usage = response_data.get("usage") or {}
+        self.total_input_tokens += int(usage.get("prompt_tokens", 0) or 0)
+        self.total_output_tokens += int(usage.get("completion_tokens", 0) or 0)
+        self.num_calls += 1
+
         # 提取內容
         if "choices" not in response_data or not response_data["choices"]:
             print(f"❌ API 響應缺少 choices 字段", file=sys.stderr)
@@ -1213,7 +1268,7 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str,
                  entities_window: int = 400,
                  entities_max: int = 6500,
                  analysis_window: int = 500,
-                 analysis_max: int = 6500) -> dict:
+                 analysis_max: int = 6500) -> tuple[dict, OpenAICompatibleClient]:
     log = lambda m: print(m, file=sys.stderr)

     log("[0/7] 檢測語言...")
@@ -1277,7 +1332,55 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str,
     for w in warnings:
         log(f" ⚠️ {w}")

-    return final
+    return final, client
+
+
+# =============================================================================
+# 5.5 Cost statistics: compute this run's cost from the models.json pricing
+# =============================================================================
+
+def compute_cost(client: OpenAICompatibleClient,
+                 elapsed_seconds: float,
+                 profile: dict | None) -> dict:
+    """Summarise token usage, elapsed time and cost for one extraction run.
+
+    Prices are interpreted as "price per million tokens":
+        input_cost = input_tokens / 1_000_000 * input_price
+        output_cost = output_tokens / 1_000_000 * output_price
+    All price and cost fields are None when profile is None (no --config).
+    """
+    input_tokens = client.total_input_tokens
+    output_tokens = client.total_output_tokens
+
+    input_price = output_price = None
+    price_unit = None
+    if profile is not None:
+        input_price = float(profile.get("input_price") or 0)
+        output_price = float(profile.get("output_price") or 0)
+        price_unit = profile.get("price_unit")
+
+    input_cost = output_cost = total_cost = None
+    if input_price is not None and output_price is not None:
+        input_cost = round(input_tokens / 1_000_000 * input_price, 6)
+        output_cost = round(output_tokens / 1_000_000 * output_price, 6)
+        total_cost = round(input_cost + output_cost, 6)
+
+    return {
+        "config_name": profile.get("Name") if profile else None,
+        "source": profile.get("source") if profile else None,
+        "model": client.model,
+        "elapsed_seconds": round(elapsed_seconds, 3),
+        "num_api_calls": client.num_calls,
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "total_tokens": client.total_tokens,
+        "input_price_per_million": input_price,
+        "output_price_per_million": output_price,
+        "price_unit": price_unit,
+        "input_cost": input_cost,
+        "output_cost": output_cost,
+        "total_cost": total_cost,
+    }


 # =============================================================================
@@ -1322,9 +1425,13 @@ def main() -> None:
         description="香港判決書結構化抽取(OpenAI 兼容 API)",
         epilog="""
 示例用法:
+  # 使用 models.json 中的配置名稱(推薦,省去多個參數)
+  python hk_case_extractor.py case.txt --config openrouter-claude-sonnet --out result.yaml --cost
+  # 成本統計會寫入 result_cost.json
+
   # 使用本地 Ollama
   python hk_case_extractor.py case.txt --model qwen2.5:7b-instruct
-  
+
   # 使用 OpenRouter
   python hk_case_extractor.py case.txt \\
     --base-url https://openrouter.ai/api/v1 \\
@@ -1347,13 +1454,20 @@ def main() -> None:
         formatter_class=argparse.RawDescriptionHelpFormatter
     )
     ap.add_argument("input", help="判決書文本路徑(.txt 或 .json)")
-    ap.add_argument("--model", default=DEFAULT_MODEL, 
+    ap.add_argument("--config", default=None,
+                    help="models.json 中的配置名稱(Name),"
+                         "使用後可省略 --model/--base-url/--api-key")
+    ap.add_argument("--models-file", default=DEFAULT_MODELS_FILE,
+                    help=f"模型配置文件路徑(默認:{DEFAULT_MODELS_FILE})")
+    ap.add_argument("--model", default=DEFAULT_MODEL,
                     help=f"模型名稱(默認:{DEFAULT_MODEL})")
     ap.add_argument("--base-url", default=DEFAULT_BASE_URL,
                     help=f"API base URL(默認:{DEFAULT_BASE_URL})")
     ap.add_argument("--api-key", default=DEFAULT_API_KEY,
                     help="API key(Ollama 可忽略)")
     ap.add_argument("--out", default=None, help="輸出 YAML 路徑(默認 stdout)")
+    ap.add_argument("--cost", action="store_true",
+                    help="輸出成本統計到 {輸出文件名}_cost.json(默認不輸出)")
     ap.add_argument("--debug-dump", default=None,
                     help="額外輸出原始 JSON 結果到該路徑(便於 diff)")

@@ -1373,6 +1487,20 @@ def main() -> None:

     args = ap.parse_args()

+    # Resolve the model profile: --config wins; fields it leaves empty fall back to CLI/default values
+    profile: dict | None = None
+    if args.config:
+        profile = load_model_profile(args.config, args.models_file)
+        model = profile.get("model") or args.model
+        base_url = profile.get("BaseApiUrl") or args.base_url
+        api_key = profile.get("ApiKey") or args.api_key
+        print(f"使用配置 '{args.config}':model={model}, base_url={base_url}",
+              file=sys.stderr)
+    else:
+        model = args.model
+        base_url = args.base_url
+        api_key = args.api_key
+
     # 支持从 .json 文件的 content 字段读取
     input_path = Path(args.input)
     if input_path.suffix.lower() == '.json':
@@ -1385,10 +1513,27 @@ def main() -> None:
     else:
         text = input_path.read_text(encoding="utf-8")

-    result = run_pipeline(text, args.model, args.base_url, args.api_key,
-                          args.head_length, args.tail_length,
-                          args.entities_window, args.entities_max,
-                          args.analysis_window, args.analysis_max)
+    start = time.perf_counter()
+    result, client = run_pipeline(text, model, base_url, api_key,
+                                  args.head_length, args.tail_length,
+                                  args.entities_window, args.entities_max,
+                                  args.analysis_window, args.analysis_max)
+    elapsed = time.perf_counter() - start
+
+    # Cost statistics: written to {output stem}_cost.json only when --cost is given
+    if args.cost:
+        cost = compute_cost(client, elapsed, profile)
+        if args.out:
+            cost_path = Path(args.out).with_name(Path(args.out).stem + "_cost.json")
+        else:
+            cost_path = input_path.with_name(input_path.stem + "_cost.json")
+        cost_path.parent.mkdir(parents=True, exist_ok=True)
+        cost_path.write_text(json.dumps(cost, ensure_ascii=False, indent=2),
+                             encoding="utf-8")
+        print(f"💰 成本統計已寫入 {cost_path}:耗時 {cost['elapsed_seconds']}s,"
+              f"input={cost['input_tokens']} output={cost['output_tokens']} "
+              f"total_cost={cost['total_cost']} {cost['price_unit'] or ''}",
+              file=sys.stderr)

     if args.debug_dump:
         debug_path = Path(args.debug_dump)
diff --git a/models.json b/models.json
new file mode 100644
index 0000000..c7f7fef
--- /dev/null
+++ b/models.json
@@ -0,0 +1,42 @@
+[
+  {
+    "Name": "ollama-qwen",
+    "source": "ollama",
+    "BaseApiUrl": "http://localhost:11434/v1",
+    "ApiKey": "ollama",
+    "model": "qwen2.5:7b-instruct",
+    "input_price": 0,
+    "output_price": 0,
+    "price_unit": "USD"
+  },
+  {
+    "Name": "openrouter-claude-sonnet",
+    "source": "openrouter",
+    "BaseApiUrl": "https://openrouter.ai/api/v1",
+    "ApiKey": "",
+    "model": "anthropic/claude-3.5-sonnet",
+    "input_price": 3.0,
+    "output_price": 15.0,
+    "price_unit": "USD"
+  },
+  {
+    "Name": "openai-gpt-4o",
+    "source": "openai",
+    "BaseApiUrl": "https://api.openai.com/v1",
+    "ApiKey": "",
+    "model": "gpt-4o",
+    "input_price": 2.5,
+    "output_price": 10.0,
+    "price_unit": "USD"
+  },
+  {
+    "Name": "deepseek-chat",
+    "source": "deepseek",
+    "BaseApiUrl": "https://api.deepseek.com/v1",
+    "ApiKey": "",
+    "model": "deepseek-chat",
+    "input_price": 2.0,
+    "output_price": 8.0,
+    "price_unit": "CNY"
+  }
+]