| """ Multi-Scale Vision Transformer v2 | |
| @inproceedings{li2021improved, | |
| title={MViTv2: Improved multiscale vision transformers for classification and detection}, | |
| author={Li, Yanghao and Wu, Chao-Yuan and Fan, Haoqi and Mangalam, Karttikeya and Xiong, Bo and Malik, Jitendra and Feichtenhofer, Christoph}, | |
| booktitle={CVPR}, | |
| year={2022} | |
| } | |
| Code adapted from original Apache 2.0 licensed impl at https://github.com/facebookresearch/mvit | |
| Original copyright below. | |
| Modifications and timm support by / Copyright 2022, Ross Wightman | |
| """ | |
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
| import operator | |
| from collections import OrderedDict | |
| from dataclasses import dataclass | |
| from functools import partial, reduce | |
| from typing import Union, List, Tuple, Optional | |
| import torch | |
| import torch.utils.checkpoint as checkpoint | |
| from torch import nn | |
| from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
| from .fx_features import register_notrace_function | |
| from .helpers import build_model_with_cfg | |
| from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple | |
| from .registry import register_model | |
def _cfg(url='', **kwargs):
    """Build a default pretrained-weight config dict; kwargs override defaults."""
    cfg = dict(
        url=url,
        num_classes=1000, input_size=(3, 224, 224), pool_size=None,
        crop_pct=.9, interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.proj', classifier='head.fc',
        fixed_input_size=True,
    )
    cfg.update(kwargs)
    return cfg
# Pretrained weight configurations keyed by model variant name.
default_cfgs = {
    'mvitv2_tiny': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'),
    'mvitv2_small': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'),
    'mvitv2_base': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'),
    'mvitv2_large': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'),
    'mvitv2_base_in21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
        num_classes=19168,
    ),
    'mvitv2_large_in21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
        num_classes=19168,
    ),
    'mvitv2_huge_in21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
        num_classes=19168,
    ),
    'mvitv2_small_cls': _cfg(url=''),
}
@dataclass
class MultiScaleVitCfg:
    """Architecture hyper-parameters for a MultiScaleVit (MViT-v2) model.

    The ``@dataclass`` decorator is required (it was missing): ``model_cfgs``
    instantiates this class with keyword arguments, and ``__post_init__`` must
    run to normalize the per-stage settings below.
    """
    depths: Tuple[int, ...] = (2, 3, 16, 3)
    embed_dim: Union[int, Tuple[int, ...]] = 96
    num_heads: Union[int, Tuple[int, ...]] = 1
    mlp_ratio: float = 4.
    pool_first: bool = False
    expand_attn: bool = True
    qkv_bias: bool = True
    use_cls_token: bool = False
    use_abs_pos: bool = False
    residual_pooling: bool = True
    mode: str = 'conv'
    kernel_qkv: Tuple[int, int] = (3, 3)
    stride_q: Optional[Tuple[Tuple[int, int]]] = ((1, 1), (2, 2), (2, 2), (2, 2))
    stride_kv: Optional[Tuple[Tuple[int, int]]] = None
    stride_kv_adaptive: Optional[Tuple[int, int]] = (4, 4)
    patch_kernel: Tuple[int, int] = (7, 7)
    patch_stride: Tuple[int, int] = (4, 4)
    patch_padding: Tuple[int, int] = (3, 3)
    pool_type: str = 'max'
    rel_pos_type: str = 'spatial'
    act_layer: Union[str, Tuple[str, str]] = 'gelu'
    norm_layer: Union[str, Tuple[str, str]] = 'layernorm'
    norm_eps: float = 1e-6

    def __post_init__(self):
        """Expand scalar embed_dim / num_heads per stage and derive stride_kv."""
        num_stages = len(self.depths)
        if not isinstance(self.embed_dim, (tuple, list)):
            # scalar -> doubled each stage
            self.embed_dim = tuple(self.embed_dim * 2 ** i for i in range(num_stages))
        assert len(self.embed_dim) == num_stages

        if not isinstance(self.num_heads, (tuple, list)):
            # scalar -> doubled each stage
            self.num_heads = tuple(self.num_heads * 2 ** i for i in range(num_stages))
        assert len(self.num_heads) == num_stages

        if self.stride_kv_adaptive is not None and self.stride_kv is None:
            # Shrink the k/v pooling stride whenever q stride reduces resolution,
            # keeping the pooled k/v size roughly constant across stages.
            _stride_kv = self.stride_kv_adaptive
            pool_kv_stride = []
            for i in range(num_stages):
                if min(self.stride_q[i]) > 1:
                    _stride_kv = [
                        max(_stride_kv[d] // self.stride_q[i][d], 1)
                        for d in range(len(_stride_kv))
                    ]
                pool_kv_stride.append(tuple(_stride_kv))
            self.stride_kv = tuple(pool_kv_stride)
# Architecture hyper-parameters per model variant.
model_cfgs = {
    'mvitv2_tiny': MultiScaleVitCfg(depths=(1, 2, 5, 2)),
    'mvitv2_small': MultiScaleVitCfg(depths=(1, 2, 11, 2)),
    'mvitv2_base': MultiScaleVitCfg(depths=(2, 3, 16, 3)),
    'mvitv2_large': MultiScaleVitCfg(
        depths=(2, 6, 36, 4),
        embed_dim=144,
        num_heads=2,
        expand_attn=False,
    ),
    'mvitv2_base_in21k': MultiScaleVitCfg(depths=(2, 3, 16, 3)),
    'mvitv2_large_in21k': MultiScaleVitCfg(
        depths=(2, 6, 36, 4),
        embed_dim=144,
        num_heads=2,
        expand_attn=False,
    ),
    'mvitv2_small_cls': MultiScaleVitCfg(
        depths=(1, 2, 11, 2),
        use_cls_token=True,
    ),
}
def prod(iterable):
    """Return the product of all elements of *iterable* (1 for empty input)."""
    result = 1
    for value in iterable:
        result *= value
    return result
class PatchEmbed(nn.Module):
    """Image-to-patch embedding via a single strided convolution.

    Produces a flattened token sequence plus the spatial grid size of the
    patch feature map.
    """

    def __init__(
            self,
            dim_in=3,
            dim_out=768,
            kernel=(7, 7),
            stride=(4, 4),
            padding=(3, 3),
    ):
        super().__init__()
        self.proj = nn.Conv2d(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
        feat = self.proj(x)
        spatial_size = feat.shape[-2:]
        # B C H W -> B (H*W) C token sequence
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, spatial_size
def reshape_pre_pool(
        x,
        feat_size: List[int],
        has_cls_token: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Fold a (B, heads, N, C) token tensor into a (B*heads, C, H, W) grid.

    The class token (if present) is split off and returned separately so it
    can be re-attached by reshape_post_pool after 2D pooling.
    """
    H, W = feat_size
    cls_tok = None
    if has_cls_token:
        # sequence position 0 holds the class token
        cls_tok = x[:, :, :1, :]
        x = x[:, :, 1:, :]
    grid = x.reshape(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous()
    return grid, cls_tok
def reshape_post_pool(
        x,
        num_heads: int,
        cls_tok: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[int]]:
    """Unfold a pooled (B*heads, C, H, W) grid back to (B, heads, N, C) tokens.

    Re-attaches the class token (if given) at sequence position 0 and returns
    the pooled [H, W] feature size alongside the token tensor.
    """
    H, W = x.shape[2], x.shape[3]
    tokens = x.reshape(-1, num_heads, x.shape[1], H * W).transpose(2, 3)
    if cls_tok is not None:
        tokens = torch.cat((cls_tok, tokens), dim=2)
    return tokens, [H, W]
def cal_rel_pos_type(
        attn: torch.Tensor,
        q: torch.Tensor,
        has_cls_token: bool,
        q_size: List[int],
        k_size: List[int],
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
):
    """
    Spatial Relative Positional Embeddings.

    Adds decomposed (per-axis) relative position biases to the attention
    logits for the spatial (non-cls) positions, writing the result back into
    ``attn`` in place, and returns ``attn``.
    """
    # NOTE(review): upstream timm decorates this function with
    # @register_notrace_function (imported at top of file but unused here) to
    # exclude it from FX tracing — confirm whether the decorator was lost.
    sp_idx = 1 if has_cls_token else 0  # skip the cls position in attn
    q_h, q_w = q_size
    k_h, k_w = k_size

    # Scale up rel pos if shapes for q and k are different.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio
    # shift so distances index from 0
    dist_h += (k_h - 1) * k_h_ratio
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio
    dist_w += (k_w - 1) * k_w_ratio

    # gather per-axis relative embeddings by (shifted, non-negative) distance
    Rh = rel_pos_h[dist_h.long()]
    Rw = rel_pos_w[dist_w.long()]

    B, n_head, q_N, dim = q.shape

    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
    # per-axis content-dependent bias terms (MViTv2 decomposed rel-pos)
    rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, Rh)
    rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, Rw)

    attn[:, :, sp_idx:, sp_idx:] = (
        attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
        + rel_h[:, :, :, :, :, None]
        + rel_w[:, :, :, :, None, :]
    ).view(B, -1, q_h * q_w, k_h * k_w)
    return attn
class MultiScaleAttentionPoolFirst(nn.Module):
    """Pooling attention, 'pool first' ordering (MViTv1 style).

    The shared input tensor is spatially pooled per head *before* the separate
    q/k/v linear projections are applied; compare MultiScaleAttention, which
    projects first and then pools the projected q/k/v.
    """

    def __init__(
            self,
            dim,
            dim_out,
            feat_size,
            num_heads=8,
            qkv_bias=True,
            mode="conv",
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            has_cls_token=True,
            rel_pos_type='spatial',
            residual_pooling=True,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5
        self.has_cls_token = has_cls_token
        # 'same'-style padding for the pooling ops
        padding_q = tuple([int(q // 2) for q in kernel_q])
        padding_kv = tuple([int(kv // 2) for kv in kernel_kv])

        # separate q/k/v projections, applied *after* pooling in forward()
        self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
        self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
        self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
        self.proj = nn.Linear(dim_out, dim_out)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if prod(kernel_q) == 1 and prod(stride_q) == 1:
            kernel_q = None
        if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
            kernel_kv = None
        self.mode = mode
        self.unshared = mode == 'conv_unshared'
        self.pool_q, self.pool_k, self.pool_v = None, None, None
        self.norm_q, self.norm_k, self.norm_v = None, None, None
        if mode in ("avg", "max"):
            # parameter-free spatial pooling, no norm layers
            pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
            if kernel_q:
                self.pool_q = pool_op(kernel_q, stride_q, padding_q)
            if kernel_kv:
                self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
                self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
        elif mode == "conv" or mode == "conv_unshared":
            # depthwise conv pooling over per-head channels ('conv') or full dim ('conv_unshared')
            dim_conv = dim // num_heads if mode == "conv" else dim
            if kernel_q:
                self.pool_q = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_q = norm_layer(dim_conv)
            if kernel_kv:
                self.pool_k = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_k = norm_layer(dim_conv)
                self.pool_v = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_v = norm_layer(dim_conv)
        else:
            raise NotImplementedError(f"Unsupported model {mode}")

        # relative pos embedding
        self.rel_pos_type = rel_pos_type
        if self.rel_pos_type == 'spatial':
            assert feat_size[0] == feat_size[1]  # spatial rel-pos requires square feature maps
            size = feat_size[0]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            rel_sp_dim = 2 * max(q_size, kv_size) - 1
            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            trunc_normal_tf_(self.rel_pos_h, std=0.02)
            trunc_normal_tf_(self.rel_pos_w, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, feat_size: List[int]):
        B, N, _ = x.shape

        # fold channels into the head dim up front so pooling operates per head
        fold_dim = 1 if self.unshared else self.num_heads
        x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
        q = k = v = x

        if self.pool_q is not None:
            q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
            q = self.pool_q(q)
            q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
        else:
            q_size = feat_size
        if self.norm_q is not None:
            q = self.norm_q(q)

        if self.pool_k is not None:
            k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
            k = self.pool_k(k)
            k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
        else:
            k_size = feat_size
        if self.norm_k is not None:
            k = self.norm_k(k)

        if self.pool_v is not None:
            v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
            v = self.pool_v(v)
            v, v_size = reshape_post_pool(v, self.num_heads, v_tok)
        else:
            v_size = feat_size
        if self.norm_v is not None:
            v = self.norm_v(v)

        # linear projections applied after pooling ('pool first')
        q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
        q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
        q = self.q(q).reshape(B, q_N, self.num_heads, -1).permute(0, 2, 1, 3)

        k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
        k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
        k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 1, 3)

        v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
        v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
        v = self.v(v).reshape(B, v_N, self.num_heads, -1).permute(0, 2, 1, 3)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_type == 'spatial':
            attn = cal_rel_pos_type(
                attn,
                q,
                self.has_cls_token,
                q_size,
                k_size,
                self.rel_pos_h,
                self.rel_pos_w,
            )
        attn = attn.softmax(dim=-1)
        x = attn @ v

        if self.residual_pooling:
            # residual connection from the (pooled) query, per MViTv2
            x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)
        return x, q_size
class MultiScaleAttention(nn.Module):
    """Pooling attention, 'project first' ordering (MViTv2 default).

    A fused qkv linear projection is applied to the input, then q/k/v are each
    spatially pooled; compare MultiScaleAttentionPoolFirst, which pools the
    shared input before projecting.
    """

    def __init__(
            self,
            dim,
            dim_out,
            feat_size,
            num_heads=8,
            qkv_bias=True,
            mode="conv",
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            has_cls_token=True,
            rel_pos_type='spatial',
            residual_pooling=True,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5
        self.has_cls_token = has_cls_token
        # 'same'-style padding for the pooling ops
        padding_q = tuple([int(q // 2) for q in kernel_q])
        padding_kv = tuple([int(kv // 2) for kv in kernel_kv])

        # fused qkv projection, applied *before* pooling in forward()
        self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim_out, dim_out)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if prod(kernel_q) == 1 and prod(stride_q) == 1:
            kernel_q = None
        if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
            kernel_kv = None
        self.mode = mode
        self.unshared = mode == 'conv_unshared'
        self.norm_q, self.norm_k, self.norm_v = None, None, None
        self.pool_q, self.pool_k, self.pool_v = None, None, None
        if mode in ("avg", "max"):
            # parameter-free spatial pooling, no norm layers
            pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
            if kernel_q:
                self.pool_q = pool_op(kernel_q, stride_q, padding_q)
            if kernel_kv:
                self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
                self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
        elif mode == "conv" or mode == "conv_unshared":
            # depthwise conv pooling over per-head channels ('conv') or full dim ('conv_unshared')
            dim_conv = dim_out // num_heads if mode == "conv" else dim_out
            if kernel_q:
                self.pool_q = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_q = norm_layer(dim_conv)
            if kernel_kv:
                self.pool_k = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_k = norm_layer(dim_conv)
                self.pool_v = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_v = norm_layer(dim_conv)
        else:
            raise NotImplementedError(f"Unsupported model {mode}")

        # relative pos embedding
        self.rel_pos_type = rel_pos_type
        if self.rel_pos_type == 'spatial':
            assert feat_size[0] == feat_size[1]  # spatial rel-pos requires square feature maps
            size = feat_size[0]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            rel_sp_dim = 2 * max(q_size, kv_size) - 1
            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            trunc_normal_tf_(self.rel_pos_h, std=0.02)
            trunc_normal_tf_(self.rel_pos_w, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, feat_size: List[int]):
        B, N, _ = x.shape
        # project to q/k/v first, then pool each ('project first')
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(dim=0)

        if self.pool_q is not None:
            q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
            q = self.pool_q(q)
            q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
        else:
            q_size = feat_size
        if self.norm_q is not None:
            q = self.norm_q(q)

        if self.pool_k is not None:
            k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
            k = self.pool_k(k)
            k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
        else:
            k_size = feat_size
        if self.norm_k is not None:
            k = self.norm_k(k)

        if self.pool_v is not None:
            v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
            v = self.pool_v(v)
            # pooled v size is not needed downstream, only the tokens
            v, _ = reshape_post_pool(v, self.num_heads, v_tok)
        if self.norm_v is not None:
            v = self.norm_v(v)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_type == 'spatial':
            attn = cal_rel_pos_type(
                attn,
                q,
                self.has_cls_token,
                q_size,
                k_size,
                self.rel_pos_h,
                self.rel_pos_w,
            )
        attn = attn.softmax(dim=-1)
        x = attn @ v

        if self.residual_pooling:
            # residual connection from the (pooled) query, per MViTv2
            x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)
        return x, q_size
class MultiScaleBlock(nn.Module):
    """Transformer block pairing multi-scale (pooling) attention with an MLP.

    Handles channel expansion (dim -> dim_out) via a linear shortcut projection
    on the attention branch when expand_attn=True, otherwise on the MLP branch;
    when stride_q > 1 the attention-branch shortcut is max-pooled so the
    residual stays shape-compatible with the downsampled attention output.
    """

    def __init__(
            self,
            dim,
            dim_out,
            num_heads,
            feat_size,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_path=0.0,
            norm_layer=nn.LayerNorm,
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            mode="conv",
            has_cls_token=True,
            expand_attn=False,
            pool_first=False,
            rel_pos_type='spatial',
            residual_pooling=True,
    ):
        super().__init__()
        proj_needed = dim != dim_out
        self.dim = dim
        self.dim_out = dim_out
        self.has_cls_token = has_cls_token

        self.norm1 = norm_layer(dim)

        # shortcut projection lives on the attention branch only when expanding there
        self.shortcut_proj_attn = nn.Linear(dim, dim_out) if proj_needed and expand_attn else None
        if stride_q and prod(stride_q) > 1:
            # max-pool the shortcut to match the attention branch's spatial downsampling
            kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
            stride_skip = stride_q
            padding_skip = [int(skip // 2) for skip in kernel_skip]
            self.shortcut_pool_attn = nn.MaxPool2d(kernel_skip, stride_skip, padding_skip)
        else:
            self.shortcut_pool_attn = None

        att_dim = dim_out if expand_attn else dim
        attn_layer = MultiScaleAttentionPoolFirst if pool_first else MultiScaleAttention
        self.attn = attn_layer(
            dim,
            att_dim,
            num_heads=num_heads,
            feat_size=feat_size,
            qkv_bias=qkv_bias,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            has_cls_token=has_cls_token,
            mode=mode,
            rel_pos_type=rel_pos_type,
            residual_pooling=residual_pooling,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(att_dim)
        mlp_dim_out = dim_out
        # when not expanding in attention, the MLP branch carries the dim change
        self.shortcut_proj_mlp = nn.Linear(dim, dim_out) if proj_needed and not expand_attn else None
        self.mlp = Mlp(
            in_features=att_dim,
            hidden_features=int(att_dim * mlp_ratio),
            out_features=mlp_dim_out,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def _shortcut_pool(self, x, feat_size: List[int]):
        # Downsample the residual path to the attention output resolution,
        # keeping the cls token (if any) untouched.
        if self.shortcut_pool_attn is None:
            return x
        if self.has_cls_token:
            cls_tok, x = x[:, :1, :], x[:, 1:, :]
        else:
            cls_tok = None
        B, L, C = x.shape
        H, W = feat_size
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
        x = self.shortcut_pool_attn(x)
        x = x.reshape(B, C, -1).transpose(1, 2)
        if cls_tok is not None:
            x = torch.cat((cls_tok, x), dim=1)
        return x

    def forward(self, x, feat_size: List[int]):
        x_norm = self.norm1(x)
        # NOTE as per the original impl, this seems odd, but shortcut uses un-normalized input if no proj
        x_shortcut = x if self.shortcut_proj_attn is None else self.shortcut_proj_attn(x_norm)
        x_shortcut = self._shortcut_pool(x_shortcut, feat_size)
        x, feat_size_new = self.attn(x_norm, feat_size)
        x = x_shortcut + self.drop_path1(x)

        x_norm = self.norm2(x)
        x_shortcut = x if self.shortcut_proj_mlp is None else self.shortcut_proj_mlp(x_norm)
        x = x_shortcut + self.drop_path2(self.mlp(x_norm))
        return x, feat_size_new
class MultiScaleVitStage(nn.Module):
    """A stage of MultiScaleBlocks; only the first block applies stride_q
    (spatial downsampling).

    With expand_attn=True every block in the stage outputs dim_out; otherwise
    the channel change happens in the stage's last block only.
    """

    def __init__(
            self,
            dim,
            dim_out,
            depth,
            num_heads,
            feat_size,
            mlp_ratio=4.0,
            qkv_bias=True,
            mode="conv",
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            has_cls_token=True,
            expand_attn=False,
            pool_first=False,
            rel_pos_type='spatial',
            residual_pooling=True,
            norm_layer=nn.LayerNorm,
            drop_path=0.0,
    ):
        super().__init__()
        # toggled externally (see MultiScaleVit.set_grad_checkpointing)
        self.grad_checkpointing = False

        self.blocks = nn.ModuleList()
        if expand_attn:
            out_dims = (dim_out,) * depth
        else:
            out_dims = (dim,) * (depth - 1) + (dim_out,)

        for i in range(depth):
            attention_block = MultiScaleBlock(
                dim=dim,
                dim_out=out_dims[i],
                num_heads=num_heads,
                feat_size=feat_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                kernel_q=kernel_q,
                kernel_kv=kernel_kv,
                stride_q=stride_q if i == 0 else (1, 1),  # downsample only in first block
                stride_kv=stride_kv,
                mode=mode,
                has_cls_token=has_cls_token,
                pool_first=pool_first,
                rel_pos_type=rel_pos_type,
                residual_pooling=residual_pooling,
                expand_attn=expand_attn,
                norm_layer=norm_layer,
                drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
            )
            dim = out_dims[i]
            self.blocks.append(attention_block)
            if i == 0:
                # track the reduced feature size after the strided first block
                feat_size = tuple([size // stride for size, stride in zip(feat_size, stride_q)])

        # output spatial size of this stage, consumed by the next stage
        self.feat_size = feat_size

    def forward(self, x, feat_size: List[int]):
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x, feat_size = checkpoint.checkpoint(blk, x, feat_size)
            else:
                x, feat_size = blk(x, feat_size)
        return x, feat_size
class MultiScaleVit(nn.Module):
    """
    Improved Multiscale Vision Transformers for Classification and Detection
    Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
    Christoph Feichtenhofer*
    https://arxiv.org/abs/2112.01526

    Multiscale Vision Transformers
    Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
    Christoph Feichtenhofer*
    https://arxiv.org/abs/2104.11227
    """

    def __init__(
            self,
            cfg: MultiScaleVitCfg,
            img_size: Tuple[int, int] = (224, 224),
            in_chans: int = 3,
            global_pool: str = 'avg',
            num_classes: int = 1000,
            drop_path_rate: float = 0.,
            drop_rate: float = 0.,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.depths = tuple(cfg.depths)
        self.expand_attn = cfg.expand_attn

        embed_dim = cfg.embed_dim[0]
        self.patch_embed = PatchEmbed(
            dim_in=in_chans,
            dim_out=embed_dim,
            kernel=cfg.patch_kernel,
            stride=cfg.patch_stride,
            padding=cfg.patch_padding,
        )
        patch_dims = (img_size[0] // cfg.patch_stride[0], img_size[1] // cfg.patch_stride[1])
        num_patches = prod(patch_dims)

        if cfg.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.num_prefix_tokens = 1
            pos_embed_dim = num_patches + 1
        else:
            self.num_prefix_tokens = 0
            self.cls_token = None
            pos_embed_dim = num_patches

        if cfg.use_abs_pos:
            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim))
        else:
            self.pos_embed = None

        num_stages = len(cfg.embed_dim)
        feat_size = patch_dims
        # per-block stochastic depth rates, linearly increasing, split per stage
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
        self.stages = nn.ModuleList()
        for i in range(num_stages):
            if cfg.expand_attn:
                dim_out = cfg.embed_dim[i]
            else:
                # without attn expansion, the stage's last block transitions
                # to the next stage's dim
                dim_out = cfg.embed_dim[min(i + 1, num_stages - 1)]
            stage = MultiScaleVitStage(
                dim=embed_dim,
                dim_out=dim_out,
                depth=cfg.depths[i],
                num_heads=cfg.num_heads[i],
                feat_size=feat_size,
                mlp_ratio=cfg.mlp_ratio,
                qkv_bias=cfg.qkv_bias,
                mode=cfg.mode,
                pool_first=cfg.pool_first,
                expand_attn=cfg.expand_attn,
                kernel_q=cfg.kernel_qkv,
                kernel_kv=cfg.kernel_qkv,
                stride_q=cfg.stride_q[i],
                stride_kv=cfg.stride_kv[i],
                has_cls_token=cfg.use_cls_token,
                rel_pos_type=cfg.rel_pos_type,
                residual_pooling=cfg.residual_pooling,
                norm_layer=norm_layer,
                drop_path=dpr[i],
            )
            embed_dim = dim_out
            feat_size = stage.feat_size
            self.stages.append(stage)

        self.num_features = embed_dim
        self.norm = norm_layer(embed_dim)
        self.head = nn.Sequential(OrderedDict([
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
        ]))

        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            trunc_normal_tf_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal init for all linear layers, zero bias
        if isinstance(m, nn.Linear):
            trunc_normal_tf_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay (pos/rel-pos/cls tokens)."""
        return {k for k, _ in self.named_parameters()
                if any(n in k for n in ["pos_embed", "rel_pos_h", "rel_pos_w", "cls_token"])}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex grouping of parameters for layer-wise LR decay / freezing."""
        matcher = dict(
            stem=r'^patch_embed',  # stem and embed
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable/disable gradient checkpointing in all stages."""
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the classification head for a new number of classes."""
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = nn.Sequential(OrderedDict([
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
        ]))

    def forward_features(self, x):
        x, feat_size = self.patch_embed(x)
        B, N, C = x.shape

        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed

        for stage in self.stages:
            x, feat_size = stage(x, feat_size)

        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        # pool over spatial tokens ('avg') or take the cls/first token
        if self.global_pool:
            if self.global_pool == 'avg':
                x = x[:, self.num_prefix_tokens:].mean(1)
            else:
                x = x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def checkpoint_filter_fn(state_dict, model):
    """Remap an original facebookresearch/mvit checkpoint to timm's layout.

    timm-native checkpoints pass through untouched; otherwise the flat
    'blocks.N' numbering is mapped to 'stages.S.blocks.B' and the shortcut
    projection / head keys are renamed.
    """
    if 'stages.0.blocks.0.norm1.weight' in state_dict:
        # already in timm format
        return state_dict

    import re
    if 'model_state' in state_dict:
        state_dict = state_dict['model_state']

    depths = getattr(model, 'depths', None)
    expand_attn = getattr(model, 'expand_attn', True)
    assert depths is not None, 'model requires depth attribute to remap checkpoints'
    # map flat block index -> (stage index, block index within stage)
    depth_map = {}
    block_idx = 0
    for stage_idx, d in enumerate(depths):
        depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)})
        block_idx += d

    out_dict = {}
    for k, v in state_dict.items():
        k = re.sub(
            r'blocks\.(\d+)',
            lambda x: f'stages.{depth_map[int(x.group(1))][0]}.blocks.{depth_map[int(x.group(1))][1]}',
            k)

        # the branch carrying the shortcut projection depends on expand_attn
        if expand_attn:
            k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', f'stages.\\1.blocks.\\2.shortcut_proj_attn', k)
        else:
            k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', f'stages.\\1.blocks.\\2.shortcut_proj_mlp', k)
        if 'head' in k:
            k = k.replace('head.projection', 'head.fc')
        out_dict[k] = v

        # for k, v in state_dict.items():
        #     if model.pos_embed is not None and k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
        #         # To resize pos embedding when using model at different size from pretrained weights
        #         v = resize_pos_embed(
        #             v,
        #             model.pos_embed,
        #             0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
        #             model.patch_embed.grid_size
        #         )

    return out_dict
def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
    """Instantiate a MultiScaleVit variant through timm's model factory."""
    # use the arch config of cfg_variant when given, else the variant's own
    return build_model_with_cfg(
        MultiScaleVit,
        variant,
        pretrained,
        model_cfg=model_cfgs[cfg_variant or variant],
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True),
        **kwargs,
    )
@register_model
def mvitv2_tiny(pretrained=False, **kwargs):
    """MViT-v2 Tiny (depths 1-2-5-2). Decorator restored to match the
    registration pattern used by the other variants in this file."""
    return _create_mvitv2('mvitv2_tiny', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_small(pretrained=False, **kwargs):
    """MViT-v2 Small (depths 1-2-11-2). Decorator restored to match the
    registration pattern used by the other variants in this file."""
    return _create_mvitv2('mvitv2_small', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_base(pretrained=False, **kwargs):
    """MViT-v2 Base (depths 2-3-16-3). Decorator restored to match the
    registration pattern used by the other variants in this file."""
    return _create_mvitv2('mvitv2_base', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_large(pretrained=False, **kwargs):
    """MViT-v2 Large (depths 2-6-36-4, wider dims). Decorator restored to
    match the registration pattern used by the other variants in this file."""
    return _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs)
| # @register_model | |
| # def mvitv2_base_in21k(pretrained=False, **kwargs): | |
| # return _create_mvitv2('mvitv2_base_in21k', pretrained=pretrained, **kwargs) | |
| # | |
| # | |
| # @register_model | |
| # def mvitv2_large_in21k(pretrained=False, **kwargs): | |
| # return _create_mvitv2('mvitv2_large_in21k', pretrained=pretrained, **kwargs) | |
| # | |
| # | |
| # @register_model | |
| # def mvitv2_huge_in21k(pretrained=False, **kwargs): | |
| # return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs) | |
@register_model
def mvitv2_small_cls(pretrained=False, **kwargs):
    """MViT-v2 Small with class token. Decorator restored to match the
    registration pattern used by the other variants in this file."""
    return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)