UniFlow-Audio / modeling_uniflow_audio.py
wsntxxn's picture
Fix g2p issue
37f392f
from typing import Any, Sequence
from pathlib import Path
import json
import os
import shutil
import h5py
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from safetensors.torch import load_file
import hydra
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import T5EncoderModel, T5Tokenizer
class UniFlowAudioModel(nn.Module):
def __init__(self, model_name: str = "wsntxxn/UniFlow-Audio-large"):
assert model_name in (
"wsntxxn/UniFlow-Audio-large",
"wsntxxn/UniFlow-Audio-medium",
"wsntxxn/UniFlow-Audio-small",
)
super().__init__()
model_dir = snapshot_download(repo_id=model_name)
model_dir = Path(model_dir)
self.config = OmegaConf.load(model_dir / "config.yaml")
self.config["model"]["autoencoder"]["pretrained_ckpt"] = str(
model_dir / self.config["model"]["autoencoder"]["pretrained_ckpt"]
)
flan_t5_path = os.environ.get("FLAN_T5_PATH", "google/flan-t5-large")
try:
tokenizer = T5Tokenizer.from_pretrained(flan_t5_path)
encoder = T5EncoderModel.from_pretrained(flan_t5_path)
except Exception as e:
raise RuntimeError(
"Failed to initialize Flan-T5, please download it manually and set the `FLAN_T5_PATH`"
"environment variable to the path of the downloaded model."
) from e
self.config["model"]["content_encoder"]["text_encoder"]["model_name"
] = flan_t5_path
self.model = hydra.utils.instantiate(
self.config["model"], _convert_="all"
)
state_dict = load_file(model_dir / "model.safetensors")
self.model.load_pretrained(state_dict)
self.model.eval()
self.g2p_model_path = model_dir / "mfa_g2p" / "english_us_arpa_unhashed.zip"
if not self.g2p_model_path.exists():
ori_model_path = (model_dir / "mfa_g2p" /
"english_us_arpa.zip").resolve()
shutil.copy(ori_model_path, self.g2p_model_path)
self.tts_phone_set_path = model_dir / "mfa_g2p" / "phone_set.json"
self.tts_word2phone_dict_path = model_dir / "mfa_g2p" / "word2phone.json"
self.build_tts_phone_mapping()
self.svs_phone_set_path = model_dir / "svs" / "phone_set.json"
singers = json.load(open(model_dir / "svs" / "spk_set.json", "r"))
self.svs_singer_mapping = {
singer: i
for i, singer in enumerate(singers)
}
self.svs_pinyin2ph = model_dir / "svs" / "m4singer_pinyin2ph.txt"
self.task_to_instructions = {}
with h5py.File(model_dir / "instructions" / "t5_embeddings.h5") as hf:
for key in hf.keys():
self.task_to_instructions[key] = hf[key][()]
self.init_instruction_encoder()
def build_tts_phone_mapping(self):
with open(self.tts_phone_set_path, "r", encoding="utf-8") as f:
phone_set = json.load(f)
self.tts_phone2id = {p: i for i, p in enumerate(phone_set)}
def init_instruction_encoder(self):
flan_t5_path = os.environ.get("FLAN_T5_PATH", "google/flan-t5-large")
try:
self.instruction_tokenizer = T5Tokenizer.from_pretrained(
flan_t5_path
)
self.instruction_encoder = T5EncoderModel.from_pretrained(
flan_t5_path
)
except Exception as e:
raise RuntimeError(
"Failed to initialize Flan-T5, please download it manually and set the `FLAN_T5_PATH`"
"environment variable to the path of the downloaded model."
) from e
self.instruction_encoder.eval()
@torch.inference_mode()
def encode_instruction(self, instruction: list[str], device: torch.device):
with torch.amp.autocast(enabled=False):
tokens = self.instruction_tokenizer(
instruction,
max_length=self.instruction_tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
)
input_ids = tokens.input_ids.to(device)
attention_mask = tokens.attention_mask.to(device)
output = self.instruction_encoder(
input_ids=input_ids, attention_mask=attention_mask
)
output = output.last_hidden_state
length = attention_mask.sum(dim=1)
return output, length
@torch.inference_mode()
def sample(
self,
content: list[Any],
task: list[str],
is_time_aligned: Sequence[bool],
instruction: list[str] | None = None,
instruction_idx: list[int] | None = None,
num_steps: int = 20,
sway_sampling_coef: float | None = -1.0,
guidance_scale: float = 3.0,
disable_progress: bool = True,
):
device = self.model.dummy_param.device
if instruction is None:
instructions = []
instruction_lengths = []
for sample_idx, task_ in enumerate(task):
if instruction_idx:
instruction_idx_ = instruction_idx[sample_idx]
else:
instruction_idx_ = 0
instruction_ = self.task_to_instructions[
f"{task_}_{instruction_idx_}"]
instructions.append(torch.as_tensor(instruction_))
instruction_lengths.append(instruction_.shape[0])
instructions = pad_sequence(instructions,
batch_first=True).to(device)
instruction_lengths = torch.as_tensor(instruction_lengths
).to(device)
else:
instructions, instruction_lengths = self.encode_instruction(
instruction, device
)
return self.model.inference(
content, task, is_time_aligned, instructions, instruction_lengths,
num_steps, sway_sampling_coef, guidance_scale, disable_progress
)