当前位置: 首页 > news >正文

Kaggle竞赛实战:用TensorFlow2搞定Severstal钢板缺陷检测(附完整代码解析)

Kaggle竞赛实战:用TensorFlow2搞定Severstal钢板缺陷检测(附完整代码解析)

当工业质检遇上深度学习,传统人工检测的局限性被彻底打破。Severstal钢板缺陷检测竞赛正是这样一个典型场景——要求参赛者通过算法自动识别钢板表面的四类缺陷。这不仅考验模型对微小缺陷的敏感度,更需要处理工业场景特有的数据分布问题。本文将带您从零构建ResUNet模型,并分享竞赛中那些教科书不会告诉你的实战技巧。

1. 竞赛背景与数据解码艺术

钢板表面缺陷检测是工业质检中的经典难题。Severstal提供的训练集包含12568张灰度图像,每张图像尺寸为1600×256像素。与常规分割任务不同,竞赛采用RLE(Run-Length Encoding)格式存储标注数据,这种压缩存储方式对数据处理提出了特殊要求。

1.1 RLE编码解析实战

RLE编码原理是将连续出现的像素值用"起始位置-长度"表示。例如编码串"29102 12 29346 24"表示:

  • 从第29102个像素开始连续12个像素为缺陷区域
  • 从第29346个像素开始连续24个像素为缺陷区域
def rle_to_mask(rle_string, height, width): if rle_string == -1: # 无缺陷情况 return np.zeros((height, width)) rle_numbers = list(map(int, rle_string.split())) pairs = np.array(rle_numbers).reshape(-1, 2) mask = np.zeros(height * width, dtype=np.uint8) for pos, length in pairs: mask[pos-1:pos-1+length] = 255 # 注意RLE计数从1开始 return mask.reshape(width, height).T # 转置恢复原始方向

关键细节:RLE编码按列优先(column-major)顺序存储,而OpenCV默认按行优先(row-major)读取,需要转置处理

1.2 数据分布与类别不平衡

四类缺陷的样本数量差异显著:

缺陷类别训练样本数占比
Class189742.7%
Class224711.8%
Class334216.3%
Class461229.2%

这种不平衡会导致模型偏向多数类。我们的解决方案是:

  • 对少数类样本进行过采样
  • 在损失函数中引入类别权重
  • 采用Focal Loss降低易分类样本的权重

2. 构建工业级数据管道

高效的数据管道能提升GPU利用率,避免训练过程出现瓶颈。我们设计的多线程DataGenerator支持实时数据增强。

2.1 自定义DataGenerator

class SteelDefectGenerator(tf.keras.utils.Sequence): def __init__(self, image_ids, masks, batch_size=16, img_size=(256,512), augment=True): self.image_ids = image_ids self.masks = masks self.batch_size = batch_size self.img_size = img_size self.augment = augment self.on_epoch_end() def __len__(self): return int(np.ceil(len(self.image_ids)/self.batch_size)) def __getitem__(self, idx): batch_ids = self.image_ids[idx*self.batch_size:(idx+1)*self.batch_size] X = np.zeros((len(batch_ids), *self.img_size, 1), dtype=np.float32) y = np.zeros((len(batch_ids), *self.img_size, 4), dtype=np.float32) for i, img_id in enumerate(batch_ids): img = cv2.imread(f"train_images/{img_id}", 0) img = cv2.resize(img, (self.img_size[1], self.img_size[0])) # 标准化 img = (img - img.mean()) / (img.std() + 1e-6) # 为每张图像生成4个通道的mask for class_id in range(4): rle = self.masks.get(f"{img_id}_{class_id+1}") mask = rle_to_mask(rle, 256, 1600) if rle else np.zeros((256,1600)) mask = cv2.resize(mask, (self.img_size[1], self.img_size[0])) y[i,...,class_id] = (mask > 127).astype(np.float32) # 数据增强 if self.augment and np.random.rand() > 0.5: img, y[i] = self.random_augment(img, y[i]) X[i,...,0] = img return X, y def random_augment(self, img, mask): # 水平翻转 if np.random.rand() > 0.5: img = cv2.flip(img, 1) mask = np.stack([cv2.flip(mask[...,i],1) for i in range(4)], axis=-1) # 随机亮度调整 img = np.clip(img * (0.8 + 0.4*np.random.rand()), -1, 1) return img, mask

2.2 数据加载优化技巧

  • 预取(prefetch):在GPU处理当前批次时,CPU准备下一批数据
  • 并行化(map):利用多核CPU并行执行图像解码
  • 缓存(cache):将预处理后的数据缓存到内存或本地存储
def create_dataset(image_ids, masks, batch_size=16): def generator(): for img_id in image_ids: img = cv2.imread(f"train_images/{img_id}", 0) img = cv2.resize(img, (512,256)) img = (img - img.mean()) / (img.std() + 1e-6) masks = [] for class_id in range(4): rle = masks.get(f"{img_id}_{class_id+1}") mask = rle_to_mask(rle, 256,1600) if rle else np.zeros((256,1600)) mask = cv2.resize(mask, (512,256)) masks.append((mask > 127).astype(np.float32)) yield img[...,np.newaxis], np.stack(masks, axis=-1) return tf.data.Dataset.from_generator( generator, output_types=(tf.float32, tf.float32), output_shapes=((256,512,1), (256,512,4)) ).batch(batch_size).prefetch(tf.data.AUTOTUNE)

3. ResUNet模型架构解析

结合ResNet的残差连接与UNet的跳跃连接,我们设计了一个适合小样本学习的轻量级架构。

3.1 核心构建块实现

def conv_block(x, filters, kernel_size=3, use_bn=True): x = tf.keras.layers.Conv2D(filters, kernel_size, padding='same', use_bias=False)(x) if use_bn: x = tf.keras.layers.BatchNormalization()(x) x = tf.keras.layers.ReLU()(x) return x def residual_block(x, filters): shortcut = x x = conv_block(x, filters) x = conv_block(x, filters, use_bn=False) return tf.keras.layers.Add()([x, shortcut])

3.2 完整模型架构

def build_resunet(input_shape=(256,512,1)): inputs = tf.keras.Input(shape=input_shape) # 编码器 x = conv_block(inputs, 32) skip1 = x x = tf.keras.layers.MaxPool2D()(x) # 128x256 x = residual_block(x, 64) skip2 = x x = tf.keras.layers.MaxPool2D()(x) # 64x128 x = residual_block(x, 128) skip3 = x x = tf.keras.layers.MaxPool2D()(x) # 32x64 # 桥接层 x = residual_block(x, 256) # 解码器 x = tf.keras.layers.UpSampling2D()(x) # 64x128 x = tf.keras.layers.Concatenate()([x, skip3]) x = residual_block(x, 128) x = tf.keras.layers.UpSampling2D()(x) # 128x256 x = tf.keras.layers.Concatenate()([x, skip2]) x = residual_block(x, 64) x = tf.keras.layers.UpSampling2D()(x) # 256x512 x = tf.keras.layers.Concatenate()([x, skip1]) x = residual_block(x, 32) # 输出层 outputs = tf.keras.layers.Conv2D(4, 1, activation='sigmoid')(x) return tf.keras.Model(inputs, outputs)

模型参数量约1.2M,在RTX 3090上单个epoch训练时间约3分钟

4. 竞赛专用损失函数与调参策略

工业质检场景需要特殊设计的损失函数,我们对比了多种方案的优劣。

4.1 复合损失函数实现

def dice_coef(y_true, y_pred, smooth=1e-6): y_true_f = tf.keras.layers.Flatten()(y_true) y_pred_f = tf.keras.layers.Flatten()(y_pred) intersection = tf.reduce_sum(y_true_f * y_pred_f) return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth) def focal_tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3, gamma=0.75): y_true_pos = tf.keras.layers.Flatten()(y_true) y_pred_pos = tf.keras.layers.Flatten()(y_pred) true_pos = tf.reduce_sum(y_true_pos * y_pred_pos) false_neg = tf.reduce_sum(y_true_pos * (1-y_pred_pos)) false_pos = tf.reduce_sum((1-y_true_pos)*y_pred_pos) tversky = (true_pos + 1e-6)/(true_pos + alpha*false_neg + beta*false_pos + 1e-6) return tf.pow(1-tversky, gamma) def mixed_loss(y_true, y_pred): bce = tf.keras.losses.binary_crossentropy( tf.keras.layers.Flatten()(y_true), tf.keras.layers.Flatten()(y_pred) ) return 0.5*bce + 0.5*focal_tversky_loss(y_true, y_pred)

4.2 学习率调度策略

def get_lr_callback(batch_size=16): lr_start = 0.0005 lr_max = 0.001 * batch_size lr_min = 0.0001 lr_ramp_ep = 5 lr_sus_ep = 0 lr_decay = 0.8 def lrfn(epoch): if epoch < lr_ramp_ep: lr = (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start elif epoch < lr_ramp_ep + lr_sus_ep: lr = lr_max else: lr = (lr_max - lr_min) * lr_decay**(epoch - lr_ramp_ep - lr_sus_ep) + lr_min return lr return tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True)

4.3 模型训练技巧

  • 渐进式图像尺寸:先在小尺寸(128x256)上预训练,再微调全尺寸(256x512)
  • 早停机制:当验证损失连续3个epoch未下降时停止训练
  • SWA(Stochastic Weight Averaging):在训练后期对权重进行平均提升稳定性
checkpoint = tf.keras.callbacks.ModelCheckpoint( 'best_model.h5', monitor='val_dice_coef', mode='max', save_best_only=True, verbose=1 ) early_stop = tf.keras.callbacks.EarlyStopping( monitor='val_loss', patience=5, restore_best_weights=True ) history = model.fit( train_dataset, validation_data=valid_dataset, epochs=30, callbacks=[get_lr_callback(), checkpoint, early_stop] )

5. 后处理与提交优化

模型预测后的处理步骤直接影响最终得分,我们开发了针对性的后处理流程。

5.1 缺陷区域优化

def postprocess_mask(mask, min_area=50): # 去除小面积区域 mask = morphology.remove_small_objects(mask > 0.5, min_size=min_area) # 填充小孔洞 mask = morphology.remove_small_holes(mask, area_threshold=min_area) return mask.astype(np.uint8) * 255 def mask_to_submission_format(mask, img_id, class_id): mask = cv2.resize(mask, (1600, 256)) mask = postprocess_mask(mask) rle = mask_to_rle(mask) return f"{img_id}_{class_id}", rle

5.2 测试集推理流程

def predict_test_set(model, test_dir, output_file='submission.csv'): test_images = glob.glob(f"{test_dir}/*.jpg") results = [] for img_path in tqdm(test_images): img_id = os.path.basename(img_path) img = cv2.imread(img_path, 0) img = cv2.resize(img, (512,256)) img = (img - img.mean()) / (img.std() + 1e-6) pred = model.predict(img[np.newaxis,...,np.newaxis])[0] for class_id in range(4): mask = pred[...,class_id] item_id, rle = mask_to_submission_format(mask, img_id, class_id+1) results.append([item_id, rle]) pd.DataFrame(results, columns=["ImageId_ClassId", "EncodedPixels"]).to_csv(output_file, index=False)

在最终提交版本中,我们通过以下策略将Dice系数从0.72提升到0.89:

  • 采用TTA(Test Time Augmentation):对测试图像进行水平翻转等增强并平均预测结果
  • 动态调整缺陷面积阈值:根据不同类别特点设置不同的min_area参数
  • 模型集成:融合ResUNet、EfficientNet-B4和DeepLabV3+三个模型的预测结果
http://www.jsqmd.com/news/502118/

相关文章:

  • StructBERT情感分类模型在旅游评论分析中的创新应用
  • 3大维度彻底攻克ComfyUI视频合成节点缺失问题
  • 无需代码的文本分类神器:AI万能分类器WebUI快速上手体验
  • YOLO11快速部署指南:一键安装,无需配置,开箱即用
  • MiniCPM-V-2_6入门必看:C语言调用模型API的完整示例
  • 3DDFA:如何用单张图片实现高精度三维人脸重建
  • 基于Fay数字人框架的虚拟主持人互动游戏道具系统:从搭建到实战完整指南
  • 如何使用BlurAdmin构建响应式表单:动态字段与复杂验证完整指南
  • PE Tools常见问题解答:解决逆向工程中的典型问题与挑战
  • 如何解决CKEditor编辑器粘贴Word文档时公式乱码的问题?
  • SmallThinker-3B-Preview模型服务化:使用Dify平台构建可视化AI工作流
  • 革新性数据看板:重新定义个人知识管理的工作方式
  • 腾讯混元翻译模型Hunyuan-MT-7B效果展示:多语言翻译实测对比
  • Clawdbot+Qwen3:32B实战:一键部署私有AI对话网关
  • Kingbase数据库运维实战:这些高频命令帮你省下80%时间(附场景案例)
  • 从需求到落地:2026园区专用边缘计算盒子厂家推荐 - 品牌2026
  • RT-1背后的秘密:为什么Transformer能成为具身智能的最佳选择?
  • Gemma-3-12b-it本地AI助手升级指南:集成OCR+语音输入多模态入口
  • ABB机器人有效载荷测定实战:如何用LoadIdentify程序快速校准搬运夹具参数
  • 科幻角色设计宝库:LumiPixel Canvas Quest生成外星种族与未来人类
  • DeepChat多平台部署指南:3大系统×6个关键步骤实现跨平台兼容
  • Pi0 Robot Control Center快速部署:Docker镜像构建与8080端口自定义配置
  • 阿里通义Z-Image-Turbo实战:用AI为电商生成高质感产品概念图
  • 什么是初始访问权限?如何用它落实最小权限原则
  • 如何高效获取中小学电子课本:教师与学生的实用下载工具指南
  • Pixel Art to CSS:像素艺术与CSS转换的无缝桥梁 | 前端开发者的创意解决方案
  • AgentCPM深度研报助手:保障数据隐私的本地研究工具
  • Botkit享元模式:优化机器人资源使用的终极指南
  • 3C认证充电宝哪个品牌靠谱?2026年安全品牌推荐与选购指南 - 新闻快传
  • DeOldify与数据库联动:构建历史图像色彩管理平台