@ -67,7 +67,8 @@ import argparse
import json
import re
import sys
from dataclasses import dataclass
import time
from dataclasses import dataclass , field
from pathlib import Path
from typing import Any
@ -85,6 +86,45 @@ DEFAULT_API_KEY = "ollama" # Ollama 不需要真實 key, 但 API 需要提供
DEFAULT_TIMEOUT = 600 # 增加到 10 分鐘,適應遠程服務器
MAX_RETRIES = 2
DEFAULT_MODELS_FILE = " models.json "
# =============================================================================
# 模型配置( models.json) : 按配置名加載 base_url / api_key / model / 計費價格
# =============================================================================
def load_model_profile ( name : str , models_file : str = DEFAULT_MODELS_FILE ) - > dict :
""" 從 models.json 按 Name 加載一個模型配置。
配置格式 ( 數組 ) :
[ { " Name " : . . . , " source " : . . . , " BaseApiUrl " : . . . , " ApiKey " : . . . ,
" model " : . . . , " input_price " : . . . , " output_price " : . . . , " price_unit " : . . . } ]
找不到文件或配置名時直接退出並提示可用配置 。
"""
path = Path ( models_file )
if not path . exists ( ) :
print ( f " ❌ 找不到模型配置文件: { models_file } " , file = sys . stderr )
sys . exit ( 1 )
try :
profiles = json . loads ( path . read_text ( encoding = " utf-8 " ) )
except json . JSONDecodeError as e :
print ( f " ❌ 模型配置文件不是有效的 JSON: { e } " , file = sys . stderr )
sys . exit ( 1 )
if not isinstance ( profiles , list ) :
print ( " ❌ 模型配置文件應為配置對象組成的數組 " , file = sys . stderr )
sys . exit ( 1 )
for p in profiles :
if p . get ( " Name " ) == name :
return p
available = [ p . get ( " Name " ) for p in profiles ]
print ( f " ❌ 配置 ' { name } ' 不存在。可用配置: { available } " , file = sys . stderr )
sys . exit ( 1 )
# =============================================================================
# 0. 語言檢測
@ -450,6 +490,15 @@ class OpenAICompatibleClient:
api_key : str = DEFAULT_API_KEY
timeout : int = DEFAULT_TIMEOUT
# token 用量累計(跨所有調用,含重試)
total_input_tokens : int = field ( default = 0 , init = False )
total_output_tokens : int = field ( default = 0 , init = False )
num_calls : int = field ( default = 0 , init = False )
@property
def total_tokens ( self ) - > int :
return self . total_input_tokens + self . total_output_tokens
def chat_json ( self , system : str , user : str , schema : dict ,
temperature : float = 0.0 ,
max_tokens : int = 4096 ) - > dict :
@ -494,6 +543,12 @@ class OpenAICompatibleClient:
print ( f " 響應內容: { r . text [ : 500 ] } " , file = sys . stderr )
raise
# 累計 token 用量( OpenAI 兼容端點通常在 usage 字段返回)
usage = response_data . get ( " usage " ) or { }
self . total_input_tokens + = int ( usage . get ( " prompt_tokens " , 0 ) or 0 )
self . total_output_tokens + = int ( usage . get ( " completion_tokens " , 0 ) or 0 )
self . num_calls + = 1
# 提取內容
if " choices " not in response_data or not response_data [ " choices " ] :
print ( f " ❌ API 響應缺少 choices 字段 " , file = sys . stderr )
@ -1213,7 +1268,7 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str,
entities_window : int = 400 ,
entities_max : int = 6500 ,
analysis_window : int = 500 ,
analysis_max : int = 6500 ) - > dict:
analysis_max : int = 6500 ) - > tuple[ dict, OpenAICompatibleClient ] :
log = lambda m : print ( m , file = sys . stderr )
log ( " [0/7] 檢測語言... " )
@ -1277,7 +1332,55 @@ def run_pipeline(text: str, model: str, base_url: str, api_key: str,
for w in warnings :
log ( f " ⚠️ { w } " )
return final
return final , client
# =============================================================================
# 5.5 成本統計:根據 models.json 計費價格計算本次抽取消耗
# =============================================================================
def compute_cost ( client : OpenAICompatibleClient ,
elapsed_seconds : float ,
profile : dict | None ) - > dict :
""" 根據 token 用量、耗時和模型計費價格計算本次抽取成本。
價格單位為 「 每百萬 token 價格 」 :
input_cost = input_tokens / 1_000_000 * input_price
output_cost = output_tokens / 1_000_000 * output_price
profile 為 None ( 未使用 - - config ) 時 , 價格相關字段為 null 。
"""
input_tokens = client . total_input_tokens
output_tokens = client . total_output_tokens
input_price = output_price = None
price_unit = None
if profile is not None :
input_price = float ( profile . get ( " input_price " ) or 0 )
output_price = float ( profile . get ( " output_price " ) or 0 )
price_unit = profile . get ( " price_unit " )
input_cost = output_cost = total_cost = None
if input_price is not None and output_price is not None :
input_cost = round ( input_tokens / 1_000_000 * input_price , 6 )
output_cost = round ( output_tokens / 1_000_000 * output_price , 6 )
total_cost = round ( input_cost + output_cost , 6 )
return {
" config_name " : profile . get ( " Name " ) if profile else None ,
" source " : profile . get ( " source " ) if profile else None ,
" model " : client . model ,
" elapsed_seconds " : round ( elapsed_seconds , 3 ) ,
" num_api_calls " : client . num_calls ,
" input_tokens " : input_tokens ,
" output_tokens " : output_tokens ,
" total_tokens " : client . total_tokens ,
" input_price_per_million " : input_price ,
" output_price_per_million " : output_price ,
" price_unit " : price_unit ,
" input_cost " : input_cost ,
" output_cost " : output_cost ,
" total_cost " : total_cost ,
}
# =============================================================================
@ -1322,9 +1425,13 @@ def main() -> None:
description = " 香港判決書結構化抽取( OpenAI 兼容 API) " ,
epilog = """
示例用法 :
# 使用 models.json 中的配置名稱(推薦,省去多個參數)
python hk_case_extractor . py case . txt - - config openrouter - claude - sonnet - - out result . yaml
# 成本統計會寫入 result_cost.json
# 使用本地 Ollama
python hk_case_extractor . py case . txt - - model qwen2 .5 : 7 b - instruct
# 使用 OpenRouter
python hk_case_extractor . py case . txt \\
- - base - url https : / / openrouter . ai / api / v1 \\
@ -1347,13 +1454,20 @@ def main() -> None:
formatter_class = argparse . RawDescriptionHelpFormatter
)
ap . add_argument ( " input " , help = " 判決書文本路徑(.txt 或 .json) " )
ap . add_argument ( " --model " , default = DEFAULT_MODEL ,
ap . add_argument ( " --config " , default = None ,
help = " models.json 中的配置名稱( Name) , "
" 使用後可省略 --model/--base-url/--api-key " )
ap . add_argument ( " --models-file " , default = DEFAULT_MODELS_FILE ,
help = f " 模型配置文件路徑(默認: { DEFAULT_MODELS_FILE } ) " )
ap . add_argument ( " --model " , default = DEFAULT_MODEL ,
help = f " 模型名稱(默認: { DEFAULT_MODEL } ) " )
ap . add_argument ( " --base-url " , default = DEFAULT_BASE_URL ,
help = f " API base URL( 默認: { DEFAULT_BASE_URL } ) " )
ap . add_argument ( " --api-key " , default = DEFAULT_API_KEY ,
help = " API key( Ollama 可忽略) " )
ap . add_argument ( " --out " , default = None , help = " 輸出 YAML 路徑(默認 stdout) " )
ap . add_argument ( " --cost " , action = " store_true " ,
help = " 輸出成本統計到 { 輸出文件名}_cost.json( 默認不輸出) " )
ap . add_argument ( " --debug-dump " , default = None ,
help = " 額外輸出原始 JSON 結果到該路徑(便於 diff) " )
@ -1373,6 +1487,20 @@ def main() -> None:
args = ap . parse_args ( )
# 解析模型配置:--config 優先,未命中的字段回退到命令行/默認值
profile : dict | None = None
if args . config :
profile = load_model_profile ( args . config , args . models_file )
model = profile . get ( " model " ) or args . model
base_url = profile . get ( " BaseApiUrl " ) or args . base_url
api_key = profile . get ( " ApiKey " ) or args . api_key
print ( f " 使用配置 ' { args . config } ' : model={ model } , base_url= { base_url } " ,
file = sys . stderr )
else :
model = args . model
base_url = args . base_url
api_key = args . api_key
# 支持从 .json 文件的 content 字段读取
input_path = Path ( args . input )
if input_path . suffix . lower ( ) == ' .json ' :
@ -1385,10 +1513,27 @@ def main() -> None:
else :
text = input_path . read_text ( encoding = " utf-8 " )
result = run_pipeline ( text , args . model , args . base_url , args . api_key ,
args . head_length , args . tail_length ,
args . entities_window , args . entities_max ,
args . analysis_window , args . analysis_max )
start = time . perf_counter ( )
result , client = run_pipeline ( text , model , base_url , api_key ,
args . head_length , args . tail_length ,
args . entities_window , args . entities_max ,
args . analysis_window , args . analysis_max )
elapsed = time . perf_counter ( ) - start
# 成本統計:僅在 --cost 時輸出到 {summary_file_name}_cost.json
if args . cost :
cost = compute_cost ( client , elapsed , profile )
if args . out :
cost_path = Path ( args . out ) . with_name ( Path ( args . out ) . stem + " _cost.json " )
else :
cost_path = input_path . with_name ( input_path . stem + " _cost.json " )
cost_path . parent . mkdir ( parents = True , exist_ok = True )
cost_path . write_text ( json . dumps ( cost , ensure_ascii = False , indent = 2 ) ,
encoding = " utf-8 " )
print ( f " 💰 成本統計已寫入 { cost_path } :耗時 { cost [ ' elapsed_seconds ' ] } s, "
f " input= { cost [ ' input_tokens ' ] } output= { cost [ ' output_tokens ' ] } "
f " total_cost= { cost [ ' total_cost ' ] } { cost [ ' price_unit ' ] or ' ' } " ,
file = sys . stderr )
if args . debug_dump :
debug_path = Path ( args . debug_dump )