1. 环境准备与项目初始化
好了,咱们直接开始。如果你对图像超分辨率感兴趣,特别是想亲手复现一下ESRGAN这个经典算法,那这篇手把手的教程就是为你准备的。我假设你已经有一些PyTorch的基础,比如知道怎么搭个简单的CNN,跑过几个训练循环。如果没有也没关系,跟着步骤来,遇到不懂的随时查,也能跟得上。ESRGAN,全称增强型超分辨率生成对抗网络,它干的事儿就是把一张模糊的、像素低的小图,变成一张清晰的、细节丰富的大图。这技术应用场景太多了,比如修复老照片、提升游戏或视频的画质、甚至帮助医学影像分析。今天,我们不只讲理论,重点是让你能真正跑起来一个属于自己的ESRGAN模型。
首先,你得把“厨房”收拾好,也就是搭建开发环境。我强烈建议使用Anaconda来管理Python环境,它能帮你省去很多依赖冲突的麻烦。咱们一步步来。
1.1 创建并激活Conda环境
打开你的终端(Windows用Anaconda Prompt,Mac/Linux用终端),输入以下命令来创建一个新的Python环境。我习惯用Python 3.8,比较稳定,和各个包的兼容性也好。
conda create -n esrgan_env python=3.8 -y
创建成功后,激活这个环境:
conda activate esrgan_env
你会看到命令行提示符前面变成了(esrgan_env),这说明你已经在这个独立的环境里了,接下来安装的所有包都不会影响系统或其他项目。
1.2 安装PyTorch及相关依赖
这是核心步骤。PyTorch的安装命令取决于你的操作系统和是否有GPU。有GPU(NVIDIA显卡)并且安装了CUDA的话,训练速度会快几十倍。你可以去NVIDIA控制面板看看你的CUDA版本(比如11.3, 11.6等)。访问PyTorch官网(https://pytorch.org/get-started/locally/)获取最准确的安装命令。这里我以CUDA 11.3为例:
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
如果你没有GPU,或者想先确保环境能跑起来,可以安装CPU版本:
pip install torch torchvision torchaudio
安装完PyTorch,我们还需要一些图像处理和数据加载的帮手:
pip install opencv-python pillow matplotlib numpy tqdm
opencv-python(cv2): 强大的图像处理库,我们可能用来读图、写图或做一些预处理。pillow(PIL): Python图像处理的基础库,必不可少。matplotlib: 用来可视化我们的图像,看看超分效果。numpy: 数值计算基础,PyTorch的好搭档。tqdm: 给你的训练循环加个进度条,看着训练进度一点点走,心里有底。
1.3 准备项目目录与数据集
环境好了,我们来规划一下项目文件夹。别把所有文件都扔在桌面,一个清晰的结构会让你后续操作非常舒服。我通常这样组织:
esrgan_project/
├── data/
│ ├── train/
│ │ ├── HR/ # 存放所有高分辨率原图
│ │ └── LR/ # 存放对应的低分辨率图(可通过脚本生成)
│ └── val/ # 验证集,结构同train
├── models/ # 存放我们定义的PyTorch模型代码
├── utils/ # 存放工具函数,比如计算指标、图像处理
├── outputs/ # 存放训练日志、生成的图像、保存的模型
├── train.py # 主训练脚本
└── inference.py # 推理(预测)脚本
现在说说数据集。ESRGAN需要成对的低分辨率(LR)和高分辨率(HR)图像进行训练。你可以使用公开数据集,比如DIV2K,这是超分领域常用的高质量数据集。下载后,你需要把高分辨率图像(比如2048x2048)放到data/train/HR/下。然后,关键的一步:你需要生成对应的低分辨率图像。通常做法是对HR图像进行下采样(比如用双三次插值缩小4倍),然后保存到data/train/LR/。我写了一个简单的脚本帮你做这个事:
import cv2
import os
hr_dir = ‘data/train/HR/'
lr_dir = ‘data/train/LR/'
scale_factor = 4 # 我们实现4倍超分
os.makedirs(lr_dir, exist_ok=True)
for img_name in os.listdir(hr_dir):
hr_path = os.path.join(hr_dir, img_name)
hr_img = cv2.imread(hr_path)
h, w = hr_img.shape[:2]
# 计算下采样后的尺寸
lr_h, lr_w = h // scale_factor, w // scale_factor
# 使用双三次插值下采样,模拟低分辨率图像
lr_img = cv2.resize(hr_img, (lr_w, lr_h), interpolation=cv2.INTER_CUBIC)
lr_path = os.path.join(lr_dir, img_name)
cv2.imwrite(lr_path, lr_img)
print(“低分辨率图像生成完毕!”)
运行这个脚本,你的训练数据就准备好了。验证集也如法炮制。记住,数据是模型效果的基石,尽量使用高质量、多样化的图像。
2. 核心模块:深入理解并构建RRDB
环境数据都齐了,现在我们来啃最硬的一块骨头——RRDB模块。这是ESRGAN的灵魂,理解了它,你就理解了ESRGAN为何比之前的SRGAN更强。RRDB的全称是残差中的残差密集块。名字听起来复杂,我们拆开看。
首先,什么是密集连接?你可以想象成在卷积层之间修了很多条“小路”。每一层的输出,不仅传给下一层,还会直接传给后面所有层。这样做的好处是特征复用性极强,梯度流动也更顺畅,缓解了深层网络训练时的梯度消失问题。在RRDB中,一个块内部有三个卷积层,它们就是密集连接的。
其次,什么是残差学习?这是ResNet的核心思想。与其让网络直接学习一个复杂的映射H(x),不如让它学习残差F(x) = H(x) - x。这样,网络只需要学习输入和输出之间的“差值”或“修饰”,任务变得更简单。在RRDB中,每个块的最终输出是 输入 + 经过密集块处理后的特征 * 0.2。注意那个0.2,这是一个残差缩放因子,是ESRGAN的一个关键技巧。它让残差分支的贡献小一些,让主路径的梯度占主导,进一步稳定训练。
最后,**“残差中的残差”**是什么意思?ESRGAN的生成器会把很多个RRDB块串联起来(原文用了23个)。每个RRDB块内部有残差连接(输入加残差),而这一堆RRDB块整体又作为一个大的组件,其输出会与最初的输入特征相加。这就形成了多级的残差结构,让网络既能学习深层的抽象特征,又不会丢失原始的细节信息。
现在,我们动手把它写成代码。在models文件夹下创建generator.py。
import torch
import torch.nn as nn
class ResidualDenseBlock(nn.Module):
“”“单个残差密集块”“”
def __init__(self, num_channels=64, growth_channels=32):
super(ResidualDenseBlock, self).__init_

1158

被折叠的 条评论
为什么被折叠?



