Spaces: Runtime error
Antoni Bigata committed · Commit 6ea1ef7
1 Parent(s): cb604f6
requirements
Browse files: WavLM_modules.py (+112 -36)

WavLM_modules.py CHANGED
@@ -121,9 +121,14 @@ class GLU_Linear(nn.Module):
         x = self.linear(x)

         if self.glu_type == "bilinear":
-            x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
         else:
-            x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])

         return x

@@ -131,7 +136,9 @@ class GLU_Linear(nn.Module):
 def gelu_accurate(x):
     if not hasattr(gelu_accurate, "_a"):
         gelu_accurate._a = math.sqrt(2 / math.pi)
-    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))


 def gelu(x: torch.Tensor) -> torch.Tensor:
@@ -223,13 +230,17 @@ def quant_noise(module, p, block_size):

     # 2D matrix
     if not is_conv:
-        assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"

     # 4D matrix
     else:
         # 1x1 convolutions
         if module.kernel_size == (1, 1):
-            assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
         # regular convolutions
         else:
             k = module.kernel_size[0] * module.kernel_size[1]
@@ -245,7 +256,9 @@ def quant_noise(module, p, block_size):
                 out_features = weight.size(0)

                 # split weight matrix into blocks and randomly drop selected blocks
-                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
                 mask.bernoulli_(p)
                 mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

@@ -264,12 +277,20 @@ def quant_noise(module, p, block_size):
                     mask.bernoulli_(p)
                     mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                 else:
-                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
                     mask.bernoulli_(p)
-                    mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])

                 # scale weights and apply mask
-                mask = mask.to(torch.bool)  # x.bool() is not currently supported in TorchScript
                 s = 1 / (1 - p)
                 mod.weight.data = s * weight.masked_fill(mask, 0)

@@ -320,14 +341,16 @@ class MultiheadAttention(nn.Module):
         self.head_dim = embed_dim // num_heads
         self.q_head_dim = self.head_dim
         self.k_head_dim = self.head_dim
-        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
         self.scaling = self.head_dim**-0.5

         self.self_attention = self_attention
         self.encoder_decoder_attention = encoder_decoder_attention

         assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and " "value to be of the same size"
         )

         k_bias = True
@@ -337,11 +360,19 @@ class MultiheadAttention(nn.Module):
         k_embed_dim = embed_dim
         q_embed_dim = embed_dim

-        self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
-        self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
-        self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)

-        self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)

         if add_bias_kv:
             self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
@@ -390,7 +421,9 @@ class MultiheadAttention(nn.Module):
             relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
             relative_positions = torch.abs(relative_positions)
         else:
-            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

         max_exact = num_buckets // 2
         is_small = relative_positions < max_exact
@@ -401,18 +434,25 @@ class MultiheadAttention(nn.Module):
             * (num_buckets - max_exact)
         ).to(torch.long)
         relative_postion_if_large = torch.min(
-            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
         )

-        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
         return relative_buckets

     def compute_bias(self, query_length, key_length):
         context_position = torch.arange(query_length, dtype=torch.long)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
         relative_position = memory_position - context_position
-        relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
-        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
         values = self.relative_attention_bias(relative_position_bucket)
         values = values.permute([2, 0, 1])
         return values
@@ -450,7 +490,7 @@ class MultiheadAttention(nn.Module):
         if need_head_weights:
             need_weights = True

-        is_tpu = query.device.type == "xla"

         tgt_len, bsz, embed_dim = query.size()
         src_len = tgt_len
@@ -466,7 +506,9 @@ class MultiheadAttention(nn.Module):
         if self.has_relative_attention_bias and position_bias is None:
             position_bias = self.compute_bias(tgt_len, src_len)
             position_bias = (
-                position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
             )

         if (
@@ -492,10 +534,14 @@ class MultiheadAttention(nn.Module):
                     _B, _H, _L, __ = query_layer.size()

                     gate_a, gate_b = torch.sigmoid(
-                        self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
                     ).chunk(2, dim=-1)
                     gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
-                    attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias

                 attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
             k_proj_bias = self.k_proj.bias
@@ -565,7 +611,9 @@ class MultiheadAttention(nn.Module):
             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
             v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
             if attn_mask is not None:
-                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
             if key_padding_mask is not None:
                 key_padding_mask = torch.cat(
                     [
@@ -575,11 +623,23 @@ class MultiheadAttention(nn.Module):
                     dim=1,
                 )

-        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
         if k is not None:
-            k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
         if v is not None:
-            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

         if saved_state is not None:
             # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
@@ -638,12 +698,16 @@ class MultiheadAttention(nn.Module):
                 k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
                 v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
                 if attn_mask is not None:
-                    attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
                 if key_padding_mask is not None:
                     key_padding_mask = torch.cat(
                         [
                             key_padding_mask,
-                            torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                         ],
                         dim=1,
                     )
@@ -679,10 +743,14 @@ class MultiheadAttention(nn.Module):
             query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
             _B, _H, _L, __ = query_layer.size()
             gate_a, gate_b = torch.sigmoid(
-                self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
             ).chunk(2, dim=-1)
             gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
-            position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias

             position_bias = position_bias.view(attn_weights.size())

@@ -699,7 +767,9 @@ class MultiheadAttention(nn.Module):
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
-            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
             if not need_head_weights:
                 # average attention weights over heads
                 attn_weights = attn_weights.mean(dim=0)
@@ -718,7 +788,9 @@ class MultiheadAttention(nn.Module):
         if prev_key_padding_mask is not None and static_kv:
             new_key_padding_mask = prev_key_padding_mask
         elif prev_key_padding_mask is not None and key_padding_mask is not None:
-            new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
         # During incremental decoding, as the padding token enters and
         # leaves the frame, there will be a time when prev or current
         # is None
@@ -728,7 +800,9 @@ class MultiheadAttention(nn.Module):
                     (batch_size, src_len - prev_key_padding_mask.size(1)),
                     device=prev_key_padding_mask.device,
                 )
-                new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
             else:
                 new_key_padding_mask = prev_key_padding_mask.float()
         elif key_padding_mask is not None:
@@ -737,7 +811,9 @@ class MultiheadAttention(nn.Module):
                     (batch_size, src_len - key_padding_mask.size(1)),
                     device=key_padding_mask.device,
                 )
-                new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
             else:
                 new_key_padding_mask = key_padding_mask.float()
         else:
@@ -121,9 +121,14 @@ class GLU_Linear(nn.Module):
         x = self.linear(x)

         if self.glu_type == "bilinear":
+            x = (
+                x[:, :, 0 : self.output_dim]
+                * x[:, :, self.output_dim : self.output_dim * 2]
+            )
         else:
+            x = x[:, :, 0 : self.output_dim] * self.glu_act(
+                x[:, :, self.output_dim : self.output_dim * 2]
+            )

         return x

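For context, a minimal standalone sketch of the gating computed in the hunk above: the linear layer emits 2 * output_dim channels, the first half is the value path and the second half the gate (identity gate for "bilinear", an activation such as sigmoid otherwise). The sizes below are illustrative only, not taken from this file.

    import torch
    import torch.nn as nn

    input_dim, output_dim = 16, 8                 # made-up sizes for the example
    linear = nn.Linear(input_dim, output_dim * 2)

    x = torch.randn(4, 10, input_dim)             # (batch, time, channels)
    h = linear(x)                                 # (batch, time, 2 * output_dim)
    value, gate = h[..., :output_dim], h[..., output_dim:]

    bilinear_out = value * gate                   # "bilinear" branch: identity gate
    sigmoid_out = value * torch.sigmoid(gate)     # e.g. glu_act = sigmoid branch
    print(bilinear_out.shape, sigmoid_out.shape)  # torch.Size([4, 10, 8]) twice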
@@ -131,7 +136,9 @@ class GLU_Linear(nn.Module):
 def gelu_accurate(x):
     if not hasattr(gelu_accurate, "_a"):
         gelu_accurate._a = math.sqrt(2 / math.pi)
+    return (
+        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+    )


 def gelu(x: torch.Tensor) -> torch.Tensor:
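As a sanity check (not part of this commit), the tanh-based approximation returned above should agree with PyTorch's built-in tanh-approximate GELU (available since PyTorch 1.12) to floating-point precision:

    import math
    import torch
    import torch.nn.functional as F

    def gelu_accurate(x):
        if not hasattr(gelu_accurate, "_a"):
            gelu_accurate._a = math.sqrt(2 / math.pi)
        return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))

    x = torch.randn(1000)
    # Both compute the tanh approximation of GELU.
    print(torch.allclose(gelu_accurate(x), F.gelu(x, approximate="tanh"), atol=1e-6))  # True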
@@ -223,13 +230,17 @@ def quant_noise(module, p, block_size):

     # 2D matrix
     if not is_conv:
+        assert module.weight.size(1) % block_size == 0, (
+            "Input features must be a multiple of block sizes"
+        )

     # 4D matrix
     else:
         # 1x1 convolutions
         if module.kernel_size == (1, 1):
+            assert module.in_channels % block_size == 0, (
+                "Input channels must be a multiple of block sizes"
+            )
         # regular convolutions
         else:
             k = module.kernel_size[0] * module.kernel_size[1]
@@ -245,7 +256,9 @@ def quant_noise(module, p, block_size):
                 out_features = weight.size(0)

                 # split weight matrix into blocks and randomly drop selected blocks
+                mask = torch.zeros(
+                    in_features // block_size * out_features, device=weight.device
+                )
                 mask.bernoulli_(p)
                 mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

@@ -264,12 +277,20 @@ def quant_noise(module, p, block_size):
                     mask.bernoulli_(p)
                     mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                 else:
+                    mask = torch.zeros(
+                        weight.size(0), weight.size(1), device=weight.device
+                    )
                     mask.bernoulli_(p)
+                    mask = (
+                        mask.unsqueeze(2)
+                        .unsqueeze(3)
+                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+                    )

                 # scale weights and apply mask
+                mask = mask.to(
+                    torch.bool
+                )  # x.bool() is not currently supported in TorchScript
                 s = 1 / (1 - p)
                 mod.weight.data = s * weight.masked_fill(mask, 0)

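A small, self-contained illustration of the block masking reformatted above (linear case only): one Bernoulli draw per (output row, block of block_size input features) zeroes whole blocks, and the surviving weights are rescaled by 1 / (1 - p). The sizes below are made up for the example.

    import torch

    p, block_size = 0.25, 4
    weight = torch.randn(6, 8)                       # (out_features, in_features)
    in_features, out_features = weight.size(1), weight.size(0)

    # one draw per (output row, input block), as in the 2D branch above
    mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
    mask.bernoulli_(p)
    mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

    s = 1 / (1 - p)
    noised = s * weight.masked_fill(mask.to(torch.bool), 0)
    print(noised.shape)                              # torch.Size([6, 8])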
@@ -320,14 +341,16 @@ class MultiheadAttention(nn.Module):
         self.head_dim = embed_dim // num_heads
         self.q_head_dim = self.head_dim
         self.k_head_dim = self.head_dim
+        assert self.head_dim * num_heads == self.embed_dim, (
+            "embed_dim must be divisible by num_heads"
+        )
         self.scaling = self.head_dim**-0.5

         self.self_attention = self_attention
         self.encoder_decoder_attention = encoder_decoder_attention

         assert not self.self_attention or self.qkv_same_dim, (
+            "Self-attention requires query, key and value to be of the same size"
         )

         k_bias = True
@@ -337,11 +360,19 @@ class MultiheadAttention(nn.Module):
         k_embed_dim = embed_dim
         q_embed_dim = embed_dim

+        self.k_proj = quant_noise(
+            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
+        )
+        self.v_proj = quant_noise(
+            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
+        )
+        self.q_proj = quant_noise(
+            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
+        )

+        self.out_proj = quant_noise(
+            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+        )

         if add_bias_kv:
             self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
@@ -390,7 +421,9 @@ class MultiheadAttention(nn.Module):
             relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
             relative_positions = torch.abs(relative_positions)
         else:
+            relative_positions = -torch.min(
+                relative_positions, torch.zeros_like(relative_positions)
+            )

         max_exact = num_buckets // 2
         is_small = relative_positions < max_exact
@@ -401,18 +434,25 @@ class MultiheadAttention(nn.Module):
             * (num_buckets - max_exact)
         ).to(torch.long)
         relative_postion_if_large = torch.min(
+            relative_postion_if_large,
+            torch.full_like(relative_postion_if_large, num_buckets - 1),
         )

+        relative_buckets += torch.where(
+            is_small, relative_positions, relative_postion_if_large
+        )
         return relative_buckets

     def compute_bias(self, query_length, key_length):
         context_position = torch.arange(query_length, dtype=torch.long)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
         relative_position = memory_position - context_position
+        relative_position_bucket = self._relative_positions_bucket(
+            relative_position, bidirectional=True
+        )
+        relative_position_bucket = relative_position_bucket.to(
+            self.relative_attention_bias.weight.device
+        )
         values = self.relative_attention_bias(relative_position_bucket)
         values = values.permute([2, 0, 1])
         return values
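For intuition (not part of this commit), compute_bias relies on the usual T5-style bucketing: small offsets each get their own bucket, while larger offsets share logarithmically spaced buckets up to max_distance. A toy standalone version with made-up num_buckets / max_distance values:

    import math
    import torch

    def toy_relative_buckets(relative_positions, num_buckets=32, max_distance=128):
        # bidirectional: separate bucket ranges for keys before and after the query
        num_buckets //= 2
        buckets = (relative_positions > 0).to(torch.long) * num_buckets
        relative_positions = torch.abs(relative_positions)

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact
        if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        if_large = torch.min(if_large, torch.full_like(if_large, num_buckets - 1))
        return buckets + torch.where(is_small, relative_positions, if_large)

    query_len = key_len = 4
    context = torch.arange(query_len)[:, None]
    memory = torch.arange(key_len)[None, :]
    print(toy_relative_buckets(memory - context))    # (4, 4) tensor of bucket ids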
@@ -450,7 +490,7 @@ class MultiheadAttention(nn.Module):
         if need_head_weights:
             need_weights = True

+        is_tpu = False

         tgt_len, bsz, embed_dim = query.size()
         src_len = tgt_len
@@ -466,7 +506,9 @@ class MultiheadAttention(nn.Module):
         if self.has_relative_attention_bias and position_bias is None:
             position_bias = self.compute_bias(tgt_len, src_len)
             position_bias = (
+                position_bias.unsqueeze(0)
+                .repeat(bsz, 1, 1, 1)
+                .view(bsz * self.num_heads, tgt_len, src_len)
             )

         if (
@@ -492,10 +534,14 @@ class MultiheadAttention(nn.Module):
                     _B, _H, _L, __ = query_layer.size()

                     gate_a, gate_b = torch.sigmoid(
+                        self.grep_linear(query_layer)
+                        .view(_B, _H, _L, 2, 4)
+                        .sum(-1, keepdim=False)
                     ).chunk(2, dim=-1)
                     gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
+                    attn_mask_rel_pos = (
+                        gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+                    )

                 attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
             k_proj_bias = self.k_proj.bias
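Shape-wise, the gated relative position bias above (WavLM's gated relative position mechanism) derives one scalar gate per head and query position from the query itself and rescales position_bias with it. A minimal shape check with made-up sizes; the grep_linear output size of 8 follows from the .view(_B, _H, _L, 2, 4) above, and grep_a is assumed here to be a (1, num_heads, 1, 1) parameter:

    import torch
    import torch.nn as nn

    bsz, num_heads, tgt_len, head_dim = 2, 4, 5, 16
    grep_linear = nn.Linear(head_dim, 8)         # 8 = 2 gates * 4 summed components
    grep_a = torch.ones(1, num_heads, 1, 1)      # assumed parameter shape

    query_layer = torch.randn(bsz, num_heads, tgt_len, head_dim)
    position_bias = torch.randn(bsz * num_heads, tgt_len, tgt_len)

    gate_a, gate_b = torch.sigmoid(
        grep_linear(query_layer).view(bsz, num_heads, tgt_len, 2, 4).sum(-1)
    ).chunk(2, dim=-1)
    gate_a_1 = gate_a * (gate_b * grep_a - 1.0) + 2.0
    gated_bias = gate_a_1.view(bsz * num_heads, -1, 1) * position_bias
    print(gated_bias.shape)                      # torch.Size([8, 5, 5])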
@@ -565,7 +611,9 @@ class MultiheadAttention(nn.Module):
             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
             v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
             if attn_mask is not None:
+                attn_mask = torch.cat(
+                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                )
             if key_padding_mask is not None:
                 key_padding_mask = torch.cat(
                     [
@@ -575,11 +623,23 @@ class MultiheadAttention(nn.Module):
                     dim=1,
                 )

+        q = (
+            q.contiguous()
+            .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
+            .transpose(0, 1)
+        )
         if k is not None:
+            k = (
+                k.contiguous()
+                .view(-1, bsz * self.num_heads, self.k_head_dim)
+                .transpose(0, 1)
+            )
         if v is not None:
+            v = (
+                v.contiguous()
+                .view(-1, bsz * self.num_heads, self.head_dim)
+                .transpose(0, 1)
+            )

         if saved_state is not None:
             # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
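The reshapes above implement the usual (seq_len, batch, embed_dim) to (batch * num_heads, seq_len, head_dim) layout change used for the batched attention matmuls; a minimal shape check with illustrative sizes:

    import torch

    tgt_len, bsz, embed_dim, num_heads = 7, 2, 32, 4
    head_dim = embed_dim // num_heads

    q = torch.randn(tgt_len, bsz, embed_dim)
    q = (
        q.contiguous()
        .view(tgt_len, bsz * num_heads, head_dim)
        .transpose(0, 1)
    )
    print(q.shape)   # torch.Size([8, 7, 8]) = (bsz * num_heads, tgt_len, head_dim)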
@@ -638,12 +698,16 @@ class MultiheadAttention(nn.Module):
                 k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
                 v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
                 if attn_mask is not None:
+                    attn_mask = torch.cat(
+                        [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                    )
                 if key_padding_mask is not None:
                     key_padding_mask = torch.cat(
                         [
                             key_padding_mask,
+                            torch.zeros(key_padding_mask.size(0), 1).type_as(
+                                key_padding_mask
+                            ),
                         ],
                         dim=1,
                     )
@@ -679,10 +743,14 @@ class MultiheadAttention(nn.Module):
             query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
             _B, _H, _L, __ = query_layer.size()
             gate_a, gate_b = torch.sigmoid(
+                self.grep_linear(query_layer)
+                .view(_B, _H, _L, 2, 4)
+                .sum(-1, keepdim=False)
             ).chunk(2, dim=-1)
             gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
+            position_bias = (
+                gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+            )

             position_bias = position_bias.view(attn_weights.size())

@@ -699,7 +767,9 @@ class MultiheadAttention(nn.Module):
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
+            attn_weights = attn_weights_float.view(
+                bsz, self.num_heads, tgt_len, src_len
+            ).transpose(1, 0)
             if not need_head_weights:
                 # average attention weights over heads
                 attn_weights = attn_weights.mean(dim=0)
@@ -718,7 +788,9 @@ class MultiheadAttention(nn.Module):
         if prev_key_padding_mask is not None and static_kv:
             new_key_padding_mask = prev_key_padding_mask
         elif prev_key_padding_mask is not None and key_padding_mask is not None:
+            new_key_padding_mask = torch.cat(
+                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
+            )
         # During incremental decoding, as the padding token enters and
         # leaves the frame, there will be a time when prev or current
         # is None
@@ -728,7 +800,9 @@ class MultiheadAttention(nn.Module):
                     (batch_size, src_len - prev_key_padding_mask.size(1)),
                     device=prev_key_padding_mask.device,
                 )
+                new_key_padding_mask = torch.cat(
+                    [prev_key_padding_mask.float(), filler.float()], dim=1
+                )
             else:
                 new_key_padding_mask = prev_key_padding_mask.float()
         elif key_padding_mask is not None:
@@ -737,7 +811,9 @@ class MultiheadAttention(nn.Module):
                     (batch_size, src_len - key_padding_mask.size(1)),
                     device=key_padding_mask.device,
                 )
+                new_key_padding_mask = torch.cat(
+                    [filler.float(), key_padding_mask.float()], dim=1
+                )
             else:
                 new_key_padding_mask = key_padding_mask.float()
         else:
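For context, the branches above only keep the per-batch key padding mask aligned with the growing key length during incremental decoding, padding with zeros ("not padded") when one side is shorter or missing. A toy example with made-up sizes:

    import torch

    batch_size, src_len = 2, 5
    prev_key_padding_mask = torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float)

    # prev mask covers only 3 of the 5 keys, so pad the remaining positions with zeros
    filler = torch.zeros(
        (batch_size, src_len - prev_key_padding_mask.size(1)),
        device=prev_key_padding_mask.device,
    )
    new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
    print(new_key_padding_mask.shape)   # torch.Size([2, 5]); trailing positions are 0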