TaliDror committed on
Commit
efd5117
·
1 Parent(s): 780f1aa

Replaced _make_causal_mask and _expand_mask with their equivalents from newer transformers versions (_create_4d_causal_attention_mask and _prepare_4d_attention_mask)

Browse files
Files changed (1) hide show
  1. external/arc2face/models.py +41 -50
external/arc2face/models.py CHANGED
@@ -2,58 +2,36 @@ import torch
2
  from transformers import CLIPTextModel
3
  from typing import Any, Callable, Dict, Optional, Tuple, Union, List
4
  from transformers.modeling_outputs import BaseModelOutputWithPooling
5
- #from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
6
- try:
7
- from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
8
- except ImportError:
9
- # transformers >=4.47 removed these internal helpers from modeling_clip.
10
- # Reimplement them directly from the transformers 4.34 source so the mask
11
- # format (additive, shape [bsz,1,tgt,src]) matches what CLIPEncoder expects.
12
- def _make_causal_mask(input_ids_shape, dtype, device, past_key_values_length=0):
13
- bsz, tgt_len = input_ids_shape
14
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
15
- mask_cond = torch.arange(tgt_len, device=device)
16
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
17
- mask = mask.to(dtype)
18
- if past_key_values_length > 0:
19
- mask = torch.cat(
20
- [torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1
21
- )
22
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
23
 
24
- def _expand_mask(mask, dtype, tgt_len=None):
25
- bsz, src_len = mask.shape
26
- tgt_len = tgt_len if tgt_len is not None else src_len
27
- expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
28
- inverted = 1.0 - expanded
29
- return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)
30
 
31
  class CLIPTextModelWrapper(CLIPTextModel):
32
  # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
33
  # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
34
- # Supports both transformers <=4.46 (self.text_model sub-attribute) and >=4.47 (flat structure, no text_model).
35
  def forward(
36
- self,
37
- input_ids: Optional[torch.Tensor] = None,
38
- attention_mask: Optional[torch.Tensor] = None,
39
- position_ids: Optional[torch.Tensor] = None,
40
- output_attentions: Optional[bool] = None,
41
- output_hidden_states: Optional[bool] = None,
42
- return_dict: Optional[bool] = None,
43
- input_token_embs: Optional[torch.Tensor] = None,
44
- return_token_embs: Optional[bool] = False,
45
  ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
46
 
47
- # In transformers <=4.46 the transformer lives in self.text_model;
48
- # in >=4.47 it was inlined directly onto CLIPTextModel (flat structure).
49
- tm = getattr(self, 'text_model', self)
50
-
51
  if return_token_embs:
52
- return tm.embeddings.token_embedding(input_ids)
53
 
54
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
55
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
56
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
57
 
58
  if input_ids is None:
59
  raise ValueError("You have to specify input_ids")
@@ -61,13 +39,19 @@ class CLIPTextModelWrapper(CLIPTextModel):
61
  input_shape = input_ids.size()
62
  input_ids = input_ids.view(-1, input_shape[-1])
63
 
64
- hidden_states = tm.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)
 
65
 
66
- causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
 
 
 
 
67
  if attention_mask is not None:
68
- attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
 
69
 
70
- encoder_outputs = tm.encoder(
71
  inputs_embeds=hidden_states,
72
  attention_mask=attention_mask,
73
  causal_attention_mask=causal_attention_mask,
@@ -77,18 +61,25 @@ class CLIPTextModelWrapper(CLIPTextModel):
77
  )
78
 
79
  last_hidden_state = encoder_outputs[0]
80
- last_hidden_state = tm.final_layer_norm(last_hidden_state)
81
 
82
- eos_token_id = getattr(tm, 'eos_token_id', self.config.eos_token_id)
83
- if eos_token_id == 2:
 
 
 
 
 
84
  pooled_output = last_hidden_state[
85
  torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
86
  input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
87
  ]
88
  else:
 
89
  pooled_output = last_hidden_state[
90
  torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
91
- (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == eos_token_id)
 
92
  .int()
93
  .argmax(dim=-1),
94
  ]
@@ -101,4 +92,4 @@ class CLIPTextModelWrapper(CLIPTextModel):
101
  pooler_output=pooled_output,
102
  hidden_states=encoder_outputs.hidden_states,
103
  attentions=encoder_outputs.attentions,
104
- )
 
2
  from transformers import CLIPTextModel
3
  from typing import Any, Callable, Dict, Optional, Tuple, Union, List
4
  from transformers.modeling_outputs import BaseModelOutputWithPooling
5
+ # from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
6
+ from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, \
7
+ _prepare_4d_attention_mask # transformers 4.36.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
 
9
 
10
  class CLIPTextModelWrapper(CLIPTextModel):
11
  # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
12
  # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
 
13
  def forward(
14
+ self,
15
+ input_ids: Optional[torch.Tensor] = None,
16
+ attention_mask: Optional[torch.Tensor] = None,
17
+ position_ids: Optional[torch.Tensor] = None,
18
+ output_attentions: Optional[bool] = None,
19
+ output_hidden_states: Optional[bool] = None,
20
+ return_dict: Optional[bool] = None,
21
+ input_token_embs: Optional[torch.Tensor] = None,
22
+ return_token_embs: Optional[bool] = False,
23
  ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
24
 
 
 
 
 
25
  if return_token_embs:
26
+ return self.text_model.embeddings.token_embedding(input_ids)
27
 
28
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
29
+
30
+ output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
31
+ output_hidden_states = (
32
+ output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
33
+ )
34
+ return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict
35
 
36
  if input_ids is None:
37
  raise ValueError("You have to specify input_ids")
 
39
  input_shape = input_ids.size()
40
  input_ids = input_ids.view(-1, input_shape[-1])
41
 
42
+ hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids,
43
+ inputs_embeds=input_token_embs)
44
 
45
+ # CLIP's text model uses causal mask, prepare it here.
46
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
47
+ causal_attention_mask = _create_4d_causal_attention_mask(input_shape, hidden_states.dtype,
48
+ device=hidden_states.device)
49
+ # expand attention_mask
50
  if attention_mask is not None:
51
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
52
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
53
 
54
+ encoder_outputs = self.text_model.encoder(
55
  inputs_embeds=hidden_states,
56
  attention_mask=attention_mask,
57
  causal_attention_mask=causal_attention_mask,
 
61
  )
62
 
63
  last_hidden_state = encoder_outputs[0]
64
+ last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)
65
 
66
+ if self.text_model.eos_token_id == 2:
67
+ # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
68
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
69
+ # ------------------------------------------------------------
70
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
71
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
72
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
73
  pooled_output = last_hidden_state[
74
  torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
75
  input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
76
  ]
77
  else:
78
+ # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
79
  pooled_output = last_hidden_state[
80
  torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
81
+ # We need to get the first position of `eos_token_id` value (`pad_token_ids` might be equal to `eos_token_id`)
82
+ (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
83
  .int()
84
  .argmax(dim=-1),
85
  ]
 
92
  pooler_output=pooled_output,
93
  hidden_states=encoder_outputs.hidden_states,
94
  attentions=encoder_outputs.attentions,
95
+ )