Fix g2p issue
Files changed:
- inference_cli.py +26 -13
- modeling_uniflow_audio.py +27 -6
- utils/phonemize.py +87 -0
inference_cli.py
CHANGED
@@ -2,7 +2,6 @@

 from typing import Any, Callable
 import json
-import os

 import fire
 import torch
@@ -11,7 +10,8 @@ import soundfile as sf
 import numpy as np

 from modeling_uniflow_audio import UniFlowAudioModel
-from constants import TIME_ALIGNED_TASKS
+from constants import TIME_ALIGNED_TASKS
+from utils.phonemize import sentence_to_phones


 class InferenceCLI:
@@ -21,6 +21,7 @@ class InferenceCLI:
             "cuda" if torch.cuda.is_available() else "cpu"
         )
         self.g2p = None
+        self.word2phone = None
         self.speaker_model = None
         self.svs_processor = None
         self.singer_mapping = None
@@ -82,7 +83,7 @@ class InferenceCLI:
     @staticmethod
     def add_prehook(func: Callable, ):
         def wrapper(self, *args, **kwargs):
-            model_name = kwargs
+            model_name = kwargs.get("model_name", "UniFlow-Audio-large")
             self.on_inference_start(model_name)
             return func(self, *args, **kwargs)

@@ -144,22 +145,33 @@ class InferenceCLI:
         num_steps: int = 25,
         output_path: str = "./output.wav",
     ):
-
-        import
+
+        from montreal_forced_aligner.g2p.generator import PyniniConsoleGenerator

         self.init_speaker_model()

         if not self.g2p:
-
-
-
+            self.g2p = PyniniConsoleGenerator(
+                g2p_model_path=self.model.g2p_model_path,
+                strict_graphemes=False,
+                num_pronunciations=1,
+                include_bracketed=False
+            )
+            self.g2p.setup()
+
+        if not self.word2phone:
+            self.word2phone = json.load(
+                open(
+                    self.model.tts_word2phone_dict_path, "r", encoding="utf-8"
                 )
-            )
-
-
+            )
+
+        # OOV word will use a g2p model to predict phoneme
+        phonemes, OOV_list = sentence_to_phones(
+            transcript, self.word2phone, self.g2p
+        )

-
-        phonemes = [ph for ph in phonemes if ph != " "]
+        # print(phonemes)
         phone_indices = [
             self.model.tts_phone2id.get(
                 p, self.model.tts_phone2id.get("spn", 0)
@@ -370,6 +382,7 @@ class InferenceCLI:
         )
         waveform = waveform[0, 0].cpu().numpy()

+        output_path = output_path.__str__()
         if not output_path.endswith(".mp4"):
             sf.write(output_path, waveform, self.sample_rate)

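
Taken together, the new TTS path normalizes the transcript, looks each word up in the word2phone lexicon, falls back to the MFA Pynini G2P model for OOV words, and finally maps phones to ids with "spn" as the catch-all. A minimal sketch of that last mapping step, assuming a plain dict shaped like self.model.tts_phone2id (the helper name and toy inventory are illustrative, not part of the commit):

# Sketch of the phone-to-id lookup that follows sentence_to_phones in the
# diff above; "spn" (spoken noise) doubles as the fallback id for phones
# missing from the inventory.
def phones_to_indices(phonemes, phone2id):
    spn_id = phone2id.get("spn", 0)
    return [phone2id.get(p, spn_id) for p in phonemes]

phone2id = {"sil": 0, "spn": 1, "HH": 2, "AH0": 3, "L": 4, "OW1": 5}  # toy inventory
print(phones_to_indices(["sil", "HH", "AH0", "L", "OW1", "sil"], phone2id))
# -> [0, 2, 3, 4, 5, 0]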
modeling_uniflow_audio.py
CHANGED
@@ -1,6 +1,7 @@
 from typing import Any, Sequence
 from pathlib import Path
 import json
+import os
 import shutil

 import h5py
@@ -28,6 +29,17 @@ class UniFlowAudioModel(nn.Module):
         self.config["model"]["autoencoder"]["pretrained_ckpt"] = str(
             model_dir / self.config["model"]["autoencoder"]["pretrained_ckpt"]
         )
+        flan_t5_path = os.environ.get("FLAN_T5_PATH", "google/flan-t5-large")
+        try:
+            tokenizer = T5Tokenizer.from_pretrained(flan_t5_path)
+            encoder = T5EncoderModel.from_pretrained(flan_t5_path)
+        except Exception as e:
+            raise RuntimeError(
+                "Failed to initialize Flan-T5, please download it manually and set the `FLAN_T5_PATH`"
+                "environment variable to the path of the downloaded model."
+            ) from e
+        self.config["model"]["content_encoder"]["text_encoder"]["model_name"
+        ] = flan_t5_path
         self.model = hydra.utils.instantiate(
             self.config["model"], _convert_="all"
         )
@@ -42,6 +54,7 @@ class UniFlowAudioModel(nn.Module):
         shutil.copy(ori_model_path, self.g2p_model_path)

         self.tts_phone_set_path = model_dir / "mfa_g2p" / "phone_set.json"
+        self.tts_word2phone_dict_path = model_dir / "mfa_g2p" / "word2phone.json"
         self.build_tts_phone_mapping()
         self.svs_phone_set_path = model_dir / "svs" / "phone_set.json"
         singers = json.load(open(model_dir / "svs" / "spk_set.json", "r"))
@@ -65,12 +78,20 @@ class UniFlowAudioModel(nn.Module):
         self.tts_phone2id = {p: i for i, p in enumerate(phone_set)}

     def init_instruction_encoder(self):
-
-
-
-
-
-
+        flan_t5_path = os.environ.get("FLAN_T5_PATH", "google/flan-t5-large")
+        try:
+            self.instruction_tokenizer = T5Tokenizer.from_pretrained(
+                flan_t5_path
+            )
+            self.instruction_encoder = T5EncoderModel.from_pretrained(
+                flan_t5_path
+            )
+        except Exception as e:
+            raise RuntimeError(
+                "Failed to initialize Flan-T5, please download it manually and set the `FLAN_T5_PATH`"
+                "environment variable to the path of the downloaded model."
+            ) from e
+
         self.instruction_encoder.eval()

     @torch.inference_mode()
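
Both Flan-T5 loading sites now resolve the checkpoint through the same FLAN_T5_PATH environment variable, so an offline deployment only has to export one path before the model is constructed. A usage sketch, assuming a locally downloaded snapshot (the directory path is illustrative):

# Point UniFlow-Audio at a local Flan-T5 checkpoint; any directory holding
# the Hugging Face snapshot works. Set the variable before constructing the
# model, since both the config patch and init_instruction_encoder read it
# at load time.
import os
os.environ["FLAN_T5_PATH"] = "/data/checkpoints/flan-t5-large"

from modeling_uniflow_audio import UniFlowAudioModel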
utils/phonemize.py
ADDED
@@ -0,0 +1,87 @@
+import re
+
+
+def g2p_resolve(word, g2p_model):
+    """Call G2P to generate pronunciation (used for handling OOV words)."""
+    try:
+        result = g2p_model.rewriter(word.lower())
+        if result and result[0][0]:
+            return result[0][0].split()
+    except Exception:
+        return None
+    return None
+
+
+def text_norm(s):
+    """
+    Text normalization (keep internal apostrophes like don't, it's; remove quote-like apostrophes and other punctuation):
+    1. Lowercase the text
+    2. Keep apostrophes between letters (e.g. don't)
+    3. Remove apostrophes that are not between letters (used as quotes or standalone)
+    4. Remove other common punctuation marks (.,;!?()[]-"“” etc.)
+    5. Collapse multiple spaces into a single space
+    """
+    s = s.lower()
+
+    # First temporarily replace apostrophes between letters (a'b) with a placeholder to avoid deletion
+    # Support both ASCII ' and Unicode ’, ‘
+    APOST = "<<<APOST>>>"  # Placeholder string (ensured not to appear in normal sentences)
+    s = re.sub(r"(?<=[A-Za-z0-9])['\u2019\u2018](?=[A-Za-z0-9])", APOST, s)
+
+    # Remove all remaining apostrophes (these are quotes or isolated marks)
+    s = re.sub(r"['\u2019\u2018]", " ", s)
+
+    # Remove other punctuation (while keeping internal apostrophes protected by the placeholder)
+    s = re.sub(r"[,\.\!\?\;\:\(\)\[\]\"“”\-]", " ", s)
+
+    # Restore internal apostrophes back to ASCII apostrophe (or to the original character if needed)
+    s = s.replace(APOST, "'")
+
+    # Merge extra spaces
+    s = " ".join(s.split())
+
+    return s
+
+
+# ---------------- Core conversion ----------------
+def sentence_to_phones(sentence, word2phones, g2p_model):
+    """
+    Convert sentence to phones:
+    1. Split the original sentence and keep punctuation positions to insert sil later
+    2. Insert sil at punctuation positions
+    3. Add sil at the beginning and end of the sentence
+    """
+    original_sentence = sentence  # Save the original sentence
+    sentence = text_norm(sentence)
+
+    phone_sequence = ["sil"]  # Initial silence
+    oov_list = []
+
+    # Split the original sentence to locate punctuation positions
+
+    tokens = re.findall(r"[A-Za-z]+(?:'[A-Za-z]+)?|[.,;!?]", original_sentence)
+
+    for token in tokens:
+        if re.match(r"[.,;!?]", token):  # Punctuation
+            phone_sequence.append("sil")
+        else:
+            word = text_norm(token)  # Normalize word
+
+            if word not in word2phones:
+                g2p_ph = g2p_resolve(word, g2p_model)
+                if g2p_ph:
+                    phone_sequence.extend(g2p_ph)
+                else:
+                    phone_sequence.append(
+                        "spn"
+                    )  # If it really cannot be handled, use a short pause
+
+                oov_list.append(word)
+
+            else:
+                pron, _ = max(word2phones[word].items(), key=lambda x: x[1])
+                phone_sequence.extend(pron.split())
+
+    if phone_sequence[-1] != 'sil':
+        phone_sequence.append("sil")  # Ending silence
+    return phone_sequence, oov_list
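
The word2phones lexicon implied by the lookup maps each word to a dict of pronunciation string -> frequency, and sentence_to_phones keeps the most frequent entry. A toy round trip, with illustrative lexicon entries and a stub standing in for PyniniConsoleGenerator (g2p_resolve only calls its rewriter method):

from utils.phonemize import sentence_to_phones, text_norm

# Illustrative lexicon: word -> {pronunciation: count}; the most frequent
# pronunciation wins.
word2phones = {
    "hello": {"HH AH0 L OW1": 3, "HH EH0 L OW1": 1},
    "world": {"W ER1 L D": 5},
}

class StubG2P:
    """Hypothetical stand-in for the MFA generator; rewriter() should return
    ranked (pronunciation, score) candidates."""
    def rewriter(self, word):
        return [("F UW1", 0.0)] if word == "foo" else []

print(text_norm('Don\'t "stop"!'))  # -> don't stop
phones, oov = sentence_to_phones("Hello, foo world!", word2phones, StubG2P())
print(phones)  # ['sil', 'HH', 'AH0', 'L', 'OW1', 'sil', 'F', 'UW1', 'W', 'ER1', 'L', 'D', 'sil']
print(oov)     # ['foo']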