timm 视觉库中的 create_model 函数详解

timm 视觉库中的 create_model 函数详解

最近一年 Vision Transformer 及其相关改进的工作层出不穷,在他们开源的代码中,大部分都用到了这样一个库:timm。各位炼丹师应该已经想必已经对其无比熟悉了,本文将介绍其中最关键的函数之一:create_model 函数。

timm简介

PyTorchImageModels,简称timm,是一个巨大的PyTorch代码集合,包括了一系列:

  • image models
  • layers
  • utilities
  • optimizers
  • schedulers
  • data-loaders / augmentations
  • training / validation scripts
  • 旨在将各种 SOTA 模型、图像实用工具、常用的优化器、训练策略等视觉相关常用函数的整合在一起,并具有复现ImageNet训练结果的能力。

    源码:https://github.com/rwightman/pytorch-image-models

    文档:https://fastai.github.io/timmdocs/

    create_model 函数的使用及常用参数

    本小节先介绍 create_model 函数,及常用的参数 **kwargs

    顾名思义,create_model 函数是用来创建一个网络模型(如 ResNet、ViT 等),timm 库本身可供直接调用的模型已有接近400个,用户也可以自己实现一些模型并注册进 timm (这一部分内容将在下一小节着重介绍),供自己调用。

    model_name

    我们首先来看最简单地用法:直接传入模型名称 model_name

    import timm 
    # 创建 resnet-34 
    model = timm.create_model('resnet34')
    # 创建 efficientnet-b0
    model = timm.create_model('efficientnet_b0')
    

    我们可以通过 list_models 函数来查看已经可以直接创建、有预训练参数的模型列表:

    all_pretrained_models_available = timm.list_models(pretrained=True)
    print(all_pretrained_models_available)
    print(len(all_pretrained_models_available))
    

    输出:

    [..., 'vit_large_patch16_384', 'vit_large_patch32_224_in21k', 'vit_large_patch32_384', 'vit_small_patch16_224', 'wide_resnet50_2', 'wide_resnet101_2', 'xception', 'xception41', 'xception65', 'xception71']
    452
    

    如果没有设置 pretrained=True 的话有将会输出612,即有预训练权重参数的模型有452个,没有预训练参数,只有模型结构的共有612个。

    pretrained

    如果我们传入 pretrained=True,那么 timm 会从对应的 URL 下载模型权重参数并载入模型,只有当第一次(即本地还没有对应模型参数时)会去下载,之后会直接从本地加载模型权重参数。

    model = timm.create_model('resnet34', pretrained=True)
    

    输出:

    Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth" to /home/song/.cache/torch/hub/checkpoints/resnet34-43635321.pth
    

    features_only、out_indices

    create_mode 函数还支持 features_only=True 参数,此时函数将返回部分网络,该网络提取每一步最深一层的特征图。还可以使用 out_indices=[…] 参数指定层的索引,以提取中间层特征。

    # 创建一个 (1, 3, 224, 224) 形状的张量
    x = torch.randn(1, 3, 224, 224)
    model = timm.create_model('resnet34')
    preds = model(x)
    print('preds shape: {}'.format(preds.shape))
    
    all_feature_extractor = timm.create_model('resnet34', features_only=True)
    all_features = all_feature_extractor(x)
    print('All {} Features: '.format(len(all_features)))
    for i in range(len(all_features)):
        print('feature {} shape: {}'.format(i, all_features[i].shape))
    
    out_indices = [2, 3, 4]
    selected_feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=out_indices)
    selected_features = selected_feature_extractor(x)
    print('Selected Features: ')
    for i in range(len(out_indices)):
        print('feature {} shape: {}'.format(out_indices[i], selected_features[i].shape))
    

    我们以一个 (1, 3, 224, 224) 形状的张量为输入,在视觉任务中,图像输入张量总是类似的形状。上面例程展示了,创建完整模型 model,创建完整特征提取器 all_feature_extractor,和创建某几层特征提取器 selected_feature_extractor 的具体输出。

    可以结合下面 ResNet34 的结构图来理解(图中不同的颜色表示不同的 layer),根据下图分析各层的卷积操作,计算各层最后一个卷积的输入,并与上面例程的输出(附在图后)验证是否一致。

    输出:

    preds shape: torch.Size([1, 1000])
    All 5 Features:
    feature 0 shape: torch.Size([1, 64, 112, 112])
    feature 1 shape: torch.Size([1, 64, 56, 56])
    feature 2 shape: torch.Size([1, 128, 28, 28])
    feature 3 shape: torch.Size([1, 256, 14, 14])
    feature 4 shape: torch.Size([1, 512, 7, 7])
    Selected Features:
    feature 2 shape: torch.Size([1, 128, 28, 28])
    feature 3 shape: torch.Size([1, 256, 14, 14])
    feature 4 shape: torch.Size([1, 512, 7, 7])
    

    这样,我们就可以通过 timm_model 函数及其 features_onlyout_indices 参数将预训练模型方便地转换为自己想要的特征提取器。

    接下来我们来看一下这些特征提取器究竟是什么类型:

    import timm
    feature_extractor = timm.create_model('resnet34', features_only=True, out_indices=[3])
    
    print('type:', type(feature_extractor))
    print('len: ', len(feature_extractor))
    for item in feature_extractor:
        print(item)
    

    输出:

    type: <class 'timm.models.features.FeatureListNet'>
    len:  7
    conv1
    bn1
    act1
    maxpool
    layer1
    layer2
    layer3
    

    可以看到,feature_extractor 其实也是一个神经网络,在 timm 中称为 FeatureListNet,而我们通过 out_indices 参数来指定截取到哪一层特征。

    需要注意的是,ViT 模型并不支持 features_only 选项(0.4.12版本)。

    extractor = timm.create_model('vit_base_patch16_224', features_only=True)
    

    输出:

    RuntimeError: features_only not implemented for Vision Transformer models.
    

    create_model 函数究竟做了什么

    registry

    在了解了 create_model 函数的基本使用之后,我们来深入探索一下 create_model 函数的源码,看一下究竟是怎样实现从模型到特征提取器的转换的。

    create_model 主体只有 50 行左右的代码,因此所有这些神奇的事情是在其他地方完成的。我们知道 timm.list_models() 函数中的每一个模型名字(str)实际上都是一个函数。以下代码可以测试这一点:

    import timm
    import random 
    from timm.models import registry
    
    m = timm.list_models()[-1]
    print(m)
    registry.is_model(m)
    

    输出:

    xception71
    True
    

    实际上,在 timm 内部,有一个字典称为 _model_entrypoints 包含了所有的模型名称和他们各自的函数。比如说,我们可以通过 model_entrypoint 函数从 _model_entrypoints 内部得到 xception71 模型的构造函数。

    constuctor_fn = registry.model_entrypoint(m)
    print(constuctor_fn)
    

    输出:

    <function timm.models.xception_aligned.xception71(pretrained=False, **kwargs)>
    

    也有可能输出:

    <function xception71 at 0x7fc0cba0eca0>
    

    一样的。

    如我们所见,在 timm.models.xception_aligned 模块中有一个函数称为 xception71 。类似的,timm 中的每一个模型都有着一个这样的构造函数。事实上,内部的 _model_entrypoints 字典大概长这个样子:

    _model_entrypoints
    > > 
    {
    'cspresnet50':<function timm.models.cspnet.cspresnet50(pretrained=False, **kwargs)>,'cspresnet50d': <function timm.models.cspnet.cspresnet50d(pretrained=False, **kwargs)>,
    'cspresnet50w': <function timm.models.cspnet.cspresnet50w(pretrained=False, **kwargs)>,
    'cspresnext50': <function timm.models.cspnet.cspresnext50(pretrained=False, **kwargs)>,
    'cspresnext50_iabn': <function timm.models.cspnet.cspresnext50_iabn(pretrained=False, **kwargs)>,
    'cspdarknet53': <function timm.models.cspnet.cspdarknet53(pretrained=False, **kwargs)>,
    'cspdarknet53_iabn': <function timm.models.cspnet.cspdarknet53_iabn(pretrained=False, **kwargs)>,
    'darknet53': <function timm.models.cspnet.darknet53(pretrained=False, **kwargs)>,
    'densenet121': <function timm.models.densenet.densenet121(pretrained=False, **kwargs)>,
    'densenetblur121d': <function timm.models.densenet.densenetblur121d(pretrained=False, **kwargs)>,
    'densenet121d': <function timm.models.densenet.densenet121d(pretrained=False, **kwargs)>,
    'densenet169': <function timm.models.densenet.densenet169(pretrained=False, **kwargs)>,
    'densenet201': <function timm.models.densenet.densenet201(pretrained=False, **kwargs)>,
    'densenet161': <function timm.models.densenet.densenet161(pretrained=False, **kwargs)>,
    'densenet264': <function timm.models.densenet.densenet264(pretrained=False, **kwargs)>,
    
    }
    

    所以说,在 timm 对应的模块中,每个模型都有一个构造器。比如说 ResNets 系列模型被定义在 timm.models.resnet 模块中。因此,实际上我们有两种方式来创建一个 resnet34 模型:

    import timm
    from timm.models.resnet import resnet34
    
    # 使用 create_model
    m = timm.create_model('resnet34')
    
    # 直接调用构造函数
    m = resnet34()
    

    但使用上,我们无须调用构造函数。所用模型都可以通过 create_model 函数来将创建。

    Register model

    resnet34 构造函数的源码如下:

    @register_model
    def resnet34(pretrained=False, **kwargs):
        """Constructs a ResNet-34 model.
        """
        model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
        return _create_resnet('resnet34', pretrained, **model_args)
    

    我们会发现 timm 中的每个模型都有一个 register_model 装饰器。最开始, _model_entrypoints 是一个空字典。我们是通过 register_model 装饰器来不断地像其中添加模型名称和它对应的构造函数。该装饰器的定义如下:

    def register_model(fn):
        # lookup containing module
        mod = sys.modules[fn.__module__]
        module_name_split = fn.__module__.split('.')
        module_name = module_name_split[-1] if len(module_name_split) else ''
    
        # add model to __all__ in module
        model_name = fn.__name__
        if hasattr(mod, '__all__'):
            mod.__all__.append(model_name)
        else:
            mod.__all__ = [model_name]
    
        # add entries to registry dict/sets
        _model_entrypoints[model_name] = fn
        _model_to_module[model_name] = module_name
        _module_to_models[module_name].add(model_name)
        has_pretrained = False  # check if model has a pretrained url to allow filtering on this
        if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
            # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
            # entrypoints or non-matching combos
            has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
        if has_pretrained:
            _model_has_pretrained.add(model_name)
        return fn
    

    我们可以看到, register_model 函数完成了一些比较基础的步骤,但这里需要指出的是这一句:

    _model_entrypoints[model_name] = fn
    

    它将给定的 fn 添加到 _model_entrypoints 其键名为 fn.__name__。所以说 resnet34 函数上的装饰器 @register_model_model_entrypoints 中创建一个新的条目,像这样:

    {’resnet34’: <function timm.models.resnet.resnet34(pretrained=False, **kwargs)>}
    

    我们同样可以看到在 resnet34 构造函数的源码中,在设置完一些 model_args 之后,它会随后调用 _create_resnet 函数。让我们再来看一下该函数的源码:

    def _create_resnet(variant, pretrained=False, **kwargs):
        return build_model_with_cfg(
            ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
    

    所以在 _create_resnet 函数之中,会再调用 build_model_with_cfg 函数并将一个构造器类 ResNet 、变量名 resnet34、一个 default_cfg 和一些 **kwargs 传入其中。

    default config

    timm 中所有的模型都有一个默认的配置,包括指向它的预训练权重参数的URL、类别数、输入图像尺寸、池化尺寸等。

    resnet34 的默认配置如下:

    {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
    'num_classes': 1000,
    'input_size': (3, 224, 224),
    'pool_size': (7, 7),
    'crop_pct': 0.875,
    'interpolation': 'bilinear',
    'mean': (0.485, 0.456, 0.406),
    'std': (0.229, 0.224, 0.225),
    'first_conv': 'conv1',
    'classifier': 'fc'}
    

    此默认配置与其他参数(如构造函数类和一些模型参数)一起传递给 build_model_with_cfg 函数。

    build model with config

    这个 build_model_with_cfg 函数负责:

    1. 真正地实例化一个模型类来创建一个模型
    2. pruned=True,对模型进行剪枝
    3. pretrained=True,加载预训练模型参数
    4. features_only=True,将模型转换为特征提取器

    看一下该函数的源码:

    def build_model_with_cfg(
            model_cls: Callable,
            variant: str,
            pretrained: bool,
            default_cfg: dict,
            model_cfg: dict = None,
            feature_cfg: dict = None,
            pretrained_strict: bool = True,
            pretrained_filter_fn: Callable = None,
            pretrained_custom_load: bool = False,
            **kwargs):
        pruned = kwargs.pop('pruned', False)
        features = False
        feature_cfg = feature_cfg or {}
    
        if kwargs.pop('features_only', False):
            features = True
            feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
            if 'out_indices' in kwargs:
                feature_cfg['out_indices'] = kwargs.pop('out_indices')
    
        model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
        model.default_cfg = deepcopy(default_cfg)
    
        if pruned:
            model = adapt_model_from_file(model, variant)
    
        # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
        num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
        if pretrained:
            if pretrained_custom_load:
                load_custom_pretrained(model)
            else:
                load_pretrained(
                    model,
                    num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
                    filter_fn=pretrained_filter_fn, strict=pretrained_strict)
    
        if features:
            feature_cls = FeatureListNet
            if 'feature_cls' in feature_cfg:
                feature_cls = feature_cfg.pop('feature_cls')
                if isinstance(feature_cls, str):
                    feature_cls = feature_cls.lower()
                    if 'hook' in feature_cls:
                        feature_cls = FeatureHookNet
                    else:
                        assert False, f'Unknown feature class {feature_cls}'
            model = feature_cls(model, **feature_cfg)
            model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
    
        return model
    

    我们可以看到,模型在这一步被创建出来:model = model_cls(**kwargs)。本文将不再深入到 prunedadapt_model_from_file 内部查看。

    总结

    通过本文,我们已经完全了解了 create_model 函数,我们了解到:

  • 每个模型有不同的构造函数,可以传入不同的参数, _model_entrypoints 字典包括了所有的模型名称及其对应的构造函数
  • build_with_model_cfg 函数接收模型构造器类和其中的一些具体参数,真正地实例化一个模型
  • load_pretrained 会加载预训练参数
  • FeatureListNet 类可以将模型转换为特征提取器
  • Ref:

    https://github.com/rwightman/pytorch-image-models

    https://fastai.github.io/timmdocs/

    https://fastai.github.io/timmdocs/create_model#Turn-any-model-into-a-feature-extractor

    https://fastai.github.io/timmdocs/tutorial_feature_extractor

    https://zhuanlan.zhihu.com/p/404107277

    来源:Adenialzz

    物联沃分享整理
    物联沃-IOTWORD物联网 » timm 视觉库中的 create_model 函数详解

    发表评论