构建鲁棒性AI医疗模型:从青光眼筛查竞赛到工程实践
1. 项目概述:从一场竞赛看AI医疗的硬核落地
最近几年,医疗AI领域的竞赛层出不穷,但真正能推动技术走向临床、解决实际痛点的却不多。我关注到一项名为“AIROGS”的国际挑战赛,全称是“AI for Robust Glaucoma Screening”,直译过来就是“用于鲁棒性青光眼筛查的人工智能”。这个标题本身就很有意思,它没有停留在“准确率”这个单一维度,而是把“鲁棒性”放在了核心位置。这意味着什么?意味着它关注的不再是实验室里干净数据上的漂亮分数,而是AI模型在真实、复杂、充满不确定性的临床环境中,能否依然稳定、可靠地工作。
青光眼被称为“视力的小偷”,是全球不可逆性致盲的首要原因。关键在于早期筛查,而眼底图像检查是筛查的核心手段。但现实是,全球范围内,特别是基层和欠发达地区,具备读片能力的眼科医生严重短缺。AI辅助筛查被寄予厚望,但过去很多模型在遇到图像质量不佳(如对焦模糊、曝光不均)、设备型号差异、或存在其他眼底病变干扰时,性能会急剧下降,这就是“鲁棒性”不足。AIROGS挑战赛正是为了攻克这一难题而设,它提供了一个大规模、高质量且特意包含了各种真实世界挑战的眼底图像数据集,要求参赛者构建的模型不仅能判断“有无青光眼”,更要能“在各种不利条件下稳定地判断”。
这不仅仅是技术爱好者的游戏,它直接指向了AI医疗产品能否真正部署到社区医院、体检中心甚至移动筛查车上的关键。接下来,我将结合对这类竞赛的深度参与和行业观察,拆解构建一个鲁棒性青光眼筛查AI的核心技术栈、实操难点以及那些在论文里不会写的“坑”。
2. 核心需求与挑战拆解:为什么“鲁棒性”比“准确率”更难?
在动手构建模型之前,我们必须彻底理解AIROGS这类任务提出的深层挑战。如果只追求在精选数据集上的高准确率,那是一个相对单纯的技术问题;但追求鲁棒性,则是一个系统工程。
2.1 真实世界数据的“不完美性”
竞赛提供的训练数据,会刻意模拟真实临床采集环境中的多种干扰因素。这是我们构建鲁棒性模型的出发点。
图像质量变异:这是最大的挑战来源。包括:
- 对焦模糊:设备操作不当或患者配合度差导致。
- 曝光问题:过曝会使视盘等关键结构细节丢失,欠曝则噪声明显。
- 低对比度:图像整体灰蒙蒙,组织结构边界不清。
- 伪影:睫毛、灰尘、镜头光晕遮挡部分视野。
- 不均匀照明:图像周边暗,中心亮,或反之。
设备与采集协议差异:不同医院可能使用不同品牌(如Zeiss, Canon, Topcon)的眼底相机,其成像传感器、色彩空间、分辨率、视场角(如45°、50°)均不同。一个在Canon设备上训练完美的模型,在Topcon图像上可能直接“失灵”。
病理共存与干扰:青光眼患者常伴有其他眼底疾病,如糖尿病视网膜病变(DR)的出血、渗出,或年龄相关性黄斑变性(AMD)的病灶。这些病变可能改变视盘或视网膜的结构外观,对模型造成混淆。
标签噪声与不确定性:即使由专家标注,青光眼的诊断本身也存在灰色地带。早期青光眼、疑似青光眼与正常眼之间的界限有时并不绝对清晰。数据集中可能包含这种带有不确定性的标签,模型需要学会处理这种模糊性,而不是强行拟合。
注意:很多团队初赛成绩很好,但在面对独立外部验证集(来自完全不同机构、设备的数据)时性能骤降,根本原因就是模型过拟合了训练集的“特定分布”,而缺乏对上述不完美因素的泛化能力。
2.2 模型鲁棒性的多维定义
在AIROGS的语境下,鲁棒性至少体现在三个层面:
- 对输入扰动的稳定性:面对质量不佳的图像,模型给出的预测概率不应发生剧烈波动。例如,同一只眼睛的清晰图和轻微模糊图,模型判断为“青光眼”的概率应该接近,而不是从0.9骤降到0.3。
- 跨设备、跨中心的泛化能力:在未见过的设备类型或医疗机构采集的数据上,模型性能(如AUC, Sensitivity, Specificity)的衰减应控制在可接受范围内。这是部署的前提。
- 对混淆因素的抗干扰能力:当图像中存在其他病变时,模型应主要依据青光眼相关特征(如视盘杯盘比、盘沿宽度、神经纤维层缺损)做出判断,而不是被出血、渗出等无关特征带偏。
理解这些挑战后,我们的技术方案设计就有了明确的靶心:不是一味堆叠更深的网络,而是构建一个能“抗干扰”、“善适应”的AI系统。
3. 技术方案设计与核心组件选型
基于上述挑战,一个鲁棒性青光眼筛查流水线通常包含以下几个核心模块,其设计思路处处体现着对“鲁棒性”的考量。
3.1 数据预处理与质量增强模块
这是提升鲁棒性的第一道,也是至关重要的一道防线。目标是将多样化的输入图像,归一化到一个相对标准、干净的特征空间。
- 自适应图像质量评估与筛选:并非所有低质量图像都该被丢弃。我们可以训练一个轻量级的分类器(如MobileNet)或使用无参考图像质量评估(NR-IQA)算法(如BRISQUE, NIQE)对输入图像打分。对于质量极差、关键解剖结构(视盘)完全不可见的图像,应触发“质量不合格,建议重新拍摄”的提示,而不是强行分析。对于质量尚可但存在瑕疵的图像,则进入增强流程。
- 鲁棒的颜色归一化:不同设备的色彩差异巨大。简单的直方图均衡化可能破坏病理信息。更优的做法是采用基于深度学习的方法,如CycleGAN,学习从多个源设备域到一个标准颜色域的映射。或者使用更经典的Macenko方法进行染色归一化(虽然源自组织病理学,但其思想可借鉴),它能在保留生物学结构信息的前提下校正颜色。
- 针对性的图像增强:
- 去模糊:对于运动模糊,可使用盲去卷积或基于深度学习的方法(如DeblurGAN-v2)进行轻量级修复。但需谨慎,过度去模糊可能引入伪影。
- 对比度增强:采用自适应直方图均衡化(CLAHE),它能限制局部对比度过度增强,避免噪声放大,特别适用于改善视盘周边区域的可见度。
- 阴影校正:利用形态学操作(如顶帽变换)或拟合背景光照场来消除不均匀照明。
实操心得:这个模块不宜过于复杂耗时,需在效果和推理速度间权衡。在竞赛中,我们通常会在训练前对全体训练集做一次离线的、统一的预处理。但在部署时,这个模块必须是在线、实时的。一个技巧是,可以准备多个预处理后的图像版本(如原图、增强图),在后续特征提取阶段进行融合,让模型自己学习该关注哪个版本的信息。
3.2 核心深度学习模型架构选型
主干网络的选择是基础。当前的主流选择依然是卷积神经网络(CNN)及其变体。
- Backbone选择:EfficientNet、ConvNeXt、Swin Transformer是常见的选择。EfficientNet在精度和效率上平衡得很好;ConvNeXt融合了CNN和Transformer的设计思想,性能强劲;Swin Transformer因其窗口自注意力机制,在捕捉长距离依赖(如视盘与黄斑的关系)上可能有优势。关键不在于追逐最新最热的模型,而在于结合数据特性。眼底图像是高分辨率、结构化的,需要模型能精细捕捉视盘、血管的局部纹理,同时理解全局结构关系。
- 输入策略:是输入整张眼底图,还是先定位裁剪出视盘区域(ROI)再输入?
- 整图输入:简单直接,模型能自主关注所有区域。但计算量大,且可能让模型分心于无关背景。
- ROI输入:需要先运行一个视盘检测模型(如YOLO、U-Net),精准裁剪出视盘及周边区域。这大幅减少了输入信息量,迫使模型聚焦于青光眼最相关的区域,理论上有利于鲁棒性,因为屏蔽了大部分外部干扰。这是目前主流且被证明有效的策略。AIROGS数据集中通常已提供或可通过公开模型获得视盘中心坐标。
- 多尺度与注意力机制:青光眼特征既有局部(盘沿的微小切迹),也有全局(杯盘比的整体形态)。在模型中集成特征金字塔网络(FPN)或多尺度特征融合模块,能让模型同时利用低层细节和高层语义信息。注意力机制(如SE Block, CBAM)可以让模型动态地“聚焦”于图像中与青光眼判别最相关的区域,例如自动忽略睫毛伪影,增强对盘沿区域的关注。
3.3 提升鲁棒性的关键训练策略
模型架构是骨架,训练策略才是赋予其鲁棒性的灵魂。
- 数据增强的“艺术”:这是成本最低、效果最显著的鲁棒性提升手段。不能只使用简单的旋转、翻转。必须模拟真实世界的干扰:
- 模拟模糊:应用高斯模糊、运动模糊核。
- 模拟噪声:添加高斯噪声、泊松噪声。
- 模拟颜色偏移:在HSV或LAB空间随机扰动色调、饱和度和亮度。
- 模拟伪影:随机添加模拟睫毛、灰尘的黑色块状掩膜。
- 模拟设备差异:使用色彩抖动(Color Jitter)并加大强度,或应用随机风格迁移。 核心思想是:让模型在训练阶段就见遍“世间丑恶”,它才能在测试时处变不惊。
- 领域泛化与域自适应:如果数据集中明确包含了不同设备(域)的数据,我们可以采用领域泛化技术。例如:
- 域混合:在批次(Batch)内混合来自不同设备的图像,强制模型学习域不变的特征。
- 对抗性训练:引入一个域分类器,与主特征提取器进行对抗训练。特征提取器的目标是生成让域分类器无法区分设备来源的特征,从而剥离掉设备特异性信息,保留疾病本质特征。
- 标签平滑与不确定性建模:针对标签噪声,使用标签平滑技术,将硬标签(如0或1)转换为软标签(如0.1或0.9),防止模型对可能存在错误的标签过度自信。更高级的做法是引入证据深度学习或直接建模预测的不确定性,让模型在遇到模糊病例时,能输出一个较高的不确定性分数,供医生参考。
- 损失函数设计:除了标准的交叉熵损失,可以引入:
- 焦点损失:自动降低易分类样本(如非常正常的眼底)的权重,让模型更专注于难例(如早期青光眼、图像质量差的病例)。
- 一致性损失:对同一图像的不同增强版本,要求模型输出相似的预测,这能显著提升模型对输入微小扰动的稳定性。
4. 实操流程与核心环节实现
假设我们已获得AIROGS格式的数据集(包含图像、青光眼标签、视盘坐标、可能的质量标签),下面是一个可复现的实操流程。
4.1 环境准备与数据探查
# 创建环境 conda create -n airogs python=3.8 conda activate airogs pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # 根据CUDA版本调整 pip install opencv-python pillow scikit-learn pandas matplotlib albumentations timm首先,对数据进行彻底分析。使用Pandas加载标注文件,统计各类别数量、设备分布、图像质量分布。用OpenCV随机抽样查看不同质量、不同设备的图像,直观感受挑战所在。这一步至关重要,它决定了后续增强策略的强度和方法选择。
4.2 构建鲁棒的数据预处理流水线
我们使用albumentations库,它比torchvision的transforms更强大,支持更复杂的空间和像素级变换。
import albumentations as A from albumentations.pytorch import ToTensorV2 def get_strong_train_transform(img_size=512): return A.Compose([ # 基础空间变换 A.RandomResizedCrop(height=img_size, width=img_size, scale=(0.8, 1.0)), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.1), # 眼底图上下翻转虽不常见,但可增强 A.RandomRotate90(p=0.5), # 模拟质量变异 - 像素级变换 A.OneOf([ A.MotionBlur(blur_limit=5, p=0.5), A.GaussianBlur(blur_limit=(3, 5), p=0.3), A.MedianBlur(blur_limit=3, p=0.2), ], p=0.5), A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.75), A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5), A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3), p=0.3), # 模拟伪影 A.CoarseDropout(max_holes=2, max_height=30, max_width=30, fill_value=0, p=0.1), # 颜色归一化 A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2(), ]) def get_val_transform(img_size=512): # 验证/测试时,仅进行中心裁剪、归一化和Tensor转换 return A.Compose([ A.CenterCrop(height=img_size, width=img_size), A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2(), ])关键点:OneOf操作让每次增强只随机选择一种模糊方式,更符合实际情况。CoarseDropout模拟随机遮挡。色彩扰动的强度(hue_shift_limit等)需要根据数据探查结果谨慎设置,避免偏离真实色彩太远。
4.3 实现基于ROI的模型训练
我们采用先检测视盘,再裁剪ROI的策略。假设已有视盘中心坐标(cx, cy)。
import torch import torch.nn as nn import timm from torch.utils.data import Dataset, DataLoader import cv2 import numpy as np class GlaucomaROIDataset(Dataset): def __init__(self, df, transform=None, roi_size=512): self.df = df self.transform = transform self.roi_size = roi_size def __len__(self): return len(self.df) def __getitem__(self, idx): row = self.df.iloc[idx] img_path = row['image_path'] label = row['glaucoma_label'] cx, cy = row['disk_center_x'], row['disk_center_y'] # 读取图像 image = cv2.imread(img_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) h, w = image.shape[:2] # 以视盘为中心裁剪ROI,处理边界情况 x1 = max(0, int(cx - self.roi_size/2)) y1 = max(0, int(cy - self.roi_size/2)) x2 = min(w, int(cx + self.roi_size/2)) y2 = min(h, int(cy + self.roi_size/2)) roi = image[y1:y2, x1:x2] # 如果裁剪区域小于目标尺寸,进行填充 if roi.shape[0] < self.roi_size or roi.shape[1] < self.roi_size: new_roi = np.zeros((self.roi_size, self.roi_size, 3), dtype=roi.dtype) y_offset = (self.roi_size - roi.shape[0]) // 2 x_offset = (self.roi_size - roi.shape[1]) // 2 new_roi[y_offset:y_offset+roi.shape[0], x_offset:x_offset+roi.shape[1]] = roi roi = new_roi # 应用增强 if self.transform: augmented = self.transform(image=roi) roi = augmented['image'] return roi, torch.tensor(label, dtype=torch.float32) # 定义模型 class GlaucomaClassifier(nn.Module): def __init__(self, model_name='convnext_tiny', pretrained=True, num_classes=1): super().__init__() self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='') # 获取backbone特征维度 feature_dim = self.backbone.num_features # 添加自定义头部 self.global_pool = nn.AdaptiveAvgPool2d(1) self.head = nn.Sequential( nn.Linear(feature_dim, 256), nn.BatchNorm1d(256), nn.ReLU(inplace=True), nn.Dropout(0.3), nn.Linear(256, num_classes) ) # 添加注意力模块示例(CBAM) self.cbam = CBAM(gate_channels=feature_dim) def forward(self, x): features = self.backbone(x) # 假设backbone输出是 [B, C, H, W] # 应用注意力 features = self.cbam(features) pooled = self.global_pool(features).flatten(1) out = self.head(pooled) return out # 简化的CBAM模块 class CBAM(nn.Module): def __init__(self, gate_channels, reduction_ratio=16): super().__init__() # 通道注意力 self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) self.mlp = nn.Sequential( nn.Linear(gate_channels, gate_channels // reduction_ratio), nn.ReLU(), nn.Linear(gate_channels // reduction_ratio, gate_channels) ) # 空间注意力 self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3) def forward(self, x): # 通道注意力 avg_out = self.mlp(self.avg_pool(x).squeeze(-1).squeeze(-1)) max_out = self.mlp(self.max_pool(x).squeeze(-1).squeeze(-1)) channel_att = torch.sigmoid(avg_out + max_out).unsqueeze(-1).unsqueeze(-1) x = x * channel_att # 空间注意力 avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) spatial_att = torch.cat([avg_out, max_out], dim=1) spatial_att = torch.sigmoid(self.conv(spatial_att)) return x * spatial_att4.4 训练循环与损失函数配置
def train_epoch(model, dataloader, optimizer, criterion, device): model.train() running_loss = 0.0 for batch_idx, (images, labels) in enumerate(dataloader): images, labels = images.to(device), labels.to(device).unsqueeze(1) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) # 可选:添加一致性损失(假设我们有强增强和弱增强两个视图) # weak_aug_images = ... 弱增强版本 # strong_aug_images = ... 强增强版本 # outputs_weak = model(weak_aug_images) # outputs_strong = model(strong_aug_images) # consistency_loss = F.mse_loss(outputs_weak.sigmoid(), outputs_strong.sigmoid()) # loss = loss + 0.1 * consistency_loss loss.backward() optimizer.step() running_loss += loss.item() return running_loss / len(dataloader) # 损失函数:带标签平滑的二元交叉熵 + 焦点损失 class FocalLossWithSmoothing(nn.Module): def __init__(self, alpha=0.25, gamma=2.0, smoothing=0.1): super().__init__() self.alpha = alpha self.gamma = gamma self.smoothing = smoothing self.bce = nn.BCEWithLogitsLoss(reduction='none') def forward(self, inputs, targets): # 标签平滑 targets = targets * (1 - self.smoothing) + 0.5 * self.smoothing bce_loss = self.bce(inputs, targets) pt = torch.exp(-bce_loss) # 模型预测对应真实标签的概率 focal_loss = self.alpha * (1-pt)**self.gamma * bce_loss return focal_loss.mean() # 主训练流程 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = GlaucomaClassifier().to(device) criterion = FocalLossWithSmoothing(smoothing=0.05) # 轻微标签平滑 optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) for epoch in range(num_epochs): train_loss = train_epoch(model, train_loader, optimizer, criterion, device) val_metrics = evaluate(model, val_loader, device) # 评估函数需自己实现,计算AUC等 scheduler.step() print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Val AUC: {val_metrics["auc"]:.4f}')关键点:这里集成了多个提升鲁棒性的技巧:ROI裁剪聚焦关键区域、强数据增强、标签平滑的焦点损失、以及可选的模型注意力机制。优化器使用AdamW并配合余弦退火调度,有助于模型收敛到更平坦的极小值,这通常与更好的泛化能力相关。
5. 模型评估、集成与部署考量
训练完成后,评估不能只看整体准确率或AUC。
5.1 分层评估与鲁棒性测试
我们需要在多个子集上评估模型,以诊断其薄弱环节:
- 按图像质量分层:将测试集按质量评分分为“优”、“中”、“差”三组,分别计算每组的表现。鲁棒的模型在三组上的性能衰减应很小。
- 按设备来源分层:如果数据有设备信息,按设备分层评估。目标是模型在不同设备上的表现差异不大。
- 对抗性测试:主动对测试图像施加模拟的退化(如高斯模糊、加噪声),观察模型预测概率的变化幅度。一个鲁棒的模型,其预测输出应对此类扰动不敏感。
5.2 模型集成策略
单一模型可能在某些边缘case上失效。集成学习是提升最终系统鲁棒性的有效手段。
- 异构模型集成:使用不同架构的模型(如ConvNeXt, Swin Transformer, EfficientNet)在相同数据上训练,进行预测概率平均。它们可能捕捉到互补的特征。
- 多尺度/多裁剪集成:对同一张测试图像,除了中心裁剪,还可以在视盘坐标周围进行多次随机轻微偏移裁剪,或输入不同尺度的ROI,将所有预测结果平均。这模拟了医生从略微不同角度观察视盘的行为,能提升空间稳定性。
- 测试时增强:在推理时,对输入图像进行多种增强(如水平翻转、小角度旋转),将增强后图像的预测结果进行平均。这是一种廉价的集成方式,能有效平滑输出。
def tta_predict(model, image, transforms_list, device): """测试时增强预测""" model.eval() all_preds = [] with torch.no_grad(): for transform in transforms_list: augmented_img = transform(image=image)['image'].unsqueeze(0).to(device) pred = torch.sigmoid(model(augmented_img)) all_preds.append(pred.cpu()) # 对翻转的预测进行反翻转平均(可选) # 最终预测 final_pred = torch.mean(torch.stack(all_preds), dim=0) return final_pred.item()5.3 部署时的关键考量
将研究模型转化为临床可用的工具,还有最后几公里要走。
- 推理速度优化:使用ONNX或TensorRT对模型进行转换和量化(如FP16/INT8),在不显著损失精度的情况下大幅提升推理速度,满足实时筛查需求。
- 不确定性量化与拒绝机制:模型应能输出其预测的置信度或不确定性。对于置信度低于阈值的病例(通常是质量极差或病理表现不典型的边缘案例),系统应给出“无法判断,建议专科就诊”的提示,而不是强行给出一个可能错误的二分类结果。这是AI系统负责任的表现。
- 人机交互设计:输出不应只是一个“是/否”标签。可提供热力图(Grad-CAM, Score-CAM),可视化模型关注区域,帮助医生理解AI的判断依据,建立信任。例如,将热力图叠加在视盘区域,显示模型认为哪些区域的杯盘比异常或盘沿缺损。
- 持续学习与监控:模型部署后,应在严格隐私保护前提下,收集新的、来自目标部署环境的困难样本(模型判断错误或低置信度的样本),定期进行模型迭代更新,使其能适应数据分布的缓慢漂移。
6. 常见问题、避坑指南与心得
在实际操作中,会遇到许多在理论论文中不会提及的细节问题。
6.1 数据与预处理相关
- 问题:视盘检测不准导致ROI裁剪错误。
- 排查:可视化一批裁剪后的ROI,检查视盘是否在图像中心。如果偏移严重,问题出在检测模型或标注坐标上。
- 解决:使用更鲁棒的视盘检测模型(如基于U-Net的语义分割模型,而非简单的中心点检测)。或者在训练数据中,加入一定比例的、以标注点为中心进行随机小范围偏移的裁剪,让模型对轻微的位置偏差不敏感。
- 问题:数据增强后,模型在干净测试集上性能反而下降。
- 排查:增强强度过大,严重扭曲了眼底图像固有的解剖结构,导致模型学习到的是虚假特征。
- 解决:增强强度需要“微调”。从弱增强开始,逐步加强,并在一个固定的验证集上监控性能。找到那个让模型在干净数据和扰动数据上表现都最好的“平衡点”。
- 问题:类别不平衡。青光眼阳性样本远少于阴性样本。
- 解决:除了使用焦点损失(Focal Loss),更根本的方法是采用分层采样,确保每个训练批次(Batch)中正负样本比例均衡。也可以在数据增强时,对阳性样本施加更多样化的增强。
6.2 模型训练相关
- 问题:模型训练很快过拟合,验证集AUC停滞不前。
- 排查:检查Dropout比率、权重衰减(Weight Decay)是否足够。可能是模型容量过大或数据增强不足。
- 解决:增加Dropout比率(如0.5),增大权重衰减(如1e-3),使用更强的数据增强和标签平滑。考虑使用更小的预训练模型,或提前停止(Early Stopping)。
- 问题:不同随机种子下,模型性能波动很大。
- 排查:这通常说明任务难度大,或模型/训练过程不稳定。
- 解决:使用多个随机种子训练多个模型,最终进行集成。这本身就是一种鲁棒性策略。确保数据加载的Shuffle是充分的,并固定所有可能的随机种子(Python, NumPy, PyTorch)以确保可复现性。
- 问题:注意力热力图显示模型关注无关区域。
- 排查:模型可能学到了数据中的虚假相关性(例如,某个设备拍摄的阴性图像恰好都有某种反光伪影)。
- 解决:在数据增强中模拟这种伪影,并同时出现在正负样本中,打破这种虚假关联。使用领域对抗训练,剥离设备特异性特征。
6.3 评估与部署相关
- 问题:在内部测试集上AUC很高(>0.99),但在外部公开数据集上骤降。
- 原因:这是典型的“实验室效应”,模型过拟合了内部数据集的特定分布。
- 解决:在训练初期就引入外部数据(或留出一部分内部数据作为“伪外部”测试集)进行监控。采用更激进的领域泛化技术。必须认识到,在单一来源数据上刷到再高的分数,临床价值也有限。
- 问题:模型对“疑似青光眼”病例的判断非常不稳定。
- 认知:这未必是模型缺陷。这些病例本身就是临床诊断的难点,专家间也存在分歧。模型输出一个中等概率并伴随高不确定性分数,可能是更合理的。
- 行动:不要追求在这些病例上达到100%的专家一致性。转而评估模型能否将这些病例与明确的正常和异常病例区分开,并为其提供可靠的不确定性估计。
个人心得:参与AIROGS这类竞赛,最大的收获不是名次,而是被迫以“临床可用”的标准来审视自己的技术方案。它让我从一味追求SOTA(最先进)的思维,转向追求STAR(稳定、可转移、可问责、鲁棒)的思维。一个在10个数据源上平均AUC为0.92的模型,远比在1个数据源上AUC为0.98的模型更有价值。最终,医疗AI的价值不在于替代医生,而在于成为医生手中一个在复杂环境下依然可靠的“智能手电筒”,照亮那些原本可能被忽略的角落。在技术实现上,永远要在“模型复杂度”、“数据多样性”、“推理效率”和“结果可解释性”之间寻找最佳平衡点,而这个平衡点,必须通过在实际数据上的反复实验和严谨的分层评估来确定。
