import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:4096'  # must be set before importing torch

import gc
import math
from typing import Tuple
from urllib.parse import urljoin

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch import seed_everything
from torch.optim.lr_scheduler import _LRScheduler
from transformers import EsmModel, PreTrainedModel

from configuration import MetaLATTEConfig

seed_everything(42)

class GELU(nn.Module):
    """Implementation of the gelu activation function.

    For information: OpenAI GPT's gelu is slightly different
    (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
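
# Illustrative sketch (not part of the original Space): the exact erf-based GELU above and the
# tanh approximation quoted in the docstring agree to within roughly 1e-3 on typical inputs.
def _gelu_sanity_check():
    x = torch.linspace(-3.0, 3.0, steps=101)
    exact = GELU()(x)
    tanh_approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    assert torch.allclose(exact, tanh_approx, atol=1e-3)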

def rotate_half(x):
    # x: (B, L, H, head_dim) -> x1, x2: (B, L, H, head_dim // 2)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    # x has shape (B, L, H, head_dim); cos and sin have shape (1, L, head_dim)
    cos = cos.unsqueeze(2)  # (1, L, 1, head_dim)
    sin = sin.unsqueeze(2)  # (1, L, 1, head_dim)
    return (x * cos) + (rotate_half(x) * sin)

class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).

    A crucial insight from the method is that the queries and keys are
    transformed by rotation matrices which depend on their relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration for this one.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
    """
    def __init__(self, dim: int, *_, **__):
        super().__init__()
        # Generate and save the inverse frequency buffer (non-trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=1):
        seq_len = x.shape[seq_dimension]
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing, for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # (L, dim // 2)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)  # (L, dim)
            self._cos_cached = emb.cos()[None, :, :]  # (1, L, dim)
            self._sin_cached = emb.sin()[None, :, :]  # (1, L, dim)
        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k)
        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),  # (B, L, H, dim)
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
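
# Illustrative sketch (not part of the original Space): applying RotaryEmbedding to per-head
# query/key tensors of shape (B, L, H, head_dim). All dimensions below are arbitrary.
def _rotary_embedding_example():
    B, L, H, head_dim = 2, 10, 4, 64
    rot_emb = RotaryEmbedding(head_dim)
    q = torch.randn(B, L, H, head_dim)
    k = torch.randn(B, L, H, head_dim)
    q_rot, k_rot = rot_emb(q, k)
    # The position-dependent rotation preserves shapes and per-vector norms
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)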

def macro_f1(y_true, y_pred, thresholds):
    y_pred_binary = (y_pred >= thresholds).float()
    tp = (y_true * y_pred_binary).sum(dim=0)
    fp = ((1 - y_true) * y_pred_binary).sum(dim=0)
    fn = (y_true * (1 - y_pred_binary)).sum(dim=0)
    precision = tp / (tp + fp + 1e-7)
    recall = tp / (tp + fn + 1e-7)
    f1 = 2 * precision * recall / (precision + recall + 1e-7)
    macro_f1 = f1.mean()
    return macro_f1
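
# Illustrative sketch (not part of the original Space): macro_f1 on a tiny 3-sample, 2-label
# batch with a 0.5 threshold per label.
def _macro_f1_example():
    y_true = torch.tensor([[1., 0.], [1., 1.], [0., 1.]])
    y_pred = torch.tensor([[0.9, 0.2], [0.8, 0.4], [0.3, 0.7]])
    thresholds = torch.tensor([0.5, 0.5])
    # Label 0 is predicted perfectly (F1 = 1.0); label 1 misses one positive (F1 = 2/3),
    # so the macro average is roughly 0.83.
    print(macro_f1(y_true, y_pred, thresholds))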

def safeguard_softmax(logits, dim=-1):
    # Subtract the row-wise max so that exp() cannot overflow to inf
    max_logits, _ = logits.max(dim=dim, keepdim=True)
    exp_logits = torch.exp(logits - max_logits)
    exp_sum = exp_logits.sum(dim=dim, keepdim=True)
    probs = exp_logits / (exp_sum + 1e-7)  # small epsilon to prevent division by zero
    return probs
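
# Illustrative sketch (not part of the original Space): on finite logits this matches
# torch.softmax up to the 1e-7 epsilon in the denominator, even when the raw logits are
# large enough that an unstabilized exp() would overflow.
def _safeguard_softmax_example():
    logits = torch.tensor([[1e4, 2e4, 3e4], [0.0, 1.0, 2.0]])
    probs = safeguard_softmax(logits, dim=-1)
    assert torch.allclose(probs, torch.softmax(logits, dim=-1), atol=1e-5)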

class PositionalAttentionHead(nn.Module):
    def __init__(self, hidden_dim, n_heads):
        super(PositionalAttentionHead, self).__init__()
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // n_heads
        self.preattn_ln = nn.LayerNorm(self.head_dim)
        self.Q = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.K = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.V = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.rot_emb = RotaryEmbedding(self.head_dim)

    def forward(self, x, attention_mask):
        batch_size, seq_len, _ = x.size()  # x: (B, L, hidden_dim)
        x = x.view(batch_size, seq_len, self.n_heads, self.head_dim)
        x = self.preattn_ln(x)
        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)
        q, k = self.rot_emb(q, k)
        gc.collect()
        torch.cuda.empty_cache()
        attn_scores = torch.einsum('bqhd,bkhd->bhqk', q, k) / math.sqrt(self.head_dim)
        # Mask out padded key positions before the softmax
        attn_scores = attn_scores.masked_fill(torch.logical_not(attention_mask.unsqueeze(1).unsqueeze(1)), float("-inf"))  # (B, H, L, L)
        attn_probs = safeguard_softmax(attn_scores, dim=-1)
        x = torch.einsum('bhqk,bkhd->bqhd', attn_probs, v)
        x = x.reshape(batch_size, seq_len, self.hidden_dim)  # (B, L, hidden_dim)
        gc.collect()
        torch.cuda.empty_cache()
        return x, attn_probs
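
# Illustrative sketch (not part of the original Space): running the attention head on a random
# batch with one padded position. Dimensions are arbitrary but must satisfy hidden_dim % n_heads == 0.
def _positional_attention_example():
    head = PositionalAttentionHead(hidden_dim=64, n_heads=4)
    x = torch.randn(2, 5, 64)            # (B, L, hidden_dim)
    mask = torch.ones(2, 5, dtype=torch.bool)
    mask[1, -1] = False                  # last token of the second sequence is padding
    out, attn = head(x, mask)
    # out: (2, 5, 64); attn: (2, 4, 5, 5), with ~zero probability on the padded key column
    print(out.shape, attn.shape)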

class CosineAnnealingWithWarmup(_LRScheduler):
    # Cosine decay with linear warmup, implemented based on the LLaMA paper's description
    # https://arxiv.org/abs/2302.13971
    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_ratio = eta_ratio  # ratio of minimum to maximum learning rate
        super(CosineAnnealingWithWarmup, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs]
        progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
        return [decayed_lr * base_lr for base_lr in self.base_lrs]
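
# Illustrative sketch (not part of the original Space): attaching the scheduler to a dummy
# optimizer and stepping it, so the warmup and cosine-decay phases can be inspected.
def _scheduler_example():
    params = [nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-3)
    scheduler = CosineAnnealingWithWarmup(optimizer, warmup_steps=10, total_steps=100, eta_ratio=0.1)
    lrs = []
    for _ in range(100):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    # lr ramps linearly to 1e-3 over the first 10 steps, then decays along a cosine toward 1e-4
    print(lrs[9], lrs[-1])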

class RobertaLMHead(nn.Module):
    """Head for masked language modeling."""
    def __init__(self, embed_dim, output_dim, weight):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.weight = weight
        self.gelu = GELU()
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        x = self.dense(features)
        x = self.gelu(x)
        x = self.layer_norm(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight) + self.bias
        return x
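
# Illustrative sketch (not part of the original Space): the LM head projects hidden states back
# to vocabulary logits through an externally supplied (typically tied) embedding weight matrix.
def _lm_head_example():
    vocab_size, embed_dim = 33, 1280
    tied_weight = nn.Embedding(vocab_size, embed_dim).weight
    lm_head = RobertaLMHead(embed_dim=embed_dim, output_dim=vocab_size, weight=tied_weight)
    hidden = torch.randn(2, 7, embed_dim)   # (B, L, embed_dim)
    logits = lm_head(hidden)                # (B, L, vocab_size)
    print(logits.shape)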

class MultitaskProteinModel(PreTrainedModel):
    config_class = MetaLATTEConfig
    base_model_prefix = "metalatte"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.esm_model = EsmModel.from_pretrained(self.config.esm_model_name)
        # Layer freezing for the original ESM model:
        # first freeze everything
        for param in self.esm_model.parameters():
            param.requires_grad = False
        # then unfreeze the layers to be fine-tuned
        for i in range(config.num_layers_to_finetune):
            for param in self.esm_model.encoder.layer[-i-1].parameters():
                param.requires_grad = True
        # esm_dim should be 1280
        self.lm_head = RobertaLMHead(embed_dim=1280, output_dim=33, weight=self.esm_model.embeddings.word_embeddings.weight)
        self.attn_head = PositionalAttentionHead(self.config.hidden_size, self.config.num_attention_heads)
        self.attn_ln = nn.LayerNorm(self.config.hidden_size)
        self.attn_skip = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        # Add linear layers after the attention head
        self.linear_layers = nn.ModuleList()
        for _ in range(self.config.num_linear_layers):
            self.linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size))
        self.reduction_layers = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.hidden_dim),
            GELU(),
            nn.Linear(self.config.hidden_dim, self.config.num_labels)
        )
        self.clf_ln = nn.LayerNorm(self.config.hidden_size)
        self.classification_thresholds = nn.Parameter(torch.tensor([0.5] * self.config.num_labels))
        # Initialize weights and apply final processing
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        config = kwargs.pop("config", None)
        if config is None:
            config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        # Local alternative:
        # state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
        try:
            state_dict_url = urljoin(f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/", "pytorch_model.bin")
            state_dict = torch.hub.load_state_dict_from_url(
                state_dict_url,
                map_location=torch.device('cpu')
            )['state_dict']
            model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            raise RuntimeError(f"Error loading state_dict from {pretrained_model_name_or_path}/pytorch_model.bin: {e}")
        return model

    def forward(self, input_ids, attention_mask=None):
        outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        embeddings = outputs.last_hidden_state
        attention_masks = attention_mask
        x_pool, x_attns = self.attn_head(embeddings, attention_masks)
        x_pool = self.attn_ln(x_pool + self.attn_skip(x_pool))  # skip connection for the attention layer
        for linear_layer in self.linear_layers:
            residue = x_pool
            x_pool = linear_layer(x_pool)  # 1280 -> 1280
            x_pool = F.silu(x_pool)
            x_pool = x_pool + residue  # skip connection
        x_weighted = torch.einsum('bhlk,bld->bhld', x_attns, x_pool)  # (B, H, L, 1280)
        x_combined = x_weighted.mean(dim=1)  # average over heads: (B, L, 1280)
        x_combined = self.clf_ln(x_combined)
        mlm_logits = self.lm_head(x_combined)
        attention_masks = attention_masks.unsqueeze(-1).float()  # (B, L, 1)
        attention_sum = attention_masks.sum(dim=1, keepdim=True)  # (B, 1, 1)
        x_combined_masked = (x_combined * attention_masks).sum(dim=1) / attention_sum.squeeze(1)  # (B, 1280)
        # Compute classification logits
        x_pred = self.reduction_layers(x_combined_masked)
        gc.collect()
        torch.cuda.empty_cache()
        return x_pred, x_attns, x_combined_masked, mlm_logits

    def predict(self, input_ids, attention_mask=None):
        x_pred, _, _, _ = self.forward(input_ids, attention_mask)
        classification_output = torch.sigmoid(x_pred)
        predictions = (classification_output >= self.classification_thresholds).float()
        # If no label clears its threshold, fall back to the single most probable label
        for i, pred in enumerate(predictions):
            if pred.sum() == 0:
                weighted_probs = classification_output[i]
                max_class = torch.argmax(weighted_probs)
                predictions[i, max_class] = 1.0
        return classification_output, predictions
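
# Illustrative sketch (not part of the original Space): end-to-end inference. The tokenizer
# checkpoint and the "your-org/metalatte" repo id below are placeholders, not the Space's actual
# identifiers; substitute the real ones before running.
def _inference_example():
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")  # placeholder checkpoint
    model = MultitaskProteinModel.from_pretrained("your-org/metalatte")        # placeholder repo id
    model.eval()
    sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
    inputs = tokenizer(sequence, return_tensors="pt")
    with torch.no_grad():
        probs, preds = model.predict(inputs["input_ids"], inputs["attention_mask"])
    print(probs, preds)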