import sys
import tempfile
import unittest
from pathlib import Path

import numpy as np
import soundfile as sf

REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    # Put the repository root on the import path so the project packages
    # imported inside the tests (``infer``, ``tools``) resolve when this file
    # is executed directly.
    sys.path.insert(0, str(REPO_ROOT))


def _tone(sr: int, amplitude: float, freq_hz: float, phase: float = 0.0) -> np.ndarray:
    """Return one second of ``amplitude * sin(2*pi*freq_hz*t + phase)`` at *sr* Hz.

    The time axis is float32 (``sr`` samples spanning [0, 1) seconds).
    """
    t = np.arange(sr, dtype=np.float32) / sr
    return amplitude * np.sin(2 * np.pi * freq_hz * t + phase)


def _write_wavs(directory: Path, sr: int, signals: dict) -> dict:
    """Write each signal to ``directory / f"{name}.wav"`` at *sr* Hz.

    Args:
        directory: Existing directory that receives the files.
        sr: Sample rate passed to ``soundfile.write``.
        signals: Mapping of file stem -> sample array.

    Returns:
        Mapping of the same file stems -> written ``Path``.
    """
    paths = {}
    for name, samples in signals.items():
        path = directory / f"{name}.wav"
        sf.write(path, samples, sr)
        paths[name] = path
    return paths


class ModelDefaultTests(unittest.TestCase):
    """Pin the default / SOTA model identifiers exposed by ``infer.separator``."""

    def test_roformer_default_uses_public_audio_separator_sota_model(self):
        from infer import separator

        self.assertEqual(
            separator.ROFORMER_DEFAULT_MODEL,
            "ensemble:vocal_rvc",
        )
        self.assertEqual(
            separator.ROFORMER_SOTA_MODELS,
            [
                "melband_roformer_big_beta6x.ckpt",
                "mel_band_roformer_vocals_fv4_gabox.ckpt",
            ],
        )
        self.assertIn(
            "vocals_mel_band_roformer.ckpt",
            separator.ROFORMER_LEGACY_SINGLE_MODEL,
        )

    def test_karaoke_default_uses_public_sota_ensemble(self):
        from infer import separator

        self.assertEqual(
            separator.KARAOKE_DEFAULT_MODEL,
            "ensemble:karaoke",
        )
        self.assertEqual(
            separator.KARAOKE_SOTA_MODEL,
            "ensemble:karaoke",
        )
        self.assertEqual(
            separator.KARAOKE_SOTA_MODELS,
            [
                "mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt",
                "mel_band_roformer_karaoke_gabox_v2.ckpt",
                "mel_band_roformer_karaoke_becruily.ckpt",
            ],
        )
        self.assertEqual(
            separator.KARAOKE_LEGACY_SINGLE_MODEL,
            "mel_band_roformer_karaoke_gabox.ckpt",
        )

    def test_deecho_default_uses_public_roformer_dereverb_model(self):
        from infer import separator

        self.assertEqual(
            separator.ROFORMER_DEREVERB_DEFAULT_MODEL,
            "dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt",
        )

    def test_strict_sota_defaults_do_not_expose_model_fallback_lists(self):
        from infer import separator

        # The module must not offer fallback model lists at all.
        self.assertFalse(hasattr(separator, "ROFORMER_FALLBACK_MODELS"))
        self.assertFalse(hasattr(separator, "KARAOKE_FALLBACK_MODELS"))
        self.assertFalse(hasattr(separator, "ROFORMER_DEREVERB_FALLBACK_MODELS"))


class KaraokeCandidateScoringTests(unittest.TestCase):
    """Exercise the karaoke stem scorers in ``tools.evaluate_karaoke_models``.

    Every test synthesizes one second of sine tones, writes them as WAV files
    into a temporary directory and scores them.
    """

    def test_karaoke_candidate_score_rewards_reconstruction_and_low_correlation(self):
        from tools.evaluate_karaoke_models import score_karaoke_stems

        sr = 16000
        lead_good = _tone(sr, 0.18, 220)
        backing_good = _tone(sr, 0.05, 330, phase=0.4)
        input_vocals = lead_good + backing_good

        with tempfile.TemporaryDirectory() as tmp_dir:
            paths = _write_wavs(
                Path(tmp_dir),
                sr,
                {
                    "input": input_vocals,
                    "lead_good": lead_good,
                    "backing_good": backing_good,
                    # "Bad" split: the lead is the untouched mixture and the
                    # backing is a scaled copy of it, so the stems do not sum
                    # back to the input and are fully correlated.
                    "lead_bad": input_vocals,
                    "backing_bad": 0.7 * input_vocals,
                },
            )

            good = score_karaoke_stems(
                paths["input"], paths["lead_good"], paths["backing_good"]
            )
            bad = score_karaoke_stems(
                paths["input"], paths["lead_bad"], paths["backing_bad"]
            )

            self.assertGreater(good["score"], bad["score"])
            self.assertLess(good["reconstruction_error"], bad["reconstruction_error"])
            self.assertLess(good["lead_backing_abs_corr"], bad["lead_backing_abs_corr"])

    def test_karaoke_candidate_score_penalizes_truncated_stems(self):
        from tools.evaluate_karaoke_models import score_karaoke_stems

        sr = 16000
        lead_good = _tone(sr, 0.18, 220)
        backing_good = _tone(sr, 0.04, 330, phase=0.4)
        input_vocals = lead_good + backing_good
        # The truncated stems keep only the first quarter of the input.
        short_len = sr // 4

        with tempfile.TemporaryDirectory() as tmp_dir:
            paths = _write_wavs(
                Path(tmp_dir),
                sr,
                {
                    "input": input_vocals,
                    "lead_short": lead_good[:short_len],
                    "backing_short": backing_good[:short_len],
                    # Full-length stems, slightly attenuated, so they are not a
                    # perfect split but do cover the whole input.
                    "lead_full": 0.97 * lead_good,
                    "backing_full": 0.97 * backing_good,
                },
            )

            short = score_karaoke_stems(
                paths["input"], paths["lead_short"], paths["backing_short"]
            )
            full = score_karaoke_stems(
                paths["input"], paths["lead_full"], paths["backing_full"]
            )

            self.assertIn("length_coverage", short)
            self.assertLess(short["length_coverage"], 0.999)
            self.assertGreaterEqual(full["length_coverage"], 0.999)
            self.assertGreater(full["score"], short["score"])

    def test_reference_karaoke_score_uses_true_si_sdr_when_refs_exist(self):
        from tools.evaluate_karaoke_models import score_reference_stems

        sr = 16000
        lead = _tone(sr, 0.18, 220)
        backing = _tone(sr, 0.04, 330, phase=0.4)

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Estimates are written from the same arrays as the references, so
            # the files hold identical samples.
            paths = _write_wavs(
                Path(tmp_dir),
                sr,
                {
                    "reference_lead": lead,
                    "reference_backing": backing,
                    "lead": lead,
                    "backing": backing,
                },
            )

            metrics = score_reference_stems(
                paths["reference_lead"],
                paths["reference_backing"],
                paths["lead"],
                paths["backing"],
            )

            # A perfect estimate should score far above any realistic SI-SDR.
            self.assertGreater(metrics["mean_si_sdr"], 100.0)
            self.assertIn("lead", metrics["stems"])
            self.assertIn("backing", metrics["stems"])


if __name__ == "__main__":
    unittest.main()