Rthur2003 committed on
Commit
e93262a
·
1 Parent(s): 8f9848d

fix: add return type annotations and docstrings for clarity in Wav2Vec2MusicClassifier and AudioDataset

Browse files
Files changed (1) hide show
  1. app/training/wav2vec2_classifier.py +10 -5
app/training/wav2vec2_classifier.py CHANGED
@@ -53,7 +53,8 @@ class Wav2Vec2MusicClassifier(nn.Module):
53
  learn task-specific temporal patterns.
54
  """
55
 
56
- def __init__(self, config: Wav2Vec2Config | None = None):
 
57
  super().__init__()
58
  self.config = config or Wav2Vec2Config()
59
 
@@ -109,16 +110,18 @@ class AudioDataset(Dataset):
109
  labels: list[int],
110
  sample_rate: int = 16000,
111
  max_sec: float = 30.0,
112
- ):
 
113
  self.file_paths = file_paths
114
  self.labels = labels
115
  self.sample_rate = sample_rate
116
  self.max_samples = int(max_sec * sample_rate)
117
 
118
- def __len__(self):
 
119
  return len(self.file_paths)
120
 
121
- def __getitem__(self, idx):
122
  import librosa
123
 
124
  path = self.file_paths[idx]
@@ -136,7 +139,9 @@ class AudioDataset(Dataset):
136
  return torch.tensor(y, dtype=torch.float32), label
137
 
138
 
139
- def collate_fn(batch):
 
 
140
  """Collate audio tensors and labels."""
141
  audios, labels = zip(*batch)
142
  audios = torch.stack(audios)
 
53
  learn task-specific temporal patterns.
54
  """
55
 
56
+ def __init__(self, config: Wav2Vec2Config | None = None) -> None:
57
+ """Initialize wav2vec2 classifier with frozen CNN encoder."""
58
  super().__init__()
59
  self.config = config or Wav2Vec2Config()
60
 
 
110
  labels: list[int],
111
  sample_rate: int = 16000,
112
  max_sec: float = 30.0,
113
+ ) -> None:
114
+ """Initialize audio dataset with file paths and labels."""
115
  self.file_paths = file_paths
116
  self.labels = labels
117
  self.sample_rate = sample_rate
118
  self.max_samples = int(max_sec * sample_rate)
119
 
120
+ def __len__(self) -> int:
121
+ """Return dataset size."""
122
  return len(self.file_paths)
123
 
124
+ def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
125
  import librosa
126
 
127
  path = self.file_paths[idx]
 
139
  return torch.tensor(y, dtype=torch.float32), label
140
 
141
 
142
+ def collate_fn(
143
+ batch: list[tuple[torch.Tensor, int]],
144
+ ) -> tuple[torch.Tensor, torch.Tensor]:
145
  """Collate audio tensors and labels."""
146
  audios, labels = zip(*batch)
147
  audios = torch.stack(audios)