当前位置: 首页 > news >正文

动手学深度学习笔记:丢弃法(Dropout)代码实现

在上一节中,我们已经知道,Dropout的核心思想是:

在训练时随机丢弃一部分神经元,防止模型过度依赖某些局部特征,从而缓解过拟合。

这一节的重点,就是把 Dropout 真正写出来,看看它在代码里到底怎么实现。


1. 从零实现 Dropout 的思路

Dropout 的本质其实很简单:

假设某一层输出是X,我们做三件事:

  1. 生成一个随机掩码mask

  2. 按概率把一部分位置变成 0

  3. 对保留下来的值做缩放

对应公式是:

其中:

  • (X):输入

  • (M):随机掩码

  • (p):丢弃概率

  • (\odot):按元素相乘

也就是说:

  • 被丢弃的位置直接清零

  • 保留下来的位置除以1-p,保持整体期望不变


2. 从零开始实现 Dropout 函数

先手动写一个dropout_layer

import torch from torch import nn from d2l import torch as d2l

导入需要的库。

然后定义 Dropout 层:

def dropout_layer(X, dropout): assert 0 <= dropout <= 1 if dropout == 1: return torch.zeros_like(X) if dropout == 0: return X mask = (torch.rand(X.shape) > dropout).float() return mask * X / (1.0 - dropout)

3. 这段代码怎么理解?

第 1 行

assert 0 <= dropout <= 1

确保丢弃概率合法,必须在 0 到 1 之间。


第 2~3 行

if dropout == 1: return torch.zeros_like(X)

如果丢弃率是 1,说明全部丢掉,直接返回全 0 张量。


第 4~5 行

if dropout == 0: return X

如果丢弃率是 0,说明一个都不丢,直接返回原输入。


第 6 行

mask = (torch.rand(X.shape) > dropout).float()

这一步生成随机掩码。

  • torch.rand(X.shape)会生成和X同形状的随机数

  • 每个位置都是 0 到 1 之间的均匀分布

  • 如果某个位置大于dropout,就保留,记为 1

  • 否则丢弃,记为 0

例如dropout=0.5时,大约一半位置会变成 0。


第 7 行

return mask * X / (1.0 - dropout)

这一步才是真正的 Dropout。

  • mask * X:被丢掉的位置直接归零

  • / (1.0 - dropout):对保留下来的值做缩放

这样做是为了让训练时输出的期望值和测试时尽量一致。


4. 简单测试一下

可以随便构造一个输入看看效果:

X = torch.arange(16, dtype=torch.float32).reshape((2, 8)) print(X) print(dropout_layer(X, 0)) print(dropout_layer(X, 0.5)) print(dropout_layer(X, 1))

你会发现:

  • dropout=0:输出不变

  • dropout=0.5:随机一半元素变成 0,其余放大

  • dropout=1:全部变成 0


5. 在多层感知机中使用 Dropout

接下来,把 Dropout 加到 MLP 中。

这里仍然用 Fashion-MNIST 数据集做分类任务。

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256 dropout1, dropout2 = 0.2, 0.5

这里表示:

  • 输入维度 784

  • 输出类别 10

  • 两个隐藏层,每层 256 个神经元

  • 第一层后丢弃率 0.2

  • 第二层后丢弃率 0.5


6. 定义模型参数

W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens1) * 0.01) b1 = nn.Parameter(torch.zeros(num_hiddens1)) W2 = nn.Parameter(torch.randn(num_hiddens1, num_hiddens2) * 0.01) b2 = nn.Parameter(torch.zeros(num_hiddens2)) W3 = nn.Parameter(torch.randn(num_hiddens2, num_outputs) * 0.01) b3 = nn.Parameter(torch.zeros(num_outputs)) params = [W1, b1, W2, b2, W3, b3]

这和前面从零实现 MLP 的思路一样:

  • W1, b1:输入层到第一隐藏层

  • W2, b2:第一隐藏层到第二隐藏层

  • W3, b3:第二隐藏层到输出层


7. 定义 ReLU 激活函数

def relu(X): a = torch.zeros_like(X) return torch.max(X, a)

这里还是使用最常见的 ReLU。


8. 定义前向传播

这一步最关键。

def net(X, is_training=True): X = X.reshape((-1, num_inputs)) H1 = relu(X @ W1 + b1) if is_training: H1 = dropout_layer(H1, dropout1) H2 = relu(H1 @ W2 + b2) if is_training: H2 = dropout_layer(H2, dropout2) return H2 @ W3 + b3

9. 这一段前向传播怎么理解?

第 1 行

X = X.reshape((-1, num_inputs))

把输入图片展平成 784 维向量。


第 2 行

H1 = relu(X @ W1 + b1)

输入层经过第一层线性变换,再经过 ReLU,得到第一隐藏层输出。


第 3~4 行

if is_training: H1 = dropout_layer(H1, dropout1)

如果当前处于训练模式,就对第一隐藏层做 Dropout。

如果是测试模式,就不做 Dropout。


第 5 行

H2 = relu(H1 @ W2 + b2)

第一隐藏层再经过第二层线性变换和 ReLU,得到第二隐藏层输出。


第 6~7 行

if is_training: H2 = dropout_layer(H2, dropout2)

第二隐藏层也做 Dropout。


第 8 行

return H2 @ W3 + b3

最后送入输出层,得到 10 个类别分数。


10. 训练模型

下面就和前面一样,定义损失函数和优化器。

num_epochs, lr, batch_size = 10, 0.5, 256 loss = nn.CrossEntropyLoss(reduction='none') train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size) updater = torch.optim.SGD(params, lr=lr)

然后训练:

d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

11. 完整代码

下面给你一份适合直接放博客的完整版本。

import torch from torch import nn from d2l import torch as d2l def dropout_layer(X, dropout): assert 0 <= dropout <= 1 if dropout == 1: return torch.zeros_like(X) if dropout == 0: return X mask = (torch.rand(X.shape) > dropout).float() return mask * X / (1.0 - dropout) num_inputs, num_outputs = 784, 10 num_hiddens1, num_hiddens2 = 256, 256 dropout1, dropout2 = 0.2, 0.5 W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens1) * 0.01) b1 = nn.Parameter(torch.zeros(num_hiddens1)) W2 = nn.Parameter(torch.randn(num_hiddens1, num_hiddens2) * 0.01) b2 = nn.Parameter(torch.zeros(num_hiddens2)) W3 = nn.Parameter(torch.randn(num_hiddens2, num_outputs) * 0.01) b3 = nn.Parameter(torch.zeros(num_outputs)) params = [W1, b1, W2, b2, W3, b3] def relu(X): a = torch.zeros_like(X) return torch.max(X, a) def net(X, is_training=True): X = X.reshape((-1, num_inputs)) H1 = relu(X @ W1 + b1) if is_training: H1 = dropout_layer(H1, dropout1) H2 = relu(H1 @ W2 + b2) if is_training: H2 = dropout_layer(H2, dropout2) return H2 @ W3 + b3 num_epochs, lr, batch_size = 10, 0.5, 256 loss = nn.CrossEntropyLoss(reduction='none') train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size) updater = torch.optim.SGD(params, lr=lr) d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

12. 简洁实现

如果不用从零写,PyTorch 已经内置了 Dropout,写法更简单。

net = nn.Sequential( nn.Flatten(), nn.Linear(784, 256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, 256), nn.ReLU(), nn.Dropout(0.5), nn.Linear(256, 10) )

然后正常训练:

num_epochs, lr, batch_size = 10, 0.5, 256 loss = nn.CrossEntropyLoss() trainer = torch.optim.SGD(net.parameters(), lr=lr) train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size) d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

13. 从零实现和简洁实现的区别

从零实现

你自己手动写了:

  • 随机掩码

  • 保留与丢弃逻辑

  • 缩放操作

  • 训练和测试模式控制

优点是能真正理解 Dropout 的本质。


简洁实现

直接调用:

nn.Dropout(p)

优点是代码短、开发方便,实际工程中更常用。


14. 总结

丢弃法的代码实现,本质上就是三步:

  • 随机生成掩码

  • 把一部分神经元输出置 0

  • 对保留下来的部分做缩放

在训练时启用 Dropout,在测试时关闭 Dropout。
这样做可以有效打破神经元之间的过强依赖,从而缓解过拟合。

所以这一节最关键的代码其实就是:

mask = (torch.rand(X.shape) > dropout).float() return mask * X / (1.0 - dropout)

它几乎就把 Dropout 的核心思想完整体现出来了。


你下一条最适合接的是《丢弃法每一行代码详细注释版》

http://www.jsqmd.com/news/457046/

相关文章:

  • Linux 无处不在,却征服不了台式机?
  • 从“群聊会议”到“施工蓝图”:LangGraph如何让AI工作流稳如泰山?
  • Linux Vim编辑器完全教程:从入门到精通,程序员必备
  • 企业主管必看!Ecovadis评级材料的时效性 - 奋飞咨询ecovadis
  • OpenClaw如何重塑AI代理为个人操作系统的?为什么值得每一个网络工程师关注?
  • 基于 YOLOv8 的肺炎 X 光影像智能辅助诊断系统 前沿 AI 算法 + 实用医疗场景
  • 2026年玻璃钢桥架厂家实力推荐:河北沃瀚环保设备有限公司全系产品解析 - 品牌推荐官
  • 织梦DedeCms 5.7 无法生成首页的解决方法
  • SQL 基础及 MySQL DBA 运维实战 - 4:MySQL 备份与恢复全实战(XtraBackup和mysqldump)
  • 2026年3月安全门窗十大品牌最新推荐 国标权威抗台风 - 资讯焦点
  • K8S存储管理:从Volume到PV/PVC实战
  • 2026年企业人事服务推荐:厦门布瑞泽人才信息服务有限公司,人事代理/外包/招聘一站式解决方案 - 品牌推荐官
  • 2026年谷歌SEO公司权威榜单:十大顶级服务商深度评测 - 资讯焦点
  • 2026成人用品加盟平台哪家好?5大维度实测对比,找到最适合你的那一款 - 资讯焦点
  • [学点编程]python workout,每天10分钟学会python 读书笔记
  • 2026钢带增强螺旋波纹管厂家推荐:pe钢带增强波纹管/钢带增强pe波纹管/hdpe增强钢带螺旋波纹管厂家精选。 - 品牌推荐官
  • 2026年制砖机设备推荐:郑州不二精工设备有限公司,全系制砖机满足多样化生产需求 - 品牌推荐官
  • 3秒去除豆包AI图片水印(无需PS、美图秀秀等工具)
  • 2026年3月GEO服务商实力评测排名:Top7综合竞争力权威榜单发布 - 资讯焦点
  • 2026年外贸建站服务商深度评测:十大实力派机构助您出海无忧 - 资讯焦点
  • 复试专业课问答题
  • !!形成网页显示当前系统时间!!
  • 2026年管网监测设备推荐:安耐恩窖井数据采集器/管网RTU/遥测终端全系解决方案 - 品牌推荐官
  • 专精特新典范:绍兴镭斯特测径仪,小仪器撬动大制造的质量革命 - 资讯焦点
  • ssm+java2026年毕设奢品网站系统【源码+论文】
  • 【实时Linux工业PLC解决方案系列】第二十篇 - 实时Linux PLC故障诊断与报警机制
  • ssm+java2026年毕设舌象识别健康服务系统app【源码+论文】
  • 毕设程序java保险客户管理系统 基于SpringBoot的寿险客户全生命周期管理平台 数字化保险客户运营与保单服务中心系统
  • Highcharts旭日图(Sunburst)使用指南|层级数据的环形可视化艺术
  • ssm+java2026年毕设设备营销【源码+论文】