当前位置: 首页 > news >正文

动手学深度学习——样式迁移代码

1. 前言

上一篇我们已经从整体上理解了**样式迁移(Style Transfer)**的思想:

  • 内容图提供结构与布局

  • 样式图提供纹理与风格

  • 生成图通过不断优化得到

  • 优化目标同时考虑内容损失和风格损失

这一节我们就正式进入代码实现。

和前面 FCN 那种“搭模型”不太一样,样式迁移代码的核心不是新建一个复杂网络,而是:

固定一个预训练特征提取网络,然后把生成图本身当作待优化参数。

所以这一节真正要看懂的是:

  • 怎么提取内容特征

  • 怎么提取风格特征

  • Gram 矩阵怎么计算

  • 总损失怎么写

  • 怎么一步步优化生成图


2. 样式迁移代码的整体流程

李沐这里的实现思路非常清楚,可以概括成下面几步:

第一步:读取内容图和样式图

把两张图都处理成网络可接受的张量格式。

第二步:加载预训练 CNN

通常使用 VGG 这类网络提取中间层特征。

第三步:指定哪些层表示内容,哪些层表示风格

  • 内容通常取较深层

  • 风格通常取多个层

第四步:初始化生成图

常见做法是直接用内容图初始化,而不是随机噪声。

第五步:定义损失函数

包括:

  • 内容损失

  • 风格损失

  • 总变差损失

第六步:优化生成图

不断更新生成图,使其同时接近内容图和样式图。


3. 读取与预处理图像

首先要做的是把图片读进来,并统一成网络需要的输入格式。

常见写法大致如下:

import torch from torch import nn from torchvision import models, transforms from PIL import Image

然后定义图像预处理:

rgb_mean = torch.tensor([0.485, 0.456, 0.406]) rgb_std = torch.tensor([0.229, 0.224, 0.225]) preprocess = transforms.Compose([ transforms.Resize((300, 450)), transforms.ToTensor(), transforms.Normalize(mean=rgb_mean, std=rgb_std) ])

这里做了三件事:

  • 调整图像尺寸

  • 转成张量

  • 按 ImageNet 的均值和方差做标准化

为什么要标准化?

因为后面用的是预训练网络,它就是在这种输入分布下训练出来的。


4. 为什么显示图像时要反归一化

输入网络时用了标准化,但如果想把结果显示出来,就必须把它恢复回正常图像范围。

通常会写一个“后处理”函数,把图像从标准化空间还原回来。

核心思想是:

  • 先乘标准差

  • 再加均值

  • 最后裁剪到0~1

因为网络处理的是“适合训练的张量”,
而我们人眼想看的是“正常 RGB 图片”。


5. 加载预训练特征提取网络

经典样式迁移里,最常用的是 VGG 网络。
例如:

pretrained_net = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features

这里取的是.features,而不是整个分类头。

原因很简单:

样式迁移只需要卷积层提取的中间特征,不需要最后的分类器。

所以这里保留的是特征提取部分。


6. 为什么样式迁移主要用中间层特征

因为我们不关心最终的分类结果,
我们关心的是图像在不同层上的表示。

不同层特征有不同作用:

  • 较浅层:更偏向边缘、纹理、颜色局部模式

  • 较深层:更偏向物体结构和语义布局

这刚好对应样式迁移里的两个目标:

  • 内容

  • 风格

所以中间层特征才是关键。


7. 指定内容层和风格层

李沐这里通常会显式指定:

  • 哪些层用来表示内容

  • 哪些层用来表示风格

例如:

style_layers = [0, 5, 10, 19, 28] content_layers = [25]

这表示:

  • 风格用多个层共同表示

  • 内容通常只用一个较深层表示

为什么风格要用多个层?

因为风格不是单一层次的信息,而是从浅层纹理到稍深层模式的综合统计。

为什么内容只用较深层?

因为内容更关心整体结构和语义,不需要太多浅层细节。


8. 提取中间层特征

通常会定义一个函数,把指定层的输出取出来:

def extract_features(X, content_layers, style_layers): contents = [] styles = [] for i, layer in enumerate(pretrained_net): X = layer(X) if i in style_layers: styles.append(X) if i in content_layers: contents.append(X) return contents, styles

这段代码的意义非常大。

它不是只拿最终输出,而是在网络前向传播过程中,
把我们关心的中间层结果都收集起来。

最终返回两组特征:

  • 内容特征

  • 风格特征


9. 内容特征怎么用

对于内容图,我们提取一次内容特征并固定下来:

content_Y, _ = extract_features(content_X, content_layers, style_layers)

这里的content_Y就是内容目标。

之后每次优化生成图时,都要让生成图的对应内容特征接近它。

所以内容图特征在整个训练过程中相当于:

被模仿的内容标准答案


10. 风格特征为什么不能直接拿原特征比较

因为风格不强调空间位置精确一致。

如果你直接逐元素比较某层特征,就会隐含一个要求:

这个纹理必须出现在完全对应的位置

这显然不符合“风格”的含义。

风格更像一种整体统计规律,而不是某个具体位置上的精确匹配。

所以风格不能直接比较特征本身,而要比较其统计关系。


11. Gram 矩阵的代码实现

样式迁移中最核心的一个函数就是 Gram 矩阵。

常见写法如下:

def gram(X): num_channels, n = X.shape[1], X.numel() // X.shape[1] X = X.reshape((num_channels, n)) return torch.matmul(X, X.T) / (num_channels * n)

这段代码一定要理解。

第一步:把特征图按通道展开

原来的特征图可能是:

(batch, channels, height, width)

这里本质上是把每个通道拉成一个长向量。

第二步:计算通道之间内积

X @ X.T

得到的是一个:

channels × channels

的矩阵。

它描述的是:

不同特征通道之间的相关性

这正是风格的重要表达。


12. 为什么 Gram 矩阵能表示风格

因为风格更像:

  • 颜色分布规律

  • 纹理重复模式

  • 特征通道之间共同激活的结构

而 Gram 矩阵正好抓住了:

通道之间的整体统计相关性

它弱化了“具体在哪里”,强化了“整体像不像这种风格”。

所以经典样式迁移里,风格目标通常就是:

让生成图的 Gram 矩阵接近样式图的 Gram 矩阵。


13. 先提取样式图的 Gram 矩阵目标

通常会先把样式图跑一遍网络,并计算各风格层的 Gram 矩阵:

_, style_Y = extract_features(style_X, content_layers, style_layers) style_Y_gram = [gram(Y) for Y in style_Y]

这相当于固定下:

  • 样式图在多个层上的风格表示

后面优化生成图时,就让生成图的 Gram 矩阵逐渐接近这些目标。


14. 内容损失怎么写

内容损失通常很直接,就是比较内容特征之间的差异。

例如:

def content_loss(Y_hat, Y): return torch.square(Y_hat - Y.detach()).mean()

这里:

  • Y_hat是生成图的内容特征

  • Y是内容图的内容特征

detach()的作用是把目标视为常量,不参与梯度更新。

内容损失越小,说明生成图在高层语义结构上越接近内容图。


15. 风格损失怎么写

风格损失则比较 Gram 矩阵:

def style_loss(Y_hat, gram_Y): return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

这里:

  • gram(Y_hat)是生成图当前风格特征的 Gram 矩阵

  • gram_Y是样式图对应层的 Gram 矩阵目标

风格损失越小,说明生成图越具有样式图的纹理和风格统计特征。


16. 总变差损失是什么

除了内容和风格损失,通常还会加一个总变差损失。

例如:

def tv_loss(Y_hat): return 0.5 * ( torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() + torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean() )

它的作用是:

让生成图在相邻像素之间更平滑,减少噪声和过度抖动

因为如果只优化内容和风格,有时图像会出现很多局部噪点。
总变差损失能让结果更自然一些。


17. 总损失如何组合

接下来就要把三类损失加起来。

常见形式是:

total_loss = content_weight * content_l \ + style_weight * style_l \ + tv_weight * tv_l

这里的三个权重非常重要:

  • content_weight决定保留内容的强度

  • style_weight决定风格迁移的强度

  • tv_weight决定平滑程度

如果:

  • 内容权重大,结果更像原图

  • 风格权重大,结果更像画作

  • 平滑权重大,图更干净但可能少些细节

所以样式迁移效果很大程度上取决于这些权重平衡。


18. 生成图为什么常常直接用内容图初始化

在理论上,生成图可以随机初始化。
但实际中,常见做法是:

gen_img = content_X.clone().requires_grad_(True)

为什么?

因为如果直接从内容图开始优化,那么:

  • 一开始就已经有正确的内容结构

  • 后面只需要逐渐叠加风格

这样训练更稳定,也更容易得到视觉上合理的结果。

如果从纯随机噪声开始,优化过程通常更慢,也更难控制。


19. 为什么优化的是图像,而不是模型参数

这是样式迁移代码最特别的一点。

平时训练神经网络时,我们更新的是:

  • 卷积核参数

  • 全连接层参数

但样式迁移里:

  • 预训练网络固定不动

  • 内容图和样式图固定不动

  • 真正更新的是生成图本身

也就是说,生成图被视为一个可学习参数。

这是一种非常经典的“输入优化”思路。


20. 优化循环怎么写

优化流程一般长这样:

optimizer = torch.optim.Adam([gen_img], lr=0.3)

然后循环做:

  1. 提取生成图特征

  2. 计算内容损失

  3. 计算风格损失

  4. 计算总变差损失

  5. 合成总损失

  6. 反向传播更新生成图

伪代码如下:

for epoch in range(num_epochs): optimizer.zero_grad() contents_Y_hat, styles_Y_hat = extract_features(gen_img, content_layers, style_layers) contents_l = [content_loss(Y_hat, Y) for Y_hat, Y in zip(contents_Y_hat, content_Y)] styles_l = [style_loss(Y_hat, Y) for Y_hat, Y in zip(styles_Y_hat, style_Y_gram)] tv_l = tv_loss(gen_img) l = content_weight * sum(contents_l) + \ style_weight * sum(styles_l) + \ tv_weight * tv_l l.backward() optimizer.step()

这就是样式迁移的核心优化主循环。


21. 每轮优化后生成图发生了什么

每做一次反向传播,生成图像素都会被轻微调整。

这些调整方向由损失函数共同决定:

  • 内容损失要求它保持内容结构

  • 风格损失要求它具有样式纹理

  • 总变差损失要求它更平滑

所以随着迭代次数增加,生成图会逐渐从“普通照片”变成“带画风的内容图”。

这也是样式迁移最直观、最有成就感的地方。


22. 为什么要把不同层的风格损失加起来

因为单层特征不能完整表示风格。

浅层更擅长描述:

  • 边缘

  • 简单纹理

  • 局部颜色

稍深层则能描述:

  • 更复杂的纹理模式

  • 更抽象的局部结构关系

所以经典样式迁移通常会取多个风格层,把它们的损失加权求和。

这相当于从多个尺度共同约束“风格像不像”。


23. 这节代码最该掌握什么

如果从学习重点看,这一节最应该吃透的是下面几件事。

23.1 特征提取函数

知道如何从预训练网络中拿到中间层特征。

23.2 Gram 矩阵

知道它怎么写、为什么能表示风格。

23.3 三种损失

  • 内容损失

  • 风格损失

  • 总变差损失

23.4 生成图是优化变量

这是样式迁移区别于普通训练的关键。

23.5 总损失的权重平衡

知道为什么不同权重会明显影响最终图像效果。


24. 样式迁移代码的主线可以怎么背

这篇代码其实非常适合用一句主线来背:

固定预训练网络,提取内容图和样式图的目标特征,然后不断更新生成图,让它同时匹配内容和风格。

如果再拆细一点,就是:

  • 读图

  • 提特征

  • 算 Gram

  • 定义损失

  • 优化生成图

只要这五步不乱,样式迁移代码整体就能看懂。


25. 本节总结

这一节我们学习了样式迁移的代码实现,核心内容可以总结为以下几点。

25.1 用预训练 CNN 提取中间层特征

不同层特征分别用于内容和风格表示。

25.2 内容损失比较深层特征差异

它负责保留原图的结构和布局。

25.3 风格损失比较 Gram 矩阵差异

它负责迁移样式图的纹理和风格。

25.4 总变差损失让图像更平滑

防止结果中出现太多噪声。

25.5 样式迁移优化的是生成图本身

这与普通训练网络参数的思路完全不同。


26. 学习感悟

样式迁移代码很有意思,因为它让我们看到:

神经网络不只是“拿来训练参数”的工具,它还可以作为一个固定的特征空间,指导我们直接优化输入本身。

这是一种非常漂亮的思路。

你会发现,到了这里,卷积网络已经不再只是分类器,
而变成了一个“视觉特征度量器”:

  • 它告诉我们什么叫内容接近

  • 什么叫风格接近

  • 然后再利用梯度一步步把图像改出来

这就是样式迁移最迷人的地方。

http://www.jsqmd.com/news/637788/

相关文章:

  • 推荐1款家庭库存管理软件,建议收藏使用!
  • 万象视界灵坛实操手册:图像预处理Pipeline(Resize/Crop/Normalize)对齐CLIP标准
  • 可靠性如何嵌入产品开发流程
  • 忍者像素绘卷开源可部署:支持国产操作系统(OpenEuler)的兼容方案
  • AIAgent目标分解到底难在哪?5大认知陷阱正在拖垮你的智能体落地进度
  • unifolm-vla的数据训练recipe统计
  • Langchain .. 学习 --- LCEL和Runnable劳
  • DAMO-YOLO TinyNAS保姆级教学:EagleEye日志分析、错误排查与常见报错解决方案
  • 仿真模拟电击穿路径的模型:自定义形状、有限元Comsol相场法及PDE模块应用
  • 新能源极耳裁切产线:西门子S7-1500 PLC与基恩士变频器EtherNet/IP协议转换应用
  • 负载箱的故障模式与工程局限:从理想模型到现实约束的技术反思
  • 协议层延迟骤增87%?揭秘AIAgent微服务间通信协议设计的4层降本增效架构实践,今天不看明天宕机
  • 以前我背的是字母,现在才像是在真正记单词
  • DeerFlow PPT自动生成:研究报告一键转换为演示文稿
  • 国企行政筹办正式会议,标准国企会议纪要撰写权威指南
  • 像素语言·维度裂变器:5分钟上手,让AI帮你一键改写平庸文案
  • Phi-4-mini-reasoning企业实操:金融风控规则推理引擎构建案例
  • AI头像生成器保姆级教程:中文描述转Midjourney V6可用Prompt全解析
  • SpringBoot 应用启动流程:从启动到 Web 容器初始化
  • 【工业级AIAgent仿真底座】:基于Docker+Kubernetes+gymnasium的可复现、可审计、可压测环境搭建全链路
  • 从零搭建高性能BitTorrent Tracker:xbt-Tracker与Transmission全流程指南
  • 双非本科入行AI Agent:我是怎么跑通这条路的
  • 45、如何理解和实现递归?数组扁平化里递归有什么缺陷?
  • LightOnOCR-2-1B手把手教学:从零开始,打造你的智能文字提取工具
  • RobotStudio多版本共存避坑指南:5.0/6.0/2019版如何和平共处?
  • 智能优化算法专题(7)【讲解+报告】基于PID控制与模糊PID控制搭建一阶倒立摆仿真(在线整定PID参数)-对比小车位移与摆杆角度
  • 2026年4月洁净手术室厂商推荐,弥散供氧/厂房净化/供氧设备带/医用气体/集中供氧/无菌手术室,洁净手术室商家怎么选择 - 品牌推荐师
  • GX0011单线脉冲温度传感器实战:从NTC替代到STM32驱动,实现低功耗多点测温
  • 杭州专业WordPress模板开发服务商
  • 安科瑞AIM-T系列工业IT绝缘监测及故障定位解决方案为关键供电场所筑牢安全防线