add model config profiles and cost tracking
- models.json: named profiles with base url/key/model and per-million token pricing
- support --config <Name> to load a profile, replacing --model/--base-url/--api-key
- accumulate input/output token usage across all API calls in the client
- compute time + token cost per case, written to {out}_cost.json behind opt-in --cost flag
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
main
parent
b0e5fd7fd5
commit
ae1005aab7
|
|
@ -2,4 +2,5 @@
|
||||||
*.docx
|
*.docx
|
||||||
*.html
|
*.html
|
||||||
|
|
||||||
result.yaml
|
result.yaml
|
||||||
|
__pycache__/
|
||||||
|
|
@ -67,7 +67,8 @@ import argparse
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -85,6 +86,45 @@ DEFAULT_API_KEY = "ollama" # Ollama 不需要真實 key,但 API 需要提供
|
||||||
DEFAULT_TIMEOUT = 600 # 增加到 10 分鐘,適應遠程服務器
|
DEFAULT_TIMEOUT = 600 # 增加到 10 分鐘,適應遠程服務器
|
||||||
MAX_RETRIES = 2
|
MAX_RETRIES = 2
|
||||||
|
|
||||||
|
DEFAULT_MODELS_FILE = "models.json"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 模型配置(models.json):按配置名加載 base_url / api_key / model / 計費價格
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def load_model_profile(name: str, models_file: str = DEFAULT_MODELS_FILE) -> dict:
|
||||||
|
"""從 models.json 按 Name 加載一個模型配置。
|
||||||
|
|
||||||
|
配置格式(數組):
|
||||||
|
[{"Name":..., "source":..., "BaseApiUrl":..., "ApiKey":...,
|
||||||
|
"model":..., "input_price":..., "output_price":..., "price_unit":...}]
|
||||||
|
|
||||||
|
找不到文件或配置名時直接退出並提示可用配置。
|
||||||
|
"""
|
||||||
|
path = Path(models_file)
|
||||||
|
if not path.exists():
|
||||||
|
print(f"❌ 找不到模型配置文件:{models_file}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
profiles = json.loads(path.read_text(encoding="utf-8"))
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"❌ 模型配置文件不是有效的 JSON:{e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not isinstance(profiles, list):
|
||||||
|
print("❌ 模型配置文件應為配置對象組成的數組", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
for p in profiles:
|
||||||
|
if p.get("Name") == name:
|
||||||
|
return p
|
||||||
|
|
||||||
|
available = [p.get("Name") for p in profiles]
|
||||||
|
print(f"❌ 配置 '{name}' 不存在。可用配置:{available}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 0. 語言檢測
|
# 0. 語言檢測
|
||||||
|
|
@ -450,6 +490,15 @@ class OpenAICompatibleClient:
|
||||||
api_key: str = DEFAULT_API_KEY
|
api_key: str = DEFAULT_API_KEY
|
||||||
timeout: int = DEFAULT_TIMEOUT
|
timeout: int = DEFAULT_TIMEOUT
|
||||||
|
|
||||||
|
# token 用量累計(跨所有調用,含重試)
|
||||||
|
total_input_tokens: int = field(default=0, init=False)
|
||||||
|
total_output_tokens: int = field(default=0, init=False)
|
||||||
|
num_calls: int = field(default=0, init=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tokens(self) -> int:
|
||||||
|
return self.total_input_tokens + self.total_output_tokens
|
||||||
|
|
||||||
def chat_json(self, system: str, user: str, schema: dict,
|
def chat_json(self, system: str, user: str, schema: dict,
|
||||||
temperature: float = 0.0,
|
temperature: float = 0.0,
|
||||||
max_tokens: int = 4096) -> dict:
|
max_tokens: int = 4096) -> dict:
|
||||||
|
|
@ -494,6 +543,12 @@ class OpenAICompatibleClient:
|
||||||
print(f" 響應內容: {r.text[:500]}", file=sys.stderr)
|
print(f" 響應內容: {r.text[:500]}", file=sys.stderr)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# 累計 token 用量(OpenAI 兼容端點通常在 usage 字段返回)
|
||||||
|
usage = response_data.get("usage") or {}
|
||||||
|
self.total_input_tokens += int(usage.get("prompt_tokens", 0) or 0)
|
||||||
|
self.total_output_tokens += int(usage.get("completion_tokens", 0) or 0)
|
||||||
|
self.num_calls += 1
|
||||||
|
|
||||||
# 提取內容
|
# 提取內容
|
||||||
if "choices" not in response_data or not response_data["choices"]:
|
if "choices" not in response_data or not response_data["choices"]:
|
||||||
print(f"❌ API 響應缺少 choices 字段", file=sys.stderr)
|
print(f"❌ API 響應缺少 choices 字段", file=sys.stderr)
|
||||||
|
|
@ -1213,7 +1268,7 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str,
|
||||||
entities_window: int = 400,
|
entities_window: int = 400,
|
||||||
entities_max: int = 6500,
|
entities_max: int = 6500,
|
||||||
analysis_window: int = 500,
|
analysis_window: int = 500,
|
||||||
analysis_max: int = 6500) -> dict:
|
analysis_max: int = 6500) -> tuple[dict, OpenAICompatibleClient]:
|
||||||
log = lambda m: print(m, file=sys.stderr)
|
log = lambda m: print(m, file=sys.stderr)
|
||||||
|
|
||||||
log("[0/7] 檢測語言...")
|
log("[0/7] 檢測語言...")
|
||||||
|
|
@ -1277,7 +1332,55 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str,
|
||||||
for w in warnings:
|
for w in warnings:
|
||||||
log(f" ⚠️ {w}")
|
log(f" ⚠️ {w}")
|
||||||
|
|
||||||
return final
|
return final, client
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 5.5 成本統計:根據 models.json 計費價格計算本次抽取消耗
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def compute_cost(client: OpenAICompatibleClient,
|
||||||
|
elapsed_seconds: float,
|
||||||
|
profile: dict | None) -> dict:
|
||||||
|
"""根據 token 用量、耗時和模型計費價格計算本次抽取成本。
|
||||||
|
|
||||||
|
價格單位為「每百萬 token 價格」:
|
||||||
|
input_cost = input_tokens / 1_000_000 * input_price
|
||||||
|
output_cost = output_tokens / 1_000_000 * output_price
|
||||||
|
profile 為 None(未使用 --config)時,價格相關字段為 null。
|
||||||
|
"""
|
||||||
|
input_tokens = client.total_input_tokens
|
||||||
|
output_tokens = client.total_output_tokens
|
||||||
|
|
||||||
|
input_price = output_price = None
|
||||||
|
price_unit = None
|
||||||
|
if profile is not None:
|
||||||
|
input_price = float(profile.get("input_price") or 0)
|
||||||
|
output_price = float(profile.get("output_price") or 0)
|
||||||
|
price_unit = profile.get("price_unit")
|
||||||
|
|
||||||
|
input_cost = output_cost = total_cost = None
|
||||||
|
if input_price is not None and output_price is not None:
|
||||||
|
input_cost = round(input_tokens / 1_000_000 * input_price, 6)
|
||||||
|
output_cost = round(output_tokens / 1_000_000 * output_price, 6)
|
||||||
|
total_cost = round(input_cost + output_cost, 6)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"config_name": profile.get("Name") if profile else None,
|
||||||
|
"source": profile.get("source") if profile else None,
|
||||||
|
"model": client.model,
|
||||||
|
"elapsed_seconds": round(elapsed_seconds, 3),
|
||||||
|
"num_api_calls": client.num_calls,
|
||||||
|
"input_tokens": input_tokens,
|
||||||
|
"output_tokens": output_tokens,
|
||||||
|
"total_tokens": client.total_tokens,
|
||||||
|
"input_price_per_million": input_price,
|
||||||
|
"output_price_per_million": output_price,
|
||||||
|
"price_unit": price_unit,
|
||||||
|
"input_cost": input_cost,
|
||||||
|
"output_cost": output_cost,
|
||||||
|
"total_cost": total_cost,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -1322,9 +1425,13 @@ def main() -> None:
|
||||||
description="香港判決書結構化抽取(OpenAI 兼容 API)",
|
description="香港判決書結構化抽取(OpenAI 兼容 API)",
|
||||||
epilog="""
|
epilog="""
|
||||||
示例用法:
|
示例用法:
|
||||||
|
# 使用 models.json 中的配置名稱(推薦,省去多個參數)
|
||||||
|
python hk_case_extractor.py case.txt --config openrouter-claude-sonnet --out result.yaml
|
||||||
|
# 成本統計會寫入 result_cost.json
|
||||||
|
|
||||||
# 使用本地 Ollama
|
# 使用本地 Ollama
|
||||||
python hk_case_extractor.py case.txt --model qwen2.5:7b-instruct
|
python hk_case_extractor.py case.txt --model qwen2.5:7b-instruct
|
||||||
|
|
||||||
# 使用 OpenRouter
|
# 使用 OpenRouter
|
||||||
python hk_case_extractor.py case.txt \\
|
python hk_case_extractor.py case.txt \\
|
||||||
--base-url https://openrouter.ai/api/v1 \\
|
--base-url https://openrouter.ai/api/v1 \\
|
||||||
|
|
@ -1347,13 +1454,20 @@ def main() -> None:
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter
|
formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
)
|
)
|
||||||
ap.add_argument("input", help="判決書文本路徑(.txt 或 .json)")
|
ap.add_argument("input", help="判決書文本路徑(.txt 或 .json)")
|
||||||
ap.add_argument("--model", default=DEFAULT_MODEL,
|
ap.add_argument("--config", default=None,
|
||||||
|
help="models.json 中的配置名稱(Name),"
|
||||||
|
"使用後可省略 --model/--base-url/--api-key")
|
||||||
|
ap.add_argument("--models-file", default=DEFAULT_MODELS_FILE,
|
||||||
|
help=f"模型配置文件路徑(默認:{DEFAULT_MODELS_FILE})")
|
||||||
|
ap.add_argument("--model", default=DEFAULT_MODEL,
|
||||||
help=f"模型名稱(默認:{DEFAULT_MODEL})")
|
help=f"模型名稱(默認:{DEFAULT_MODEL})")
|
||||||
ap.add_argument("--base-url", default=DEFAULT_BASE_URL,
|
ap.add_argument("--base-url", default=DEFAULT_BASE_URL,
|
||||||
help=f"API base URL(默認:{DEFAULT_BASE_URL})")
|
help=f"API base URL(默認:{DEFAULT_BASE_URL})")
|
||||||
ap.add_argument("--api-key", default=DEFAULT_API_KEY,
|
ap.add_argument("--api-key", default=DEFAULT_API_KEY,
|
||||||
help="API key(Ollama 可忽略)")
|
help="API key(Ollama 可忽略)")
|
||||||
ap.add_argument("--out", default=None, help="輸出 YAML 路徑(默認 stdout)")
|
ap.add_argument("--out", default=None, help="輸出 YAML 路徑(默認 stdout)")
|
||||||
|
ap.add_argument("--cost", action="store_true",
|
||||||
|
help="輸出成本統計到 {輸出文件名}_cost.json(默認不輸出)")
|
||||||
ap.add_argument("--debug-dump", default=None,
|
ap.add_argument("--debug-dump", default=None,
|
||||||
help="額外輸出原始 JSON 結果到該路徑(便於 diff)")
|
help="額外輸出原始 JSON 結果到該路徑(便於 diff)")
|
||||||
|
|
||||||
|
|
@ -1373,6 +1487,20 @@ def main() -> None:
|
||||||
|
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
# 解析模型配置:--config 優先,未命中的字段回退到命令行/默認值
|
||||||
|
profile: dict | None = None
|
||||||
|
if args.config:
|
||||||
|
profile = load_model_profile(args.config, args.models_file)
|
||||||
|
model = profile.get("model") or args.model
|
||||||
|
base_url = profile.get("BaseApiUrl") or args.base_url
|
||||||
|
api_key = profile.get("ApiKey") or args.api_key
|
||||||
|
print(f"使用配置 '{args.config}':model={model}, base_url={base_url}",
|
||||||
|
file=sys.stderr)
|
||||||
|
else:
|
||||||
|
model = args.model
|
||||||
|
base_url = args.base_url
|
||||||
|
api_key = args.api_key
|
||||||
|
|
||||||
# 支持从 .json 文件的 content 字段读取
|
# 支持从 .json 文件的 content 字段读取
|
||||||
input_path = Path(args.input)
|
input_path = Path(args.input)
|
||||||
if input_path.suffix.lower() == '.json':
|
if input_path.suffix.lower() == '.json':
|
||||||
|
|
@ -1385,10 +1513,27 @@ def main() -> None:
|
||||||
else:
|
else:
|
||||||
text = input_path.read_text(encoding="utf-8")
|
text = input_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
result = run_pipeline(text, args.model, args.base_url, args.api_key,
|
start = time.perf_counter()
|
||||||
args.head_length, args.tail_length,
|
result, client = run_pipeline(text, model, base_url, api_key,
|
||||||
args.entities_window, args.entities_max,
|
args.head_length, args.tail_length,
|
||||||
args.analysis_window, args.analysis_max)
|
args.entities_window, args.entities_max,
|
||||||
|
args.analysis_window, args.analysis_max)
|
||||||
|
elapsed = time.perf_counter() - start
|
||||||
|
|
||||||
|
# 成本統計:僅在 --cost 時輸出到 {summary_file_name}_cost.json
|
||||||
|
if args.cost:
|
||||||
|
cost = compute_cost(client, elapsed, profile)
|
||||||
|
if args.out:
|
||||||
|
cost_path = Path(args.out).with_name(Path(args.out).stem + "_cost.json")
|
||||||
|
else:
|
||||||
|
cost_path = input_path.with_name(input_path.stem + "_cost.json")
|
||||||
|
cost_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
cost_path.write_text(json.dumps(cost, ensure_ascii=False, indent=2),
|
||||||
|
encoding="utf-8")
|
||||||
|
print(f"💰 成本統計已寫入 {cost_path}:耗時 {cost['elapsed_seconds']}s,"
|
||||||
|
f"input={cost['input_tokens']} output={cost['output_tokens']} "
|
||||||
|
f"total_cost={cost['total_cost']} {cost['price_unit'] or ''}",
|
||||||
|
file=sys.stderr)
|
||||||
|
|
||||||
if args.debug_dump:
|
if args.debug_dump:
|
||||||
debug_path = Path(args.debug_dump)
|
debug_path = Path(args.debug_dump)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"Name": "ollama-qwen",
|
||||||
|
"source": "ollama",
|
||||||
|
"BaseApiUrl": "http://localhost:11434/v1",
|
||||||
|
"ApiKey": "ollama",
|
||||||
|
"model": "qwen2.5:7b-instruct",
|
||||||
|
"input_price": 0,
|
||||||
|
"output_price": 0,
|
||||||
|
"price_unit": "USD"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "openrouter-claude-sonnet",
|
||||||
|
"source": "openrouter",
|
||||||
|
"BaseApiUrl": "https://openrouter.ai/api/v1",
|
||||||
|
"ApiKey": "",
|
||||||
|
"model": "anthropic/claude-3.5-sonnet",
|
||||||
|
"input_price": 3.0,
|
||||||
|
"output_price": 15.0,
|
||||||
|
"price_unit": "USD"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "openai-gpt-4o",
|
||||||
|
"source": "openai",
|
||||||
|
"BaseApiUrl": "https://api.openai.com/v1",
|
||||||
|
"ApiKey": "",
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"input_price": 2.5,
|
||||||
|
"output_price": 10.0,
|
||||||
|
"price_unit": "USD"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "deepseek-chat",
|
||||||
|
"source": "deepseek",
|
||||||
|
"BaseApiUrl": "https://api.deepseek.com/v1",
|
||||||
|
"ApiKey": "",
|
||||||
|
"model": "deepseek-chat",
|
||||||
|
"input_price": 2.0,
|
||||||
|
"output_price": 8.0,
|
||||||
|
"price_unit": "CNY"
|
||||||
|
}
|
||||||
|
]
|
||||||
Loading…
Reference in New Issue