A Python-Oriented Breakdown of Google AMIE's First Real-World Clinical Validation (Part 2)
Appendix 1. Complete Python Example Code
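The code below is suited for:
- Batch evaluation of multiple methods / models
- Producing a per-sample results table
- Summarizing mean metrics and confidence intervals
- Running paired significance tests
- Exporting data in a form that works for notebooks and paper tables
By default it is written for classification, scoring, and binary-judgment tasks. If your task is generative, you can swap metric_fn for ROUGE, BLEU, EM, F1, human ratings, and so on, as long as the replacement takes (gold, pred) and returns a float.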
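As a rough illustration of swapping in a different metric for a generative task, here is a minimal sketch of a token-overlap F1 with the same (gold, pred) -> float shape that metric_fn expects. The name token_f1 and the lowercase, whitespace-based tokenization are assumptions made for this example and are not part of the original code; whitespace splitting will not segment Chinese text, so a real setup would need a tokenizer suited to the language.

```python
from collections import Counter
from typing import Any


def token_f1(gold: Any, pred: Any) -> float:
    """Token-overlap F1 between a reference and a prediction.

    Hypothetical helper for illustration only: tokens are lowercase,
    whitespace-separated words (an assumption, not the article's method).
    """
    if gold is None or pred is None:
        return 0.0

    gold_tokens = str(gold).lower().split()
    pred_tokens = str(pred).lower().split()

    # If either side has no tokens, score 1.0 only when both are empty.
    if not gold_tokens or not pred_tokens:
        return float(gold_tokens == pred_tokens)

    # Multiset intersection counts each shared token at most as often
    # as it appears on both sides.
    overlap = sum((Counter(gold_tokens) & Counter(pred_tokens)).values())
    if overlap == 0:
        return 0.0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
```

Because it returns one float per (gold, pred) pair, a function like this could be passed wherever binary_accuracy or exact_match is used as metric_fn.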
from __future__ import annotations

import json
import math
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Callable, Dict, List, Any, Optional, Tuple

import numpy as np
import pandas as pd
from scipy import stats


# =========================
# 1. Data structures
# =========================
@dataclass
class Example:
    sample_id: str
    input_text: str
    gold: Any
    meta: Optional[Dict[str, Any]] = None


@dataclass
class Prediction:
    sample_id: str
    model_name: str
    pred: Any
    score: Optional[float] = None
    latency_ms: Optional[float] = None
    tokens_in: Optional[int] = None
    tokens_out: Optional[int] = None
    raw_output: Optional[str] = None


# =========================
# 2. Metric functions
# =========================
def binary_accuracy(gold: Any, pred: Any) -> float:
    return float(gold == pred)


def exact_match(gold: str, pred: str) -> float:
    if gold is None or pred is None:
        return 0.0
    return float(str(gold).strip() == str(pred).strip())


def macro_f1_from_frame(df: pd.DataFrame, gold_col: str = "gold", pred_col: str = "pred") -> float:
    """
    Compute macro-F1 over the whole table.
    Suited to an overall multi-class summary, not to per-sample scoring.
    """
    labels = sorted(set(df[gold_col].dropna().unique()) | set(df[pred_col].dropna().unique()))
    f1s = []
    for label in labels:
        # One-vs-rest counts for the current label
        tp = ((df[gold_col] == label) & (df[pred_col] == label)).sum()
        fp = ((df[gold_col] != label) & (df[pred_col] == label)).sum()
        fn = ((df[gold_col] == label) & (df[pred_col] != label)).sum()
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        f1s.append(f1)
    return float(np.mean(f1s)) if f1s else 0.0


# =========================
# 3. Main evaluation flow
# =========================
def evaluate_predictions(
    examples: List[Example],
    predictions: List[Prediction],
    metric_fn: Callable[[Any, Any], float],
) -> pd.DataFrame:
    """
    Return a per-sample evaluation table:
    [sample_id, model_name, gold, pred, metric, latency_ms, ...]
    """
    ex_map = {ex.sample_id: ex for ex in examples}
    rows = []
    for p in predictions:
        # Skip predictions that have no matching example
        if p.sample_id not in ex_map:
            continue
        ex = ex_map[p.sample_id]
        metric_value = metric_fn(ex.gold, p.pred)
        row = {
            "sample_id": p.sample_id,
            "model_name": p.model_name,
            "input_text": ex.input_text,
            "gold": ex.gold,
            "pred": p.pred,
            "metric": metric_value,
            "latency_ms": p.latency_ms,
            "tokens_in": p.tokens_in,
            "tokens_out": p.tokens_out,
            "score": p.score,
        }
        # Flatten example metadata into "meta.<key>" columns
        if ex.meta:
            for k, v in ex.meta.items():
                row[f"meta.{k}"] = v
        rows.append(row)
    return pd.DataFrame(rows)


# =========================
# 4. Bootstrap confidence interval
# =========================
def bootstrap_ci(values: np.ndarray, n_bootstrap: int = 5000, ci: float =