Spaces:
Sleeping
Sleeping
fix: add return type annotations and docstrings for clarity in Wav2Vec2MusicClassifier and AudioDataset
Browse files
app/training/wav2vec2_classifier.py
CHANGED
|
@@ -53,7 +53,8 @@ class Wav2Vec2MusicClassifier(nn.Module):
|
|
| 53 |
learn task-specific temporal patterns.
|
| 54 |
"""
|
| 55 |
|
| 56 |
-
def __init__(self, config: Wav2Vec2Config | None = None):
|
|
|
|
| 57 |
super().__init__()
|
| 58 |
self.config = config or Wav2Vec2Config()
|
| 59 |
|
|
@@ -109,16 +110,18 @@ class AudioDataset(Dataset):
|
|
| 109 |
labels: list[int],
|
| 110 |
sample_rate: int = 16000,
|
| 111 |
max_sec: float = 30.0,
|
| 112 |
-
):
|
|
|
|
| 113 |
self.file_paths = file_paths
|
| 114 |
self.labels = labels
|
| 115 |
self.sample_rate = sample_rate
|
| 116 |
self.max_samples = int(max_sec * sample_rate)
|
| 117 |
|
| 118 |
-
def __len__(self):
|
|
|
|
| 119 |
return len(self.file_paths)
|
| 120 |
|
| 121 |
-
def __getitem__(self, idx):
|
| 122 |
import librosa
|
| 123 |
|
| 124 |
path = self.file_paths[idx]
|
|
@@ -136,7 +139,9 @@ class AudioDataset(Dataset):
|
|
| 136 |
return torch.tensor(y, dtype=torch.float32), label
|
| 137 |
|
| 138 |
|
| 139 |
-
def collate_fn(
|
|
|
|
|
|
|
| 140 |
"""Collate audio tensors and labels."""
|
| 141 |
audios, labels = zip(*batch)
|
| 142 |
audios = torch.stack(audios)
|
|
|
|
| 53 |
learn task-specific temporal patterns.
|
| 54 |
"""
|
| 55 |
|
| 56 |
+
def __init__(self, config: Wav2Vec2Config | None = None) -> None:
|
| 57 |
+
"""Initialize wav2vec2 classifier with frozen CNN encoder."""
|
| 58 |
super().__init__()
|
| 59 |
self.config = config or Wav2Vec2Config()
|
| 60 |
|
|
|
|
| 110 |
labels: list[int],
|
| 111 |
sample_rate: int = 16000,
|
| 112 |
max_sec: float = 30.0,
|
| 113 |
+
) -> None:
|
| 114 |
+
"""Initialize audio dataset with file paths and labels."""
|
| 115 |
self.file_paths = file_paths
|
| 116 |
self.labels = labels
|
| 117 |
self.sample_rate = sample_rate
|
| 118 |
self.max_samples = int(max_sec * sample_rate)
|
| 119 |
|
| 120 |
+
def __len__(self) -> int:
|
| 121 |
+
"""Return dataset size."""
|
| 122 |
return len(self.file_paths)
|
| 123 |
|
| 124 |
+
def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
|
| 125 |
import librosa
|
| 126 |
|
| 127 |
path = self.file_paths[idx]
|
|
|
|
| 139 |
return torch.tensor(y, dtype=torch.float32), label
|
| 140 |
|
| 141 |
|
| 142 |
+
def collate_fn(
|
| 143 |
+
batch: list[tuple[torch.Tensor, int]],
|
| 144 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 145 |
"""Collate audio tensors and labels."""
|
| 146 |
audios, labels = zip(*batch)
|
| 147 |
audios = torch.stack(audios)
|