Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_ | |
| from openrec.modeling.common import Block, PatchEmbed | |
| from openrec.modeling.encoders.svtrv2_lnconv import Feat2D, LastStage | |
class ViT(nn.Module):
    """Plain Vision Transformer encoder.

    The input is split into patches by ``PatchEmbed``. A learnable positional
    embedding is added (after prepending an optional learnable class token).
    The token sequence then runs through ``depth`` transformer ``Block``s and
    a final ``norm_layer``.

    Optionally one post-processing head is applied to the patch tokens.
    The two heads are mutually exclusive because they share ``self.stages``:

    * ``last_stage=True``: ``LastStage(embed_dim, out_channels, last_drop=0.1)``;
      ``self.out_channels`` becomes ``out_channels``.
    * ``feat2d=True``: ``Feat2D()``; ``self.out_channels`` stays ``embed_dim``.

    Args:
        img_size: input ``[height, width]``. The positional embedding is sized
            for exactly ``PatchEmbed(...).num_patches`` tokens at this size.
        patch_size: patch ``[height, width]``.
        in_channels: number of input channels, forwarded to ``PatchEmbed``.
        out_channels: output width of ``LastStage``. Only used when
            ``last_stage`` is True.
        embed_dim: token embedding width.
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: forwarded to each ``Block``.
        qkv_bias: forwarded to each ``Block``.
        qk_scale: forwarded to each ``Block``.
        drop_rate: dropout applied after adding the positional embedding.
            Also forwarded to each ``Block`` as ``drop``.
        attn_drop_rate: forwarded to each ``Block`` as ``attn_drop``.
        drop_path_rate: the per-block ``drop_path`` values are linearly spaced
            from 0 up to this value over the ``depth`` blocks.
        norm_layer: factory called as ``norm_layer(embed_dim)``. Used for the
            final normalization and forwarded to each ``Block``.
        act_layer: forwarded to each ``Block``.
        last_stage: apply ``LastStage`` after the encoder.
        feat2d: apply ``Feat2D`` after the encoder.
        use_cls_token: prepend a learnable class token. It is removed again
            right after the final normalization.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if both ``last_stage`` and ``feat2d`` are True.
    """

    def __init__(
        self,
        img_size=[32, 128],
        patch_size=[4, 8],
        in_channels=3,
        out_channels=256,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        last_stage=False,
        feat2d=False,
        use_cls_token=False,
        **kwargs,
    ):
        super().__init__()
        if last_stage and feat2d:
            # Both heads would be stored in ``self.stages``, so one would be
            # silently discarded and the other applied twice in forward().
            raise ValueError(
                'ViT: `last_stage` and `feat2d` are mutually exclusive.')

        # Copy so this instance never aliases the shared default lists.
        img_size = list(img_size)
        patch_size = list(patch_size)

        self.img_size = img_size
        self.embed_dim = embed_dim
        # Width of the tensor returned by forward() along the feature axis.
        self.out_channels = out_channels if last_stage else embed_dim
        self.use_cls_token = use_cls_token
        self.last_stage = last_stage
        self.feat2d = feat2d
        # (rows, cols) of the patch grid; handed to the optional head.
        self.feat_sz = [
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        ]

        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels,
                                      embed_dim)
        num_tokens = self.patch_embed.num_patches
        if use_cls_token:
            self.cls_token = nn.Parameter(
                torch.zeros([1, 1, embed_dim], dtype=torch.float32),
                requires_grad=True,
            )
            trunc_normal_(self.cls_token, mean=0, std=0.02)
            num_tokens += 1  # the class token also gets a position
        self.pos_embed = nn.Parameter(
            torch.zeros([1, num_tokens, embed_dim], dtype=torch.float32),
            requires_grad=True,
        )
        self.pos_drop = nn.Dropout(p=drop_rate)

        drop_path_rates = np.linspace(0, drop_path_rate, depth)
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=act_layer,
                attn_drop=attn_drop_rate,
                drop_path=rate,
                norm_layer=norm_layer,
            ) for rate in drop_path_rates
        ])
        self.norm = norm_layer(embed_dim)

        if last_stage:
            self.stages = LastStage(embed_dim, out_channels, last_drop=0.1)
        elif feat2d:
            self.stages = Feat2D()

        # Kept after submodule construction so the order in which parameters
        # draw from the global RNG (and hence seeded init) is unchanged.
        trunc_normal_(self.pos_embed, mean=0, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize a single submodule; called on every module via ``apply``.

        * ``nn.Linear``: truncated-normal weight (std 0.02), zero bias.
        * ``nn.LayerNorm``: weight 1, bias 0. Each is skipped when absent,
          e.g. ``elementwise_affine=False``.
        * ``nn.Conv2d``: Kaiming-normal weight (``fan_out``, ReLU gain).
          The bias is left as constructed.

        Any other module type is left unchanged.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, mean=0, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                zeros_(m.bias)
            if m.weight is not None:
                ones_(m.weight)
        elif isinstance(m, nn.Conv2d):
            kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def no_weight_decay(self):
        """Return parameter names to exclude from weight decay.

        NOTE(review): ``cls_token`` (present when ``use_cls_token``) is not
        listed. Confirm whether it should also be exempt.
        """
        return {'pos_embed'}

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: image batch fed to ``PatchEmbed``. Presumably
                ``(B, in_channels, img_size[0], img_size[1])``; TODO confirm.

        Returns:
            Without a head, the normalized patch tokens ``(B, N, embed_dim)``
            with the class token removed. Otherwise, the first output of the
            configured ``LastStage`` / ``Feat2D`` head.
        """
        x = self.patch_embed(x)  # tokens laid out as (B, N, embed_dim)
        if self.use_cls_token:
            # expand() is a broadcast view; torch.cat makes the real copy.
            cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)
        x = self.pos_drop(x + self.pos_embed)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        if self.use_cls_token:
            x = x[:, 1:, :]  # drop the class token again
        if self.last_stage or self.feat2d:
            # At most one head exists (enforced in __init__); apply it once.
            x, _ = self.stages(x, self.feat_sz)
        return x