| # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved | |
| # pyre-unsafe | |
| from collections import OrderedDict | |
| from typing import Callable, List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.checkpoint import checkpoint | |
| from .model_misc import LayerScale | |
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: attention followed by an MLP, both residual.

    The block computes ``x = q_x + ls_1(attn(ln_1(q_x)))`` and then
    ``x = x + ls_2(mlp(ln_2(x)))``.

    Args:
        d_model: Embedding width of the inputs and outputs.
        n_head: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``d_model``.
        ls_init_value: When not ``None``, both residual branches are wrapped
            in ``LayerScale(d_model, ls_init_value)``; otherwise an identity
            is used.
        act_layer: Zero-argument factory for the MLP activation.
        norm_layer: Factory taking the width and returning a norm module.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ):
        super().__init__()

        def make_scale() -> nn.Module:
            # One scaling module per residual branch (identity when disabled).
            if ls_init_value is None:
                return nn.Identity()
            return LayerScale(d_model, ls_init_value)

        # Sub-modules are created in a fixed order so that seeded
        # initialisation and state_dict key order stay stable.
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        self.ls_1 = make_scale()
        self.ls_2 = make_scale()

        hidden_dim = int(d_model * mlp_ratio)
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, hidden_dim)
        mlp_layers["gelu"] = act_layer()
        mlp_layers["c_proj"] = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(mlp_layers)

    def attention(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run multi-head attention and return only the attended output.

        ``k_x`` and ``v_x`` fall back to ``q_x`` when omitted. A non-boolean
        ``attn_mask`` is cast to the dtype of ``q_x``; boolean masks are
        passed through untouched.
        """
        if k_x is None:
            k_x = q_x
        if v_x is None:
            v_x = q_x

        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q_x.dtype)

        attended, _ = self.attn(
            q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
        )
        return attended

    def forward(
        self,
        q_x: torch.Tensor,
        k_x: Optional[torch.Tensor] = None,
        v_x: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and MLP residual branches to ``q_x``.

        ``k_x`` / ``v_x`` are only used when the module has an ``ln_1_kv``
        attribute (nothing in this class creates one); otherwise they are
        discarded and the block attends over ``q_x`` itself.
        """
        has_kv_norm = hasattr(self, "ln_1_kv")
        keys = self.ln_1_kv(k_x) if has_kv_norm and k_x is not None else None
        values = self.ln_1_kv(v_x) if has_kv_norm and v_x is not None else None

        attended = self.attention(
            q_x=self.ln_1(q_x), k_x=keys, v_x=values, attn_mask=attn_mask
        )
        x = q_x + self.ls_1(attended)
        return x + self.ls_2(self.mlp(self.ln_2(x)))
class Transformer(nn.Module):
    """Stack of ``ResidualAttentionBlock`` layers applied sequentially.

    Args:
        width: Embedding width passed to every block.
        layers: Number of blocks in the stack.
        heads: Number of attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``width``.
        ls_init_value: Forwarded to each block.
        act_layer: Activation factory forwarded to each block.
        norm_layer: Norm factory forwarded to each block.
        compile_mode: When not ``None``, ``forward`` is wrapped with
            ``torch.compile(mode=compile_mode, fullgraph=True)``.
        use_act_checkpoint: Enables activation checkpointing of each block
            while the module is in training mode.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        act_layer: Callable[[], nn.Module] = nn.GELU,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = use_act_checkpoint

        blocks = nn.ModuleList()
        for _ in range(layers):
            blocks.append(
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
            )
        self.resblocks = blocks

        if compile_mode is not None:
            # Replace the bound forward on this instance with its compiled form.
            self.forward = torch.compile(
                self.forward, mode=compile_mode, fullgraph=True
            )
            # NOTE(review): this global dynamo setting is assumed to be
            # toggled only when compiling with checkpointing enabled —
            # confirm the nesting against the canonical file.
            if self.grad_checkpointing:
                torch._dynamo.config.optimize_ddp = False

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Feed ``x`` through every block, passing ``attn_mask`` to each one."""
        # Checkpointing applies only when enabled, in training mode and
        # outside TorchScript scripting.
        use_checkpoint = (
            self.grad_checkpointing
            and not torch.jit.is_scripting()
            and self.training
        )
        for block in self.resblocks:
            if use_checkpoint:
                # Positional args map to (q_x, k_x, v_x, attn_mask).
                x = checkpoint(block, x, None, None, attn_mask, use_reentrant=False)
            else:
                x = block(x, attn_mask=attn_mask)
        return x
def text_global_pool(
    x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split sequence features into a pooled vector and per-token features.

    Args:
        x: Features indexed as ``x[batch, position, ...]``.
        text: Token ids; required only for ``pool_type="argmax"``.
        pool_type: One of

            * ``"first"``  - pooled is position 0, tokens are the rest;
            * ``"last"``   - pooled is the final position, tokens are the rest;
            * ``"argmax"`` - pooled is taken at the position of the largest
              token id in each row of ``text`` (the end-of-text token is
              expected to have the highest id); tokens are all of ``x``;
            * anything else - no pooling, both outputs are ``x`` itself.

    Returns:
        ``(pooled, tokens)``.

    Raises:
        AssertionError: If ``pool_type == "argmax"`` and ``text`` is ``None``.
    """
    if pool_type == "first":
        return x[:, 0], x[:, 1:]
    if pool_type == "last":
        return x[:, -1], x[:, :-1]
    if pool_type == "argmax":
        assert text is not None, "text token ids are required for 'argmax' pooling"
        # Build both index tensors on x's device so the gather neither mixes
        # devices nor forces a host/device round trip.
        batch_idx = torch.arange(x.shape[0], device=x.device)
        eot_idx = text.argmax(dim=-1).to(x.device)
        return x[batch_idx, eot_idx], x
    return x, x
class TextTransformer(nn.Module):
    """Token-embedding text encoder built on ``Transformer``.

    Token ids are embedded, summed with a learned positional embedding, run
    through the transformer stack (causally masked unless
    ``no_causal_mask``), optionally layer-normalised, pooled with
    ``text_global_pool`` and finally projected to ``output_dim``.

    Note:
        ``positional_embedding`` and, when ``proj_bias`` is false,
        ``text_projection`` are allocated with ``torch.empty`` and are not
        initialised anywhere in this class.

    Args:
        context_length: Number of positions in the positional embedding and
            side length of the causal mask.
        vocab_size: Number of rows in the token embedding.
        width: Transformer embedding width.
        heads: Attention heads per block.
        layers: Number of transformer blocks.
        mlp_ratio: MLP hidden width as a multiple of ``width``.
        ls_init_value: Forwarded to the transformer blocks.
        output_dim: Width of the projected pooled output.
        no_causal_mask: When true, no attention mask is used.
        pool_type: ``"first"``, ``"last"``, ``"argmax"`` or ``"none"``.
        proj_bias: Use an ``nn.Linear`` projection instead of a bare matrix.
        act_layer: Activation factory for the blocks.
        norm_layer: Norm factory for the blocks and the final norm.
        output_tokens: When true, ``forward`` also returns per-token features.
        use_ln_post: Apply ``norm_layer(width)`` after the transformer.
        compile_mode: Forwarded to ``Transformer``.
        use_act_checkpoint: Forwarded to ``Transformer``.
    """

    def __init__(
        self,
        context_length: int = 77,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        output_dim: int = 512,
        no_causal_mask: bool = False,
        pool_type: str = "none",  # no pooling
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        output_tokens: bool = False,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = False,
    ):
        super().__init__()
        supported_pool_types = ("first", "last", "argmax", "none")
        assert pool_type in supported_pool_types

        # Plain configuration attributes.
        self.context_length = context_length
        self.num_pos = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pool_type = pool_type
        self.output_tokens = output_tokens

        # Learnable components, registered in a fixed order.
        self.token_embedding = nn.Embedding(self.vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()

        if not no_causal_mask:
            # Non-persistent: the mask is excluded from the state dict.
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )
        else:
            self.attn_mask = None

        self.text_projection = (
            nn.Linear(width, output_dim)
            if proj_bias
            else nn.Parameter(torch.empty(width, output_dim))
        )

    def build_causal_mask(self) -> torch.Tensor:
        """Return a ``num_pos x num_pos`` additive causal attention mask.

        Entries strictly above the diagonal are ``-inf``; the diagonal and
        everything below it are zero.
        """
        size = self.num_pos
        return torch.empty(size, size).fill_(float("-inf")).triu_(1)

    def forward(
        self, text: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encode a batch of token ids.

        Args:
            text: Token ids indexed as ``text[batch, position]``.

        Returns:
            The projected pooled features, or ``(pooled, tokens)`` when
            ``output_tokens`` is set. ``tokens`` are not projected.
        """
        num_tokens = text.shape[1]

        # Crop the precomputed mask to the actual sequence length.
        causal_mask = self.attn_mask
        if causal_mask is not None:
            causal_mask = causal_mask[:num_tokens, :num_tokens]

        # [batch_size, n_ctx, d_model]
        hidden = self.token_embedding(text) + self.positional_embedding[:num_tokens]
        hidden = self.ln_final(self.transformer(hidden, attn_mask=causal_mask))

        pooled, tokens = text_global_pool(hidden, text, pool_type=self.pool_type)

        projection = self.text_projection
        if projection is not None:
            pooled = (
                projection(pooled)
                if isinstance(projection, nn.Linear)
                else pooled @ projection
            )

        return (pooled, tokens) if self.output_tokens else pooled
class VETextEncoder(nn.Module):
    """Text encoder that tokenizes strings, encodes them and resizes to ``d_model``.

    Wraps a ``TextTransformer`` (configured to return per-token features)
    and a linear ``resizer`` mapping the encoder width to ``d_model``.

    Args:
        d_model: Output width produced by ``resizer``.
        tokenizer: Callable invoked as ``tokenizer(text, context_length=...)``;
            its result must support ``.to(device)``.
        width: Encoder embedding width.
        heads: Attention heads per encoder block.
        layers: Number of encoder blocks.
        context_length: Context length for both tokenizer and encoder.
        vocab_size: Encoder vocabulary size.
        use_ln_post: Forwarded to the encoder.
        compile_mode: Forwarded to the encoder.
        use_act_checkpoint: Forwarded to the encoder.
    """

    def __init__(
        self,
        d_model: int,
        tokenizer: Callable,
        width: int = 1024,
        heads: int = 16,
        layers: int = 24,
        context_length: int = 32,
        vocab_size: int = 49408,
        use_ln_post: bool = True,
        compile_mode: Optional[str] = None,
        use_act_checkpoint: bool = True,
    ):
        super().__init__()
        self.context_length = context_length
        self.use_ln_post = use_ln_post
        self.tokenizer = tokenizer

        encoder_config = dict(
            context_length=self.context_length,
            vocab_size=vocab_size,
            width=width,
            heads=heads,
            layers=layers,
            # per-token features are needed, not only the pooled output
            output_tokens=True,
            use_ln_post=use_ln_post,
            compile_mode=compile_mode,
            use_act_checkpoint=use_act_checkpoint,
        )
        self.encoder = TextTransformer(**encoder_config)
        self.resizer = nn.Linear(self.encoder.width, d_model)

    def forward(
        self,
        text: Union[List[str], Tuple[torch.Tensor, torch.Tensor, dict]],
        input_boxes: Optional[List] = None,
        device: torch.device = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode raw strings, or pass through an already encoded triple.

        Args:
            text: Either a sequence of strings, or a pre-encoded triple
                ``(attention_mask, resized_memory, mapping)`` where
                ``mapping["inputs_embeds"]`` holds the token embeddings.
            input_boxes: Must be ``None`` or empty; anything else fails an
                assertion.
            device: Target passed to ``.to(...)`` on the tokenizer output.

        Returns:
            ``(text_attention_mask, text_memory_resized, inputs_embeds)``.
            For string input the mask is ``True`` where the token id is 0,
            the memory is transposed to sequence-first before resizing, and
            ``inputs_embeds`` has its first two dimensions swapped. For
            pre-encoded input the first two items are returned as given.
        """
        if not isinstance(text[0], str):
            # Already encoded: unpack and hand the pieces back unchanged.
            attention_mask, resized_memory, encoded = text
            inputs_embeds = encoded["inputs_embeds"]
            assert input_boxes is None or len(input_boxes) == 0, (
                "Can't replace boxes in text if it's already encoded"
            )
        else:
            # no use case for this
            assert input_boxes is None or len(input_boxes) == 0, "not supported"

            token_ids = self.tokenizer(text, context_length=self.context_length)
            token_ids = token_ids.to(device)  # [b, seq_len]

            # PyTorch attention marks positions to ignore with True, so
            # flag the entries whose token id is 0.
            attention_mask = token_ids == 0

            # Embed the tokens manually as well; they are part of the output.
            inputs_embeds = self.encoder.token_embedding(token_ids)  # [b, seq_len, d=1024]
            _, memory = self.encoder(token_ids)  # [b, seq_len, d=1024]
            assert memory.shape[1] == inputs_embeds.shape[1]

            # Sequence-first layout, then map encoder width onto d_model.
            resized_memory = self.resizer(memory.transpose(0, 1))

        # inputs_embeds is also returned sequence-first.
        return attention_mask, resized_memory, inputs_embeds.transpose(0, 1)
Xet Storage Details
- Size:
- 11.3 kB
- Xet hash:
- 574ba404d631ddcfeb83ac08cf43c4c282b08ed27af8cac0a7285bcaaca1b1f5
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.