add model config profiles and cost tracking

- models.json: named profiles with base url/key/model and per-million token pricing
- support --config <Name> to load a profile, replacing --model/--base-url/--api-key
- accumulate input/output token usage across all API calls in the client
- compute time + token cost per case, written to {out}_cost.json behind opt-in --cost flag

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
main
fengruixiang 2026-05-28 12:00:48 +08:00
parent b0e5fd7fd5
commit ae1005aab7
3 changed files with 198 additions and 10 deletions

1
.gitignore vendored
View File

@ -3,3 +3,4 @@
*.html *.html
result.yaml result.yaml
__pycache__/

View File

@ -67,7 +67,8 @@ import argparse
import json import json
import re import re
import sys import sys
from dataclasses import dataclass import time
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -85,6 +86,45 @@ DEFAULT_API_KEY = "ollama" # Ollama 不需要真實 key但 API 需要提供
DEFAULT_TIMEOUT = 600 # 增加到 10 分鐘,適應遠程服務器 DEFAULT_TIMEOUT = 600 # 增加到 10 分鐘,適應遠程服務器
MAX_RETRIES = 2 MAX_RETRIES = 2
DEFAULT_MODELS_FILE = "models.json"
# =============================================================================
# 模型配置models.json按配置名加載 base_url / api_key / model / 計費價格
# =============================================================================
def load_model_profile(name: str, models_file: str = DEFAULT_MODELS_FILE) -> dict:
"""從 models.json 按 Name 加載一個模型配置。
配置格式數組
[{"Name":..., "source":..., "BaseApiUrl":..., "ApiKey":...,
"model":..., "input_price":..., "output_price":..., "price_unit":...}]
找不到文件或配置名時直接退出並提示可用配置
"""
path = Path(models_file)
if not path.exists():
print(f"❌ 找不到模型配置文件:{models_file}", file=sys.stderr)
sys.exit(1)
try:
profiles = json.loads(path.read_text(encoding="utf-8"))
except json.JSONDecodeError as e:
print(f"❌ 模型配置文件不是有效的 JSON{e}", file=sys.stderr)
sys.exit(1)
if not isinstance(profiles, list):
print("❌ 模型配置文件應為配置對象組成的數組", file=sys.stderr)
sys.exit(1)
for p in profiles:
if p.get("Name") == name:
return p
available = [p.get("Name") for p in profiles]
print(f"❌ 配置 '{name}' 不存在。可用配置:{available}", file=sys.stderr)
sys.exit(1)
# ============================================================================= # =============================================================================
# 0. 語言檢測 # 0. 語言檢測
@ -450,6 +490,15 @@ class OpenAICompatibleClient:
api_key: str = DEFAULT_API_KEY api_key: str = DEFAULT_API_KEY
timeout: int = DEFAULT_TIMEOUT timeout: int = DEFAULT_TIMEOUT
# token 用量累計(跨所有調用,含重試)
total_input_tokens: int = field(default=0, init=False)
total_output_tokens: int = field(default=0, init=False)
num_calls: int = field(default=0, init=False)
@property
def total_tokens(self) -> int:
return self.total_input_tokens + self.total_output_tokens
def chat_json(self, system: str, user: str, schema: dict, def chat_json(self, system: str, user: str, schema: dict,
temperature: float = 0.0, temperature: float = 0.0,
max_tokens: int = 4096) -> dict: max_tokens: int = 4096) -> dict:
@ -494,6 +543,12 @@ class OpenAICompatibleClient:
print(f" 響應內容: {r.text[:500]}", file=sys.stderr) print(f" 響應內容: {r.text[:500]}", file=sys.stderr)
raise raise
# 累計 token 用量OpenAI 兼容端點通常在 usage 字段返回)
usage = response_data.get("usage") or {}
self.total_input_tokens += int(usage.get("prompt_tokens", 0) or 0)
self.total_output_tokens += int(usage.get("completion_tokens", 0) or 0)
self.num_calls += 1
# 提取內容 # 提取內容
if "choices" not in response_data or not response_data["choices"]: if "choices" not in response_data or not response_data["choices"]:
print(f"❌ API 響應缺少 choices 字段", file=sys.stderr) print(f"❌ API 響應缺少 choices 字段", file=sys.stderr)
@ -1213,7 +1268,7 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str,
entities_window: int = 400, entities_window: int = 400,
entities_max: int = 6500, entities_max: int = 6500,
analysis_window: int = 500, analysis_window: int = 500,
analysis_max: int = 6500) -> dict: analysis_max: int = 6500) -> tuple[dict, OpenAICompatibleClient]:
log = lambda m: print(m, file=sys.stderr) log = lambda m: print(m, file=sys.stderr)
log("[0/7] 檢測語言...") log("[0/7] 檢測語言...")
@ -1277,7 +1332,55 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str,
for w in warnings: for w in warnings:
log(f" ⚠️ {w}") log(f" ⚠️ {w}")
return final return final, client
# =============================================================================
# 5.5 成本統計:根據 models.json 計費價格計算本次抽取消耗
# =============================================================================
def compute_cost(client: OpenAICompatibleClient,
elapsed_seconds: float,
profile: dict | None) -> dict:
"""根據 token 用量、耗時和模型計費價格計算本次抽取成本。
價格單位為每百萬 token 價格
input_cost = input_tokens / 1_000_000 * input_price
output_cost = output_tokens / 1_000_000 * output_price
profile None未使用 --config價格相關字段為 null
"""
input_tokens = client.total_input_tokens
output_tokens = client.total_output_tokens
input_price = output_price = None
price_unit = None
if profile is not None:
input_price = float(profile.get("input_price") or 0)
output_price = float(profile.get("output_price") or 0)
price_unit = profile.get("price_unit")
input_cost = output_cost = total_cost = None
if input_price is not None and output_price is not None:
input_cost = round(input_tokens / 1_000_000 * input_price, 6)
output_cost = round(output_tokens / 1_000_000 * output_price, 6)
total_cost = round(input_cost + output_cost, 6)
return {
"config_name": profile.get("Name") if profile else None,
"source": profile.get("source") if profile else None,
"model": client.model,
"elapsed_seconds": round(elapsed_seconds, 3),
"num_api_calls": client.num_calls,
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": client.total_tokens,
"input_price_per_million": input_price,
"output_price_per_million": output_price,
"price_unit": price_unit,
"input_cost": input_cost,
"output_cost": output_cost,
"total_cost": total_cost,
}
# ============================================================================= # =============================================================================
@ -1322,6 +1425,10 @@ def main() -> None:
description="香港判決書結構化抽取OpenAI 兼容 API", description="香港判決書結構化抽取OpenAI 兼容 API",
epilog=""" epilog="""
示例用法 示例用法
# 使用 models.json 中的配置名稱(推薦,省去多個參數)
python hk_case_extractor.py case.txt --config openrouter-claude-sonnet --out result.yaml
# 成本統計會寫入 result_cost.json
# 使用本地 Ollama # 使用本地 Ollama
python hk_case_extractor.py case.txt --model qwen2.5:7b-instruct python hk_case_extractor.py case.txt --model qwen2.5:7b-instruct
@ -1347,6 +1454,11 @@ def main() -> None:
formatter_class=argparse.RawDescriptionHelpFormatter formatter_class=argparse.RawDescriptionHelpFormatter
) )
ap.add_argument("input", help="判決書文本路徑(.txt 或 .json") ap.add_argument("input", help="判決書文本路徑(.txt 或 .json")
ap.add_argument("--config", default=None,
help="models.json 中的配置名稱Name"
"使用後可省略 --model/--base-url/--api-key")
ap.add_argument("--models-file", default=DEFAULT_MODELS_FILE,
help=f"模型配置文件路徑(默認:{DEFAULT_MODELS_FILE}")
ap.add_argument("--model", default=DEFAULT_MODEL, ap.add_argument("--model", default=DEFAULT_MODEL,
help=f"模型名稱(默認:{DEFAULT_MODEL}") help=f"模型名稱(默認:{DEFAULT_MODEL}")
ap.add_argument("--base-url", default=DEFAULT_BASE_URL, ap.add_argument("--base-url", default=DEFAULT_BASE_URL,
@ -1354,6 +1466,8 @@ def main() -> None:
ap.add_argument("--api-key", default=DEFAULT_API_KEY, ap.add_argument("--api-key", default=DEFAULT_API_KEY,
help="API keyOllama 可忽略)") help="API keyOllama 可忽略)")
ap.add_argument("--out", default=None, help="輸出 YAML 路徑(默認 stdout") ap.add_argument("--out", default=None, help="輸出 YAML 路徑(默認 stdout")
ap.add_argument("--cost", action="store_true",
help="輸出成本統計到 {輸出文件名}_cost.json默認不輸出")
ap.add_argument("--debug-dump", default=None, ap.add_argument("--debug-dump", default=None,
help="額外輸出原始 JSON 結果到該路徑(便於 diff") help="額外輸出原始 JSON 結果到該路徑(便於 diff")
@ -1373,6 +1487,20 @@ def main() -> None:
args = ap.parse_args() args = ap.parse_args()
# 解析模型配置:--config 優先,未命中的字段回退到命令行/默認值
profile: dict | None = None
if args.config:
profile = load_model_profile(args.config, args.models_file)
model = profile.get("model") or args.model
base_url = profile.get("BaseApiUrl") or args.base_url
api_key = profile.get("ApiKey") or args.api_key
print(f"使用配置 '{args.config}'model={model}, base_url={base_url}",
file=sys.stderr)
else:
model = args.model
base_url = args.base_url
api_key = args.api_key
# 支持从 .json 文件的 content 字段读取 # 支持从 .json 文件的 content 字段读取
input_path = Path(args.input) input_path = Path(args.input)
if input_path.suffix.lower() == '.json': if input_path.suffix.lower() == '.json':
@ -1385,10 +1513,27 @@ def main() -> None:
else: else:
text = input_path.read_text(encoding="utf-8") text = input_path.read_text(encoding="utf-8")
result = run_pipeline(text, args.model, args.base_url, args.api_key, start = time.perf_counter()
args.head_length, args.tail_length, result, client = run_pipeline(text, model, base_url, api_key,
args.entities_window, args.entities_max, args.head_length, args.tail_length,
args.analysis_window, args.analysis_max) args.entities_window, args.entities_max,
args.analysis_window, args.analysis_max)
elapsed = time.perf_counter() - start
# 成本統計:僅在 --cost 時輸出到 {summary_file_name}_cost.json
if args.cost:
cost = compute_cost(client, elapsed, profile)
if args.out:
cost_path = Path(args.out).with_name(Path(args.out).stem + "_cost.json")
else:
cost_path = input_path.with_name(input_path.stem + "_cost.json")
cost_path.parent.mkdir(parents=True, exist_ok=True)
cost_path.write_text(json.dumps(cost, ensure_ascii=False, indent=2),
encoding="utf-8")
print(f"💰 成本統計已寫入 {cost_path}:耗時 {cost['elapsed_seconds']}s"
f"input={cost['input_tokens']} output={cost['output_tokens']} "
f"total_cost={cost['total_cost']} {cost['price_unit'] or ''}",
file=sys.stderr)
if args.debug_dump: if args.debug_dump:
debug_path = Path(args.debug_dump) debug_path = Path(args.debug_dump)

42
models.json 100644
View File

@ -0,0 +1,42 @@
[
{
"Name": "ollama-qwen",
"source": "ollama",
"BaseApiUrl": "http://localhost:11434/v1",
"ApiKey": "ollama",
"model": "qwen2.5:7b-instruct",
"input_price": 0,
"output_price": 0,
"price_unit": "USD"
},
{
"Name": "openrouter-claude-sonnet",
"source": "openrouter",
"BaseApiUrl": "https://openrouter.ai/api/v1",
"ApiKey": "",
"model": "anthropic/claude-3.5-sonnet",
"input_price": 3.0,
"output_price": 15.0,
"price_unit": "USD"
},
{
"Name": "openai-gpt-4o",
"source": "openai",
"BaseApiUrl": "https://api.openai.com/v1",
"ApiKey": "",
"model": "gpt-4o",
"input_price": 2.5,
"output_price": 10.0,
"price_unit": "USD"
},
{
"Name": "deepseek-chat",
"source": "deepseek",
"BaseApiUrl": "https://api.deepseek.com/v1",
"ApiKey": "",
"model": "deepseek-chat",
"input_price": 2.0,
"output_price": 8.0,
"price_unit": "CNY"
}
]