在真实的项目中,无论是做学术研究还是工业落地,我们都会面临一个痛苦的现实:有太多超参数需要调节了(如学习率、Batch Size、网络层数、权重衰减率等)。 如果每次调参都去修改 Python 源代码,不仅效率低下、容易出错,而且根本无法进行自动化的批量实验。
本篇介绍 Python 标准库中功能最强大的命令行解析工具 —— argparse,它能让你的 PyTorch 脚本变成一个专业的“可执行程序”。
一、argparse的应用场景
如果没有命令行解析,你只能在代码里硬编码:
LR = 0.01
BATCH_SIZE = 32
EPOCHS = 100
DATA_PATH = "/home/user/data/train"
# … …
1.工程落地
在生产环境中,模型通常在 Linux 服务器上通过 Shell 脚本(.sh)或 Docker 容器自动化运行。算法工程师需要通过命令行动态传递设备 ID、输入路径等参数,不可能手动去改源码。
⭐⭐⭐工程铁律:生产代码如果还有硬编码路径,属于代码缺陷。
2.科研实验
- 消融实验(Ablation Study):控制变量,对比不同组件的贡献
- 超参搜索:学习率、Batch Size、权重衰减的网格搜索
- 多组实验并行:GPU 集群同时跑多个配置
你可以写一个简单的 Bash 脚本,一行命令跑 10 个实验。
⭐⭐⭐科研技巧:配合 wandb 或 tensorboard,将 argparse 参数自动同步到实验追踪平台,生成漂亮的对比表格。
二、argparse的工作流
1.argparse简介
argsparse 是 python的命令行解析的标准模块,内置于 python,不需要安装。这个库可以让我们直接在命令行中就可以向程序中传入参数。
我们可以使用python file.py来运行python文件。而argparse的作用就是将命令行传入的其他参数进行解析、保存和使用。在使用argparse后,我们在命令行输入的参数就可以以这种形式python file.py --lr 1e-4 --batch_size 32来完成对常见超参数的设置。
2.argparse的核心工作流
argparse 的使用就像是“注册会员 → 填写资料 → 领会员卡”的过程:
- 1.创建解析器:实例化
ArgumentParser对象。 - 2.添加参数:调用
add_argument()方法,定义你想让外部传入的超参数。 - 3.解析参数:调用
parse_args(),将命令行输入的字符串转换为 Python 对象的属性。
# demo.py
import argparse
# 1. 创建ArgumentParser()对象 ———— “一个空箱子”
parser = argparse.ArgumentParser(description="PyTorch Leaf Classification Demo")
# 2. 添加参数 ———— ”往箱子里塞格子,定义每个格子装什么“
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate for optimizer') # 可选参数
parser.add_argument('--batch_size', type=int, required=True, help='batch size for dataloader') # 必选参数
parser.add_argument('-o', '--output', action='store_true', help="shows output")
# action = `store_true` 会将output参数记录为True
# 3. 使用parse_args()解析参数 ———— “从命令行收货,把东西分类装好”
args = parser.parse_args()
# 在后续代码中通过 args.xxx 访问
print(f"当前使用的学习率: {args.lr}")
argparse的参数主要可以分为可选参数和必选参数。可选参数就跟我们的lr参数相类似,未输入的情况下会设置为默认值。必选参数就跟我们的batch_size参数相类似,当我们给参数设置required =True后,我们就必须传入该参数,否则就会报错。
3.超参数类型与高级用法
在深度学习脚本中,超参数的类型五花八门。以下是标准工业写法及如何处理它们。
①数字与字符串
parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
parser.add_argument('--arch', type=str, default='resnet18', help='model architecture')
②如何正确处理布尔值 (True / False)?
'''
这是错误的写法的!
因为在命令行中,无论你输入 `--use_gpu False` 还是 `--use_gpu True`,`argparse` 接收到的都是字符串 "False"。而在 Python 中,任何非空字符串的布尔值都是 `True`!这会导致你的 GPU 开关永远关不掉。
'''
parser.add_argument('--use_gpu', type=bool, default=True, help='whether to use GPU')
正确写法使用 action='store_true'
'''
默认是 False,如果在命令行加了 --use_gpu,它就自动变成 True
'''
parser.add_argument('--use_gpu', action='store_true', help='enable GPU training')
③限定选项:choices
有时你只希望用户在有限的几个选项里选,比如只能选 resnet18 或 resnet50:
parser.add_argument('--model', type=str, default='resnet18', choices=['resnet18', 'resnet50', 'vit'], help='model type')
三、高效使用argparse修改超参数
每个人都有着不同的超参数管理方式,在这里我将分享我使用argparse管理超参数的方式。实际项目中,我更推荐"独立配置模块 + 批量实验支持"的混合模式:
通常情况下,为了使代码更加简洁和模块化,将有关超参数的操作写在config.py,然后在train.py或者其他文件导入就可以。
# config.py
import argparse
def get_args():
parser = argparse.ArgumentParser()
# 常用超参数组合
parser.add_argument('--root_dir', type=str, default='./data', help='dataset path')
parser.add_argument('--batch_size', type=int, default=64, help='batch size for training')
parser.add_argument('--epochs', type=int, default=20, help='number of total epochs to run')
parser.add_argument('--lr', type=float, default=0.01, help='initial learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay (L2 penalty)')
parser.add_argument('--seed', type=int, default=42, help='seed for initializing training')
parser.add_argument('--amp', action='store_true', help='use automatic mixed precision')
return parser.parse_args()
# train.py
from config import get_args
def main():
args = get_args()
# ... 训练逻辑 ...
if __name__ == '__main__':
main()
批量调参脚本写法
# 自动连续跑三个实验,分别测试不同的学习率
python train.py --lr 0.1 --batch_size 32
python train.py --lr 0.01 --batch_size 32
python train.py --lr 0.001 --batch_size 64 --amp
or
# run.sh 批量调参(科研场景)
for lr in 0.1 0.01 0.001; do
for bs in 32 64; do
python train.py --lr $lr --batch_size $bs --output_dir "exp_lr${lr}_bs${bs}"
done
done
总结
argparse给我们提供了一种新的更加便捷的方式,而在一些大型的深度学习库中人们也会使用json、dict、yaml等文件格式去保存超参数进行训练。
1523

被折叠的 条评论
为什么被折叠?



