import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Block)")
    else:
        warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False

    warnings.warn("xFormers is not available (Block)")

class Block(nn.Module):
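    """
    Pre-norm Vision Transformer block: norm1 -> attention -> LayerScale -> residual add,
    then norm2 -> MLP -> LayerScale -> residual add, with optional stochastic depth
    (DropPath) applied to both residual branches during training.
    """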
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
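        """
        dim: embedding dimension. num_heads: number of attention heads.
        mlp_ratio: hidden width of the FFN as a multiple of dim.
        init_values: LayerScale initial value (None or 0 disables LayerScale).
        drop_path: stochastic depth rate; also used as the batch-level sample drop ratio.
        """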
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
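        """
        Two pre-norm residual updates (attention, then FFN). In training, when
        drop_path > 0.1, the residual branches are evaluated only on a random subset
        of the batch (see drop_add_residual_stochastic_depth below); for smaller
        rates, plain DropPath is used.
        """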
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead of batch-level sample dropping is only amortized
            # for drop path rates larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
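    """
    Batch-level stochastic depth: the residual branch is computed only for a random
    subset of the batch, and the surviving residuals are added back into the full
    batch with index_add, rescaled by b / sample_subset_size so that the expected
    update roughly matches per-sample DropPath while skipping the compute of the
    dropped samples.
    """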
    # 1) extract a random subset of the batch
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to the kept subset only
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the rescaled residual back into the full batch at the kept indices
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
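    """Sample the batch indices that keep their residual branch, plus the matching rescale factor."""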
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
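    """
    Scatter-add the rescaled residual into the rows of x selected by brange.
    When a LayerScale gamma is passed as scaling_vector, the fused xFormers
    scaled_index_add kernel applies it to the residual during the addition.
    """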
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


# Cache of xFormers BlockDiagonalMask objects, keyed by the tuple of
# (batch_size, seq_len) pairs of the tensors packed together in get_attn_bias_and_cat.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    Index-select each tensor (when branges is given), concatenate everything into a single
    packed sequence, and return it together with the cached block-diagonal attn_bias.
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
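    """
    List variant of drop_add_residual_stochastic_depth: the kept rows of every tensor in
    x_list are packed into one long sequence, the residual branch runs once over the packed
    sequence under a block-diagonal attention bias, and the results are split and
    scatter-added back into the corresponding tensors.
    """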
    # 1) generate a random set of kept indices and a rescale factor for each tensor
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get the attention bias and index+concat the kept rows into one packed sequence
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to the packed sequence and split the result per tensor
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))

    # 4) scatter-add each rescaled residual back into its original tensor
    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
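    """
    Block whose forward accepts either a single Tensor or a list of Tensors with
    different sequence lengths; lists are packed into one sequence and attended
    under an xFormers block-diagonal attention bias, so tokens only attend within
    their own sample.
    """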
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors (possibly of different sequence lengths) that are
        packed together and run through the block as a single sequence.
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
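        """Dispatch on the input type: a Tensor uses the standard Block forward, a list of Tensors uses the packed nested path (requires xFormers)."""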
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError(f"unsupported input type: {type(x_or_x_list)}")
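
# Illustrative usage only (a sketch, never executed at import time; the dims, head count,
# and token shapes below are arbitrary assumptions, not values used by this module):
#
#   blk = NestedTensorBlock(dim=384, num_heads=6, init_values=1e-5, attn_class=MemEffAttention)
#   x_a = torch.randn(2, 197, 384)
#   x_b = torch.randn(2, 50, 384)
#   out = blk(x_a)                   # single Tensor in -> single Tensor out
#   out_a, out_b = blk([x_a, x_b])   # list in -> list out (requires xFormers)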