Heinrich Dinkel commited on
Commit
46a0902
·
1 Parent(s): 9e7a843

Fix crash when last audio chunk has fewer frames than patch stride

Browse files

Pad the last chunk to self.time_patches frames when it's shorter after
splitting, preventing a RuntimeError in the Conv2d patch embedding when
input is not a multiple of target_length.

Files changed (1) hide show
  1. modeling_dasheng_encoder.py +8 -2
modeling_dasheng_encoder.py CHANGED
@@ -235,10 +235,16 @@ class DashengEncoder(nn.Module):
235
  x = rearrange(x, "b f t -> b 1 f t")
236
  x = self.init_bn(x)
237
 
238
- input_splits = x.split(self.target_length, dim=-1)
239
  masks = [None for _ in range(len(input_splits))]
240
  if attention_mask is not None:
241
- masks = attention_mask.split(self.target_length, dim=-1)
 
 
 
 
 
 
242
 
243
  outputs = []
244
  for i, (input_split_x, mask) in enumerate(zip(input_splits, masks)):
 
235
  x = rearrange(x, "b f t -> b 1 f t")
236
  x = self.init_bn(x)
237
 
238
+ input_splits = list(x.split(self.target_length, dim=-1))
239
  masks = [None for _ in range(len(input_splits))]
240
  if attention_mask is not None:
241
+ masks = list(attention_mask.split(self.target_length, dim=-1))
242
+
243
+ if input_splits[-1].shape[-1] < self.time_patches:
244
+ pad_size = self.time_patches - input_splits[-1].shape[-1]
245
+ input_splits[-1] = torch.nn.functional.pad(input_splits[-1], (0, pad_size))
246
+ if masks[-1] is not None:
247
+ masks[-1] = torch.nn.functional.pad(masks[-1], (0, pad_size), value=0)
248
 
249
  outputs = []
250
  for i, (input_split_x, mask) in enumerate(zip(input_splits, masks)):