# unpad-impl / modeling_bert.py
# Author: sdadas — "Update modeling_bert.py" (commit ce4368a, verified)
from typing import Unpack
import torch
from transformers import (
Cache,
EncoderDecoderCache,
DynamicCache,
DataCollatorWithFlattening,
BertModel, BertForMaskedLM,
BertForSequenceClassification,
BertForTokenClassification,
BertForMultipleChoice,
BertForQuestionAnswering
)
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.utils import TransformersKwargs
def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Flatten a padded batch into one packed sequence with flash-attn kwargs.

    Keeps only the tokens whose attention-mask entry is nonzero, then lets
    DataCollatorWithFlattening do the packing and produce the
    ``cu_seq_lens_*`` / ``max_length_*`` bookkeeping (because
    ``return_flash_attn_kwargs=True``).
    """
    rows = []
    for ids_row, mask_row in zip(input_ids, attention_mask):
        kept_tokens = ids_row[mask_row.bool()]
        rows.append({"input_ids": kept_tokens.tolist()})
    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
    return collator(rows)
def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int,) -> torch.Tensor:
if inputs.dim() == 3:
inputs = inputs.squeeze()
if inputs.dim() == 1:
output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
output[indices] = inputs
padded_inputs = output.view(batch, seqlen)
else:
_, *rest = inputs.shape
output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
output[indices] = inputs
padded_inputs = output.view(batch, seqlen, *rest)
return padded_inputs
class UnpadBertModel(BertModel):
    """BertModel that removes padding tokens before the encoder when a
    flash-attention implementation is active.

    In the flash-attention path the padded ``(batch, seq)`` batch is flattened
    into one packed sequence (padding tokens dropped), the encoder runs on the
    packed tokens driven by ``cu_seq_lens_*`` / ``max_length_*`` kwargs, and
    the output is scattered back to the original padded shape before pooling.
    """

    _no_split_modules = ["BertEmbeddings", "BertLayer"]

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config, add_pooling_layer)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            # Encoder-only usage never caches.
            use_cache = False
        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                else DynamicCache(config=self.config)
            )
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        # BUGFIX: derive device and shape from whichever input was supplied.
        # The original read them unconditionally from input_ids and raised
        # AttributeError for the (valid) inputs_embeds-only call.
        if input_ids is not None:
            device = input_ids.device
            batch_size, seq_length = input_ids.shape[0], input_ids.shape[1]
        else:
            device = inputs_embeds.device
            batch_size, seq_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
        indices = None
        if self.config._attn_implementation.startswith("flash_attention"):
            if input_ids is None or attention_mask is None:
                raise ValueError("Unpadding requires both input_ids and attention_mask")
            with torch.no_grad():
                # Flat positions of the real (non-padding) tokens; used to
                # re-pad the encoder output at the end of forward.
                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
                features = _unpad_input(input_ids, attention_mask)
            input_ids = features["input_ids"].to(device=device)
            position_ids = features["position_ids"].to(device=device)
            # The packed sequence carries its boundaries via the flash-attn
            # kwargs below, so no dense attention mask is passed onward.
            attention_mask = None
            kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
            kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
            kwargs["max_length_k"] = features["max_length_k"]
            kwargs["max_length_q"] = features["max_length_q"]
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        attention_mask, encoder_attention_mask = self._create_attention_masks(
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            embedding_output=embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_ids=position_ids,
            **kwargs,
        )
        sequence_output = encoder_outputs.last_hidden_state
        if self.config._attn_implementation.startswith("flash_attention"):
            # Scatter the packed token states back to (batch, seq, hidden).
            sequence_output = _pad_output(
                inputs=sequence_output, indices=indices, batch=batch_size, seqlen=seq_length
            )
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
        )
class UnpadBertForMaskedLM(BertForMaskedLM):
    """BertForMaskedLM with its backbone swapped for UnpadBertModel."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock BertModel with the unpadding variant; the MLM
        # head works on token states only, so no pooler is needed.
        self.bert = UnpadBertModel(config, add_pooling_layer=False)
        # Re-run weight init / tying after swapping the backbone.
        self.post_init()
class UnpadBertForSequenceClassification(BertForSequenceClassification):
    """BertForSequenceClassification with its backbone swapped for UnpadBertModel."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock BertModel with the unpadding variant (pooler kept,
        # matching the default add_pooling_layer=True).
        self.bert = UnpadBertModel(config)
        # Re-run weight init after swapping the backbone.
        self.post_init()
class UnpadBertForTokenClassification(BertForTokenClassification):
    """BertForTokenClassification with its backbone swapped for UnpadBertModel."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock BertModel with the unpadding variant (pooler kept,
        # matching the default add_pooling_layer=True).
        self.bert = UnpadBertModel(config)
        # Re-run weight init after swapping the backbone.
        self.post_init()
class UnpadBertForMultipleChoice(BertForMultipleChoice):
    """BertForMultipleChoice with its backbone swapped for UnpadBertModel."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock BertModel with the unpadding variant (pooler kept,
        # matching the default add_pooling_layer=True).
        self.bert = UnpadBertModel(config)
        # Re-run weight init after swapping the backbone.
        self.post_init()
class UnpadBertForQuestionAnswering(BertForQuestionAnswering):
    """BertForQuestionAnswering with its backbone swapped for UnpadBertModel."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock BertModel with the unpadding variant; the QA head
        # predicts per-token spans, so no pooler is needed.
        self.bert = UnpadBertModel(config, add_pooling_layer=False)
        # Re-run weight init after swapping the backbone.
        self.post_init()
def enable_bert_unpadding():
    """Globally monkey-patch ``BertModel.forward`` with the unpadding variant.

    Affects every existing and future ``BertModel`` instance — including the
    ``.bert`` backbone inside stock head models — without requiring the
    Unpad* subclasses above.
    """
    BertModel.forward = UnpadBertModel.forward