| import argparse |
| import os |
| from pathlib import Path |
| import sys |
| import torchaudio |
| import numpy as np |
| from time import time |
| import torch |
| import typing as tp |
| from omegaconf import OmegaConf |
| from vocos import VocosDecoder |
| from models.soundstream_hubert_new import SoundStream |
| from tqdm import tqdm |
|
|
def build_soundstream_model(config):
    """Instantiate the codec model described by ``config.generator``.

    ``config.generator.name`` is evaluated to obtain a callable (for example a
    class imported into this module such as ``SoundStream``), which is then
    called with ``config.generator.config`` unpacked as keyword arguments.

    Returns:
        Whatever that callable returns (the constructed model).
    """
    # NOTE(review): eval() executes the config string as Python code; only
    # load configs from trusted sources.
    model_factory = eval(config.generator.name)
    constructor_kwargs = config.generator.config
    return model_factory(**constructor_kwargs)
|
|
def build_codec_model(config_path, vocal_decoder_path, inst_decoder_path):
    """Build the vocal and instrumental Vocos decoders and load their weights.

    Both decoders are created from the same hparams file (``config_path``) and
    differ only in the checkpoint loaded into them.

    Args:
        config_path: Path to the Vocos hparams/config file.
        vocal_decoder_path: State-dict checkpoint for the vocal decoder.
        inst_decoder_path: State-dict checkpoint for the instrumental decoder.

    Returns:
        ``(vocal_decoder, inst_decoder)``.
    """
    def _load_decoder(weights_path):
        # Build one decoder from the shared hparams and fill in its weights.
        decoder = VocosDecoder.from_hparams(config_path=config_path)
        # map_location="cpu": load_state_dict copies the values into the
        # module's own parameters, so deserialising onto the device the
        # checkpoint was saved from only wastes that device's memory (or
        # fails outright when that device does not exist here).
        state_dict = torch.load(weights_path, map_location="cpu")
        decoder.load_state_dict(state_dict)
        return decoder

    vocal_decoder = _load_decoder(vocal_decoder_path)
    inst_decoder = _load_decoder(inst_decoder_path)
    return vocal_decoder, inst_decoder
|
|
def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], sample_rate: int, rescale: bool = False):
    """Write a waveform to ``path`` with its suffix forced to ``.mp3``.

    Samples are kept within +/-0.99 before writing. With ``rescale`` the whole
    signal is multiplied by ``min(0.99 / peak, 1)``: a signal whose peak exceeds
    0.99 is scaled down to that peak, and quieter signals are left untouched.
    Without ``rescale``, out-of-range samples are simply clamped.

    Args:
        wav: Waveform tensor passed to ``torchaudio.save`` (presumably
            (channels, time) — TODO confirm against callers).
        path: Destination path; its suffix is replaced with ``.mp3``.
        sample_rate: Sample rate in Hz recorded in the file.
        rescale: Peak-scale instead of clamping.
    """
    ceiling = 0.99
    peak = wav.abs().max()
    wav = wav * min(ceiling / peak, 1) if rescale else wav.clamp(-ceiling, ceiling)
    destination = Path(path).with_suffix('.mp3')
    torchaudio.save(str(destination), wav, sample_rate=sample_rate)
|
|
def process_audio(input_file, output_file, rescale, args, decoder, soundstream, sample_rate=44100):
    """Decode one file of codec tokens to a waveform and save it.

    Loads integer codes from ``input_file``, converts them to embeddings with
    ``soundstream.get_embed``, runs ``decoder`` on the embeddings, prints a
    speed figure and writes the result with ``save_audio``.

    Args:
        input_file: ``.npy`` file of codec tokens (layout assumed to be what
            ``soundstream.get_embed`` expects after ``unsqueeze(1)`` — TODO
            confirm against the producer of these files).
        output_file: Destination path; its parent directory is created.
        rescale: Forwarded to ``save_audio`` (peak-scale instead of clamp).
        args: Parsed CLI namespace. ``args.cuda_idx`` selects the GPU and
            defaults to 0 when the attribute is missing. ``args.bw`` is set
            to 4.0 as a side effect.
        decoder: Module mapping embeddings to a waveform.
        soundstream: Codec model providing ``get_embed``.
        sample_rate: Output sample rate in Hz, used for the speed figure and
            when saving.

    Returns:
        The decoded waveform as a CPU tensor.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only run this on trusted token files.
    compressed = np.load(input_file, allow_pickle=True).astype(np.int16)
    print(f"Processing {input_file}")
    print(f"Compressed shape: {compressed.shape}")

    # Kept for backward compatibility: callers may read args.bw afterwards.
    args.bw = float(4)
    # Older argument parsers may not define cuda_idx; fall back to GPU 0.
    device = torch.device(f"cuda:{getattr(args, 'cuda_idx', 0)}")

    # Both models must live on the device the token tensor is sent to.
    # Module.to() is a no-op once the parameters are already there.
    soundstream = soundstream.to(device)
    decoder.eval()
    decoder = decoder.to(device)

    codes = torch.as_tensor(compressed, dtype=torch.long).unsqueeze(1).to(device)

    with torch.no_grad():
        # Running get_embed under no_grad avoids building an autograd graph,
        # so no torch.tensor() copy is needed to detach the result.
        embeddings = torch.as_tensor(soundstream.get_embed(codes), device=device)

        # Time only the decoder forward pass (plus the copy back to the CPU).
        start_time = time()
        out = decoder(embeddings).detach().cpu()
        duration = time() - start_time

    # Seconds of audio produced per second of compute; guard a zero timer.
    rtf = (out.shape[-1] / float(sample_rate)) / max(duration, 1e-9)
    print(f"Decoded in {duration:.2f}s ({rtf:.2f}x RTF)")

    # Path(...).parent is '.' for a bare filename, unlike os.path.dirname's ''.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    save_audio(out, output_file, sample_rate, rescale=rescale)
    print(f"Saved: {output_file}")
    return out
|
|
def find_matching_pairs(input_folder):
    """Pair instrumental and vocal token files that belong to the same song.

    ``input_folder`` is either a ``.lst`` file listing one path per line (blank
    lines are skipped) or a directory whose ``*.npy`` files are scanned.

    Keys are computed from the lower-cased file stem:
    * stems containing ``instrumental`` drop every ``instrumental`` occurrence;
    * otherwise, stems containing ``vocal`` drop only the last ``vocal``.
    Files matching neither are ignored. A later file with the same key replaces
    an earlier one.

    Returns:
        A list of ``(instrumental_path, vocal_path, base_name)`` tuples in the
        order the instrumental files were seen. ``base_name`` is the
        lower-cased key.
    """
    if str(input_folder).endswith('.lst'):
        with open(input_folder, 'r') as listing:
            candidates = [entry.strip() for entry in listing if entry.strip()]
    else:
        candidates = list(Path(input_folder).glob('*.npy'))
    print(f"found {len(candidates)} npy.")

    by_instrumental = {}
    by_vocal = {}
    for candidate in candidates:
        path = candidate if isinstance(candidate, Path) else Path(candidate)
        stem = path.stem.lower()
        if 'instrumental' in stem:
            by_instrumental[stem.replace('instrumental', '')] = path
        elif 'vocal' in stem:
            # rpartition splits around the LAST 'vocal'; it is known to be present.
            head, _, tail = stem.rpartition('vocal')
            by_vocal[head + tail] = path

    return [
        (inst_path, by_vocal[key], key)
        for key, inst_path in by_instrumental.items()
        if key in by_vocal
    ]
|
|
def _parse_args():
    """Define and parse the command-line interface of this script."""
    parser = argparse.ArgumentParser(description='High fidelity neural audio codec using Vocos decoder.')
    parser.add_argument('--input_folder', type=Path, required=True, help='Input folder containing NPY files.')
    parser.add_argument('--output_base', type=Path, required=True, help='Base output folder.')
    parser.add_argument('--resume_path', type=str, default='./final_ckpt/ckpt_00360000.pth', help='Path to model checkpoint.')
    parser.add_argument('--config_path', type=str, default='./config.yaml', help='Path to Vocos config file.')
    parser.add_argument('--vocal_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
    parser.add_argument('--inst_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
    parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
    # Previously hard-coded; the default keeps the old behaviour.
    parser.add_argument('--soundstream_config', type=str, default='./final_ckpt/config.yaml', help='Path to the SoundStream codec config file.')
    # process_audio reads args.cuda_idx; without this option it raised AttributeError.
    parser.add_argument('--cuda_idx', type=int, default=0, help='Index of the CUDA device used for decoding.')
    return parser.parse_args()


def _load_soundstream(config_path, ckpt_path, device):
    """Build the SoundStream codec, load its checkpoint and move it to ``device``."""
    config_ss = OmegaConf.load(config_path)
    soundstream = build_soundstream_model(config_ss)
    # Deserialise on the CPU; load_state_dict copies the values into the module.
    parameter_dict = torch.load(ckpt_path, map_location='cpu')
    soundstream.load_state_dict(parameter_dict['codec_model'])
    soundstream.eval()
    # get_embed is fed token tensors on `device`, so the codec must live there too.
    return soundstream.to(device)


def main():
    """Decode every matching instrumental/vocal token pair and write stems plus a mix.

    For each pair found by ``find_matching_pairs``, the stems are saved under
    ``<output_base>/stems/<base_name>/``. Their sum is saved as
    ``<output_base>/mix/<base_name>.mp3``. Pairs whose mix already exists are
    skipped.
    """
    args = _parse_args()

    if not args.input_folder.exists():
        sys.exit(f"Input folder {args.input_folder} does not exist.")
    if not os.path.isfile(args.config_path):
        sys.exit(f"{args.config_path} file does not exist.")
    if not os.path.isfile(args.soundstream_config):
        sys.exit(f"{args.soundstream_config} file does not exist.")

    mix_dir = args.output_base / 'mix'
    stems_dir = args.output_base / 'stems'
    os.makedirs(mix_dir, exist_ok=True)
    os.makedirs(stems_dir, exist_ok=True)

    device = torch.device(f"cuda:{args.cuda_idx}")
    soundstream = _load_soundstream(args.soundstream_config, args.resume_path, device)
    vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)

    pairs = find_matching_pairs(args.input_folder)
    print(f"Found {len(pairs)} matching pairs")
    # Skip songs whose mix already exists, so an interrupted run can be resumed.
    pairs = [p for p in pairs if not os.path.exists(mix_dir / f'{p[2]}.mp3')]
    print(f"{len(pairs)} to reconstruct...")

    for instrumental_file, vocal_file, base_name in tqdm(pairs):
        print(f"\nProcessing pair: {base_name}")

        song_stems_dir = stems_dir / base_name
        os.makedirs(song_stems_dir, exist_ok=True)

        try:
            instrumental_output = process_audio(
                instrumental_file,
                song_stems_dir / 'instrumental.mp3',
                args.rescale,
                args,
                inst_decoder,
                soundstream,
            )
            vocal_output = process_audio(
                vocal_file,
                song_stems_dir / 'vocal.mp3',
                args.rescale,
                args,
                vocal_decoder,
                soundstream,
            )
        except IndexError as e:
            print(e)
            continue

        try:
            # Element-wise sum; raises RuntimeError when the stem lengths differ.
            mix_output = instrumental_output + vocal_output
            save_audio(mix_output, mix_dir / f'{base_name}.mp3', 44100, args.rescale)
            print(f"Created mix: {mix_dir / f'{base_name}.mp3'}")
        except RuntimeError as e:
            print(e)
            print(f"mix {base_name} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
|
|
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|
| |
| |