TaliDror committed on
Commit ·
efd5117
1
Parent(s): 780f1aa
Replaced _make_causal_mask and _expand_mask with their newer transformers equivalents (_create_4d_causal_attention_mask and _prepare_4d_attention_mask)
Browse files- external/arc2face/models.py +41 -50
external/arc2face/models.py
CHANGED
|
@@ -2,58 +2,36 @@ import torch
|
|
| 2 |
from transformers import CLIPTextModel
|
| 3 |
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
|
| 4 |
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
| 5 |
-
#from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
except ImportError:
|
| 9 |
-
# transformers >=4.47 removed these internal helpers from modeling_clip.
|
| 10 |
-
# Reimplement them directly from the transformers 4.34 source so the mask
|
| 11 |
-
# format (additive, shape [bsz,1,tgt,src]) matches what CLIPEncoder expects.
|
| 12 |
-
def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
|
| 13 |
-
bsz, tgt_len = input_ids_shape
|
| 14 |
-
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
| 15 |
-
mask_cond = torch.arange(tgt_len, device=device)
|
| 16 |
-
mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
|
| 17 |
-
mask = mask.to(dtype)
|
| 18 |
-
if past_key_values_length > 0:
|
| 19 |
-
mask = torch.cat(
|
| 20 |
-
[torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1
|
| 21 |
-
)
|
| 22 |
-
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
| 23 |
|
| 24 |
-
def _expand_mask(mask, dtype, tgt_len=None):
|
| 25 |
-
bsz, src_len = mask.shape
|
| 26 |
-
tgt_len = tgt_len if tgt_len is not None else src_len
|
| 27 |
-
expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
| 28 |
-
inverted = 1.0 - expanded
|
| 29 |
-
return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)
|
| 30 |
|
| 31 |
class CLIPTextModelWrapper(CLIPTextModel):
|
| 32 |
# Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
|
| 33 |
# Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
|
| 34 |
-
# Supports both transformers <=4.46 (self.text_model sub-attribute) and >=4.47 (flat structure, no text_model).
|
| 35 |
def forward(
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
|
| 46 |
|
| 47 |
-
# In transformers <=4.46 the transformer lives in self.text_model;
|
| 48 |
-
# in >=4.47 it was inlined directly onto CLIPTextModel (flat structure).
|
| 49 |
-
tm = getattr(self, 'text_model', self)
|
| 50 |
-
|
| 51 |
if return_token_embs:
|
| 52 |
-
return
|
| 53 |
|
| 54 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
if input_ids is None:
|
| 59 |
raise ValueError("You have to specify input_ids")
|
|
@@ -61,13 +39,19 @@ class CLIPTextModelWrapper(CLIPTextModel):
|
|
| 61 |
input_shape = input_ids.size()
|
| 62 |
input_ids = input_ids.view(-1, input_shape[-1])
|
| 63 |
|
| 64 |
-
hidden_states =
|
|
|
|
| 65 |
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
if attention_mask is not None:
|
| 68 |
-
|
|
|
|
| 69 |
|
| 70 |
-
encoder_outputs =
|
| 71 |
inputs_embeds=hidden_states,
|
| 72 |
attention_mask=attention_mask,
|
| 73 |
causal_attention_mask=causal_attention_mask,
|
|
@@ -77,18 +61,25 @@ class CLIPTextModelWrapper(CLIPTextModel):
|
|
| 77 |
)
|
| 78 |
|
| 79 |
last_hidden_state = encoder_outputs[0]
|
| 80 |
-
last_hidden_state =
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
pooled_output = last_hidden_state[
|
| 85 |
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
| 86 |
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
| 87 |
]
|
| 88 |
else:
|
|
|
|
| 89 |
pooled_output = last_hidden_state[
|
| 90 |
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
| 91 |
-
|
|
|
|
| 92 |
.int()
|
| 93 |
.argmax(dim=-1),
|
| 94 |
]
|
|
@@ -101,4 +92,4 @@ class CLIPTextModelWrapper(CLIPTextModel):
|
|
| 101 |
pooler_output=pooled_output,
|
| 102 |
hidden_states=encoder_outputs.hidden_states,
|
| 103 |
attentions=encoder_outputs.attentions,
|
| 104 |
-
)
|
|
|
|
| 2 |
from transformers import CLIPTextModel
|
| 3 |
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
|
| 4 |
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
| 5 |
+
# from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
|
| 6 |
+
from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, \
|
| 7 |
+
_prepare_4d_attention_mask # transformers 4.36.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
class CLIPTextModelWrapper(CLIPTextModel):
|
| 11 |
# Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
|
| 12 |
# Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
|
|
|
|
| 13 |
def forward(
|
| 14 |
+
self,
|
| 15 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 16 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 17 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 18 |
+
output_attentions: Optional[bool] = None,
|
| 19 |
+
output_hidden_states: Optional[bool] = None,
|
| 20 |
+
return_dict: Optional[bool] = None,
|
| 21 |
+
input_token_embs: Optional[torch.Tensor] = None,
|
| 22 |
+
return_token_embs: Optional[bool] = False,
|
| 23 |
) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
if return_token_embs:
|
| 26 |
+
return self.text_model.embeddings.token_embedding(input_ids)
|
| 27 |
|
| 28 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 29 |
+
|
| 30 |
+
output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
|
| 31 |
+
output_hidden_states = (
|
| 32 |
+
output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
|
| 33 |
+
)
|
| 34 |
+
return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict
|
| 35 |
|
| 36 |
if input_ids is None:
|
| 37 |
raise ValueError("You have to specify input_ids")
|
|
|
|
| 39 |
input_shape = input_ids.size()
|
| 40 |
input_ids = input_ids.view(-1, input_shape[-1])
|
| 41 |
|
| 42 |
+
hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids,
|
| 43 |
+
inputs_embeds=input_token_embs)
|
| 44 |
|
| 45 |
+
# CLIP's text model uses causal mask, prepare it here.
|
| 46 |
+
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
| 47 |
+
causal_attention_mask = _create_4d_causal_attention_mask(input_shape, hidden_states.dtype,
|
| 48 |
+
device=hidden_states.device)
|
| 49 |
+
# expand attention_mask
|
| 50 |
if attention_mask is not None:
|
| 51 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 52 |
+
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
|
| 53 |
|
| 54 |
+
encoder_outputs = self.text_model.encoder(
|
| 55 |
inputs_embeds=hidden_states,
|
| 56 |
attention_mask=attention_mask,
|
| 57 |
causal_attention_mask=causal_attention_mask,
|
|
|
|
| 61 |
)
|
| 62 |
|
| 63 |
last_hidden_state = encoder_outputs[0]
|
| 64 |
+
last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)
|
| 65 |
|
| 66 |
+
if self.text_model.eos_token_id == 2:
|
| 67 |
+
# The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
|
| 68 |
+
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
|
| 69 |
+
# ------------------------------------------------------------
|
| 70 |
+
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
| 71 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
| 72 |
+
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
| 73 |
pooled_output = last_hidden_state[
|
| 74 |
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
| 75 |
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
| 76 |
]
|
| 77 |
else:
|
| 78 |
+
# The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
|
| 79 |
pooled_output = last_hidden_state[
|
| 80 |
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
| 81 |
+
# We need to get the first position of `eos_token_id` value (`pad_token_ids` might be equal to `eos_token_id`)
|
| 82 |
+
(input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
|
| 83 |
.int()
|
| 84 |
.argmax(dim=-1),
|
| 85 |
]
|
|
|
|
| 92 |
pooler_output=pooled_output,
|
| 93 |
hidden_states=encoder_outputs.hidden_states,
|
| 94 |
attentions=encoder_outputs.attentions,
|
| 95 |
+
)
|