import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_

from openrec.modeling.common import DropPath, Identity, Mlp
from openrec.modeling.decoders.nrtr_decoder import Embeddings


class Attention(nn.Module):

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, kv, key_mask=None):
        # q: (B, QN, C) query tokens; kv: (B, N, C) key/value tokens.
        N, C = kv.shape[1:]
        QN = q.shape[1]
        q = self.q(q).reshape([-1, QN, self.num_heads,
                               C // self.num_heads]).transpose(1, 2)
        q = q * self.scale
        k, v = self.kv(kv).reshape(
            [-1, N, 2, self.num_heads,
             C // self.num_heads]).permute(2, 0, 3, 1, 4)
        attn = q.matmul(k.transpose(2, 3))
        if key_mask is not None:
            # key_mask: additive mask (e.g. 0 / -inf), broadcast over heads.
            attn = attn + key_mask.unsqueeze(1)
        attn = F.softmax(attn, -1)
        # if not self.training:
        #     self.attn_map = attn
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).transpose(1, 2).reshape((-1, QN, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
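

# NOTE: Illustrative shape check added for documentation; it is not part of
# the original module. A minimal sketch of how the cross-attention block
# above is driven, assuming the query and key/value sequences share the same
# embedding dim. The sizes are arbitrary examples.
def _attention_shape_demo():
    attn = Attention(dim=64, num_heads=4, qkv_bias=True)
    q = torch.zeros(2, 26, 64)   # (B, query_len, dim)
    kv = torch.zeros(2, 50, 64)  # (B, kv_len, dim)
    out = attn(q, kv)            # -> (2, 26, 64)
    return out.shape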


class EdgeDecoderLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=[0.0, 0.0],
        act_layer=nn.GELU,
        norm_layer='nn.LayerNorm',
        epsilon=1e-6,
    ):
        super().__init__()
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(
            drop_path[0]) if drop_path[0] > 0.0 else Identity()
        self.norm1 = eval(norm_layer)(dim, eps=epsilon)
        self.norm2 = eval(norm_layer)(dim, eps=epsilon)

        self.p = nn.Linear(dim, dim)
        self.cv = nn.Linear(dim, dim)
        self.pv = nn.Linear(dim, dim)

        self.dim = dim
        self.num_heads = num_heads
        self.p_proj = nn.Linear(dim, dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, p, cv, pv):
        pN = p.shape[1]
        vN = cv.shape[1]
        p_shortcut = p

        p1 = self.p(p).reshape(
            [-1, pN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        cv1 = self.cv(cv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)
        pv1 = self.pv(pv).reshape(
            [-1, vN, self.num_heads,
             self.dim // self.num_heads]).transpose(1, 2)

        edge = F.softmax(p1.matmul(pv1.transpose(2, 3)), -1)  # B h N N
        p_c = (edge @ cv1).transpose(1, 2).reshape((-1, pN, self.dim))

        x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c)))
        x = self.norm2(x1 + self.drop_path1(self.mlp(x1)))
        return x
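

# NOTE: Illustrative shape check added for documentation; it is not part of
# the original module. A minimal sketch of EdgeDecoderLayer, assuming the
# position queries `p` attend over position-aware values `pv` to weight the
# character-aware values `cv`. The sizes are arbitrary examples.
def _edge_decoder_layer_demo():
    layer = EdgeDecoderLayer(dim=64, num_heads=4)
    p = torch.zeros(2, 26, 64)   # (B, num_position_nodes, dim)
    cv = torch.zeros(2, 50, 64)  # (B, visual_len, dim)
    pv = torch.zeros(2, 50, 64)  # (B, visual_len, dim)
    out = layer(p, cv, pv)       # -> (2, 26, 64)
    return out.shape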


class DecoderLayer(nn.Module):

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        epsilon=1e-6,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim, eps=epsilon)
        self.mixer = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = norm_layer(dim, eps=epsilon)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, q, kv, key_mask=None):
        x1 = self.norm1(q + self.drop_path(self.mixer(q, kv, key_mask)))
        x = self.norm2(x1 + self.drop_path(self.mlp(x1)))
        return x
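

# NOTE: Illustrative usage added for documentation; it is not part of the
# original module. DecoderLayer is used both with q == kv (self-attention over
# concatenated node/visual tokens) and with a shorter kv (cross-attention from
# node queries to visual features); an optional additive 0/-inf mask restricts
# which keys each query may attend to. The sizes are arbitrary examples.
def _decoder_layer_demo():
    layer = DecoderLayer(dim=64, num_heads=4, qkv_bias=True)
    tokens = torch.zeros(2, 76, 64)       # e.g. 26 node + 50 visual tokens
    out_self = layer(tokens, tokens)      # self-attention: (2, 76, 64)
    out_cross = layer(tokens[:, :26, :],  # cross-attention from 26 queries
                      tokens[:, 26:, :])  # to 50 keys/values: (2, 26, 64)
    return out_self.shape, out_cross.shape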


class CPPDDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_layer=2,
                 drop_path_rate=0.1,
                 max_len=25,
                 vis_seq=50,
                 iters=1,
                 pos_len=False,
                 ch=False,
                 rec_layer=1,
                 num_heads=None,
                 ds=False,
                 **kwargs):
        super(CPPDDecoder, self).__init__()
        self.out_channels = out_channels  # none + 26 + 10
        dim = in_channels
        self.dim = dim
        self.iters = iters
        self.max_len = max_len + 1  # max_len + eos
        self.pos_len = pos_len
        self.ch = ch

        self.char_node_embed = Embeddings(d_model=dim,
                                          vocab=self.out_channels,
                                          scale_embedding=True)
        self.pos_node_embed = Embeddings(d_model=dim,
                                         vocab=self.max_len,
                                         scale_embedding=True)
        dpr = np.linspace(0, drop_path_rate, num_layer + rec_layer)

        self.char_node_decoder = nn.ModuleList([
            DecoderLayer(
                dim=dim,
                num_heads=dim // 32 if num_heads is None else num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                drop_path=dpr[i],
            ) for i in range(num_layer)
        ])
        self.pos_node_decoder = nn.ModuleList([
            DecoderLayer(
                dim=dim,
                num_heads=dim // 32 if num_heads is None else num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                drop_path=dpr[i],
            ) for i in range(num_layer)
        ])
        self.edge_decoder = nn.ModuleList([
            DecoderLayer(
                dim=dim,
                num_heads=dim // 32 if num_heads is None else num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                qk_scale=1.0 if (rec_layer + i) % 2 != 0 else None,
                drop_path=dpr[num_layer + i],
            ) for i in range(rec_layer)
        ])
        self.rec_layer_num = rec_layer

        # Additive causal mask: 0 for the current and earlier positions,
        # -inf for later positions.
        self_mask = torch.tril(
            torch.ones([self.max_len, self.max_len], dtype=torch.float32))
        self_mask = torch.where(
            self_mask > 0,
            torch.zeros_like(self_mask, dtype=torch.float32),
            torch.full([self.max_len, self.max_len],
                       float('-inf'),
                       dtype=torch.float32),
        )
        self.self_mask = self_mask.unsqueeze(0)

        self.char_pos_embed = nn.Parameter(torch.zeros([1, self.max_len, dim],
                                                       dtype=torch.float32),
                                           requires_grad=True)
        self.ds = ds
        if not self.ds:
            self.vis_pos_embed = nn.Parameter(torch.zeros([1, vis_seq, dim],
                                                          dtype=torch.float32),
                                              requires_grad=True)
            trunc_normal_(self.vis_pos_embed, std=0.02)

        self.char_node_fc1 = nn.Linear(dim, max_len)
        self.pos_node_fc1 = nn.Linear(dim, self.max_len)
        self.edge_fc = nn.Linear(dim, self.out_channels)

        trunc_normal_(self.char_pos_embed, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def no_weight_decay(self):
        return {
            'char_pos_embed', 'vis_pos_embed', 'char_node_embed',
            'pos_node_embed'
        }

    def forward(self, x, data=None):
        if self.training:
            return self.forward_train(x, data)
        else:
            return self.forward_test(x)

    def forward_test(self, x):
        if not self.ds:
            visual_feats = x + self.vis_pos_embed
        else:
            visual_feats = x
        bs = visual_feats.shape[0]

        pos_node_embed = self.pos_node_embed(
            torch.arange(self.max_len).cuda(
                x.get_device())).unsqueeze(0) + self.char_pos_embed
        pos_node_embed = torch.tile(pos_node_embed, [bs, 1, 1])

        char_vis_node_query = visual_feats
        pos_vis_node_query = torch.concat([pos_node_embed, visual_feats], 1)

        for char_decoder_layer, pos_decoder_layer in zip(
                self.char_node_decoder, self.pos_node_decoder):
            char_vis_node_query = char_decoder_layer(char_vis_node_query,
                                                     char_vis_node_query)
            pos_vis_node_query = pos_decoder_layer(
                pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])

        pos_node_query = pos_vis_node_query[:, :self.max_len, :]
        char_vis_feats = char_vis_node_query

        # pos_vis_feats = pos_vis_node_query[:, self.max_len :, :]
        # pos_node_feats = self.edge_decoder(
        #     pos_node_query, char_vis_feats, pos_vis_feats
        # )  # B, 26, dim
        pos_node_feats = pos_node_query
        for layer_i in range(self.rec_layer_num):
            rec_layer = self.edge_decoder[layer_i]
            if (self.rec_layer_num + layer_i) % 2 == 0:
                pos_node_feats = rec_layer(pos_node_feats, pos_node_feats,
                                           self.self_mask)
            else:
                pos_node_feats = rec_layer(pos_node_feats, char_vis_feats)

        edge_feats = self.edge_fc(pos_node_feats)  # B, 26, 37
        edge_logits = F.softmax(
            edge_feats,
            -1)  # * F.sigmoid(pos_node_feats1.unsqueeze(-1))  # B, 26, 37
        return edge_logits

    def forward_train(self, x, targets=None):
        if not self.ds:
            visual_feats = x + self.vis_pos_embed
        else:
            visual_feats = x
        bs = visual_feats.shape[0]

        if self.ch:
            char_node_embed = self.char_node_embed(targets[-2])
        else:
            char_node_embed = self.char_node_embed(
                torch.arange(self.out_channels).cuda(
                    x.get_device())).unsqueeze(0)
            char_node_embed = torch.tile(char_node_embed, [bs, 1, 1])
        counting_char_num = char_node_embed.shape[1]

        pos_node_embed = self.pos_node_embed(
            torch.arange(self.max_len).cuda(
                x.get_device())).unsqueeze(0) + self.char_pos_embed
        pos_node_embed = torch.tile(pos_node_embed, [bs, 1, 1])

        node_feats = []

        char_vis_node_query = torch.concat([char_node_embed, visual_feats], 1)
        pos_vis_node_query = torch.concat([pos_node_embed, visual_feats], 1)

        for char_decoder_layer, pos_decoder_layer in zip(
                self.char_node_decoder, self.pos_node_decoder):
            char_vis_node_query = char_decoder_layer(
                char_vis_node_query,
                char_vis_node_query[:, counting_char_num:, :])
            pos_vis_node_query = pos_decoder_layer(
                pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :])

        char_node_query = char_vis_node_query[:, :counting_char_num, :]
        pos_node_query = pos_vis_node_query[:, :self.max_len, :]
        char_vis_feats = char_vis_node_query[:, counting_char_num:, :]

        char_node_feats1 = self.char_node_fc1(char_node_query)
        pos_node_feats1 = self.pos_node_fc1(pos_node_query)

        if not self.pos_len:
            diag_mask = torch.eye(pos_node_feats1.shape[1]).unsqueeze(0).tile(
                [pos_node_feats1.shape[0], 1, 1])
            pos_node_feats1 = (
                pos_node_feats1 *
                diag_mask.cuda(pos_node_feats1.get_device())).sum(-1)

        node_feats.append(char_node_feats1)
        node_feats.append(pos_node_feats1)

        pos_node_feats = pos_node_query
        for layer_i in range(self.rec_layer_num):
            rec_layer = self.edge_decoder[layer_i]
            if (self.rec_layer_num + layer_i) % 2 == 0:
                pos_node_feats = rec_layer(pos_node_feats, pos_node_feats,
                                           self.self_mask)
            else:
                pos_node_feats = rec_layer(pos_node_feats, char_vis_feats)

        edge_feats = self.edge_fc(pos_node_feats)  # B, 26, 37

        return node_feats, edge_feats
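

# NOTE: Illustrative usage added for documentation; it is not part of the
# original module. A minimal inference sketch, assuming in_channels=384, a
# 37-class charset (none + 26 + 10) and a 50-token visual sequence from the
# encoder. The decoder calls `.cuda(...)` internally, so the sketch is
# guarded by a CUDA-availability check.
def _cppd_decoder_demo():
    if not torch.cuda.is_available():
        return None
    decoder = CPPDDecoder(in_channels=384, out_channels=37, vis_seq=50).cuda()
    decoder.eval()
    visual_feats = torch.zeros(2, 50, 384).cuda()  # (B, vis_seq, dim)
    with torch.no_grad():
        edge_logits = decoder(visual_feats)        # -> (2, 26, 37)
    return edge_logits.shape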