# Source: OmniVoice_sync_data_and_code / prepare_sync_data.py
# Uploaded by Abdelrahman2922.
# Commit a4d9876 (verified): "Add files using upload-large-folder tool"
# File size: 2.84 kB
import argparse
import io
import json
import random
import shutil
from pathlib import Path
import soundfile as sf
from datasets import Audio, load_dataset
from tqdm import tqdm
# Hugging Face dataset ID used when --repo is not given on the command line.
DEFAULT_REPO = "saleh1312/syncing_data"
# Longest clip kept, in seconds; samples longer than this are skipped.
MAX_DURATION = 10.0
def _write_jsonl(path, records):
    """Write *records* to *path* as UTF-8 JSON Lines, one object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


def main():
    """Download a dataset and lay it out as WAV files plus JSONL manifests.

    Command-line options:
        --repo  Hugging Face dataset ID (default: DEFAULT_REPO).
        --out   Output directory (default: "sync_data").
        --seed  Seed for the train/dev shuffle (default: 42).

    Everything is written under ``<out>/data``; that directory is deleted
    first if it already exists. Each kept sample is down-mixed to mono and
    saved as 16-bit PCM in ``<out>/data/wavs``. Samples longer than
    MAX_DURATION seconds, or without audio bytes, are skipped. The remaining
    records are shuffled and split 95% / 5% into ``train_raw.jsonl`` and
    ``dev_raw.jsonl``.
    """
    parser = argparse.ArgumentParser(description="Prepare data for OmniVoice Training")
    parser.add_argument("--repo", default=DEFAULT_REPO, help="HF Dataset ID")
    parser.add_argument("--out", default="sync_data", help="Output directory")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    out_root = Path(args.out).resolve() / "data"
    if out_root.exists():
        # Start from a clean slate so stale WAVs from a previous run cannot
        # end up next to a manifest that no longer references them.
        print(f"Cleaning up old directory: {out_root}")
        shutil.rmtree(out_root)
    wav_dir = out_root / "wavs"
    wav_dir.mkdir(parents=True)

    print(f"Loading dataset: {args.repo}")
    ds = load_dataset(args.repo, split="train")
    # Keep the raw encoded bytes; decoding is done below with soundfile.
    ds = ds.cast_column("audio", Audio(decode=False))

    processed_records = []
    skipped = 0  # samples longer than MAX_DURATION
    missing = 0  # samples that carry no audio bytes
    print("Processing audio files...")
    for i, row in enumerate(tqdm(ds)):
        audio_data = row["audio"]["bytes"]
        if not audio_data:
            missing += 1
            continue

        # Load audio to check duration
        with io.BytesIO(audio_data) as f:
            data, sr = sf.read(f)
        duration = len(data) / sr
        if duration > MAX_DURATION:
            skipped += 1
            continue

        # Ensure Mono
        if data.ndim > 1:
            data = data.mean(axis=1)

        # The ID uses the dataset row index, so numbering has gaps wherever
        # a row was skipped.
        sample_id = f"sample_{i:06d}"
        wav_path = wav_dir / f"{sample_id}.wav"
        sf.write(wav_path, data, sr, subtype="PCM_16")

        # A missing, null or blank tone falls back to "neutral" instead of
        # leaking the string "none" (or an empty tag) into the instruction.
        tone = str(row.get("tone") or "").strip().lower() or "neutral"
        processed_records.append({
            "id": sample_id,
            "audio_path": str(wav_path.resolve()),
            "text": row["text"],
            "language_id": "ar",
            "instruct": f"saudi, conversational, {tone}",
        })

    # A dedicated generator gives the same order as seeding the module-level
    # RNG, without touching global random state.
    random.Random(args.seed).shuffle(processed_records)
    split_idx = int(len(processed_records) * 0.95)
    train_data = processed_records[:split_idx]
    dev_data = processed_records[split_idx:]

    # Write input files for the tokenization script
    train_path = out_root / "train_raw.jsonl"
    dev_path = out_root / "dev_raw.jsonl"
    for out_path, records in [(train_path, train_data), (dev_path, dev_data)]:
        _write_jsonl(out_path, records)
        print(f"Created {out_path} ({len(records)} samples)")

    print("\nPreparation Complete!")
    print(f"Skipped {skipped} samples (> {MAX_DURATION}s)")
    if missing:
        print(f"Skipped {missing} samples with no audio bytes")
    # Point at the file that was actually written; it lives under <out>/data.
    print(f"Next: Run the 'extract_audio_tokens.py' script using '{train_path}'")
# Run the preparation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()