import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from typing import Any, Optional, Tuple, Type

from torchvision.models import swin_b, convnext_base
from transformers.utils.generic import ModelOutput

from .transformer import TwoWayTransformer, LayerNorm2d


class MLP(nn.Module):
    """Simple multi-layer perceptron with optional sigmoid on the output."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            # ReLU on all but the last layer.
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = torch.sigmoid(x)
        return x


class FaceDecoder(nn.Module):
    """Transformer decoder that maps fused image features to per-task outputs."""

    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        # One learnable query token per task.
        self.landmarks_token = nn.Embedding(1, transformer_dim)
        self.pose_token = nn.Embedding(1, transformer_dim)
        self.attribute_token = nn.Embedding(1, transformer_dim)
        self.visibility_token = nn.Embedding(1, transformer_dim)
        self.age_token = nn.Embedding(1, transformer_dim)
        self.gender_token = nn.Embedding(1, transformer_dim)
        self.race_token = nn.Embedding(1, transformer_dim)
        # One token per segmentation class.
        self.mask_tokens = nn.Embedding(11, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )
        self.output_hypernetwork_mlps = MLP(
            transformer_dim, transformer_dim, transformer_dim // 8, 3
        )

        # Task-specific prediction heads.
        self.landmarks_prediction_head = MLP(transformer_dim, transformer_dim, 136, 3)
        self.pose_prediction_head = MLP(transformer_dim, transformer_dim, 3, 3)
        self.attribute_prediction_head = MLP(transformer_dim, transformer_dim, 40, 3)
        self.visibility_prediction_head = MLP(transformer_dim, transformer_dim, 29, 3)
        self.age_prediction_head = MLP(transformer_dim, transformer_dim, 8, 3)
        self.gender_prediction_head = MLP(transformer_dim, transformer_dim, 2, 3)
        self.race_prediction_head = MLP(transformer_dim, transformer_dim, 5, 3)

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # Concatenate all task tokens and broadcast them across the batch.
        output_tokens = torch.cat(
            [
                self.landmarks_token.weight,
                self.pose_token.weight,
                self.attribute_token.weight,
                self.visibility_token.weight,
                self.age_token.weight,
                self.gender_token.weight,
                self.race_token.weight,
                self.mask_tokens.weight,
            ],
            dim=0,
        )
        tokens = output_tokens.unsqueeze(0).expand(image_embeddings.size(0), -1, -1)

        src = image_embeddings
        pos_src = image_pe.expand(image_embeddings.size(0), -1, -1, -1)
        b, c, h, w = src.shape

        hs, src = self.transformer(src, pos_src, tokens)

        landmarks_token_out = hs[:, 0, :]
        pose_token_out = hs[:, 1, :]
        attribute_token_out = hs[:, 2, :]
        visibility_token_out = hs[:, 3, :]
        age_token_out = hs[:, 4, :]
        gender_token_out = hs[:, 5, :]
        race_token_out = hs[:, 6, :]
        mask_token_out = hs[:, 7:, :]

        landmark_output = self.landmarks_prediction_head(landmarks_token_out)
        headpose_output = self.pose_prediction_head(pose_token_out)
        attribute_output = self.attribute_prediction_head(attribute_token_out)
        visibility_output = self.visibility_prediction_head(visibility_token_out)
        age_output = self.age_prediction_head(age_token_out)
        gender_output = self.gender_prediction_head(gender_token_out)
        race_output = self.race_prediction_head(race_token_out)

        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in = self.output_hypernetwork_mlps(mask_token_out)
        b, c, h, w = upscaled_embedding.shape
        seg_output = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        return (
            landmark_output,
            headpose_output,
            attribute_output,
            visibility_output,
            age_output,
            gender_output,
            race_output,
            seg_output,
        )


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C


class FaceXFormerMLP(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, 256)  # 128, 256, 512, 1024 => 256

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        hidden_states = self.proj(hidden_states)
        return hidden_states


class FaceXFormer(nn.Module):
    def __init__(self):
        super(FaceXFormer, self).__init__()

        # Backbone: Swin-B
        swin_v2 = swin_b(weights="IMAGENET1K_V1")
        self.backbone = torch.nn.Sequential(*(list(swin_v2.children())[:-1]))
        self.backbone.requires_grad_(False)

        # # Backbone: ConvNext-B
        # convnext_v2 = convnext_base(weights='IMAGENET1K_V1')
        # self.backbone = torch.nn.Sequential(
        #     *(list(convnext_v2.children())[:-1]))

        self.target_layer_names = ["0.1", "0.3", "0.5", "0.7"]
        self.multi_scale_features = []

        embed_dim = 1024
        out_chans = 256
        self.pe_layer = PositionEmbeddingRandom(out_chans // 2)

        for name, module in self.backbone.named_modules():
            if name in self.target_layer_names:
                module.register_forward_hook(self.save_features_hook(name))

        self.face_decoder = FaceDecoder(
            transformer_dim=256,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=256,
                mlp_dim=2048,
                num_heads=8,
            ),
        )

        num_encoder_blocks = 4
        hidden_sizes = [128, 256, 512, 1024]
        decoder_hidden_size = 256

        mlps = []
        for i in range(num_encoder_blocks):
            mlp = FaceXFormerMLP(input_dim=hidden_sizes[i])
            mlps.append(mlp)
        self.linear_c = nn.ModuleList(mlps)

        self.linear_fuse = nn.Conv2d(
            in_channels=decoder_hidden_size * num_encoder_blocks,  # 1024
            out_channels=decoder_hidden_size,  # 256
            kernel_size=1,
            bias=False,
        )

    def save_features_hook(self, name):
        def hook(module, input, output):
            # Swin stage outputs are channels-last (B, H, W, C); store as (B, C, H, W).
            self.multi_scale_features.append(output.permute(0, 3, 1, 2).contiguous())

        return hook

    def predict(self, x, labels, tasks):
        self.multi_scale_features.clear()

        _, _, h, w = x.shape
        # Run the backbone to populate multi_scale_features via the forward hooks;
        # the pooled backbone output itself is not used further.
        features = self.backbone(x).squeeze()
        batch_size = self.multi_scale_features[-1].shape[0]

        # Project each multi-scale feature map to the decoder dimension and
        # upsample to the resolution of the highest-resolution feature map.
        all_hidden_states = ()
        for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(
                batch_size, -1, height, width
            )
            encoder_hidden_state = nn.functional.interpolate(
                encoder_hidden_state,
                size=self.multi_scale_features[0].size()[2:],
                mode="bilinear",
                align_corners=False,
            )
            all_hidden_states += (encoder_hidden_state,)

        fused_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
        image_pe = self.pe_layer(
            (fused_states.shape[2], fused_states.shape[3])
        ).unsqueeze(0)

        (
            landmark_output,
            headpose_output,
            attribute_output,
            visibility_output,
            age_output,
            gender_output,
            race_output,
            seg_output,
        ) = self.face_decoder(image_embeddings=fused_states, image_pe=image_pe)

        # Route each sample's outputs by task id:
        # 0: segmentation, 1: landmarks, 2: head pose, 3: attributes,
        # 4: age/gender/race, 5: visibility.
        segmentation_indices = tasks == 0
        seg_output = seg_output[segmentation_indices]
        landmarks_indices = tasks == 1
        landmark_output = landmark_output[landmarks_indices]
        headpose_indices = tasks == 2
        headpose_output = headpose_output[headpose_indices]
        attribute_indices = tasks == 3
        attribute_output = attribute_output[attribute_indices]
        age_indices = tasks == 4
        age_output = age_output[age_indices]
        gender_output = gender_output[age_indices]
        race_output = race_output[age_indices]
        visibility_indices = tasks == 5
        visibility_output = visibility_output[visibility_indices]

        return (
            landmark_output,
            headpose_output,
            attribute_output,
            visibility_output,
            age_output,
            gender_output,
            race_output,
            seg_output,
        )

    def loss(
        self, predictions: torch.Tensor, labels: torch.Tensor, num_items_in_batch=None
    ):
        # print(predictions.shape)
        # print(labels.shape)
        # print("predic:", predictions)
        # print("labels:", labels)

        # L2 loss for now: summed MSE, optionally normalized by the number of
        # items in the batch.
        loss = torch.nn.functional.mse_loss(predictions, labels, reduction="sum")
        if num_items_in_batch:
            loss /= num_items_in_batch
        return loss

    def forward(self, x, labels, num_items_in_batch=None):
        self.multi_scale_features.clear()

        _, _, h, w = x.shape
        # Run the backbone to populate multi_scale_features via the forward hooks.
        features = self.backbone(x).squeeze()
        batch_size = self.multi_scale_features[-1].shape[0]

        all_hidden_states = ()
        for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(
                batch_size, -1, height, width
            )
            encoder_hidden_state = nn.functional.interpolate(
                encoder_hidden_state,
                size=self.multi_scale_features[0].size()[2:],
                mode="bilinear",
                align_corners=False,
            )
            all_hidden_states += (encoder_hidden_state,)

        fused_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
        image_pe = self.pe_layer(
            (fused_states.shape[2], fused_states.shape[3])
        ).unsqueeze(0)
        (
            landmark_output,
            headpose_output,
            attribute_output,
            visibility_output,
            age_output,
            gender_output,
            race_output,
            seg_output,
        ) = self.face_decoder(image_embeddings=fused_states, image_pe=image_pe)

        # All tasks are landmark prediction here, so only the landmark head
        # contributes to the loss.
        if labels is not None:
            loss = self.loss(landmark_output.view(-1, 68, 2), labels, num_items_in_batch)
        else:
            loss = None

        return ModelOutput(
            loss=loss,
        )
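
# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hedged example of how this model could be exercised. Assumptions:
# 224x224 RGB inputs preprocessed for the ImageNet-pretrained Swin-B backbone,
# 68-point landmark targets of shape (B, 68, 2), and execution in the package
# context (the relative import of TwoWayTransformer means this file cannot be
# run as a standalone script; use `python -m <package>.<module>`). The dummy
# tensors below are placeholders, not real data.
if __name__ == "__main__":
    model = FaceXFormer()
    model.eval()

    images = torch.randn(2, 3, 224, 224)  # dummy batch of preprocessed face crops
    landmarks = torch.randn(2, 68, 2)  # dummy 68-point landmark targets

    # Training-style call: forward() computes the summed landmark MSE and
    # wraps it in a transformers ModelOutput.
    out = model(images, labels=landmarks)
    print("loss:", out.loss)

    # Inference-style call: predict() routes per-sample outputs by task id
    # (0: segmentation, 1: landmarks, 2: head pose, 3: attributes,
    #  4: age/gender/race, 5: visibility).
    tasks = torch.tensor([1, 1])
    with torch.no_grad():
        (
            landmark_output,
            headpose_output,
            attribute_output,
            visibility_output,
            age_output,
            gender_output,
            race_output,
            seg_output,
        ) = model.predict(images, labels=None, tasks=tasks)
    print("landmark predictions:", landmark_output.view(-1, 68, 2).shape)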