Update modeling_minicpm.py
modeling_minicpm.py (+49 -14)
@@ -27,7 +27,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
+from transformers import LlamaTokenizer
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import (
@@ -35,6 +35,7 @@ from transformers.modeling_attn_mask_utils import (
     _prepare_4d_attention_mask,
     _prepare_4d_causal_attention_mask,
     _prepare_4d_causal_attention_mask_for_sdpa,
+    _prepare_4d_attention_mask_for_sdpa,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
@@ -320,9 +321,6 @@ class MiniCPMAttention(nn.Module):
         self.rope_theta = config.rope_theta
 
         self.is_causal = config.is_causal
-
-        logger.info(f"self.is_causal = {self.is_causal}")
-
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -1049,17 +1047,29 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
         elif self._use_sdpa and not output_attentions:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
-                (batch_size, seq_length),
-                inputs_embeds,
-                past_key_values_length,
-            )
+            if self.is_causal:
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
+            else:
+                attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    attention_mask,
+                    inputs_embeds.dtype,
+                )
         else:
             # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-            )
+            if self.is_causal:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+                )
+            else:
+                attention_mask = _prepare_4d_attention_mask(
+                    attention_mask,
+                    inputs_embeds.dtype,
+                )
 
         # embed positions
         hidden_states = inputs_embeds
@@ -1119,7 +1129,6 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
             attentions=all_self_attns,
         )
 
-
 class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
@@ -1335,6 +1344,32 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
         return response, history
 
 
+
+
+class MiniCPMRerankerLLamaTokenizer(LlamaTokenizer):
+    def build_inputs_with_special_tokens(
+        self, token_ids_0, token_ids_1 = None
+    ):
+        """
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s> B`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+
+        if token_ids_1 is None:
+            return super().build_inputs_with_special_tokens(token_ids_0)
+        bos = [self.bos_token_id]
+        sep = [self.eos_token_id]
+        return bos + token_ids_0 + sep + token_ids_1
+
 @add_start_docstrings(
     """
     The MiniCPM Model transformer with a sequence classification head on top (linear layer).
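The functional core of the change is the attention-mask hunk: mask construction in MiniCPMModel now branches on self.is_causal, building a causal 4D mask for decoder-style generation and a plain padding mask for bidirectional use such as reranking. The sketch below is not part of the patch; it uses the eager-path helpers from transformers.modeling_attn_mask_utils (the SDPA variants may return None as an optimization) with invented toy sizes, just to show what the two branches produce.

# Illustrative sketch only (not from the commit); relies on private helpers
# present in recent transformers releases, toy shapes are invented.
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
)

batch_size, seq_length, hidden_size = 1, 4, 8
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)  # 2D padding mask, no padding
inputs_embeds = torch.zeros(batch_size, seq_length, hidden_size)
past_key_values_length = 0

# is_causal=True path: lower-triangular 4D mask; future positions get the dtype minimum.
causal_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

# is_causal=False path: only padding is masked; with no padding the mask is all zeros.
bidirectional_4d = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

print(causal_4d.shape, bidirectional_4d.shape)  # both (1, 1, 4, 4)
print(bool((causal_4d == 0).all()))             # False: the upper triangle is masked
print(bool((bidirectional_4d == 0).all()))      # True: every token attends to every token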
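The other substantive addition is MiniCPMRerankerLLamaTokenizer, which only changes how sequence pairs are packed: a query/document pair becomes `<s> A </s> B`, while single sequences fall through to LlamaTokenizer's default handling. Below is a minimal sketch of the resulting layout, assuming the usual LLaMA convention of bos_token_id=1 and eos_token_id=2; the token IDs are invented for illustration.

# Illustrative sketch only: the IDs and the bos/eos values (1, 2) are assumptions,
# not taken from a real MiniCPM checkpoint.
bos_token_id, eos_token_id = 1, 2
query_ids = [3923, 374, 279]   # hypothetical IDs for a tokenized query    (A)
doc_ids = [7927, 11741, 82]    # hypothetical IDs for a tokenized document (B)

# What build_inputs_with_special_tokens produces when token_ids_1 is given:
pair_input_ids = [bos_token_id] + query_ids + [eos_token_id] + doc_ids
print(pair_input_ids)  # [1, 3923, 374, 279, 2, 7927, 11741, 82]  ->  <s> A </s> B

# When token_ids_1 is None, the call defers to LlamaTokenizer's own
# build_inputs_with_special_tokens, i.e. the tokenizer's configured BOS/EOS handling.

Note that the override appends no trailing `</s>` after the second segment; only the separator between A and B is special.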