wsntxxn committed on
Commit
37f392f
·
1 Parent(s): b929107

Fix g2p issue

Browse files
Files changed (3) hide show
  1. inference_cli.py +26 -13
  2. modeling_uniflow_audio.py +27 -6
  3. utils/phonemize.py +87 -0
inference_cli.py CHANGED
@@ -2,7 +2,6 @@
2
 
3
  from typing import Any, Callable
4
  import json
5
- import os
6
 
7
  import fire
8
  import torch
@@ -11,7 +10,8 @@ import soundfile as sf
11
  import numpy as np
12
 
13
  from modeling_uniflow_audio import UniFlowAudioModel
14
- from constants import TIME_ALIGNED_TASKS, NON_TIME_ALIGNED_TASKS
 
15
 
16
 
17
  class InferenceCLI:
@@ -21,6 +21,7 @@ class InferenceCLI:
21
  "cuda" if torch.cuda.is_available() else "cpu"
22
  )
23
  self.g2p = None
 
24
  self.speaker_model = None
25
  self.svs_processor = None
26
  self.singer_mapping = None
@@ -82,7 +83,7 @@ class InferenceCLI:
82
  @staticmethod
83
  def add_prehook(func: Callable, ):
84
  def wrapper(self, *args, **kwargs):
85
- model_name = kwargs["model_name"]
86
  self.on_inference_start(model_name)
87
  return func(self, *args, **kwargs)
88
 
@@ -144,22 +145,33 @@ class InferenceCLI:
144
  num_steps: int = 25,
145
  output_path: str = "./output.wav",
146
  ):
147
- from g2p_en import G2p
148
- import nltk
149
 
150
  self.init_speaker_model()
151
 
152
  if not self.g2p:
153
- if not os.path.exists(
154
- os.path.expanduser(
155
- "~/nltk_data/taggers/averaged_perceptron_tagger_eng"
 
 
 
 
 
 
 
 
 
156
  )
157
- ):
158
- nltk.download("averaged_perceptron_tagger_eng")
159
- self.g2p = G2p()
 
 
 
160
 
161
- phonemes = self.g2p(transcript)
162
- phonemes = [ph for ph in phonemes if ph != " "]
163
  phone_indices = [
164
  self.model.tts_phone2id.get(
165
  p, self.model.tts_phone2id.get("spn", 0)
@@ -370,6 +382,7 @@ class InferenceCLI:
370
  )
371
  waveform = waveform[0, 0].cpu().numpy()
372
 
 
373
  if not output_path.endswith(".mp4"):
374
  sf.write(output_path, waveform, self.sample_rate)
375
 
 
2
 
3
  from typing import Any, Callable
4
  import json
 
5
 
6
  import fire
7
  import torch
 
10
  import numpy as np
11
 
12
  from modeling_uniflow_audio import UniFlowAudioModel
13
+ from constants import TIME_ALIGNED_TASKS
14
+ from utils.phonemize import sentence_to_phones
15
 
16
 
17
  class InferenceCLI:
 
21
  "cuda" if torch.cuda.is_available() else "cpu"
22
  )
23
  self.g2p = None
24
+ self.word2phone = None
25
  self.speaker_model = None
26
  self.svs_processor = None
27
  self.singer_mapping = None
 
83
  @staticmethod
84
  def add_prehook(func: Callable, ):
85
  def wrapper(self, *args, **kwargs):
86
+ model_name = kwargs.get("model_name", "UniFlow-Audio-large")
87
  self.on_inference_start(model_name)
88
  return func(self, *args, **kwargs)
89
 
 
145
  num_steps: int = 25,
146
  output_path: str = "./output.wav",
147
  ):
148
+
149
+ from montreal_forced_aligner.g2p.generator import PyniniConsoleGenerator
150
 
151
  self.init_speaker_model()
152
 
153
  if not self.g2p:
154
+ self.g2p = PyniniConsoleGenerator(
155
+ g2p_model_path=self.model.g2p_model_path,
156
+ strict_graphemes=False,
157
+ num_pronunciations=1,
158
+ include_bracketed=False
159
+ )
160
+ self.g2p.setup()
161
+
162
+ if not self.word2phone:
163
+ self.word2phone = json.load(
164
+ open(
165
+ self.model.tts_word2phone_dict_path, "r", encoding="utf-8"
166
  )
167
+ )
168
+
169
+ # OOV word will use a g2p model to predict phoneme
170
+ phonemes, OOV_list = sentence_to_phones(
171
+ transcript, self.word2phone, self.g2p
172
+ )
173
 
174
+ # print(phonemes)
 
175
  phone_indices = [
176
  self.model.tts_phone2id.get(
177
  p, self.model.tts_phone2id.get("spn", 0)
 
382
  )
383
  waveform = waveform[0, 0].cpu().numpy()
384
 
385
+ output_path = output_path.__str__()
386
  if not output_path.endswith(".mp4"):
387
  sf.write(output_path, waveform, self.sample_rate)
388
 
modeling_uniflow_audio.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import Any, Sequence
2
  from pathlib import Path
3
  import json
 
4
  import shutil
5
 
6
  import h5py
@@ -28,6 +29,17 @@ class UniFlowAudioModel(nn.Module):
28
  self.config["model"]["autoencoder"]["pretrained_ckpt"] = str(
29
  model_dir / self.config["model"]["autoencoder"]["pretrained_ckpt"]
30
  )
 
 
 
 
 
 
 
 
 
 
 
31
  self.model = hydra.utils.instantiate(
32
  self.config["model"], _convert_="all"
33
  )
@@ -42,6 +54,7 @@ class UniFlowAudioModel(nn.Module):
42
  shutil.copy(ori_model_path, self.g2p_model_path)
43
 
44
  self.tts_phone_set_path = model_dir / "mfa_g2p" / "phone_set.json"
 
45
  self.build_tts_phone_mapping()
46
  self.svs_phone_set_path = model_dir / "svs" / "phone_set.json"
47
  singers = json.load(open(model_dir / "svs" / "spk_set.json", "r"))
@@ -65,12 +78,20 @@ class UniFlowAudioModel(nn.Module):
65
  self.tts_phone2id = {p: i for i, p in enumerate(phone_set)}
66
 
67
  def init_instruction_encoder(self):
68
- self.instruction_tokenizer = T5Tokenizer.from_pretrained(
69
- "google/flan-t5-large"
70
- )
71
- self.instruction_encoder = T5EncoderModel.from_pretrained(
72
- "google/flan-t5-large"
73
- )
 
 
 
 
 
 
 
 
74
  self.instruction_encoder.eval()
75
 
76
  @torch.inference_mode()
 
1
  from typing import Any, Sequence
2
  from pathlib import Path
3
  import json
4
+ import os
5
  import shutil
6
 
7
  import h5py
 
29
  self.config["model"]["autoencoder"]["pretrained_ckpt"] = str(
30
  model_dir / self.config["model"]["autoencoder"]["pretrained_ckpt"]
31
  )
32
+ flan_t5_path = os.environ.get("FLAN_T5_PATH", "google/flan-t5-large")
33
+ try:
34
+ tokenizer = T5Tokenizer.from_pretrained(flan_t5_path)
35
+ encoder = T5EncoderModel.from_pretrained(flan_t5_path)
36
+ except Exception as e:
37
+ raise RuntimeError(
38
+ "Failed to initialize Flan-T5, please download it manually and set the `FLAN_T5_PATH`"
39
+ "environment variable to the path of the downloaded model."
40
+ ) from e
41
+ self.config["model"]["content_encoder"]["text_encoder"]["model_name"
42
+ ] = flan_t5_path
43
  self.model = hydra.utils.instantiate(
44
  self.config["model"], _convert_="all"
45
  )
 
54
  shutil.copy(ori_model_path, self.g2p_model_path)
55
 
56
  self.tts_phone_set_path = model_dir / "mfa_g2p" / "phone_set.json"
57
+ self.tts_word2phone_dict_path = model_dir / "mfa_g2p" / "word2phone.json"
58
  self.build_tts_phone_mapping()
59
  self.svs_phone_set_path = model_dir / "svs" / "phone_set.json"
60
  singers = json.load(open(model_dir / "svs" / "spk_set.json", "r"))
 
78
  self.tts_phone2id = {p: i for i, p in enumerate(phone_set)}
79
 
80
  def init_instruction_encoder(self):
81
+ flan_t5_path = os.environ.get("FLAN_T5_PATH", "google/flan-t5-large")
82
+ try:
83
+ self.instruction_tokenizer = T5Tokenizer.from_pretrained(
84
+ flan_t5_path
85
+ )
86
+ self.instruction_encoder = T5EncoderModel.from_pretrained(
87
+ flan_t5_path
88
+ )
89
+ except Exception as e:
90
+ raise RuntimeError(
91
+ "Failed to initialize Flan-T5, please download it manually and set the `FLAN_T5_PATH`"
92
+ "environment variable to the path of the downloaded model."
93
+ ) from e
94
+
95
  self.instruction_encoder.eval()
96
 
97
  @torch.inference_mode()
utils/phonemize.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
def g2p_resolve(word, g2p_model):
    """Predict the phones of an out-of-vocabulary word with a G2P model.

    Args:
        word: The word to pronounce; it is lowercased before the lookup.
        g2p_model: Object exposing ``rewriter(str)``. The first element of
            its first result is taken as a whitespace-separated phone string.

    Returns:
        The list of phones, or ``None`` when the model yields nothing usable
        or any exception is raised along the way.
    """
    phones = None
    try:
        candidates = g2p_model.rewriter(word.lower())
        if candidates:
            top_pronunciation = candidates[0][0]
            if top_pronunciation:
                phones = top_pronunciation.split()
    except Exception:
        # Best effort: any failure is reported to the caller as "no result".
        phones = None
    return phones
13
+
14
+
15
def text_norm(s):
    """
    Normalize text before dictionary lookup.

    Steps:
    1. Lowercase the text.
    2. Keep apostrophes that sit between two ASCII letters/digits
       (e.g. don't, it's), rewriting the Unicode forms U+2019 / U+2018
       to the ASCII apostrophe.
    3. Replace every other apostrophe (quote-like or standalone) with a space.
    4. Replace the punctuation marks , . ! ? ; : ( ) [ ] " - and the curly
       double quotes U+201C / U+201D with a space.
    5. Collapse runs of whitespace into single spaces and strip both ends.

    Args:
        s: Input text.

    Returns:
        The normalized string.
    """
    s = s.lower()

    # Blank out apostrophes that are NOT flanked by an alphanumeric on both
    # sides; those act as quotes or stray marks rather than part of a word.
    s = re.sub(
        r"(?<![A-Za-z0-9])['\u2019\u2018]|['\u2019\u2018](?![A-Za-z0-9])",
        " ",
        s,
    )

    # Every apostrophe still present is word-internal: normalize it to ASCII.
    s = s.replace("\u2019", "'").replace("\u2018", "'")

    # Curly double quotes are written as escapes so the pattern cannot be
    # corrupted by a wrong source-file encoding.
    s = re.sub(r"[,\.\!\?\;\:\(\)\[\]\"\u201c\u201d\-]", " ", s)

    # Merge extra spaces
    return " ".join(s.split())
44
+
45
+
46
+ # ---------------- Core conversion ----------------
47
def sentence_to_phones(sentence, word2phones, g2p_model):
    """
    Convert a sentence into a phone sequence framed by "sil".

    1. Tokenize the sentence into words (ASCII letters with at most one
       internal apostrophe, ASCII or U+2019 / U+2018) and the punctuation
       marks . , ; ! ? -- any other character is skipped.
    2. Emit "sil" for every punctuation token.
    3. For each word, use the highest-scored pronunciation from
       ``word2phones``; fall back to the G2P model when the word is missing
       (or has no pronunciations), and to "spn" when G2P also fails.
    4. Make sure the sequence starts and ends with "sil".

    Args:
        sentence: Raw input text.
        word2phones: Mapping ``word -> {pronunciation string: score}`` where
            a pronunciation is a whitespace-separated phone string.
        g2p_model: Model handed to ``g2p_resolve`` for out-of-vocabulary words.

    Returns:
        A tuple ``(phone_sequence, oov_list)``: the list of phones and the
        list of normalized words that were not resolved from the dictionary.
    """
    pause_marks = frozenset(".,;!?")

    phone_sequence = ["sil"]  # Initial silence
    oov_list = []

    # Accept the same apostrophe variants that text_norm understands, so a
    # typographic "don\u2019t" stays one token instead of "don" + "t".
    tokens = re.findall(
        r"[A-Za-z]+(?:['\u2019\u2018][A-Za-z]+)?|[.,;!?]", sentence
    )

    for token in tokens:
        if token in pause_marks:
            phone_sequence.append("sil")
            continue

        word = text_norm(token)  # Lowercase and normalize the apostrophe

        pronunciations = word2phones.get(word)
        if pronunciations:
            # Pick the pronunciation with the highest score.
            best = max(pronunciations, key=pronunciations.get)
            phone_sequence.extend(best.split())
            continue

        # Out-of-vocabulary (or empty entry): ask the G2P model, and fall
        # back to the "spn" symbol when it cannot produce anything.
        oov_list.append(word)
        phone_sequence.extend(g2p_resolve(word, g2p_model) or ["spn"])

    if phone_sequence[-1] != "sil":
        phone_sequence.append("sil")  # Ending silence
    return phone_sequence, oov_list