1. 为什么你需要了解Timm的create_model?
如果你正在用PyTorch做计算机视觉项目,尤其是想快速尝试各种Vision Transformer(ViT)模型,那你大概率已经听说过或者用过Timm库了。这个由Ross Wightman大神维护的库,简直就是我们这些“炼丹师”的瑞士军刀。它里面集成了几百个预训练好的模型,从经典的ResNet到最新的Swin Transformer、ConvNeXt,应有尽有。
但很多时候,我们只是停留在“会用”的层面:model = timm.create_model('vit_base_patch16_224', pretrained=True),一行代码搞定,模型和预训练权重就都到手了,非常方便。然而,当你想做一些定制化操作时,比如修改模型结构、加载部分预训练权重、或者把自己写的模型也接入Timm这个强大的生态系统时,问题就来了。你会发现,如果不了解create_model这个函数内部是怎么“变魔术”的,很多操作就会变得磕磕绊绊,只能去网上找一些“黑魔法”代码,知其然不知其所以然。
我自己在项目里就踩过不少坑。比如,想用一个在ImageNet-21k上预训练的ViT模型,但我的分类任务只有10个类别,直接加载预训练权重后,最后的分类头维度对不上,导致报错。又比如,我想把模型当成一个特征提取器,只提取中间某几层的特征,而不是最后的分类结果,该怎么设置参数?这些问题的答案,其实都藏在create_model的源码和它背后那一套精巧的模型注册与构建机制里。
所以,这篇指南的目的,就是带你一起“掀开Timm的引擎盖”,从源码层面彻底搞懂create_model。我们不仅会一行行解析关键代码,更会结合Vision Transformer这个具体案例,手把手教你如何灵活运用它,解决实际项目中遇到的各种问题。相信我,花点时间弄明白这些,以后你再使用Timm时,会感觉无比通透和自信。
2. 初窥门径:create_model的直接使用与核心参数
在深入源码之前,我们先来快速回顾和扩展一下create_model的基本用法。这就像学开车,得先知道方向盘、油门、刹车在哪,再去研究发动机原理。
最基础的调用方式,原始文章已经提到了,就是传入模型名字:
import timm
# 创建一个带有预训练权重的DeiT-base模型
model = timm.create_model('vit_deit_base_patch16_384', pretrained=True)
这里有个非常实用的函数叫list_models(),它能帮你探索Timm的“模型动物园”。我经常用它来查找符合我需求的模型。
# 列出所有可用的模型(包括有预训练权重和没有的)
all_models = timm.list_models()
print(f"Timm总共支持 {len(all_models)} 个模型结构")
# 只列出有预训练权重的模型
pretrained_models = timm.list_models(pretrained=True)
print(f"其中有预训练权重的模型有 {len(pretrained_models)} 个")
# 使用通配符过滤,比如找所有Vision Transformer相关的模型
vit_models = timm.list_models('*vit*', pretrained=True)
print(f"ViT系列模型有:{vit_models[:5]}...") # 查看前5个
接下来是几个至关重要的参数,它们是你灵活操控模型的钥匙:
pretrained: 这个参数大家最熟悉。设为True时,Timm会尝试从云端(通常是GitHub Release或Hugging Face Hub)下载对应的预训练权重文件(.pth文件)并加载。这里有个小细节:下载后的权重文件会缓存在本地目录(通常是~/.cache/torch/hub/checkpoints/),下次再创建相同模型时就直接从本地加载,非常快。我踩过的坑:有时候网络不稳定会导致下载的权重文件损坏,加载时会报错。解决方法就是找到那个缓存文件,删掉它,重新运行程序让它再次下载。num_classes: 这是最常需要修改的参数之一。预训练模型通常是在ImageNet(1000类)上训练的。如果你的任务类别数不同,比如猫狗二分类(2类)或者自己的数据集有200类,就必须设置num_classes=2或num_classes=200。这时,模型最后的全连接分类层(classifier)会被替换成一个新的、具有正确输出维度的层。注意:新层的权重是随机初始化的,你需要在自己的数据上重新训练它。in_chans: 输入图像的通道数,默认为3(RGB)。如果你的输入是灰度图(1通道)或者多光谱图像(比如4通道),就需要修改这个参数。模型第一层卷积或Patch Embedding层的输入通道数会相应改变。pretrained_cfg和pretrained_cfg_overlay: 这是更高级的用法,用于精细控制预训练配置。比如,你想加载的预训练权重来源(URL)、均值标准差(mean/std)等。

403

被折叠的 条评论
为什么被折叠?



