PrimeTTS / scripts /synth_long.py
Luigi's picture
add long-text chunking synth
0998dc1 verified
Raw
History Blame Contribute Delete
4.78 kB
#!/usr/bin/env python3
"""Long-text synthesis with automatic chunking. The acoustic model degrades past ~max_frames
(1400 ~ 15s) because the absolute positional encoding saturates -> garbled syllables in the back
half of long utterances. Fix: split text at punctuation into clauses, greedily pack clauses so each
chunk's PREDICTED frame count stays under a safe budget, synth each chunk, concatenate with a short
gap. No retrain. Same pipeline as synth_from_text.py."""
import argparse, json, re, sys
from pathlib import Path
sys.path.insert(0, "/home/luigi/jetson-tts/mossnano/zhtw8k")
import numpy as np, soundfile as sf, onnxruntime as ort
import frontend_bopomofo as F
from synth_from_text import host_regulate
BN = ["frames", "frame_meta", "local_ctx_raw", "abs_pos", "pitch_frame", "frame_mask"]
# split AFTER these (keep the delimiter with the preceding clause for prosody)
SPLIT_RE = re.compile(r'(?<=[。!?;;!?\n,,、])')
def clauses(text):
parts = [p for p in SPLIT_RE.split(text) if p.strip()]
return parts or [text]
class LongSynth:
    """Chunked long-text synthesizer built from three ONNX sessions.

    Pipeline per chunk: frontend ids -> acoustic encoder -> host length regulation ->
    acoustic decoder (mel) -> vocoder (waveform). Text is split into clauses, clauses are
    greedily packed so each chunk's predicted frame count stays under ``frame_budget``, and
    the chunk waveforms are joined with ``gap_ms`` of silence between them.
    """

    def __init__(self, onnx_dir, frame_budget=1100, gap_ms=70, num_threads=4):
        """Load ``meta.json`` and the encoder/decoder/vocoder ONNX sessions.

        Args:
            onnx_dir: directory containing meta.json, acoustic_encoder.onnx,
                acoustic_decoder.onnx and vocoder.onnx.
            frame_budget: maximum predicted frames per packed chunk.
            gap_ms: silence inserted between consecutive chunks, in milliseconds.
            num_threads: intra-op thread count for every ONNX Runtime session.
        """
        # Use a context manager so the metadata file handle is closed promptly.
        with open(f"{onnx_dir}/meta.json", encoding="utf-8") as fh:
            self.meta = json.load(fh)
        so = ort.SessionOptions()
        so.intra_op_num_threads = num_threads
        providers = ["CPUExecutionProvider"]
        self.sA = ort.InferenceSession(f"{onnx_dir}/acoustic_encoder.onnx", so, providers=providers)
        self.sB = ort.InferenceSession(f"{onnx_dir}/acoustic_decoder.onnx", so, providers=providers)
        self.sV = ort.InferenceSession(f"{onnx_dir}/vocoder.onnx", so, providers=providers)
        self.sr = self.meta["sample_rate"]
        self.budget = frame_budget
        # Pre-built block of silence inserted between chunks.
        self.gap = np.zeros(int(self.sr * gap_ms / 1000), np.float32)

    def _encode(self, text):
        """Run the frontend and acoustic encoder on ``text``.

        Returns:
            ``(cond, dur, pitch)`` from the encoder, or ``None`` when the frontend
            produces no phone ids (e.g. punctuation-only text).
        """
        o = F.text_to_ids(text)
        if not o["phone_ids"]:
            return None
        # Batch of one: every id sequence gets a leading batch axis; speaker 0.
        phone = np.array([o["phone_ids"]], np.int64)
        tone = np.array([o["tone_ids"]], np.int64)
        lang = np.array([o["lang_ids"]], np.int64)
        spk = np.zeros(1, np.int64)
        cond, dur, pitch = self.sA.run(None, {"phone": phone, "tone": tone, "lang": lang, "speaker": spk})
        return cond, dur, pitch

    def _decode(self, enc):
        """Turn encoder outputs into a flat waveform.

        Length-regulates on the host, runs the acoustic decoder to get a mel, then the
        vocoder, and flattens the vocoder output to 1-D.
        """
        cond, dur, pitch = enc
        reg = host_regulate(cond, dur, pitch, self.meta["abs_frame_bins"], self.meta["max_frames"])
        # Decoder inputs are float32 except boolean masks (passed through) and abs_pos (int64).
        feeds = {n: (reg[n].astype(np.float32) if reg[n].dtype != bool else reg[n]) for n in BN}
        feeds["abs_pos"] = reg["abs_pos"].astype(np.int64)
        mel = self.sB.run(None, feeds)[0]
        return self.sV.run(None, {"mel": mel.astype(np.float32)})[0].reshape(-1)

    def pack(self, text):
        """Greedily pack clauses into chunks whose predicted frame total <= budget.

        Each clause is encoded once to read its predicted durations (summed as its frame
        count). Clauses the frontend cannot phonemize count as 0 frames and stay attached
        to the current chunk. A clause that alone exceeds the budget is emitted as its own
        chunk.

        Returns:
            list[str]: chunk texts in original order.
        """
        chunks, cur, cur_frames = [], "", 0
        for cl in clauses(text):
            enc = self._encode(cl)
            f = int(enc[1].sum()) if enc is not None else 0
            if cur and cur_frames + f > self.budget:
                chunks.append(cur)
                cur, cur_frames = "", 0
            cur += cl
            cur_frames += f
            # A single clause already over budget: still emit it alone (rare).
            if cur_frames > self.budget and cur == cl:
                chunks.append(cur)
                cur, cur_frames = "", 0
        if cur:
            chunks.append(cur)
        return chunks

    def synth(self, text):
        """Synthesize ``text`` chunk by chunk.

        Chunks that cannot be encoded are skipped. The silence gap is inserted only
        *between* decoded pieces, so no trailing gap appears when the final chunk(s)
        produce no audio.

        Returns:
            ``(wav, chunks)``: the concatenated float waveform (a single zero sample if
            nothing could be synthesized) and the list of chunk texts from :meth:`pack`.
        """
        chs = self.pack(text)
        pieces = []
        for c in chs:
            enc = self._encode(c)
            if enc is None:
                continue
            pieces.append(self._decode(enc))
        if not pieces:
            return np.zeros(1, np.float32), chs
        wavs = [pieces[0]]
        for p in pieces[1:]:
            wavs.append(self.gap)
            wavs.append(p)
        return np.concatenate(wavs), chs
def main():
    """CLI entry point: synthesize every ``{id, text}`` record of a JSONL file.

    For each non-blank input line, writes ``<out-dir>/<id>.wav`` and appends a record
    (id, text, wav path, chunk count, duration in seconds) to ``<out-dir>/synth.jsonl``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--onnx-dir", required=True)
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--texts", required=True, help="jsonl with {id,text}")
    ap.add_argument("--frame-budget", type=int, default=1100)
    ap.add_argument("--gap-ms", type=int, default=70)
    a = ap.parse_args()
    ls = LongSynth(a.onnx_dir, a.frame_budget, a.gap_ms)
    Path(a.out_dir).mkdir(parents=True, exist_ok=True)
    manifest_path = f"{a.out_dir}/synth.jsonl"
    # Explicit UTF-8 on both files: the texts are non-ASCII and the manifest is written with
    # ensure_ascii=False, so the platform default encoding must not be relied on. The context
    # managers also close (and flush) both files if synthesis fails partway through.
    with open(a.texts, encoding="utf-8") as src, \
            open(manifest_path, "w", encoding="utf-8") as man:
        for line in src:
            if not line.strip():
                continue
            r = json.loads(line)
            wav, chs = ls.synth(r["text"])
            wp = f"{a.out_dir}/{r['id']}.wav"
            sf.write(wp, wav, ls.sr)
            dur = len(wav) / ls.sr
            man.write(json.dumps({"id": r["id"], "text": r["text"], "wav": wp, "chunks": len(chs),
                                  "dur": round(dur, 2)}, ensure_ascii=False) + "\n")
            print(f" {r['id']}: {len(chs)} chunks, {dur:.1f}s -> {wp}")
    print(f"DONE synth_long -> {manifest_path}")


if __name__ == "__main__":
    main()