# This code is referenced from https://github.com/dhansmair/flamingo-mini
import math
from functools import partial

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum, nn
from torch.nn.init import trunc_normal_
from transformers.activations import ACT2FN

from .configuration_gecko import GeckoConfig


def feed_forward_layer(dim: int, mult: int = 4, activation: str = 'gelu'):
    """Feed forward layer with the given activation function."""
    activations = dict(gelu=nn.GELU, relu=nn.ReLU)
    assert activation in activations, f'activation can only be one of {activations.keys()}'

    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        activations[activation](),
        nn.Linear(inner_dim, dim, bias=False),
    )
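

# A minimal usage sketch (hypothetical shapes, not part of the original module):
# the block is shape-preserving (dim -> mult*dim -> dim), so it can sit on any
# residual stream.
#
#   ffw = feed_forward_layer(dim=512, mult=4, activation='gelu')
#   out = ffw(torch.randn(2, 64, 512))  # -> (2, 64, 512)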


class PerceiverAttentionLayer(nn.Module):
    """Perceiver attention layer: learnt latents cross-attend to visual features."""

    def __init__(self, dim: int, dim_head: int = 64, heads: int = 8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads

        # Trainable components of the PerceiverAttentionLayer
        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, features, latents):
| """Latent vectors are cross-attending to the visual features x | |
| Args: | |
| features: Batch of visual features with shape (batch_size, n_tokens, dim) | |
| latents: Latent learnt vectors which are used to compute queries with shape (batch_size, n_latents, dim) | |
| Returns: | |
| Attention score with shape (batch_size, n_latents, dim) | |
| """ | |
        assert features.ndim == 3
        assert latents.ndim == 3
        assert features.shape[0] == latents.shape[0]
        assert features.shape[2] == latents.shape[2]

        n_heads = self.heads
        n_batch, n_features, dim = features.shape
        n_queries = latents.shape[1]

        # Layer normalization
        x = self.norm_media(features)
        latents = self.norm_latents(latents)

        # Compute the queries from the latents, for all attention heads simultaneously
        q = self.to_q(latents)
        q = rearrange(q, 'b q (h d) -> b h q d', h=n_heads)
        assert q.shape == torch.Size([n_batch, n_heads, n_queries, self.dim_head])

        # Keys and values are computed from the visual features concatenated with the latents
        kv_input = torch.cat((x, latents), dim=-2)
        n_features_latents = n_features + n_queries

        k = self.to_k(kv_input)
        v = self.to_v(kv_input)
        k, v = rearrange_many((k, v), 'b f (h d) -> b h f d', h=n_heads)
        assert v.shape == torch.Size([n_batch, n_heads, n_features_latents, self.dim_head])

        q = q * self.scale

        # Attention scores, stabilized by subtracting the per-query maximum before the softmax
        sim = einsum('b h q d, b h f d -> b h q f', q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        alphas = sim.softmax(dim=-1)

        out = einsum('b h q f, b h f v -> b h q v', alphas, v)
        out = rearrange(out, 'b h q v -> b q (h v)')
        return self.to_out(out)
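

# A minimal usage sketch (hypothetical shapes, not part of the original module):
# 64 latents attend to 256 visual tokens, and the output keeps the latent shape,
# so the layer can be stacked with residual connections.
#
#   layer = PerceiverAttentionLayer(dim=512, dim_head=64, heads=8)
#   feats = torch.randn(2, 256, 512)
#   lats = torch.randn(2, 64, 512)
#   out = layer(feats, lats)  # -> (2, 64, 512)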


class GeckoResamplerProjector(nn.Module):
    """Perceiver Resampler with multi-head attention layers."""

    def __init__(
        self,
        config: GeckoConfig,
        num_queries: int = 64,
        depth: int = 2,
        dim_head: int = 32,
        heads: int = 4,
        ff_mult: int = 2,
    ):
        super().__init__()
        self.dim = config.text_config.hidden_size
        self.num_queries = num_queries

        # Learnt latent queries, shared across the batch
        self.latents = nn.Parameter(torch.randn(self.num_queries, self.dim))  # type: ignore[reportPrivateUsage]
        # Project the visual embeddings into the text hidden size
        self.linear = nn.Linear(config.vision_config.hidden_size, self.dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttentionLayer(dim=self.dim, dim_head=dim_head, heads=heads),
                        feed_forward_layer(dim=self.dim, mult=ff_mult, activation=config.projector_hidden_act),
                    ]
                )
            )

        # Final layer normalization over the latent dimension
        self.norm = nn.LayerNorm(self.dim)

    def forward(self, x_f: torch.Tensor):
        """Run the perceiver resampler on the input visual embeddings.

        Args:
            x_f: Input visual embeddings of shape (batch_size, num_tokens, d_visual)

        Returns:
            Resampled features of shape (batch_size, num_queries, dim), where dim is the text hidden size
        """
        assert x_f.ndim == 3

        x_f = self.linear(x_f)
        batch_size, num_tokens, dim = x_f.shape
        assert dim == self.dim

        # Copy the latents for every element in the batch
        x = repeat(self.latents, 'q d -> b q d', b=batch_size)

        # Apply the attention and feed forward layers, each with a residual connection
        for attn, ffw in self.layers:
            x = x + attn(x_f, x)
            x = x + ffw(x)

        assert x.shape == torch.Size([batch_size, self.num_queries, self.dim])
        return self.norm(x)
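

# A minimal usage sketch (hypothetical config values, not part of the original module).
# It assumes vision_config.hidden_size=1024 and text_config.hidden_size=4096; the point
# is that 576 patch embeddings are compressed into num_queries=64 tokens in the text space.
#
#   config = GeckoConfig(...)  # hypothetical instantiation
#   resampler = GeckoResamplerProjector(config, num_queries=64)
#   image_feats = torch.randn(2, 576, 1024)
#   out = resampler(image_feats)  # -> (2, 64, 4096)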


class GeckoMLPProjector(nn.Module):
    """Two-layer MLP projector from the vision hidden size to the text hidden size."""

    def __init__(self, config: GeckoConfig):
        super().__init__()
        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
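

# A minimal smoke-test sketch (hypothetical, not part of the original module). It only
# assumes what the code above already uses from GeckoConfig: vision_config.hidden_size,
# text_config.hidden_size and projector_hidden_act. Unlike the resampler, the MLP
# projector keeps the number of tokens and only changes the channel dimension.
#
#   config = GeckoConfig(...)  # hypothetical instantiation
#   mlp = GeckoMLPProjector(config)
#   image_features = torch.randn(2, 576, config.vision_config.hidden_size)
#   out = mlp(image_features)  # -> (2, 576, config.text_config.hidden_size)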