当前位置: 首页 > news >正文

GridMask--随机用“网格状”的遮挡去盖住图片的一部分,迫使模型学习更鲁棒的特征。

图片

这是一个经典的数据增强模块 GridMask,常用于目标检测、BEV、分类等视觉任务。
它的核心思想:
随机用“网格状”的遮挡去盖住图片的一部分,迫使模型学习更鲁棒的特征。
类似:
Cutout(随机挖洞)
Random Erasing
DropBlock
但 GridMask 是:
“规则网格”遮挡,而不是随机矩形。

set_prob()作用:
训练前期弱增强,
后期逐渐增强。
epoch=0 -> prob=0
epoch=50 -> prob=0.5
epoch=100 -> prob=1
属于 curriculum augmentation。

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import cv2class GridMask(nn.Module):def __init__(self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0):super(GridMask, self).__init__()self.use_h = use_hself.use_w = use_wself.rotate = rotateself.offset = offsetself.ratio = ratioself.mode = modeself.st_prob = probself.prob = probdef set_prob(self, epoch, max_epoch):self.prob = self.st_prob * epoch / max_epochdef forward(self, x):if np.random.rand() > self.prob or not self.training:return xn, c, h, w = x.size()x = x.view(-1, h, w)hh = int(1.5 * h)ww = int(1.5 * w)d = np.random.randint(2, h)self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)mask = np.ones((hh, ww), np.float32)st_h = np.random.randint(d)st_w = np.random.randint(d)if self.use_h:for i in range(hh // d):s = d * i + st_ht = min(s + self.l, hh)mask[s:t, :] *= 0if self.use_w:for i in range(ww // d):s = d * i + st_wt = min(s + self.l, ww)mask[:, s:t] *= 0r = np.random.randint(self.rotate)mask = Image.fromarray(np.uint8(mask))mask = mask.rotate(r)mask = np.asarray(mask)mask = mask[(hh - h) // 2 : (hh - h) // 2 + h,(ww - w) // 2 : (ww - w) // 2 + w,]device = x.devicemask = torch.from_numpy(mask).float().to(device)if self.mode == 1:mask = 1 - maskmask = mask.expand_as(x)if self.offset:offset = (torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float().to(device))x = x * mask + offset * (1 - mask)else:x = x * maskreturn x.view(n, c, h, w)# =========================
# 创建 GridMask
# =========================grid_mask = GridMask(True,True,rotate=10,offset=False,ratio=0.65,mode=1,prob=1.0,   # 这里改成1,保证一定触发
)# 一定要train模式
# eval模式不会增强grid_mask.train()# =========================
# 读取图片
# =========================img_path = 'test.jpg'img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)# 保存原图
img_ori = img.copy()# 转 tensor
img_tensor = torch.from_numpy(img).float() / 255.0
img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)# =========================
# GridMask增强
# =========================with torch.no_grad():out = grid_mask(img_tensor)# 转回 numpy
out = out.squeeze(0).permute(1, 2, 0).numpy()# clip
out = np.clip(out, 0, 1)# =========================
# 可视化
# =========================plt.figure(figsize=(12, 6))plt.subplot(1, 2, 1)
plt.imshow(img_ori)
plt.title('Original')
plt.axis('off')plt.subplot(1, 2, 2)
plt.imshow(out)
plt.title('GridMask')
plt.axis('off')plt.tight_layout()
plt.show()

这是一个经典的数据增强模块 GridMask,常用于目标检测、BEV、分类等视觉任务。

它的核心思想:

随机用“网格状”的遮挡去盖住图片的一部分,迫使模型学习更鲁棒的特征。

类似:

  • Cutout(随机挖洞)
  • Random Erasing
  • DropBlock

但 GridMask 是:

“规则网格”遮挡,而不是随机矩形。


整体效果

例如原图:

################
################
################
################

GridMask 后:

##..##..##..####
##..##..##..####
################
################
##..##..##..####

形成周期性遮挡。


代码整体流程

forward 流程:

输入图像↓
随机生成网格mask↓
随机旋转↓
裁剪回原图大小↓
mask乘到图像↓
输出增强图

1. 初始化

def __init__(self,use_h,use_w,rotate=1,offset=False,ratio=0.5,mode=0,prob=1.0
)

参数解释


use_h

是否沿高度方向遮挡

self.use_h = use_h

例如:

True

会生成横条遮挡。



use_w

是否沿宽度方向遮挡

self.use_w = use_w

例如:

True

会生成竖条遮挡。


如果:

use_h=True
use_w=True

就形成网格。


rotate

mask随机旋转角度范围

r = np.random.randint(self.rotate)

例如:

rotate=10

表示:

0~9°

随机旋转。


offset

是否用随机值填充被遮挡区域

默认:

False

即:

遮挡区域 = 0

如果:

True

则:

遮挡区域 = 随机噪声

ratio

遮挡比例

self.l = d * ratio

例如:

d = 100
ratio = 0.5

则:

遮挡宽度 = 50

mode

mask翻转模式

mode=0

正常:

0 -> 被遮挡
1 -> 保留

如果:

mode=1

则反转。


prob

应用增强概率

if np.random.rand() > self.prob:return x

2. set_prob()

def set_prob(self, epoch, max_epoch):self.prob = self.st_prob * epoch / max_epoch

作用:

训练前期弱增强,
后期逐渐增强。

例如:

st_prob = 1.0

训练:

epoch=0   -> prob=0
epoch=50  -> prob=0.5
epoch=100 -> prob=1

属于 curriculum augmentation。


3. forward()


Step1 判断是否增强

if np.random.rand() > self.prob or not self.training:return x

推理阶段不增强。


Step2 获取shape

n, c, h, w = x.size()

例如:

[8, 3, 256, 256]

Step3 reshape

x = x.view(-1, h, w)

变成:

[8*3, 256, 256]

方便mask broadcast。


Step4 扩大mask尺寸

hh = int(1.5 * h)
ww = int(1.5 * w)

为什么?

因为后面会旋转。

旋转后角落会缺失。

扩大后再中心裁剪。

这是经典 trick。


Step5 随机网格间距 d

d = np.random.randint(2, h)

例如:

d = 80

表示:

每80像素一个周期

Step6 计算遮挡宽度

self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)

即:

l = d * ratio

并保证:

1 <= l < d

例如:

d=80
ratio=0.5

得到:

l=40

Step7 初始化mask

mask = np.ones((hh, ww), np.float32)

开始全部为1:

111111111
111111111

Step8 横向遮挡

for i in range(hh // d):s = d * i + st_ht = min(s + self.l, hh)mask[s:t, :] *= 0

例如:

11111111
00000000
11111111
00000000

Step9 纵向遮挡

mask[:, s:t] *= 0

形成网格:

11001100
11001100
00000000
00000000
11001100

Step10 随机旋转

mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)

避免:

总是水平垂直

提升随机性。


Step11 中心裁剪

mask = mask[(hh - h) // 2 : (hh - h) // 2 + h,(ww - w) // 2 : (ww - w) // 2 + w,
]

恢复原图大小。


Step12 转tensor

mask = torch.from_numpy(mask).float().cuda()

shape:

[h,w]

Step13 mode翻转

if self.mode == 1:mask = 1 - mask

Step14 broadcast

mask = mask.expand_as(x)

扩展成:

[n*c, h, w]

Step15 应用mask


普通模式

x = x * mask

遮挡部分变0。


offset模式

x = x * mask + offset * (1 - mask)

被遮挡区域填随机噪声。

类似:

CutMix + Noise

Step16 reshape回去

return x.view(n, c, h, w)

恢复:

[B,C,H,W]

GridMask核心思想

本质:

强迫模型:
不能依赖局部纹理
必须学习全局特征

特别适合:

  • 检测
  • 分割
  • BEV
  • 自动驾驶

因为:

现实中:

  • 遮挡
  • 雨雾
  • 行人遮挡
  • 传感器缺失

都很多。


为什么 BEVDepth 里经常用?

BEV任务:

camera → depth → BEV

非常容易:

过拟合局部纹理

GridMask 能:

  • 提升泛化
  • 提升鲁棒性
  • 防止依赖单摄像头区域

所以:

很多 BEV 系列:

  • BEVDet
  • BEVDepth
  • PETR
  • BEVFormer

都会默认带 GridMask。

http://www.jsqmd.com/news/771676/

相关文章:

  • KMS智能激活工具终极指南:如何永久激活Windows和Office系统
  • Temu在韩国提速“火箭配送”:当日达背后,跨境物流的护城河正在变深
  • 如何利用 Taotoken 的用量看板分析与优化你的大模型 API 支出
  • 【限时解密】AISMM人才成熟度诊断矩阵(v3.2):仅开放72小时,测完立即生成定制化招聘策略报告
  • 热键侦探:3步解决Windows热键冲突的终极指南
  • 构建高性能Web图像处理应用:OpenCV.js架构与集成指南
  • 2026实验室净化装修公司合规选型与权威对比指南 - 品牌策略主理人
  • 基于多智能体与具身AI的龙虾社交广场:架构设计与工程实践
  • 基于AI与双级缓存的新闻聚合器:从架构设计到工程实践
  • 如何测试 CloudCone VPS 的磁盘 IO 性能是否达标
  • 如何解决Upscayl中的Vulkan兼容性问题:完整指南
  • MAA助手:明日方舟自动化工具终极使用指南
  • 告别模糊屏!AMD黑苹果Sonoma下开启2K HIDPI的详细步骤与工具推荐
  • AISMM评估数据可视化落地难?92%团队忽略的4个关键指标校准点(附权威验证脚本)
  • 开发者技能图谱:结构化学习路径与知识体系构建指南
  • 2026北京小程序开发哪家最靠谱?国内排名前十专业的小程序定制开发服务商盘点 - 品牌策略主理人
  • 收藏!小白程序员轻松入门大模型:6步解锁AI Agent开发全攻略
  • AISMM模型深度解构:从0到1打造技术品牌的4个不可逆阶段
  • 在 Hermes Agent 项目中集成 Taotoken 提供方的详细配置步骤
  • 通过Taotoken CLI工具一键配置开发环境中的API访问密钥
  • AISMM模型实施失败的3个隐性根源,92%CTO至今未察觉——今天不读,下周就可能被审计否决
  • JavaScript 鼠标滚轮事件详解:监听向上/向下滑动
  • 2026年高精度便携式超声波流量计品牌口碑与厂家实力介绍 - 品牌推荐大师1
  • 蓝桥杯单片机备赛:用NE555测频率,从原理图到代码的避坑实操
  • 2026年素材网站选购指南:实测5款优质平台,告别选型焦虑 - 极欧测评
  • 温岭市大溪致翔机械设备租赁:专业的台州吊车租赁公司 - LYL仔仔
  • 基于Next.js与GitHub Pages构建个人开发者门户:从SSG到CI/CD全流程实践
  • 拆解特斯拉Autopilot与比亚迪DiPilot:主流车企的ADAS方案到底有何不同?
  • OR-Tools:如何用Google的运筹学引擎解决现实世界优化难题?
  • 【IEEE出版、高校联合主办、启动评优】第八届物联网、自动化和人工智能国际学术会议(IoTAAI 2026)