| from typing import Unpack |
| import torch |
| from transformers import ( |
| Cache, |
| EncoderDecoderCache, |
| DynamicCache, |
| DataCollatorWithFlattening, |
| BertModel, BertForMaskedLM, |
| BertForSequenceClassification, |
| BertForTokenClassification, |
| BertForMultipleChoice, |
| BertForQuestionAnswering |
| ) |
| from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions |
| from transformers.utils import TransformersKwargs |
|
|
|
|
def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Pack the non-padding tokens of a padded batch into one flattened sequence.

    Each row of ``input_ids`` is filtered by the matching row of
    ``attention_mask`` (non-zero entries are kept). The kept ids are turned
    into a Python list, and the resulting per-row sequences are handed to
    ``DataCollatorWithFlattening`` configured with
    ``return_flash_attn_kwargs=True``.

    Returns:
        The collator's output mapping. ``UnpadBertModel.forward`` reads its
        ``input_ids``, ``position_ids``, ``cu_seq_lens_q``/``cu_seq_lens_k``
        and ``max_length_q``/``max_length_k`` entries.
    """
    # NOTE(review): ``.tolist()`` copies every kept token id into host memory
    # on each call; the collator then rebuilds tensors from those lists.
    sequences = []
    for row_ids, row_mask in zip(input_ids, attention_mask):
        kept_ids = row_ids[row_mask.bool()]
        sequences.append({"input_ids": kept_ids.tolist()})
    flattener = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
    return flattener(sequences)
|
|
|
|
def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter packed (unpadded) rows back into a zero-filled padded batch.

    Args:
        inputs: Packed values, one row per real token. A 3-D input is assumed
            to carry a leading packed-batch dimension of size 1, as in
            ``(1, total_tokens, ...)``; that dimension is dropped first.
        indices: Flat positions in ``range(batch * seqlen)`` where each packed
            row is written, in the same order as the rows of ``inputs``.
        batch: Number of sequences in the padded output.
        seqlen: Padded sequence length.

    Returns:
        A tensor of shape ``(batch, seqlen, *trailing)``. Here ``trailing`` is
        the shape of ``inputs`` after its first (token) dimension, so a 1-D
        input yields ``(batch, seqlen)``. It has the same dtype and device as
        ``inputs``, with zeros at every position not listed in ``indices``.
    """
    if inputs.dim() == 3:
        # Drop only the leading packed-batch dim. A bare ``squeeze()`` would
        # also collapse ``total_tokens == 1`` (or a size-1 feature dim), which
        # turned the rows into a 1-D vector and broke the scatter below.
        inputs = inputs.squeeze(0)
    trailing = inputs.shape[1:]
    padded = inputs.new_zeros((batch * seqlen, *trailing))
    padded[indices] = inputs
    return padded.view(batch, seqlen, *trailing)
|
|
|
|
class UnpadBertModel(BertModel):
    """``BertModel`` that packs away padding tokens under flash attention.

    When ``config._attn_implementation`` starts with ``"flash_attention"``:

    - the padded inputs are flattened into one packed row containing only the
      tokens selected by ``attention_mask``;
    - the encoder runs on that packed row with cumulative-sequence-length
      kwargs;
    - the encoder output is scattered back into a zero-padded tensor before
      pooling.

    With any other attention implementation, the ordinary padded path is used.

    NOTE: ``enable_bert_unpadding`` assigns ``UnpadBertModel.forward`` directly
    to ``BertModel.forward``. ``forward`` must therefore only rely on attributes
    that plain ``BertModel`` instances also have (no helper methods defined on
    this subclass). Module-level helpers are fine.
    """

    # Presumably read by transformers' device-map placement to keep these
    # modules whole on one device — confirm against the installed version.
    _no_split_modules = ["BertEmbeddings", "BertLayer"]

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config, add_pooling_layer)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        """Run BERT, removing padding tokens first when flash attention is configured.

        Args:
            input_ids: Token ids; batch size and length are read from dims 0 and 1.
            attention_mask: Non-zero entries mark real tokens. Required when unpadding.
            token_type_ids: Optional segment ids. Under unpadding they are packed
                with the same token selection as ``input_ids``.
            position_ids: Under unpadding, replaced by the per-sequence positions
                produced by ``_unpad_input``.
            inputs_embeds: Alternative to ``input_ids``. Not supported together
                with unpadding.
            encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache:
                Forwarded to the embeddings/encoder as in ``BertModel``.
            **kwargs: Forwarded to the encoder. Under unpadding, the
                ``cu_seq_lens_*`` / ``max_length_*`` entries are overwritten.

        Returns:
            ``BaseModelOutputWithPoolingAndCrossAttentions``. When unpadding was
            applied, ``last_hidden_state`` has been scattered back to
            ``(batch_size, seq_length, ...)`` with zeros at padding positions.

        Raises:
            ValueError: If not exactly one of ``input_ids``/``inputs_embeds`` is
                given, or if unpadding is active without both ``input_ids`` and
                ``attention_mask``.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                else DynamicCache(config=self.config)
            )

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        # Read batch geometry from whichever input was supplied. Reading it from
        # input_ids unconditionally crashed when only inputs_embeds was passed.
        reference = input_ids if input_ids is not None else inputs_embeds
        device = reference.device
        batch_size, seq_length = reference.shape[0], reference.shape[1]

        unpad = self.config._attn_implementation.startswith("flash_attention")
        indices = None
        if unpad:
            if input_ids is None or attention_mask is None:
                raise ValueError("Unpadding requires both input_ids and attention_mask")
            with torch.no_grad():
                # Flat (row-major over batch * seq) positions of the real tokens,
                # used to pack auxiliary inputs and to scatter the output back.
                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
                features = _unpad_input(input_ids, attention_mask)
                input_ids = features["input_ids"].to(device=device)
                position_ids = features["position_ids"].to(device=device)
                if token_type_ids is not None:
                    # Pack segment ids exactly like input_ids. Otherwise the
                    # embeddings would combine a packed (1, total) row with a
                    # padded (batch, seq) tensor. expand() also accepts a
                    # broadcastable (1, seq) tensor.
                    token_type_ids = (
                        token_type_ids.expand(batch_size, seq_length).reshape(-1)[indices].unsqueeze(0)
                    )
                # The packed row has no padding; sequence boundaries are conveyed
                # through the cu_seq_lens kwargs instead of a mask.
                attention_mask = None
                # NOTE(review): these kwargs reach every encoder layer; confirm how
                # cross-attention (encoder_hidden_states) treats cu_seq_lens_k.
                kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
                kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
                kwargs["max_length_k"] = features["max_length_k"]
                kwargs["max_length_q"] = features["max_length_q"]

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        attention_mask, encoder_attention_mask = self._create_attention_masks(
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            embedding_output=embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_ids=position_ids,
            **kwargs,
        )
        sequence_output = encoder_outputs.last_hidden_state
        if indices is not None:
            # Restore the padded layout so the pooler and downstream heads see
            # one row per input sequence.
            sequence_output = _pad_output(
                inputs=sequence_output, indices=indices, batch=batch_size, seqlen=seq_length
            )

        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
        )
|
|
|
|
class UnpadBertForMaskedLM(BertForMaskedLM):
    """``BertForMaskedLM`` whose ``bert`` backbone is a pooler-less ``UnpadBertModel``."""

    def __init__(self, config):
        super().__init__(config=config)
        # Swap the backbone created by the parent constructor for the unpadding
        # variant, then run post_init again now that the new module is attached.
        self.bert = UnpadBertModel(config=config, add_pooling_layer=False)
        self.post_init()
|
|
|
|
class UnpadBertForSequenceClassification(BertForSequenceClassification):
    """``BertForSequenceClassification`` backed by an ``UnpadBertModel`` that keeps its pooler."""

    def __init__(self, config):
        super().__init__(config=config)
        # Swap in the unpadding backbone (pooler kept), then run post_init again
        # now that the new module is attached.
        self.bert = UnpadBertModel(config, add_pooling_layer=True)
        self.post_init()
|
|
|
|
class UnpadBertForTokenClassification(BertForTokenClassification):
    """``BertForTokenClassification`` whose ``bert`` backbone is a pooler-less ``UnpadBertModel``."""

    def __init__(self, config):
        super().__init__(config)
        # Token classification works on per-token hidden states, so the backbone
        # is built without a pooler, as in the MaskedLM/QA variants. A pooler here
        # would add compute plus parameters that never receive gradients
        # (problematic for DDP) and that are absent from upstream
        # token-classification checkpoints.
        self.bert = UnpadBertModel(config, add_pooling_layer=False)
        # Run post_init again now that the new backbone is attached.
        self.post_init()
|
|
|
|
class UnpadBertForMultipleChoice(BertForMultipleChoice):
    """``BertForMultipleChoice`` backed by an ``UnpadBertModel`` that keeps its pooler."""

    def __init__(self, config):
        super().__init__(config=config)
        # Swap in the unpadding backbone (pooler kept), then run post_init again
        # now that the new module is attached.
        self.bert = UnpadBertModel(config, add_pooling_layer=True)
        self.post_init()
|
|
|
|
class UnpadBertForQuestionAnswering(BertForQuestionAnswering):
    """``BertForQuestionAnswering`` whose ``bert`` backbone is a pooler-less ``UnpadBertModel``."""

    def __init__(self, config):
        super().__init__(config=config)
        # Swap the backbone created by the parent constructor for the unpadding
        # variant, then run post_init again now that the new module is attached.
        self.bert = UnpadBertModel(config=config, add_pooling_layer=False)
        self.post_init()
|
|
|
|
def enable_bert_unpadding():
    """Globally replace ``BertModel.forward`` with ``UnpadBertModel.forward``.

    Because this rebinds a class attribute, every ``BertModel`` instance (and
    every subclass that does not override ``forward``) uses the unpadding-aware
    forward afterwards, including instances created before the call. Nothing in
    this function keeps or restores the previous implementation.
    """
    setattr(BertModel, "forward", UnpadBertModel.forward)
|
|