bestfleer commited on
Commit
a69896f
·
verified ·
1 Parent(s): 6b0b0ec

Add files using upload-large-folder tool

Browse files
README.md CHANGED
@@ -56,6 +56,7 @@ Ring-mini-sparse-2.0-exp achieves high inference efficiency through highly spars
56
  Installation requirements:
57
 
58
  ```shell
 
59
  pip install transformers==4.56.1
60
  ```
61
 
@@ -71,6 +72,7 @@ model = AutoModelForCausalLM.from_pretrained(
71
  dtype="auto",
72
  device_map="auto",
73
  trust_remote_code=True,
 
74
  )
75
  tokenizer = AutoTokenizer.from_pretrained(model_name)
76
 
 
56
  Installation requirements:
57
 
58
  ```shell
59
+ pip install flash-attn==2.6.3
60
  pip install transformers==4.56.1
61
  ```
62
 
 
72
  dtype="auto",
73
  device_map="auto",
74
  trust_remote_code=True,
75
+ attn_implementation="flash_attention_2",
76
  )
77
  tokenizer = AutoTokenizer.from_pretrained(model_name)
78
 
configuration_bailing_moe_v2.py CHANGED
@@ -1,45 +1,48 @@
1
- """Bailing MoE model configuration"""
2
 
3
  from transformers.configuration_utils import PretrainedConfig
4
 
5
 
6
  class BailingMoeV2Config(PretrainedConfig):
7
- model_type = "bailing_moe_v2"
8
 
9
  def __init__(
10
  self,
11
- vocab_size=30592,
12
- hidden_size=1024,
13
- intermediate_size=None,
14
- num_hidden_layers=24,
15
  num_attention_heads=16,
16
- num_key_value_heads=0,
17
  hidden_act="silu",
18
  use_qkv_bias=False, # bailing only
19
- use_bias=True, # bailing only
20
- rms_norm_eps=1e-05,
21
- norm_head=False, # bailing only
22
  tie_word_embeddings=False, # PretrainedConfig key, here change default value.
23
- embedding_dropout=0.1,
24
- attention_dropout=0.1,
25
- output_dropout=0.1,
26
  initializer_range=0.02,
27
- max_position_embeddings=16384,
28
- rope_theta=10000.0,
29
  use_cache=True,
30
- use_sliding_window=False,
31
- sliding_window=4096,
32
- max_window_layers=28,
33
  rope_scaling=None,
34
- pad_token_id=126081,
35
- num_experts=16,
36
- num_shared_experts=0,
37
- num_experts_per_tok=2,
38
- norm_topk_prob=True,
39
- moe_intermediate_size=None,
40
- first_k_dense_replace=0,
41
- head_dim=None,
 
 
42
  output_router_logits=False,
 
 
 
 
 
43
  **kwargs,
44
  ):
45
  self.num_hidden_layers = num_hidden_layers
@@ -51,28 +54,31 @@ class BailingMoeV2Config(PretrainedConfig):
51
  self.hidden_act = hidden_act
52
  self.use_qkv_bias = use_qkv_bias
53
  self.use_bias = use_bias
54
- self.norm_head = norm_head
55
  self.rms_norm_eps = rms_norm_eps
56
  self.embedding_dropout = embedding_dropout
57
  self.attention_dropout = attention_dropout
58
  self.output_dropout = output_dropout
 
 
59
  self.initializer_range = initializer_range
60
  self.max_position_embeddings = max_position_embeddings
61
  self.rope_theta = rope_theta
62
  self.use_cache = use_cache
63
- self.use_sliding_window = use_sliding_window
64
- self.sliding_window = sliding_window
65
  self.max_window_layers = max_window_layers
66
  self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
67
  self.rope_scaling = rope_scaling
 
 
 
68
 
69
  # MoE configs
70
  self.num_experts = num_experts
71
  self.num_shared_experts = num_shared_experts
72
  self.num_experts_per_tok = num_experts_per_tok
73
- self.norm_topk_prob = norm_topk_prob
 
74
  self.moe_intermediate_size = moe_intermediate_size
75
  self.first_k_dense_replace = first_k_dense_replace
76
  self.output_router_logits = output_router_logits
77
 
78
- super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
 
1
+ """Bailing MoE V2 model configuration"""
2
 
3
  from transformers.configuration_utils import PretrainedConfig
4
 
5
 
6
  class BailingMoeV2Config(PretrainedConfig):
 
7
 
8
  def __init__(
9
  self,
10
+ vocab_size=157184,
11
+ hidden_size=2048,
12
+ intermediate_size=5120,
13
+ num_hidden_layers=20,
14
  num_attention_heads=16,
15
+ num_key_value_heads=4,
16
  hidden_act="silu",
17
  use_qkv_bias=False, # bailing only
18
+ use_bias=False, # bailing only
19
+ rms_norm_eps=1e-06,
 
20
  tie_word_embeddings=False, # PretrainedConfig key, here change default value.
21
+ embedding_dropout=0.0,
22
+ attention_dropout=0.0,
23
+ output_dropout=0.0,
24
  initializer_range=0.02,
25
+ max_position_embeddings=32768,
26
+ rope_theta=600000.0,
27
  use_cache=True,
28
+ max_window_layers=20,
 
 
29
  rope_scaling=None,
30
+ pad_token_id=156892,
31
+ eos_token_id=156892,
32
+ num_experts=256,
33
+ num_shared_experts=1,
34
+ num_experts_per_tok=8,
35
+ n_group=8,
36
+ topk_group=4,
37
+ moe_intermediate_size=512,
38
+ first_k_dense_replace=1,
39
+ head_dim=128,
40
  output_router_logits=False,
41
+ use_qk_norm=True,
42
+ num_nextn_predict_layers=0,
43
+ mtp_loss_scaling_factor=0,
44
+ moe_router_enable_expert_bias=True,
45
+ routed_scaling_factor=1.0,
46
  **kwargs,
47
  ):
48
  self.num_hidden_layers = num_hidden_layers
 
54
  self.hidden_act = hidden_act
55
  self.use_qkv_bias = use_qkv_bias
56
  self.use_bias = use_bias
 
57
  self.rms_norm_eps = rms_norm_eps
58
  self.embedding_dropout = embedding_dropout
59
  self.attention_dropout = attention_dropout
60
  self.output_dropout = output_dropout
61
+ self.num_nextn_predict_layers = num_nextn_predict_layers
62
+ self.mtp_loss_scaling_factor = mtp_loss_scaling_factor
63
  self.initializer_range = initializer_range
64
  self.max_position_embeddings = max_position_embeddings
65
  self.rope_theta = rope_theta
66
  self.use_cache = use_cache
 
 
67
  self.max_window_layers = max_window_layers
68
  self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
69
  self.rope_scaling = rope_scaling
70
+ self.use_qk_norm = use_qk_norm
71
+ self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
72
+ self.routed_scaling_factor = routed_scaling_factor
73
 
74
  # MoE configs
75
  self.num_experts = num_experts
76
  self.num_shared_experts = num_shared_experts
77
  self.num_experts_per_tok = num_experts_per_tok
78
+ self.n_group = n_group
79
+ self.topk_group = topk_group
80
  self.moe_intermediate_size = moe_intermediate_size
81
  self.first_k_dense_replace = first_k_dense_replace
82
  self.output_router_logits = output_router_logits
83
 
84
+ super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
modeling_bailing_moe_v2.py CHANGED
@@ -1,14 +1,5 @@
1
- #!/usr/bin/python
2
- #****************************************************************#
3
- # ScriptName: modeling_bailing_moe_v2.py
4
- # Author: $SHTERM_REAL_USER@alibaba-inc.com
5
- # Create Date: 2025-08-12 20:22
6
- # Modify Author: $SHTERM_REAL_USER@alibaba-inc.com
7
- # Modify Date: 2025-08-12 20:22
8
- # Function:
9
- #***************************************************************#
10
  # coding=utf-8
11
- # Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
12
  #
13
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
14
  # and OPT implementations in this library. It has been modified from its
@@ -27,15 +18,14 @@
27
  # See the License for the specific language governing permissions and
28
  # limitations under the License.
29
  """PyTorch BailingMoE model."""
 
30
  import math
31
  import warnings
32
  from typing import List, Optional, Tuple, Union
33
 
34
  import torch
35
  import torch.nn.functional as F
36
- import torch.utils.checkpoint
37
  from torch import nn
38
- from torch.nn import CrossEntropyLoss
39
 
40
  from transformers.activations import ACT2FN
41
  from transformers.cache_utils import Cache, DynamicCache
@@ -45,10 +35,8 @@ from transformers.modeling_attn_mask_utils import (
45
  _prepare_4d_causal_attention_mask,
46
  _prepare_4d_causal_attention_mask_for_sdpa,
47
  )
48
- from transformers.modeling_outputs import (
49
- MoeModelOutputWithPast,
50
- MoeCausalLMOutputWithPast,
51
- )
52
  from transformers.modeling_utils import PreTrainedModel
53
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
54
  from transformers.utils import (
@@ -61,7 +49,11 @@ from transformers.utils import (
61
  )
62
  from transformers.utils.import_utils import is_torch_fx_available
63
  from .configuration_bailing_moe_v2 import BailingMoeV2Config
64
-
 
 
 
 
65
 
66
  if is_flash_attn_2_available():
67
  from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -82,6 +74,383 @@ logger = logging.get_logger(__name__)
82
  _CONFIG_FOR_DOC = "BailingMoeV2Config"
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def _get_unpad_data(attention_mask):
86
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
87
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
@@ -133,171 +502,37 @@ ALL_LAYERNORM_LAYERS.append(BailingMoeV2RMSNorm)
133
 
134
 
135
  class BailingMoeV2RotaryEmbedding(nn.Module):
136
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
137
  super().__init__()
 
 
 
 
 
 
 
138
 
139
- self.dim = dim
140
- self.max_position_embeddings = max_position_embeddings
141
- self.base = base
142
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
143
- self.register_buffer("inv_freq", inv_freq, persistent=False)
144
-
145
- # Build here to make `torch.jit.trace` work.
146
- self._set_cos_sin_cache(
147
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
148
- )
149
- self.max_seq_len_cached = None
150
-
151
- def _set_cos_sin_cache(self, seq_len, device, dtype):
152
- self.max_seq_len_cached = seq_len
153
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
154
-
155
- freqs = torch.outer(t, self.inv_freq.to(t.device))
156
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
157
- emb = torch.cat((freqs, freqs), dim=-1)
158
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
159
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
160
-
161
- def forward(self, x, seq_len=None):
162
- # x: [bs, num_attention_heads, seq_len, head_size]
163
- if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
164
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
165
-
166
- return (
167
- self.cos_cached[:seq_len].to(dtype=x.dtype),
168
- self.sin_cached[:seq_len].to(dtype=x.dtype),
169
- )
170
-
171
-
172
- # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->BailingMoeV2
173
- class BailingMoeV2LinearScalingRotaryEmbedding(BailingMoeV2RotaryEmbedding):
174
- """BailingMoeV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
175
-
176
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
177
- self.scaling_factor = scaling_factor
178
- super().__init__(dim, max_position_embeddings, base, device)
179
-
180
- def _set_cos_sin_cache(self, seq_len, device, dtype):
181
- self.max_seq_len_cached = seq_len
182
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
183
- t = t / self.scaling_factor
184
-
185
- freqs = torch.outer(t, self.inv_freq)
186
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
187
- emb = torch.cat((freqs, freqs), dim=-1)
188
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
189
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
190
-
191
-
192
- # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->BailingMoeV2
193
- class BailingMoeV2DynamicNTKScalingRotaryEmbedding(BailingMoeV2RotaryEmbedding):
194
- """BailingMoeV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
195
-
196
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
197
- self.scaling_factor = scaling_factor
198
- super().__init__(dim, max_position_embeddings, base, device)
199
-
200
- def _set_cos_sin_cache(self, seq_len, device, dtype):
201
- self.max_seq_len_cached = seq_len
202
-
203
- if seq_len > self.max_position_embeddings:
204
- base = self.base * (
205
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
206
- ) ** (self.dim / (self.dim - 2))
207
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
208
- self.register_buffer("inv_freq", inv_freq, persistent=False)
209
-
210
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
211
-
212
- freqs = torch.outer(t, self.inv_freq)
213
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
214
- emb = torch.cat((freqs, freqs), dim=-1)
215
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
216
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
217
-
218
-
219
- # Inverse dim formula to find dim based on number of rotations
220
- def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
221
- return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
222
-
223
-
224
- # Find dim range bounds based on rotations
225
- def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
226
- low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
227
- high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
228
- return max(low, 0), min(high, dim - 1) # Clamp values just in case
229
-
230
-
231
- def yarn_get_mscale(scale=1, mscale=1):
232
- if scale <= 1:
233
- return 1.0
234
- return 0.1 * mscale * math.log(scale) + 1.0
235
-
236
-
237
- def yarn_linear_ramp_mask(min, max, dim):
238
- if min == max:
239
- max += 0.001 # Prevent singularity
240
-
241
- linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
242
- ramp_func = torch.clamp(linear_func, 0, 1)
243
- return ramp_func
244
-
245
-
246
- class BailingMoeV2YarnRotaryEmbedding(BailingMoeV2RotaryEmbedding):
247
-
248
- def __init__(
249
- self,
250
- dim,
251
- max_position_embeddings=2048,
252
- base=10000,
253
- device=None,
254
- scaling_factor=1.0,
255
- original_max_position_embeddings=4096,
256
- beta_fast=32,
257
- beta_slow=1,
258
- mscale=1,
259
- mscale_all_dim=0,
260
- ):
261
- self.scaling_factor = scaling_factor
262
- self.original_max_position_embeddings = original_max_position_embeddings
263
- self.beta_fast = beta_fast
264
- self.beta_slow = beta_slow
265
- self.mscale = mscale
266
- self.mscale_all_dim = mscale_all_dim
267
- super().__init__(dim, max_position_embeddings, base, device)
268
-
269
- def _set_cos_sin_cache(self, seq_len, device, dtype):
270
- self.max_seq_len_cached = seq_len
271
- dim = self.dim
272
-
273
- freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
274
- freq_inter = 1.0 / (
275
- self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
276
- )
277
 
278
- low, high = yarn_find_correction_range(
279
- self.beta_fast,
280
- self.beta_slow,
281
- dim,
282
- self.base,
283
- self.original_max_position_embeddings,
284
- )
285
- inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32)
286
- inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
287
  self.register_buffer("inv_freq", inv_freq, persistent=False)
 
288
 
289
- t = torch.arange(seq_len, device=device, dtype=torch.float32)
 
 
 
 
290
 
291
- freqs = torch.outer(t, inv_freq)
 
 
 
 
 
292
 
293
- _mscale = float(
294
- yarn_get_mscale(self.scaling_factor, self.mscale)
295
- / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
296
- )
297
-
298
- emb = torch.cat((freqs, freqs), dim=-1)
299
- self.register_buffer("cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False)
300
- self.register_buffer("sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False)
301
 
302
 
303
  # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -309,17 +544,13 @@ def rotate_half(x):
309
 
310
 
311
  # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
312
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
313
  """Applies Rotary Position Embedding to the query and key tensors.
314
-
315
  Args:
316
  q (`torch.Tensor`): The query tensor.
317
  k (`torch.Tensor`): The key tensor.
318
  cos (`torch.Tensor`): The cosine part of the rotary embedding.
319
  sin (`torch.Tensor`): The sine part of the rotary embedding.
320
- position_ids (`torch.Tensor`):
321
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
322
- used to pass offsetted position ids when working with a KV-cache.
323
  unsqueeze_dim (`int`, *optional*, defaults to 1):
324
  The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
325
  sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -330,10 +561,21 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
330
  Returns:
331
  `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
332
  """
333
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
334
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
335
- q_embed = (q * cos) + (rotate_half(q) * sin)
336
- k_embed = (k * cos) + (rotate_half(k) * sin)
 
 
 
 
 
 
 
 
 
 
 
337
  return q_embed, k_embed
338
 
339
 
@@ -360,14 +602,15 @@ class BailingMoeV2Gate(nn.Module):
360
  self.top_k = config.num_experts_per_tok
361
  self.num_experts = config.num_experts
362
 
 
 
 
363
  # topk selection algorithm
364
- self.norm_topk_prob = config.norm_topk_prob
365
  self.gating_dim = config.hidden_size
366
  self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
367
- self.moe_router_topk_scaling_factor = config.moe_router_topk_scaling_factor
368
 
369
- if self.config.use_expert_bias:
370
- self.register_buffer("expert_bias", torch.zeros((self.num_experts)))
371
  self.reset_parameters()
372
 
373
  def reset_parameters(self) -> None:
@@ -375,39 +618,45 @@ class BailingMoeV2Gate(nn.Module):
375
 
376
  init.kaiming_uniform_(self.weight, a=math.sqrt(5))
377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  def forward(self, hidden_states):
379
  # compute gating score
380
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
381
- logits = F.linear(hidden_states, self.weight, None)
382
-
383
- if self.config.gate_score_function == 'softmax':
384
- scores = logits.softmax(dim=-1, dtype=torch.float32)
385
 
386
- # select top-k experts
387
- topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1)
388
 
389
- # norm gate to sum 1
390
- if self.top_k > 1 and self.norm_topk_prob:
391
- denominator = topk_weight.sum(dim=-1, keepdim=True)
392
- topk_weight = topk_weight / denominator
393
- topk_weight = topk_weight * self.moe_router_topk_scaling_factor
394
 
395
- return topk_idx, topk_weight, logits
396
- elif self.config.gate_score_function == 'sigmoid':
397
- scores = torch.sigmoid(logits)
398
 
399
- if self.config.use_expert_bias:
400
- scores_for_routing = scores + self.expert_bias
401
- _, topk_idx = torch.topk(scores_for_routing, k=self.top_k, dim=-1)
402
- scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits)
403
- else:
404
- scores, topk_idx = torch.topk(scores, k=self.top_k, dim=-1)
405
- topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores
406
- topk_weight = topk_weight * self.moe_router_topk_scaling_factor
407
 
408
- return topk_idx, topk_weight, logits
409
- else:
410
- raise ValueError(f"Unsupported gate_score_function: {self.config.gate_score_function}")
411
 
412
 
413
  class BailingMoeV2SparseMoeBlock(nn.Module):
@@ -460,7 +709,6 @@ class BailingMoeV2SparseMoeBlock(nn.Module):
460
  tokens_per_expert = cnts.sum(dim=0)
461
  idxs = topk_ids.view(-1).argsort()
462
  sorted_tokens = x[idxs // topk_ids.shape[1]]
463
- sorted_tokens_shape = sorted_tokens.shape
464
  tokens_per_expert = tokens_per_expert.cpu().numpy()
465
  outputs = []
466
  start_idx = 0
@@ -471,7 +719,7 @@ class BailingMoeV2SparseMoeBlock(nn.Module):
471
  expert = self.experts[i]
472
  tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
473
  expert_out = expert(tokens_for_this_expert)
474
- outputs.append(expert_out)
475
  start_idx = end_idx
476
 
477
  outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
@@ -519,6 +767,8 @@ class BailingMoeV2Attention(nn.Module):
519
  self.hidden_size = config.hidden_size
520
  self.num_heads = config.num_attention_heads
521
  self.head_dim = config.head_dim or self.hidden_size // self.num_heads
 
 
522
  self.num_key_value_heads = config.num_key_value_heads
523
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
524
  self.max_position_embeddings = config.max_position_embeddings
@@ -532,56 +782,9 @@ class BailingMoeV2Attention(nn.Module):
532
  )
533
 
534
  if self.config.use_qk_norm:
535
- self.q_norm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
536
- self.k_norm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
537
  self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)
538
- self._init_rope()
539
-
540
- def _init_rope(self):
541
- if self.config.rope_scaling is None:
542
- self.rotary_emb = BailingMoeV2RotaryEmbedding(
543
- self.head_dim,
544
- max_position_embeddings=self.max_position_embeddings,
545
- base=self.rope_theta,
546
- )
547
- else:
548
- scaling_type = self.config.rope_scaling["type"]
549
- scaling_factor = self.config.rope_scaling["factor"]
550
- if scaling_type == "linear":
551
- self.rotary_emb = BailingMoeV2LinearScalingRotaryEmbedding(
552
- self.head_dim,
553
- max_position_embeddings=self.max_position_embeddings,
554
- scaling_factor=scaling_factor,
555
- base=self.rope_theta,
556
- )
557
- elif scaling_type == "dynamic":
558
- self.rotary_emb = BailingMoeV2DynamicNTKScalingRotaryEmbedding(
559
- self.head_dim,
560
- max_position_embeddings=self.max_position_embeddings,
561
- scaling_factor=scaling_factor,
562
- base=self.rope_theta,
563
- )
564
- elif scaling_type == "yarn":
565
- kwargs = {
566
- key: self.config.rope_scaling[key]
567
- for key in [
568
- "original_max_position_embeddings",
569
- "beta_fast",
570
- "beta_slow",
571
- "mscale",
572
- "mscale_all_dim",
573
- ]
574
- if key in self.config.rope_scaling
575
- }
576
- self.rotary_emb = BailingMoeV2YarnRotaryEmbedding(
577
- self.head_dim,
578
- max_position_embeddings=self.max_position_embeddings,
579
- scaling_factor=scaling_factor,
580
- base=self.rope_theta,
581
- **kwargs,
582
- )
583
- else:
584
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
585
 
586
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
587
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -594,12 +797,9 @@ class BailingMoeV2Attention(nn.Module):
594
  past_key_value: Optional[Cache] = None,
595
  output_attentions: bool = False,
596
  use_cache: bool = False,
 
597
  **kwargs,
598
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
599
- if "padding_mask" in kwargs:
600
- warnings.warn(
601
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
602
- )
603
 
604
  bsz, q_len, _ = hidden_states.size()
605
 
@@ -614,10 +814,12 @@ class BailingMoeV2Attention(nn.Module):
614
  value_states = value_states.transpose(1, 2)
615
 
616
  if self.config.use_qk_norm:
617
- query_states = self.q_norm(query_states)
618
- key_states = self.k_norm(key_states)
 
 
 
619
 
620
- kv_seq_len = key_states.shape[-2]
621
  if past_key_value is not None:
622
  if self.layer_idx is None:
623
  raise ValueError(
@@ -625,19 +827,15 @@ class BailingMoeV2Attention(nn.Module):
625
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
626
  "with a layer index."
627
  )
628
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
629
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
630
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
631
-
632
- if past_key_value is not None:
633
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
634
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
635
 
636
  key_states = repeat_kv(key_states, self.num_key_value_groups)
637
  value_states = repeat_kv(value_states, self.num_key_value_groups)
638
 
639
- attn_weights = torch.matmul(query_states / math.sqrt(self.head_dim), key_states.transpose(2, 3))
640
 
 
641
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
642
  raise ValueError(
643
  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
@@ -698,17 +896,10 @@ class BailingMoeV2FlashAttention2(BailingMoeV2Attention):
698
  past_key_value: Optional[Cache] = None,
699
  output_attentions: bool = False,
700
  use_cache: bool = False,
 
701
  **kwargs,
702
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
703
  # BailingMoeV2FlashAttention2 attention does not support output_attentions
704
- if "padding_mask" in kwargs:
705
- warnings.warn(
706
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
707
- )
708
-
709
- # overwrite attention_mask with padding_mask
710
- attention_mask = kwargs.pop("padding_mask")
711
-
712
  output_attentions = False
713
 
714
  bsz, q_len, _ = hidden_states.size()
@@ -728,17 +919,14 @@ class BailingMoeV2FlashAttention2(BailingMoeV2Attention):
728
  value_states = value_states.transpose(1, 2)
729
 
730
  if self.config.use_qk_norm:
731
- query_states = self.q_norm(query_states)
732
- key_states = self.k_norm(key_states)
733
 
734
- kv_seq_len = key_states.shape[-2]
735
- if past_key_value is not None:
736
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
737
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
738
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
739
 
740
  if past_key_value is not None:
741
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
742
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
743
 
744
  # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
@@ -763,7 +951,7 @@ class BailingMoeV2FlashAttention2(BailingMoeV2Attention):
763
  elif torch.is_autocast_enabled():
764
  target_dtype = torch.get_autocast_gpu_dtype()
765
  else:
766
- target_dtype = self.q_proj.weight.dtype
767
 
768
  logger.warning_once(
769
  f"The input hidden states seems to be silently casted in float32, this might be related to"
@@ -774,10 +962,14 @@ class BailingMoeV2FlashAttention2(BailingMoeV2Attention):
774
  query_states = query_states.to(target_dtype)
775
  key_states = key_states.to(target_dtype)
776
  value_states = value_states.to(target_dtype)
777
-
778
- attn_output = self._flash_attention_forward(
779
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
780
- )
 
 
 
 
781
 
782
  attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
783
  attn_output = self.dense(attn_output)
@@ -786,6 +978,85 @@ class BailingMoeV2FlashAttention2(BailingMoeV2Attention):
786
  attn_weights = None
787
 
788
  return attn_output, attn_weights, past_key_value
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
789
 
790
  def _flash_attention_forward(
791
  self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
@@ -793,7 +1064,6 @@ class BailingMoeV2FlashAttention2(BailingMoeV2Attention):
793
  """
794
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
795
  first unpad the input, then computes the attention scores and pad the final attention scores.
796
-
797
  Args:
798
  query_states (`torch.Tensor`):
799
  Input query states to be passed to Flash Attention API
@@ -906,6 +1176,7 @@ class BailingMoeV2SdpaAttention(BailingMoeV2Attention):
906
  past_key_value: Optional[Cache] = None,
907
  output_attentions: bool = False,
908
  use_cache: bool = False,
 
909
  **kwargs,
910
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
911
  if output_attentions:
@@ -936,24 +1207,21 @@ class BailingMoeV2SdpaAttention(BailingMoeV2Attention):
936
  value_states = value_states.transpose(1, 2)
937
 
938
  if self.config.use_qk_norm:
939
- query_states = self.q_norm(query_states)
940
- key_states = self.k_norm(key_states)
941
 
942
- kv_seq_len = key_states.shape[-2]
943
- if past_key_value is not None:
944
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
945
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
946
-
947
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
948
 
949
  if past_key_value is not None:
950
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
951
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
952
 
953
  key_states = repeat_kv(key_states, self.num_key_value_groups)
954
  value_states = repeat_kv(value_states, self.num_key_value_groups)
955
 
956
  if attention_mask is not None:
 
957
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
958
  raise ValueError(
959
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
@@ -991,6 +1259,78 @@ ATTENTION_CLASSES = {
991
  }
992
 
993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
994
  class BailingMoeV2DecoderLayer(nn.Module):
995
  def __init__(self, config: BailingMoeV2Config, layer_idx: int):
996
  super().__init__()
@@ -1015,6 +1355,7 @@ class BailingMoeV2DecoderLayer(nn.Module):
1015
  output_attentions: Optional[bool] = False,
1016
  output_router_logits: Optional[bool] = False,
1017
  use_cache: Optional[bool] = False,
 
1018
  **kwargs,
1019
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1020
  """
@@ -1038,10 +1379,6 @@ class BailingMoeV2DecoderLayer(nn.Module):
1038
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1039
  (see `past_key_values`).
1040
  """
1041
- if "padding_mask" in kwargs:
1042
- warnings.warn(
1043
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
1044
- )
1045
  residual = hidden_states
1046
 
1047
  hidden_states = self.input_layernorm(hidden_states)
@@ -1053,6 +1390,7 @@ class BailingMoeV2DecoderLayer(nn.Module):
1053
  position_ids=position_ids,
1054
  past_key_value=past_key_value,
1055
  output_attentions=output_attentions,
 
1056
  use_cache=use_cache,
1057
  )
1058
  hidden_states = residual + hidden_states
@@ -1065,7 +1403,7 @@ class BailingMoeV2DecoderLayer(nn.Module):
1065
  hidden_states, router_logits = hidden_states
1066
  else:
1067
  router_logits = None
1068
- hidden_states = residual + hidden_states
1069
 
1070
  outputs = (hidden_states,)
1071
 
@@ -1085,11 +1423,9 @@ BAILINGMOEV2_START_DOCSTRING = r"""
1085
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1086
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1087
  etc.)
1088
-
1089
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1090
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1091
  and behavior.
1092
-
1093
  Parameters:
1094
  config ([`BailingMoeV2Config`]):
1095
  Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -1129,50 +1465,38 @@ BAILINGMOEV2_INPUTS_DOCSTRING = r"""
1129
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1130
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1131
  it.
1132
-
1133
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1134
  [`PreTrainedTokenizer.__call__`] for details.
1135
-
1136
  [What are input IDs?](../glossary#input-ids)
1137
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1138
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1139
-
1140
  - 1 for tokens that are **not masked**,
1141
  - 0 for tokens that are **masked**.
1142
-
1143
  [What are attention masks?](../glossary#attention-mask)
1144
-
1145
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1146
  [`PreTrainedTokenizer.__call__`] for details.
1147
-
1148
  If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1149
  `past_key_values`).
1150
-
1151
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1152
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1153
  information on the default strategy.
1154
-
1155
  - 1 indicates the head is **not masked**,
1156
  - 0 indicates the head is **masked**.
1157
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1158
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1159
  config.n_positions - 1]`.
1160
-
1161
  [What are position IDs?](../glossary#position-ids)
1162
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1163
  Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1164
  blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1165
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1166
-
1167
  Two formats are allowed:
1168
  - a [`~cache_utils.Cache`] instance;
1169
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1170
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1171
  cache format.
1172
-
1173
  The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1174
  legacy cache format will be returned.
1175
-
1176
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1177
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1178
  of shape `(batch_size, sequence_length)`.
@@ -1201,7 +1525,6 @@ BAILINGMOEV2_INPUTS_DOCSTRING = r"""
1201
  class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
1202
  """
1203
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BailingMoeV2DecoderLayer`]
1204
-
1205
  Args:
1206
  config: BailingMoeV2Config
1207
  """
@@ -1210,15 +1533,20 @@ class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
1210
  super().__init__(config)
1211
  self.padding_idx = config.pad_token_id
1212
  self.vocab_size = config.vocab_size
 
1213
 
1214
  self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1215
- self.layers = nn.ModuleList(
1216
- [BailingMoeV2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1217
- )
 
 
 
 
1218
  self._use_sdpa = config._attn_implementation == "sdpa"
1219
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1220
  self.norm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1221
-
1222
  self.gradient_checkpointing = False
1223
  # Initialize weights and apply final processing
1224
  self.post_init()
@@ -1243,7 +1571,7 @@ class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
1243
  output_router_logits: Optional[bool] = None,
1244
  return_dict: Optional[bool] = None,
1245
  **kwargs,
1246
- ) -> Union[Tuple, MoeModelOutputWithPast]:
1247
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1248
  output_hidden_states = (
1249
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1272,23 +1600,20 @@ class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
1272
  )
1273
  use_cache = False
1274
 
1275
- past_key_values_length = 0
1276
- if use_cache:
1277
- use_legacy_cache = not isinstance(past_key_values, Cache)
1278
- if use_legacy_cache:
1279
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1280
- past_key_values_length = past_key_values.get_usable_length(seq_length)
 
1281
 
1282
  if position_ids is None:
1283
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1284
  position_ids = torch.arange(
1285
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1286
  )
1287
  position_ids = position_ids.unsqueeze(0)
1288
 
1289
- if inputs_embeds is None:
1290
- inputs_embeds = self.word_embeddings(input_ids)
1291
-
1292
  if self._use_flash_attention_2:
1293
  # 2d mask is passed through the layers
1294
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
@@ -1299,24 +1624,29 @@ class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
1299
  attention_mask,
1300
  (batch_size, seq_length),
1301
  inputs_embeds,
1302
- past_key_values_length,
1303
  )
1304
  else:
1305
  # 4d mask is passed through the layers
1306
  attention_mask = _prepare_4d_causal_attention_mask(
1307
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1308
  )
1309
 
1310
  # embed positions
1311
  hidden_states = inputs_embeds
1312
 
 
 
 
1313
  # decoder layers
1314
  all_hidden_states = () if output_hidden_states else None
1315
  all_self_attns = () if output_attentions else None
1316
  all_router_logits = () if output_router_logits else None
1317
  next_decoder_cache = None
 
 
1318
 
1319
- for decoder_layer in self.layers:
1320
  if output_hidden_states:
1321
  all_hidden_states += (hidden_states,)
1322
 
@@ -1330,6 +1660,7 @@ class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
1330
  output_attentions,
1331
  output_router_logits,
1332
  use_cache,
 
1333
  )
1334
  else:
1335
  layer_outputs = decoder_layer(
@@ -1340,6 +1671,7 @@ class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
1340
  output_attentions=output_attentions,
1341
  output_router_logits=output_router_logits,
1342
  use_cache=use_cache,
 
1343
  )
1344
  hidden_states = layer_outputs[0]
1345
 
@@ -1353,38 +1685,90 @@ class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
1353
  all_router_logits += (layer_outputs[-1],)
1354
 
1355
  hidden_states = self.norm(hidden_states)
 
1356
 
1357
  # add hidden states from the last decoder layer
1358
  if output_hidden_states:
1359
- all_hidden_states += (hidden_states,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1360
 
1361
  next_cache = None
1362
  if use_cache:
1363
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1364
  if not return_dict:
1365
  return tuple(
1366
  v
1367
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1368
  if v is not None
1369
  )
1370
- return MoeModelOutputWithPast(
1371
- last_hidden_state=hidden_states,
1372
  past_key_values=next_cache,
1373
  hidden_states=all_hidden_states,
 
1374
  attentions=all_self_attns,
1375
  router_logits=all_router_logits,
1376
  )
1377
 
1378
 
1379
- class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel):
1380
  _tied_weights_keys = ["lm_head.weight"]
1381
 
1382
  def __init__(self, config: BailingMoeV2Config):
1383
  super().__init__(config)
1384
  self.model = BailingMoeV2Model(config)
1385
  self.vocab_size = config.vocab_size
1386
- self.norm_head = config.norm_head
1387
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
 
1388
 
1389
  # Initialize weights and apply final processing
1390
  self.post_init()
@@ -1407,26 +1791,8 @@ class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel):
1407
  def get_decoder(self):
1408
  return self.model
1409
 
1410
- def compute_logit(self, hidden_states):
1411
- if self.norm_head:
1412
- if self.training:
1413
- norm_weight = (
1414
- self.lm_head.weight / (torch.norm(self.lm_head.weight, p=2, dim=0, keepdim=True) + 1e-7).detach()
1415
- )
1416
- logits = F.linear(hidden_states, norm_weight, None)
1417
- else:
1418
- self.lm_head.weight.data = (
1419
- self.lm_head.weight.data.float()
1420
- / (torch.norm(self.lm_head.weight.data.float(), p=2, dim=0, keepdim=True) + 1e-7)
1421
- ).to(hidden_states.dtype)
1422
- logits = F.linear(hidden_states, self.lm_head.weight.data, None)
1423
- self.norm_head = False
1424
- else:
1425
- logits = self.lm_head(hidden_states)
1426
- return logits
1427
-
1428
  @add_start_docstrings_to_model_forward(BAILINGMOEV2_INPUTS_DOCSTRING)
1429
- @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1430
  def forward(
1431
  self,
1432
  input_ids: torch.LongTensor = None,
@@ -1441,27 +1807,21 @@ class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel):
1441
  output_router_logits: Optional[bool] = None,
1442
  return_dict: Optional[bool] = None,
1443
  **kwargs,
1444
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1445
  r"""
1446
  Args:
1447
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1448
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1449
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1450
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1451
-
1452
  Returns:
1453
-
1454
  Example:
1455
-
1456
  ```python
1457
  >>> from transformers import AutoTokenizer
1458
-
1459
  >>> model = BailingMoeV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1460
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1461
-
1462
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1463
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1464
-
1465
  >>> # Generate
1466
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1467
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -1490,25 +1850,40 @@ class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel):
1490
  **kwargs,
1491
  )
1492
 
1493
- hidden_states = outputs[0]
1494
-
1495
- logits = self.compute_logit(hidden_states=hidden_states)
1496
- logits = logits.float()
1497
-
1498
  loss = None
 
1499
  aux_loss = None
 
 
 
1500
 
1501
  if labels is not None:
1502
- # Shift so that tokens < n predict n
1503
- shift_logits = logits[..., :-1, :].contiguous()
1504
- shift_labels = labels[..., 1:].contiguous()
1505
- # Flatten the tokens
1506
- loss_fct = CrossEntropyLoss()
1507
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1508
- shift_labels = shift_labels.view(-1)
1509
- # Enable model parallelism
1510
- shift_labels = shift_labels.to(shift_logits.device)
1511
- loss = loss_fct(shift_logits, shift_labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1512
 
1513
  if not return_dict:
1514
  output = (logits,) + outputs[1:]
@@ -1516,82 +1891,14 @@ class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel):
1516
  output = (aux_loss,) + output
1517
  return (loss,) + output if loss is not None else output
1518
 
1519
- return MoeCausalLMOutputWithPast(
1520
  loss=loss,
 
1521
  aux_loss=aux_loss,
1522
  logits=logits,
 
1523
  past_key_values=outputs.past_key_values,
1524
  hidden_states=outputs.hidden_states,
1525
  attentions=outputs.attentions,
1526
  router_logits=outputs.router_logits,
1527
  )
1528
-
1529
- def prepare_inputs_for_generation(
1530
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_type_ids=None, **kwargs
1531
- ):
1532
- if past_key_values is not None:
1533
- if isinstance(past_key_values, Cache):
1534
- cache_length = past_key_values.get_seq_length()
1535
- past_length = past_key_values.seen_tokens
1536
- max_cache_length = (
1537
- past_key_values.get_max_length()
1538
- if hasattr(past_key_values, "get_max_length")
1539
- else past_key_values.get_max_cache_shape()
1540
- )
1541
- else:
1542
- cache_length = past_length = past_key_values[0][0].shape[2]
1543
- max_cache_length = None
1544
-
1545
- # Keep only the unprocessed tokens:
1546
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1547
- # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as input)
1548
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1549
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1550
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1551
- # input_ids based on the past_length.
1552
- elif past_length < input_ids.shape[1]:
1553
- input_ids = input_ids[:, past_length:]
1554
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1555
-
1556
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1557
- if (
1558
- max_cache_length is not None
1559
- and attention_mask is not None
1560
- and cache_length + input_ids.shape[1] > max_cache_length
1561
- ):
1562
- attention_mask = attention_mask[:, -max_cache_length:]
1563
-
1564
- position_ids = kwargs.get("position_ids", None)
1565
- if attention_mask is not None and position_ids is None:
1566
- # create position_ids on the fly for batch generation
1567
- position_ids = attention_mask.long().cumsum(-1) - 1
1568
- position_ids.masked_fill_(attention_mask == 0, 1)
1569
- if past_key_values:
1570
- position_ids = position_ids[:, -input_ids.shape[1] :]
1571
-
1572
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1573
- if inputs_embeds is not None and past_key_values is None:
1574
- model_inputs = {"inputs_embeds": inputs_embeds}
1575
- else:
1576
- model_inputs = {"input_ids": input_ids}
1577
-
1578
- model_inputs.update(
1579
- {
1580
- "position_ids": position_ids,
1581
- "past_key_values": past_key_values,
1582
- "use_cache": kwargs.get("use_cache"),
1583
- "attention_mask": attention_mask,
1584
- }
1585
- )
1586
- return model_inputs
1587
-
1588
- @staticmethod
1589
- def _reorder_cache(past_key_values, beam_idx):
1590
- reordered_past = ()
1591
- for layer_past in past_key_values:
1592
- reordered_past += (
1593
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1594
- )
1595
- return reordered_past
1596
-
1597
-
 
 
 
 
 
 
 
 
 
 
1
  # coding=utf-8
2
+ # Copyright 2025 Antgroup and The HuggingFace Inc. team. All rights reserved.
3
  #
4
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
  # and OPT implementations in this library. It has been modified from its
 
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
  """PyTorch BailingMoE model."""
21
+
22
  import math
23
  import warnings
24
  from typing import List, Optional, Tuple, Union
25
 
26
  import torch
27
  import torch.nn.functional as F
 
28
  from torch import nn
 
29
 
30
  from transformers.activations import ACT2FN
31
  from transformers.cache_utils import Cache, DynamicCache
 
35
  _prepare_4d_causal_attention_mask,
36
  _prepare_4d_causal_attention_mask_for_sdpa,
37
  )
38
+ from transformers.modeling_outputs import MoeModelOutputWithPast
39
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 
 
40
  from transformers.modeling_utils import PreTrainedModel
41
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
42
  from transformers.utils import (
 
49
  )
50
  from transformers.utils.import_utils import is_torch_fx_available
51
  from .configuration_bailing_moe_v2 import BailingMoeV2Config
52
+ from transformers.generation.utils import GenerationMixin
53
+ from dataclasses import dataclass
54
+ from transformers.utils import ModelOutput
55
+ from einops import rearrange
56
+ from functools import lru_cache
57
 
58
  if is_flash_attn_2_available():
59
  from flash_attn import flash_attn_func, flash_attn_varlen_func
 
74
  _CONFIG_FOR_DOC = "BailingMoeV2Config"
75
 
76
 
77
+ def nonzero(x):
78
+ return x.nonzero(as_tuple=True)
79
+
80
+
81
+ @lru_cache(maxsize=16)
82
+ def calc_chunks(cu_seqlen, moba_chunk_size):
83
+ """calc chunks that needs moba attention"""
84
+
85
+ # batch_sizes[batch_idx] = batch size ( seqlen ) of batch idx
86
+ batch_sizes = cu_seqlen[1:] - cu_seqlen[:-1]
87
+ # batch_num_chunk[batch_idx] = how many chunk in batch idx
88
+ batch_num_chunk = (batch_sizes + (moba_chunk_size - 1)) // moba_chunk_size
89
+ # cu_num_chunk[batch_idx] = first chunk id of this batch
90
+ cu_num_chunk = torch.ones(
91
+ batch_num_chunk.numel() + 1,
92
+ device=cu_seqlen.device,
93
+ dtype=batch_num_chunk.dtype,
94
+ )
95
+ cu_num_chunk[1:] = batch_num_chunk.cumsum(dim=0)
96
+ # total chunk ( for all batch )
97
+ num_chunk = cu_num_chunk[-1]
98
+ # chunk_sizes[chunk_idx] = chunk_size of chunk idx
99
+ chunk_sizes = torch.full((num_chunk + 1,), moba_chunk_size, dtype=torch.int32, device=cu_seqlen.device)
100
+ chunk_sizes[0] = 0 # for calc cu chunk
101
+ batch_last_chunk_size = batch_sizes - (batch_num_chunk - 1) * moba_chunk_size
102
+ chunk_sizes[cu_num_chunk[1:]] = batch_last_chunk_size
103
+ # cu_chunk[chunk_idx] = the start chunk offset of chunk idx
104
+ cu_chunk = chunk_sizes.cumsum(dim=-1, dtype=torch.int32)
105
+ # chunk_to_batch[chunk_idx] = batch idx of the chunk idx
106
+ chunk_to_batch = torch.zeros((num_chunk,), dtype=torch.int32, device=cu_seqlen.device)
107
+ chunk_to_batch[cu_num_chunk[1:-1]] = 1
108
+ chunk_to_batch = chunk_to_batch.cumsum(dim=0, dtype=torch.int32)
109
+
110
+ """ filter chunks that need moba attn """
111
+
112
+ # filter chunks ( remove last chunk of each batch )
113
+ # filtered_chunk_indices: chunk index list that excludes the last chunk of each batch
114
+ chunk_to_remove = cu_num_chunk[1:] - 1
115
+ chunk_to_remain = torch.ones((num_chunk,), dtype=torch.bool, device=cu_seqlen.device)
116
+ chunk_to_remain[chunk_to_remove] = False
117
+ filtered_chunk_indices = chunk_to_remain.nonzero(as_tuple=True)[0]
118
+ num_filtered_chunk = len(filtered_chunk_indices)
119
+
120
+ return (
121
+ cu_chunk,
122
+ filtered_chunk_indices,
123
+ num_filtered_chunk,
124
+ chunk_to_batch,
125
+ )
126
+
127
+
128
+ def _prepare_for_moba(
129
+ q: torch.Tensor,
130
+ k: torch.Tensor,
131
+ v: torch.Tensor,
132
+ cu_seqlens: torch.Tensor,
133
+ max_seqlen: int,
134
+ moba_chunk_size: int,
135
+ moba_topk: int,
136
+ is_decode: bool = False
137
+ ) -> torch.Tensor:
138
+ """An efficient version of moba implementation with triton kernels and flash-attn, the core logic:
139
+ 1. Calculate the chunks and the number of chunks, n = floor(data_size / chunk_size)
140
+ - tokens in the tail chunk are reserved for self attn
141
+ - tokens in other chunks will be processed in later steps
142
+ 2. K in each chunk will calculate mean value as the representative k, and Q will attend to these representative
143
+ k to get the gate logit, which will be used to select topk chunks
144
+ 3. Select the topk chunks and get the dense q for each kv chunk pair and do the varlen attention
145
+ 4. Combine the varlen attn and self attn results via online softmax to get the final result
146
+
147
+ Args:
148
+ q (torch.Tensor): [seqlen, head, head_dim]
149
+ k (torch.Tensor): [seqlen, head, head_dim]
150
+ v (torch.Tensor): [seqlen, head, head_dim]
151
+ cu_seqlens (torch.Tensor): the cumulative sequence length tensor, same definition in flash attn
152
+ max_seqlen (int): the max sequence length of the batch, same definition in flash attn
153
+
154
+ Returns:
155
+ attn_output (torch.Tensor): [seqlen, head, head_dim]
156
+ """
157
+
158
+ kv = torch.stack((k, v), dim=1)
159
+
160
+ """ some basic variables """
161
+ # qkv shape = [ S, H, D ]
162
+ seqlen_q, num_head, head_dim = q.shape
163
+ seqlen_kv, num_head_kv, _ = k.shape
164
+ replicas = num_head // num_head_kv
165
+
166
+ """ prepare chunk meta """
167
+ (
168
+ cu_chunk,
169
+ filtered_chunk_indices,
170
+ num_filtered_chunk,
171
+ chunk_to_batch,
172
+ ) = calc_chunks(cu_seqlens, moba_chunk_size)
173
+ # cu_chunk: [num_chunks + 1], the start position of each chunk
174
+ # filtered_chunk_indices: [num_filtered_chunk], the indices of filtered chunk (filter out last in each batch)
175
+ # chunk_to_batch: [total_num_chunks], chunk_to_batch[i] stands for the batch index of i-th chunk
176
+
177
+ self_attn_cu_seqlen = cu_chunk
178
+ # filtered_kv is a dense matrix that only contains filtered chunk of kv
179
+ filtered_kv_indices = torch.arange(0, moba_chunk_size, dtype=torch.int64, device=q.device)[None, :].repeat(
180
+ num_filtered_chunk, 1
181
+ )
182
+ filtered_kv_indices += cu_chunk[filtered_chunk_indices][:, None]
183
+ index_expanded = filtered_kv_indices.view(-1).view(-1, 1, 1, 1).expand(-1, 2, kv.shape[-2], kv.shape[-1])
184
+ filtered_kv = torch.gather(kv, 0, index_expanded)
185
+
186
+ """ calc key_gate_weight and gate """
187
+
188
+ # key_gate_weight [ F_N_CHUNK, HEAD, HEAD_DIM ]
189
+ key_gate_weight = (
190
+ filtered_kv[:, 0].view(num_filtered_chunk, moba_chunk_size, num_head_kv, head_dim).mean(dim=1)
191
+ )
192
+
193
+ # we will adjust selective topk to moba_topk - 1, as the last chunk is always chosen
194
+ moba_topk = min(moba_topk - 1, num_filtered_chunk)
195
+ need_moba_attn = moba_topk > 0
196
+ # corner case: if no moba attn needed, just return self attn
197
+ if not need_moba_attn:
198
+ return None, None, None, None, None, None, None
199
+
200
+ query_gate_weight = q.view(seqlen_q, num_head_kv, replicas, head_dim).mean(dim=2).float()
201
+ key_gate_weight = key_gate_weight.type(torch.float32) # float logit for better gate logit perception
202
+ gate = torch.einsum("nhd,shd->nhs", key_gate_weight, query_gate_weight) # gate [ F_N_CHUNK, HEAD, SEQ ]
203
+ key_gate_weight = key_gate_weight.type_as(k)
204
+ q = q.type_as(k)
205
+
206
+ # pose process gate, masking unchosen batch and apply causal mask to current chunk
207
+ gate_seq_idx = torch.arange(0, seqlen_q, device=q.device, dtype=torch.int32)[None, :].repeat(num_filtered_chunk, 1)
208
+ chunk_end = cu_chunk[filtered_chunk_indices + 1]
209
+ batch_end = cu_seqlens[chunk_to_batch[filtered_chunk_indices] + 1]
210
+ gate_chunk_end_mask = gate_seq_idx < chunk_end[:, None]
211
+ gate_batch_end_mask = gate_seq_idx >= batch_end[:, None]
212
+ gate_inf_mask = gate_chunk_end_mask | gate_batch_end_mask
213
+ gate.masked_fill_(gate_inf_mask.unsqueeze(1), -float("inf"))
214
+
215
+ """ find moba q that needs moba attn """
216
+ # find topk chunks
217
+ # gate_mask [ N_CHUNK, HEAD, SEQ ], true indicates that needs attention
218
+ _, gate_top_k_idx = torch.topk(gate, k=moba_topk, dim=0, largest=True, sorted=False)
219
+ # apply causal mask
220
+ gate_mask = torch.logical_not(gate.isinf())
221
+ # select topk chunks
222
+ gate_idx_mask = torch.zeros(gate_mask.shape, dtype=torch.bool, device=q.device)
223
+ gate_idx_mask = gate_idx_mask.scatter_(dim=0, index=gate_top_k_idx, value=True)
224
+ gate_mask = torch.logical_and(gate_mask, gate_idx_mask)
225
+
226
+ # equivalent to gate_mask.reshape(gate_mask.shape[0], -1).nonzero(as_tuple=True)[-1]; yields the flat [ HS ] indices of the selected q positions
+ moba_q_indices = nonzero(gate_mask.reshape(gate_mask.shape[0], -1))[-1]
229
+ # moba_seqlen_q: how many q tokens are selected for each kv chunk-head pair
230
+ moba_seqlen_q = gate_mask.sum(dim=-1).flatten()
231
+ # select all q that needs moba attn based on the moba_q_indices
232
+
233
+ # GQA
234
+ # moba_q_pre = q.transpose(0, 1).reshape(-1, q.size(-1))
235
+ moba_q = q.view(seqlen_q, num_head_kv, replicas, head_dim)
236
+ moba_q_pre = moba_q.transpose(0, 1).reshape(-1, *moba_q.shape[2:])
237
+
238
+ # GQA
239
+ index_expanded = moba_q_indices.view(-1, 1, 1).expand(-1, replicas, moba_q_pre.size(-1))
240
+
241
+ moba_q = torch.gather(moba_q_pre, 0, index_expanded)
242
+
243
+ # moba_q_sh_indices represents the position in the origin q tensor of each q token inside moba_q
244
+ # GQA
245
+ moba_q_sh_indices = moba_q_indices % seqlen_q * num_head_kv + moba_q_indices // seqlen_q
246
+
247
+ """ prepare moba kv """
248
+ # Since moba_q is organized as HS * N, we need to reorganize kv to adapt to q
249
+
250
+ # cut off zero experts
251
+ q_zero_mask = moba_seqlen_q == 0
252
+ valid_expert_mask = ~q_zero_mask
253
+ zero_expert_count = q_zero_mask.sum()
254
+ # only keep the kv that has q select > 0
255
+ if zero_expert_count > 0:
256
+ moba_seqlen_q = moba_seqlen_q[valid_expert_mask]
257
+ # moba cu_seqlen for flash attn
258
+ moba_cu_seqlen_q = torch.cat(
259
+ (
260
+ torch.tensor([0], device=q.device, dtype=moba_seqlen_q.dtype),
261
+ moba_seqlen_q.cumsum(dim=0),
262
+ ),
263
+ dim=0,
264
+ ).to(torch.int32)
265
+ moba_kv = filtered_kv.permute(2, 0, 1, 3)
266
+ moba_kv = moba_kv.split(moba_chunk_size, dim=1)
267
+ moba_kv = torch.cat(moba_kv, dim=0)
268
+
269
+ if zero_expert_count > 0:
270
+ assert valid_expert_mask.sum() == moba_kv.shape[0] - zero_expert_count
271
+ moba_kv = moba_kv[valid_expert_mask]  # drop kv chunks with zero selected q, otherwise the grad may be NaN
272
+ moba_kv = moba_kv.flatten(start_dim=0, end_dim=1).unsqueeze(2)
273
+ moba_cu_seqlen_kv = (
274
+ torch.arange(
275
+ 0,
276
+ num_filtered_chunk * num_head_kv + 1 - zero_expert_count,
277
+ dtype=torch.int32,
278
+ device=q.device,
279
+ )
280
+ * moba_chunk_size
281
+ )
282
+
283
+ return self_attn_cu_seqlen, moba_q, moba_kv, moba_cu_seqlen_q, moba_cu_seqlen_kv, moba_chunk_size, moba_q_sh_indices
284
+
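# --- Illustrative sketch (not part of the model code): the chunk-level gating prepared above, on
# toy shapes. The sizes and the helper name (_moba_gating_sketch) are made up for this example;
# the real path additionally excludes each batch's tail chunk, masks the gate causally, and keeps
# the q/kv layout expected by flash-attn varlen.
def _moba_gating_sketch():
    import torch

    seqlen, num_head_kv, head_dim, chunk_size, topk = 8, 2, 4, 4, 1
    q = torch.randn(seqlen, num_head_kv, head_dim)
    k = torch.randn(seqlen, num_head_kv, head_dim)
    # mean-pool K inside each chunk to get one representative key per chunk
    num_chunks = seqlen // chunk_size
    k_rep = k.view(num_chunks, chunk_size, num_head_kv, head_dim).mean(dim=1)  # [C, H, D]
    # gate logits: every query scores every chunk representative -> [C, H, S]
    gate = torch.einsum("chd,shd->chs", k_rep.float(), q.float())
    # each (head, query) position keeps only its top-k chunks for sparse attention
    _, top_chunks = torch.topk(gate, k=topk, dim=0)
    return top_chunks  # [topk, H, S] chunk indices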
285
+
286
+ def _moba_attn_varlen_prefill(
287
+ q: torch.Tensor,
288
+ k: torch.Tensor,
289
+ v: torch.Tensor,
290
+ cu_seqlens: torch.Tensor,
291
+ max_seqlen: int,
292
+ moba_chunk_size: int,
293
+ moba_topk: int,
294
+ ) -> torch.Tensor:
295
+ """An efficient version of moba implementation with triton kernels and flash-attn, the core logic:
296
+ 1. Calculate the chunks and the number of chunks, n = floor(data_size / chunk_size)
297
+ - tokens in the tail chunk are reserved for self attn
298
+ - tokens in other chunks will be processed in later steps
299
+ 2. The keys in each chunk are mean-pooled into a representative k, and Q attends to these representative
300
+ keys to get the gate logits, which are used to select the top-k chunks
301
+ 3. Select the top-k chunks, gather the dense q for each kv-chunk pair, and run varlen attention
302
+ 4. Combine the varlen attn and self attn results via online softmax to get the final result
303
+
304
+ Args:
305
+ q (torch.Tensor): [seqlen, head, head_dim]
306
+ k (torch.Tensor): [seqlen, head, head_dim]
307
+ v (torch.Tensor): [seqlen, head, head_dim]
308
+ cu_seqlens (torch.Tensor): the cumulative sequence length tensor, same definition in flash attn
309
+ max_seqlen (int): the max sequence length of the batch, same definition in flash attn
+ moba_chunk_size (int): the chunk (block) size used to partition the KV sequence
+ moba_topk (int): number of KV chunks each query attends to, including the current chunk handled by self attention
310
+
311
+ Returns:
312
+ attn_output (torch.Tensor): [seqlen, head, head_dim]
313
+ """
314
+
315
+ self_attn_cu_seqlen, moba_q, moba_kv, moba_cu_seqlen_q, moba_cu_seqlen_kv, moba_chunk_size, moba_q_sh_indices = (
316
+ _prepare_for_moba(q, k, v, cu_seqlens, max_seqlen, moba_chunk_size, moba_topk)
317
+ )
318
+
319
+ if moba_q is None:
320
+ return flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True)
321
+ softmax_scale = q.shape[-1] ** (-0.5)
322
+
323
+ # self attn
324
+ self_attn_out_sh, self_attn_lse_hs, *rest = flash_attn_varlen_func(
325
+ q=q,
326
+ k=k,
327
+ v=v,
328
+ cu_seqlens_q=self_attn_cu_seqlen,
329
+ cu_seqlens_k=self_attn_cu_seqlen,
330
+ max_seqlen_q=max_seqlen,
331
+ max_seqlen_k=max_seqlen,
332
+ softmax_scale=softmax_scale,
333
+ causal=True,
334
+ return_attn_probs=True
335
+ )
336
+
337
+ # moba attn
338
+ moba_attn_out, moba_attn_lse_hs, *rest = flash_attn_varlen_func(
339
+ q=moba_q,
340
+ k=moba_kv[:, 0],
341
+ v=moba_kv[:, 1],
342
+ cu_seqlens_q=moba_cu_seqlen_q,
343
+ cu_seqlens_k=moba_cu_seqlen_kv,
344
+ max_seqlen_q=max_seqlen,
345
+ max_seqlen_k=moba_chunk_size,
346
+ softmax_scale=softmax_scale,
347
+ causal=False,
348
+ return_attn_probs=True
349
+ )
350
+
351
+ kv_replicas = q.shape[1] // k.shape[1]
352
+ h, s = self_attn_lse_hs.shape
353
+
354
+ # convert lse shape hs -> sh ( follow the legacy mix attn logic )
355
+ self_attn_lse_sh = self_attn_lse_hs.t().view(s, k.shape[1], kv_replicas).contiguous()
356
+ moba_attn_lse = moba_attn_lse_hs.t().contiguous()
357
+
358
+ max_lse_1d = self_attn_lse_sh.view(-1, kv_replicas)
359
+ max_lse_1d = max_lse_1d.index_reduce(0, moba_q_sh_indices, moba_attn_lse, "amax")
360
+ self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh)
361
+
362
+ moba_attn_lse = (
363
+ moba_attn_lse.view(-1, kv_replicas).sub(max_lse_1d.index_select(0, moba_q_sh_indices)).reshape_as(moba_attn_lse)
364
+ )
365
+
366
+ mixed_attn_se_sh = self_attn_lse_sh.exp()
367
+ moba_attn_se = moba_attn_lse.exp()
368
+
369
+ mixed_view = mixed_attn_se_sh.view(-1, kv_replicas)
370
+ result_view = mixed_view.index_add(0, moba_q_sh_indices, moba_attn_se.view(-1, kv_replicas))
371
+
372
+ mixed_attn_se_sh = result_view.view_as(mixed_attn_se_sh)
373
+ mixed_attn_lse_sh = mixed_attn_se_sh.log()
374
+
375
+ # add attn output
376
+ factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp() # [ vS, H ]
377
+ self_attn_out_sh = self_attn_out_sh * factor.view(self_attn_out_sh.shape[0], self_attn_out_sh.shape[1], 1)
378
+ output_2d = self_attn_out_sh.reshape(q.shape[0] * k.shape[1], kv_replicas, q.shape[2])
379
+
380
+ # add moba output
381
+ mixed_attn_lse = mixed_attn_lse_sh.view(-1, kv_replicas).index_select(0, moba_q_sh_indices).view_as(moba_attn_lse)
382
+ factor = (moba_attn_lse - mixed_attn_lse).exp() # [ vS, H ]
383
+ moba_attn_out = moba_attn_out * factor.unsqueeze(-1)
384
+ raw_attn_out = moba_attn_out.view(-1, kv_replicas, moba_attn_out.shape[-1])
385
+ output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out)
386
+
387
+ # add back max lse
388
+ mixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh)
389
+
390
+ return output_2d.view(q.shape[0], q.shape[1], q.shape[2]).to(q.dtype)
391
+
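# --- Illustrative sketch (not part of the model code): the online-softmax merge used above. Two
# partial attention results, each with its log-sum-exp (lse), are combined exactly as if a single
# softmax had been taken over the union of their key sets. Toy tensors only; the real code applies
# the same identity per selected (query, head) position via index_add_.
def _lse_merge_sketch():
    import torch

    out_a, lse_a = torch.randn(5, 4), torch.randn(5)  # e.g. self-attn branch: output and lse
    out_b, lse_b = torch.randn(5, 4), torch.randn(5)  # e.g. MoBA branch: output and lse
    max_lse = torch.maximum(lse_a, lse_b)             # subtract the max for numerical stability
    w_a = (lse_a - max_lse).exp()
    w_b = (lse_b - max_lse).exp()
    denom = w_a + w_b                                 # = exp(lse_merged - max_lse)
    merged = out_a * (w_a / denom).unsqueeze(-1) + out_b * (w_b / denom).unsqueeze(-1)
    return merged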
392
+
393
+ def roll_tensor(tensor, shifts=-1, dims=-1, fill_value=0):
394
+ """Roll the tensor input along the given dimension(s).
395
+ Elements shifted into the vacated positions are set to fill_value.
+ Returns the rolled tensor and its sum.
396
+ """
397
+ rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
398
+ rolled_tensor.select(dims, shifts).fill_(fill_value)
399
+ return rolled_tensor, rolled_tensor.sum()
400
+
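# --- Illustrative usage of roll_tensor (toy values): the MTP path rolls labels/input_ids one step
# to the left so that position t is supervised with token t+1, filling the vacated slot.
def _roll_tensor_example():
    import torch

    labels = torch.tensor([[1, 2, 3, 4]])
    rolled, total = roll_tensor(labels, shifts=-1, dims=-1, fill_value=-100)
    # rolled == tensor([[2, 3, 4, -100]]); total == rolled.sum() == tensor(-91)
    return rolled, total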
401
+
402
+ @dataclass
403
+ class MoEV2CausalLMOutputWithPast(ModelOutput):
404
+ """
405
+ Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden
406
+ states terms, to train a MoE model.
407
+ Args:
408
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
409
+ Language modeling loss (for next-token prediction).
410
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
411
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
412
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
413
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
414
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
415
+ `past_key_values` input) to speed up sequential decoding.
416
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
417
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
418
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
419
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
420
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
421
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
422
+ sequence_length)`.
423
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
424
+ heads.
425
+ z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
426
+ z_loss for the sparse modules.
427
+ aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
428
+ aux_loss for the sparse modules.
429
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
430
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
431
+ Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse
432
+ modules.
433
+ """
434
+
435
+ loss: Optional[torch.FloatTensor] = None
436
+ logits: Optional[torch.FloatTensor] = None
437
+ past_key_values: Optional[Cache] = None
438
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
439
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
440
+ z_loss: Optional[torch.FloatTensor] = None
441
+ aux_loss: Optional[torch.FloatTensor] = None
442
+ router_logits: Optional[tuple[torch.FloatTensor]] = None
443
+ mtp_loss: Optional[torch.FloatTensor] = None
444
+ mtp_logits: Optional[tuple[torch.FloatTensor, ...]] = None
445
+
446
+
447
+ class MoeV2ModelOutputWithPast(MoeModelOutputWithPast):
448
+
449
+ def __init__(self, mtp_hidden_states=None, **kwargs):
450
+ super().__init__(**kwargs)
451
+ self.mtp_hidden_states = mtp_hidden_states
452
+
453
+
454
  def _get_unpad_data(attention_mask):
455
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
456
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
 
502
 
503
 
504
  class BailingMoeV2RotaryEmbedding(nn.Module):
505
+ def __init__(self, config: BailingMoeV2Config, device=None):
506
  super().__init__()
507
+ # BC: "rope_type" was originally "type"
508
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
509
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
510
+ else:
511
+ self.rope_type = "default"
512
+ self.max_seq_len_cached = config.max_position_embeddings
513
+ self.original_max_seq_len = config.max_position_embeddings
514
 
515
+ self.config = config
516
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
 
 
 
517
 
518
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
 
 
 
 
 
 
 
 
519
  self.register_buffer("inv_freq", inv_freq, persistent=False)
520
+ self.original_inv_freq = self.inv_freq
521
 
522
+ @torch.no_grad()
523
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
524
+ def forward(self, x, position_ids):
525
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
526
+ position_ids_expanded = position_ids[:, None, :].float()
527
 
528
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
529
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
530
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
531
+ emb = torch.cat((freqs, freqs), dim=-1)
532
+ cos = emb.cos() * self.attention_scaling
533
+ sin = emb.sin() * self.attention_scaling
534
 
535
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 
 
 
 
 
536
 
537
 
538
  # Copied from transformers.models.llama.modeling_llama.rotate_half
 
544
 
545
 
546
  # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
547
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
548
  """Applies Rotary Position Embedding to the query and key tensors.
 
549
  Args:
550
  q (`torch.Tensor`): The query tensor.
551
  k (`torch.Tensor`): The key tensor.
552
  cos (`torch.Tensor`): The cosine part of the rotary embedding.
553
  sin (`torch.Tensor`): The sine part of the rotary embedding.
 
 
 
554
  unsqueeze_dim (`int`, *optional*, defaults to 1):
555
  The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
556
  sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
 
561
  Returns:
562
  `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
563
  """
564
+ cos = cos.unsqueeze(unsqueeze_dim)
565
+ sin = sin.unsqueeze(unsqueeze_dim)
566
+
567
+ # Keep half or full tensor for later concatenation
568
+ rotary_dim = cos.shape[-1]
569
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
570
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
571
+
572
+ # Apply rotary embeddings on the first half or full tensor
573
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
574
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
575
+
576
+ # Concatenate back to full shape
577
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
578
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
579
  return q_embed, k_embed
580
 
581
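# --- Illustrative sketch (not part of the model code): partial RoPE as implemented by
# apply_rotary_pos_emb above. Only the first rotary_dim channels are rotated; the remaining
# channels pass through unchanged. Sizes and the helper name are made up for this example.
def _partial_rope_sketch():
    import torch

    batch, heads, seq, head_dim, rotary_dim = 1, 2, 6, 8, 4
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(seq).float(), inv_freq)  # [seq, rotary_dim / 2]
    emb = torch.cat((freqs, freqs), dim=-1)                   # [seq, rotary_dim]
    cos, sin = emb.cos()[None], emb.sin()[None]               # [1, seq, rotary_dim]
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)       # unsqueeze_dim=1 broadcasts over heads
    assert torch.allclose(q_rot[..., rotary_dim:], q[..., rotary_dim:])  # tail channels untouched
    return q_rot, k_rot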
 
 
602
  self.top_k = config.num_experts_per_tok
603
  self.num_experts = config.num_experts
604
 
605
+ self.n_group = config.n_group
606
+ self.topk_group = config.topk_group
607
+
608
  # topk selection algorithm
 
609
  self.gating_dim = config.hidden_size
610
  self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
611
+ self.routed_scaling_factor = config.routed_scaling_factor
612
 
613
+ self.register_buffer("expert_bias", torch.zeros((self.num_experts)))
 
614
  self.reset_parameters()
615
 
616
  def reset_parameters(self) -> None:
 
618
 
619
  init.kaiming_uniform_(self.weight, a=math.sqrt(5))
620
 
621
+ def group_limited_topk(
622
+ self,
623
+ scores: torch.Tensor,
624
+ ):
625
+ num_tokens, _ = scores.size()
626
+ # Organize the experts into groups
627
+ group_scores = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
628
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
629
+ group_mask = torch.zeros_like(group_scores)
630
+ group_mask.scatter_(1, group_idx, 1)
631
+
632
+ # Mask the experts based on selection groups
633
+ score_mask = (
634
+ group_mask.unsqueeze(-1)
635
+ .expand(num_tokens, self.n_group, self.num_experts // self.n_group)
636
+ .reshape(num_tokens, -1)
637
+ )
638
+
639
+ masked_scores = scores.masked_fill(~score_mask.bool(), float('-inf'))
640
+ probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1)
641
+
642
+ return probs, top_indices
643
+
644
  def forward(self, hidden_states):
645
  # compute gating score
646
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
647
+ logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
 
 
 
648
 
649
+ scores = torch.sigmoid(logits.float()).type_as(logits)
 
650
 
651
+ scores_for_routing = scores + self.expert_bias
652
+ _, topk_idx = self.group_limited_topk(scores_for_routing)
 
 
 
653
 
654
+ scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits)
 
 
655
 
656
+ topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores
657
+ topk_weight = topk_weight * self.routed_scaling_factor
 
 
 
 
 
 
658
 
659
+ return topk_idx, topk_weight, logits
 
 
660
 
661
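# --- Illustrative sketch (not part of the model code): the group-limited sigmoid routing used by
# the gate above, on toy numbers. The real forward additionally adds the learned expert_bias before
# selection and multiplies the final weights by routed_scaling_factor.
def _group_limited_routing_sketch():
    import torch

    num_tokens, num_experts, n_group, topk_group, top_k = 3, 8, 4, 2, 2
    scores = torch.sigmoid(torch.randn(num_tokens, num_experts))
    # score each group by the sum of its two best experts, keep only topk_group groups
    group_scores = scores.view(num_tokens, n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1).expand(num_tokens, n_group, num_experts // n_group).reshape(num_tokens, -1)
    )
    # pick the final top_k experts inside the surviving groups and renormalize their weights
    masked = scores.masked_fill(~score_mask.bool(), float("-inf"))
    topk_weight, topk_idx = torch.topk(masked, k=top_k, dim=-1)
    topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
    return topk_idx, topk_weight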
 
662
  class BailingMoeV2SparseMoeBlock(nn.Module):
 
709
  tokens_per_expert = cnts.sum(dim=0)
710
  idxs = topk_ids.view(-1).argsort()
711
  sorted_tokens = x[idxs // topk_ids.shape[1]]
 
712
  tokens_per_expert = tokens_per_expert.cpu().numpy()
713
  outputs = []
714
  start_idx = 0
 
719
  expert = self.experts[i]
720
  tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
721
  expert_out = expert(tokens_for_this_expert)
722
+ outputs.append(expert_out.to(x.device))
723
  start_idx = end_idx
724
 
725
  outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
 
767
  self.hidden_size = config.hidden_size
768
  self.num_heads = config.num_attention_heads
769
  self.head_dim = config.head_dim or self.hidden_size // self.num_heads
770
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
771
+ self.rope_dim = int(self.head_dim * partial_rotary_factor)
772
  self.num_key_value_heads = config.num_key_value_heads
773
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
774
  self.max_position_embeddings = config.max_position_embeddings
 
782
  )
783
 
784
  if self.config.use_qk_norm:
785
+ self.query_layernorm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
786
+ self.key_layernorm = BailingMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
787
  self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)
 
 
 
 
 
788
 
789
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
790
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
797
  past_key_value: Optional[Cache] = None,
798
  output_attentions: bool = False,
799
  use_cache: bool = False,
800
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
801
  **kwargs,
802
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
 
 
 
803
 
804
  bsz, q_len, _ = hidden_states.size()
805
 
 
814
  value_states = value_states.transpose(1, 2)
815
 
816
  if self.config.use_qk_norm:
817
+ query_states = self.query_layernorm(query_states)
818
+ key_states = self.key_layernorm(key_states)
819
+
820
+ cos, sin = position_embeddings
821
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
822
 
 
823
  if past_key_value is not None:
824
  if self.layer_idx is None:
825
  raise ValueError(
 
827
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
828
  "with a layer index."
829
  )
830
+ cache_kwargs = {"sin": sin, "cos": cos}
 
 
 
 
 
831
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
832
 
833
  key_states = repeat_kv(key_states, self.num_key_value_groups)
834
  value_states = repeat_kv(value_states, self.num_key_value_groups)
835
 
836
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
837
 
838
+ kv_seq_len = key_states.shape[-2]
839
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
840
  raise ValueError(
841
  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
 
896
  past_key_value: Optional[Cache] = None,
897
  output_attentions: bool = False,
898
  use_cache: bool = False,
899
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
900
  **kwargs,
901
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
902
  # BailingMoeV2FlashAttention2 attention does not support output_attentions
 
 
 
 
 
 
 
 
903
  output_attentions = False
904
 
905
  bsz, q_len, _ = hidden_states.size()
 
919
  value_states = value_states.transpose(1, 2)
920
 
921
  if self.config.use_qk_norm:
922
+ query_states = self.query_layernorm(query_states)
923
+ key_states = self.key_layernorm(key_states)
924
 
925
+ cos, sin = position_embeddings
926
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
 
 
927
 
928
  if past_key_value is not None:
929
+ cache_kwargs = {"sin": sin, "cos": cos}
930
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
931
 
932
  # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
 
951
  elif torch.is_autocast_enabled():
952
  target_dtype = torch.get_autocast_gpu_dtype()
953
  else:
954
+ target_dtype = self.query_key_value.weight.dtype
955
 
956
  logger.warning_once(
957
  f"The input hidden states seems to be silently casted in float32, this might be related to"
 
962
  query_states = query_states.to(target_dtype)
963
  key_states = key_states.to(target_dtype)
964
  value_states = value_states.to(target_dtype)
965
+ if hasattr(self.config, "moba_topk"):
966
+ attn_output = self._mixture_attention_forward(
967
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
968
+ )
969
+ else:
970
+ attn_output = self._flash_attention_forward(
971
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
972
+ )
973
 
974
  attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
975
  attn_output = self.dense(attn_output)
 
978
  attn_weights = None
979
 
980
  return attn_output, attn_weights, past_key_value
981
+
982
+ def _mixture_attention_forward(
983
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
984
+ ):
985
+ """
986
+ Dispatches the prefill path (query_length > 1) to MoBA chunked attention and the decode path to standard Flash Attention;
987
+ if the input hidden states contain at least one padding token, the input is first unpadded and the final output is padded back.
988
+ Args:
989
+ query_states (`torch.Tensor`):
990
+ Input query states to be passed to Flash Attention API
991
+ key_states (`torch.Tensor`):
992
+ Input key states to be passed to Flash Attention API
993
+ value_states (`torch.Tensor`):
994
+ Input value states to be passed to Flash Attention API
995
+ attention_mask (`torch.Tensor`):
996
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
997
+ position of padding tokens and 1 for the position of non-padding tokens.
998
+ dropout (`int`, *optional*):
999
+ Attention dropout
1000
+ softmax_scale (`float`, *optional*):
1001
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
1002
+ query_length (`int`):
1003
+ The length of the query sequence in terms of tokens. This represents the number of tokens in the
1004
+ `query_states` tensor along the sequence dimension. It is used to determine the effective sequence
1005
+ length for attention computations.
1006
+ """
1007
+ if not self._flash_attn_uses_top_left_mask:
1008
+ causal = self.is_causal
1009
+ else:
1010
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in BailingMoeV2FlashAttention2 __init__.
1011
+ causal = self.is_causal and query_length != 1
1012
+
1013
+ if query_length != 1:
1014
+ # prefill
1015
+ # Contains at least one padding token in the sequence
1016
+ if attention_mask is not None:
1017
+ batch_size = query_states.shape[0]
1018
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
1019
+ query_states, key_states, value_states, attention_mask, query_length
1020
+ )
1021
+
1022
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
1023
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
1024
+ attn_output_unpad = _moba_attn_varlen_prefill(
1025
+ query_states,
1026
+ key_states,
1027
+ value_states,
1028
+ cu_seqlens=cu_seqlens_k,
1029
+ max_seqlen=max_seqlen_in_batch_k,
1030
+ moba_chunk_size=self.config.moba_block_size,
1031
+ moba_topk=self.config.moba_topk
1032
+ )
1033
+
1034
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
1035
+ else:
1036
+ batch_size = query_states.shape[0]
1037
+ cu_seqlens_k = torch.cumsum(
1038
+ torch.tensor([0] + [query_length] * batch_size, device=query_states.device),
1039
+ dim=0,
1040
+ dtype=torch.int32,
1041
+ )
1042
+ query_states = query_states.view(-1, self.num_heads, self.head_dim)
1043
+ key_states = key_states.view(-1, self.num_key_value_heads, self.head_dim)
1044
+ value_states = value_states.view(-1, self.num_key_value_heads, self.head_dim)
1045
+ attn_output = _moba_attn_varlen_prefill(
1046
+ query_states,
1047
+ key_states,
1048
+ value_states,
1049
+ cu_seqlens=cu_seqlens_k,
1050
+ max_seqlen=query_length,
1051
+ moba_chunk_size=self.config.moba_block_size,
1052
+ moba_topk=self.config.moba_topk
1053
+ ).view(batch_size, query_length, -1)
1054
+ else:
1055
+ # decode
1056
+ attn_output = self._flash_attention_forward(
1057
+ query_states, key_states, value_states, attention_mask, query_length, dropout, softmax_scale
1058
+ )
1059
+ return attn_output
1060
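# --- Illustrative sketch (standalone aside, not part of this class): the flash-attn style
# cu_seqlens used by the varlen paths above. For per-sample lengths [3, 5, 2] the cumulative
# tensor is the int32 prefix sum [0, 3, 8, 10] and max_seqlen is 5.
def _cu_seqlens_sketch():
    import torch

    seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    return cu_seqlens, max_seqlen  # tensor([0, 3, 8, 10], dtype=torch.int32), 5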
 
1061
  def _flash_attention_forward(
1062
  self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
 
1064
  """
1065
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
1066
  first unpad the input, then computes the attention scores and pad the final attention scores.
 
1067
  Args:
1068
  query_states (`torch.Tensor`):
1069
  Input query states to be passed to Flash Attention API
 
1176
  past_key_value: Optional[Cache] = None,
1177
  output_attentions: bool = False,
1178
  use_cache: bool = False,
1179
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
1180
  **kwargs,
1181
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1182
  if output_attentions:
 
1207
  value_states = value_states.transpose(1, 2)
1208
 
1209
  if self.config.use_qk_norm:
1210
+ query_states = self.query_layernorm(query_states)
1211
+ key_states = self.key_layernorm(key_states)
1212
 
1213
+ cos, sin = position_embeddings
1214
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
 
 
 
1215
 
1216
  if past_key_value is not None:
1217
+ cache_kwargs = {"sin": sin, "cos": cos}
1218
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1219
 
1220
  key_states = repeat_kv(key_states, self.num_key_value_groups)
1221
  value_states = repeat_kv(value_states, self.num_key_value_groups)
1222
 
1223
  if attention_mask is not None:
1224
+ kv_seq_len = key_states.shape[-2]
1225
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
1226
  raise ValueError(
1227
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
 
1259
  }
1260
 
1261
 
1262
+ class BailingMoeV2MTPLayer(nn.Module):
1263
+ def __init__(self, config: BailingMoeV2Config, layer_idx: int):
1264
+ super().__init__()
1265
+ self.layer_idx = layer_idx
1266
+ self.input_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1267
+ self.enorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1268
+
1269
+ self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
1270
+ self.post_attention_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1271
+ self.attention = ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
1272
+ self.mlp = BailingMoeV2SparseMoeBlock(config)
1273
+
1274
+ self.hnorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1275
+ self.final_layernorm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1276
+
1277
+ def forward(
1278
+ self,
1279
+ input_embeds,
1280
+ hidden_states: torch.Tensor,
1281
+ attention_mask: Optional[torch.Tensor] = None,
1282
+ position_ids: Optional[torch.LongTensor] = None,
1283
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1284
+ output_attentions: Optional[bool] = False,
1285
+ output_router_logits: Optional[bool] = False,
1286
+ use_cache: Optional[bool] = False,
1287
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
1288
+ **kwargs,
1289
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1290
+ input_embeds = self.enorm(input_embeds)
1291
+ hidden_states = self.hnorm(hidden_states)
1292
+ hidden_states = self.eh_proj(torch.cat([input_embeds, hidden_states], dim=-1))
1293
+ residual = hidden_states
1294
+
1295
+ hidden_states = self.input_layernorm(hidden_states)
1296
+
1297
+ # Self Attention
1298
+ hidden_states, self_attn_weights, present_key_value = self.attention(
1299
+ hidden_states=hidden_states,
1300
+ attention_mask=attention_mask,
1301
+ position_ids=position_ids,
1302
+ past_key_value=past_key_value,
1303
+ output_attentions=output_attentions,
1304
+ position_embeddings=position_embeddings,
1305
+ use_cache=use_cache,
1306
+ )
1307
+ hidden_states = residual + hidden_states
1308
+
1309
+ # Fully Connected
1310
+ residual = hidden_states
1311
+ hidden_states = self.post_attention_layernorm(hidden_states)
1312
+ hidden_states = self.mlp(hidden_states)
1313
+ if isinstance(hidden_states, tuple):
1314
+ hidden_states, router_logits = hidden_states
1315
+ else:
1316
+ router_logits = None
1317
+ hidden_states = residual + hidden_states.to(residual.device)
1318
+ hidden_states = self.final_layernorm(hidden_states)
1319
+
1320
+ outputs = (hidden_states,)
1321
+
1322
+ if output_attentions:
1323
+ outputs += (self_attn_weights,)
1324
+
1325
+ if use_cache:
1326
+ outputs += (present_key_value,)
1327
+
1328
+ if output_router_logits:
1329
+ outputs += (router_logits,)
1330
+
1331
+ return outputs
1332
+
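# --- Illustrative sketch (not part of the model code): the MTP fuse step above. The embedding of
# the next-step token and the previous depth's hidden state are each normalized and projected back
# to hidden_size through eh_proj before running the shared decoder block. LayerNorm is used here
# only as a stand-in for BailingMoeV2RMSNorm.
def _mtp_fuse_sketch():
    import torch

    hidden_size, seq = 8, 4
    enorm = torch.nn.LayerNorm(hidden_size)
    hnorm = torch.nn.LayerNorm(hidden_size)
    eh_proj = torch.nn.Linear(hidden_size * 2, hidden_size, bias=False)
    input_embeds = torch.randn(1, seq, hidden_size)  # embeddings of the tokens shifted one step left
    prev_hidden = torch.randn(1, seq, hidden_size)   # hidden states coming from the previous depth
    fused = eh_proj(torch.cat([enorm(input_embeds), hnorm(prev_hidden)], dim=-1))
    return fused.shape  # torch.Size([1, 4, 8])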
1333
+
1334
  class BailingMoeV2DecoderLayer(nn.Module):
1335
  def __init__(self, config: BailingMoeV2Config, layer_idx: int):
1336
  super().__init__()
 
1355
  output_attentions: Optional[bool] = False,
1356
  output_router_logits: Optional[bool] = False,
1357
  use_cache: Optional[bool] = False,
1358
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
1359
  **kwargs,
1360
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1361
  """
 
1379
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1380
  (see `past_key_values`).
1381
  """
 
 
 
 
1382
  residual = hidden_states
1383
 
1384
  hidden_states = self.input_layernorm(hidden_states)
 
1390
  position_ids=position_ids,
1391
  past_key_value=past_key_value,
1392
  output_attentions=output_attentions,
1393
+ position_embeddings=position_embeddings,
1394
  use_cache=use_cache,
1395
  )
1396
  hidden_states = residual + hidden_states
 
1403
  hidden_states, router_logits = hidden_states
1404
  else:
1405
  router_logits = None
1406
+ hidden_states = residual + hidden_states.to(residual.device)
1407
 
1408
  outputs = (hidden_states,)
1409
 
 
1423
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1424
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1425
  etc.)
 
1426
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1427
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1428
  and behavior.
 
1429
  Parameters:
1430
  config ([`BailingMoeV2Config`]):
1431
  Model configuration class with all the parameters of the model. Initializing with a config file does not
 
1465
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1466
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1467
  it.
 
1468
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1469
  [`PreTrainedTokenizer.__call__`] for details.
 
1470
  [What are input IDs?](../glossary#input-ids)
1471
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1472
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
1473
  - 1 for tokens that are **not masked**,
1474
  - 0 for tokens that are **masked**.
 
1475
  [What are attention masks?](../glossary#attention-mask)
 
1476
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1477
  [`PreTrainedTokenizer.__call__`] for details.
 
1478
  If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1479
  `past_key_values`).
 
1480
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1481
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1482
  information on the default strategy.
 
1483
  - 1 indicates the head is **not masked**,
1484
  - 0 indicates the head is **masked**.
1485
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1486
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1487
  config.n_positions - 1]`.
 
1488
  [What are position IDs?](../glossary#position-ids)
1489
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1490
  Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1491
  blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1492
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
 
1493
  Two formats are allowed:
1494
  - a [`~cache_utils.Cache`] instance;
1495
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1496
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1497
  cache format.
 
1498
  The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1499
  legacy cache format will be returned.
 
1500
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1501
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1502
  of shape `(batch_size, sequence_length)`.
 
1525
  class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
1526
  """
1527
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BailingMoeV2DecoderLayer`]
 
1528
  Args:
1529
  config: BailingMoeV2Config
1530
  """
 
1533
  super().__init__(config)
1534
  self.padding_idx = config.pad_token_id
1535
  self.vocab_size = config.vocab_size
1536
+ self.num_nextn_predict_layers = config.num_nextn_predict_layers
1537
 
1538
  self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1539
+ self.layers = []
1540
+ for layer_idx in range(config.num_hidden_layers + config.num_nextn_predict_layers):
1541
+ layer_cls = BailingMoeV2DecoderLayer if layer_idx < config.num_hidden_layers else BailingMoeV2MTPLayer
1542
+ self.layers.append(layer_cls(config, layer_idx))
1543
+
1544
+ self.layers = nn.ModuleList(self.layers)
1545
+
1546
  self._use_sdpa = config._attn_implementation == "sdpa"
1547
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1548
  self.norm = BailingMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1549
+ self.rotary_emb = BailingMoeV2RotaryEmbedding(config=config)
1550
  self.gradient_checkpointing = False
1551
  # Initialize weights and apply final processing
1552
  self.post_init()
 
1571
  output_router_logits: Optional[bool] = None,
1572
  return_dict: Optional[bool] = None,
1573
  **kwargs,
1574
+ ) -> Union[Tuple, MoeV2ModelOutputWithPast]:
1575
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1576
  output_hidden_states = (
1577
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
1600
  )
1601
  use_cache = False
1602
 
1603
+ if use_cache and past_key_values is None:
1604
+ past_key_values = DynamicCache()
1605
+
1606
+ if inputs_embeds is None:
1607
+ inputs_embeds = self.word_embeddings(input_ids)
1608
+
1609
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1610
 
1611
  if position_ids is None:
 
1612
  position_ids = torch.arange(
1613
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1614
  )
1615
  position_ids = position_ids.unsqueeze(0)
1616
 
 
 
 
1617
  if self._use_flash_attention_2:
1618
  # 2d mask is passed through the layers
1619
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
 
1624
  attention_mask,
1625
  (batch_size, seq_length),
1626
  inputs_embeds,
1627
+ past_seen_tokens,
1628
  )
1629
  else:
1630
  # 4d mask is passed through the layers
1631
  attention_mask = _prepare_4d_causal_attention_mask(
1632
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_seen_tokens
1633
  )
1634
 
1635
  # embed positions
1636
  hidden_states = inputs_embeds
1637
 
1638
+ # create position embeddings to be shared across the decoder layers
1639
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1640
+
1641
  # decoder layers
1642
  all_hidden_states = () if output_hidden_states else None
1643
  all_self_attns = () if output_attentions else None
1644
  all_router_logits = () if output_router_logits else None
1645
  next_decoder_cache = None
1646
+ layers = self.layers[: -self.num_nextn_predict_layers] if self.num_nextn_predict_layers > 0 else self.layers
1647
+ mtp_layers = self.layers[-self.num_nextn_predict_layers :] if self.num_nextn_predict_layers > 0 else None
1648
 
1649
+ for decoder_layer in layers:
1650
  if output_hidden_states:
1651
  all_hidden_states += (hidden_states,)
1652
 
 
1660
  output_attentions,
1661
  output_router_logits,
1662
  use_cache,
1663
+ position_embeddings,
1664
  )
1665
  else:
1666
  layer_outputs = decoder_layer(
 
1671
  output_attentions=output_attentions,
1672
  output_router_logits=output_router_logits,
1673
  use_cache=use_cache,
1674
+ position_embeddings=position_embeddings,
1675
  )
1676
  hidden_states = layer_outputs[0]
1677
 
 
1685
  all_router_logits += (layer_outputs[-1],)
1686
 
1687
  hidden_states = self.norm(hidden_states)
1688
+ main_hidden_states = hidden_states
1689
 
1690
  # add hidden states from the last decoder layer
1691
  if output_hidden_states:
1692
+ all_hidden_states += (main_hidden_states,)
1693
+
1694
+ mtp_hidden_states = None
1695
+
1696
+ if mtp_layers:
1697
+ for decoder_layer in mtp_layers:
1698
+ input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1)
1699
+ inputs_embeds = self.word_embeddings(input_ids)
1700
+
1701
+ if self.gradient_checkpointing and self.training:
1702
+ layer_outputs = self._gradient_checkpointing_func(
1703
+ decoder_layer.__call__,
1704
+ inputs_embeds,
1705
+ hidden_states,
1706
+ attention_mask,
1707
+ position_ids,
1708
+ past_key_values,
1709
+ output_attentions,
1710
+ output_router_logits,
1711
+ use_cache,
1712
+ position_embeddings,
1713
+ )
1714
+ else:
1715
+ layer_outputs = decoder_layer(
1716
+ inputs_embeds,
1717
+ hidden_states,
1718
+ attention_mask=attention_mask,
1719
+ position_ids=position_ids,
1720
+ past_key_value=past_key_values,
1721
+ output_attentions=output_attentions,
1722
+ output_router_logits=output_router_logits,
1723
+ use_cache=use_cache,
1724
+ position_embeddings=position_embeddings,
1725
+ )
1726
+ if mtp_hidden_states is None:
1727
+ mtp_hidden_states = []
1728
+ hidden_states = layer_outputs[0]
1729
+ mtp_hidden_states.append(hidden_states)
1730
+
1731
+ if output_hidden_states:
1732
+ all_hidden_states += (hidden_states,)
1733
+
1734
+ if use_cache:
1735
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1736
+
1737
+ if output_attentions:
1738
+ all_self_attns += (layer_outputs[1],)
1739
+
1740
+ if output_router_logits and layer_outputs[-1] is not None:
1741
+ all_router_logits += (layer_outputs[-1],)
1742
 
1743
  next_cache = None
1744
  if use_cache:
1745
+ next_cache = next_decoder_cache
1746
  if not return_dict:
1747
  return tuple(
1748
  v
1749
+ for v in [main_hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1750
  if v is not None
1751
  )
1752
+ return MoeV2ModelOutputWithPast(
1753
+ last_hidden_state=main_hidden_states,
1754
  past_key_values=next_cache,
1755
  hidden_states=all_hidden_states,
1756
+ mtp_hidden_states=mtp_hidden_states,
1757
  attentions=all_self_attns,
1758
  router_logits=all_router_logits,
1759
  )
1760
 
1761
 
1762
+ class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel, GenerationMixin):
1763
  _tied_weights_keys = ["lm_head.weight"]
1764
 
1765
  def __init__(self, config: BailingMoeV2Config):
1766
  super().__init__(config)
1767
  self.model = BailingMoeV2Model(config)
1768
  self.vocab_size = config.vocab_size
 
1769
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1770
+ self.num_nextn_predict_layers = config.num_nextn_predict_layers
1771
+ self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor
1772
 
1773
  # Initialize weights and apply final processing
1774
  self.post_init()
 
1791
  def get_decoder(self):
1792
  return self.model
1793
 
 
 
 
1794
  @add_start_docstrings_to_model_forward(BAILINGMOEV2_INPUTS_DOCSTRING)
1795
+ @replace_return_docstrings(output_type=MoEV2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1796
  def forward(
1797
  self,
1798
  input_ids: torch.LongTensor = None,
 
1807
  output_router_logits: Optional[bool] = None,
1808
  return_dict: Optional[bool] = None,
1809
  **kwargs,
1810
+ ) -> Union[Tuple, MoEV2CausalLMOutputWithPast]:
1811
  r"""
1812
  Args:
1813
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1814
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1815
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1816
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
1817
  Returns:
 
1818
  Example:
 
1819
  ```python
1820
  >>> from transformers import AutoTokenizer
 
1821
  >>> model = BailingMoeV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1822
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
 
1823
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1824
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
1825
  >>> # Generate
1826
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1827
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
1850
  **kwargs,
1851
  )
1852
 
 
 
 
 
 
1853
  loss = None
1854
+ all_mtp_loss = None
1855
  aux_loss = None
1856
+ hidden_states = outputs[0]
1857
+ logits = self.lm_head(hidden_states)
1858
+ logits = logits.float()
1859
 
1860
  if labels is not None:
1861
+ loss = self.loss_function(logits, labels, self.config.vocab_size, **kwargs)
1862
+
1863
+ all_mtp_logits = None
1864
+ if self.num_nextn_predict_layers > 0:
1865
+ mtp_hidden_states = outputs.mtp_hidden_states
1866
+ shift_labels_mtp = None
1867
+ for i in range(self.num_nextn_predict_layers):
1868
+ mtp_hidden_state = mtp_hidden_states[i]  # index the list without overwriting it; it is indexed again for deeper MTP layers
1869
+ mtp_logits = self.lm_head(mtp_hidden_state).float()
1870
+ if all_mtp_logits is None:
1871
+ all_mtp_logits = []
1872
+ all_mtp_logits.append(mtp_logits)
1873
+ if labels is not None:
1874
+ if shift_labels_mtp is None:
1875
+ shift_labels_mtp = labels.clone()
1876
+ shift_labels_mtp, _ = roll_tensor(shift_labels_mtp, shifts=-1, dims=-1, fill_value=-100)
1877
+ mtp_logits_ = mtp_logits.view(-1, self.config.vocab_size)
1878
+ mtp_loss = self.loss_function(mtp_logits_, shift_labels_mtp.to(mtp_logits_.device).view(-1), self.config.vocab_size, **kwargs)
1879
+ if loss is not None:
1880
+ loss += self.mtp_loss_scaling_factor * mtp_loss
1881
+ else:
1882
+ loss = self.mtp_loss_scaling_factor * mtp_loss
1883
+
1884
+ if all_mtp_loss is None:
1885
+ all_mtp_loss = []
1886
+ all_mtp_loss.append(mtp_loss)
1887
 
1888
  if not return_dict:
1889
  output = (logits,) + outputs[1:]
 
1891
  output = (aux_loss,) + output
1892
  return (loss,) + output if loss is not None else output
1893
 
1894
+ return MoEV2CausalLMOutputWithPast(
1895
  loss=loss,
1896
+ mtp_loss=all_mtp_loss,
1897
  aux_loss=aux_loss,
1898
  logits=logits,
1899
+ mtp_logits=all_mtp_logits,
1900
  past_key_values=outputs.past_key_values,
1901
  hidden_states=outputs.hidden_states,
1902
  attentions=outputs.attentions,
1903
  router_logits=outputs.router_logits,
1904
  )
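# --- Illustrative sketch (standalone aside, not part of this class): how the MTP losses computed
# above are folded into the main LM loss. For each extra prediction depth the labels are rolled one
# more step to the left (vacated slots filled with -100 so they are ignored) and the resulting loss
# is added with weight mtp_loss_scaling_factor. The 0.3 default below is a placeholder, not the
# model's configured value.
def _combine_mtp_losses(main_loss, mtp_losses, mtp_loss_scaling_factor=0.3):
    total = main_loss
    for mtp_loss in mtp_losses:
        total = total + mtp_loss_scaling_factor * mtp_loss
    return total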