| |
| """Long-text synthesis with automatic chunking. The acoustic model degrades past ~max_frames |
| (1400 ~ 15s) because the absolute positional encoding saturates -> garbled syllables in the back |
| half of long utterances. Fix: split text at punctuation into clauses, greedily pack clauses so each |
| chunk's PREDICTED frame count stays under a safe budget, synth each chunk, concatenate with a short |
| gap. No retrain. Same pipeline as synth_from_text.py.""" |
| import argparse, json, re, sys |
| from pathlib import Path |
| sys.path.insert(0, "/home/luigi/jetson-tts/mossnano/zhtw8k") |
| import numpy as np, soundfile as sf, onnxruntime as ort |
| import frontend_bopomofo as F |
| from synth_from_text import host_regulate |
|
|
# Keys read from the host_regulate() result and fed to the acoustic decoder
# session by LongSynth._decode.
BN = ["frames", "frame_meta", "local_ctx_raw", "abs_pos", "pitch_frame", "frame_mask"]

# Zero-width split point immediately AFTER a clause-ending mark, so the mark
# stays attached to the clause it terminates. The CJK / full-width marks are
# written as escapes so that Unicode normalization of this source file cannot
# silently fold them into their ASCII look-alikes (which would stop Chinese
# text from splitting at its most common punctuation).
SPLIT_RE = re.compile(
    "(?<=["
    "\u3002\u3001"              # ideographic full stop, ideographic comma
    "\uff01\uff1f\uff1b\uff0c"  # full-width ! ? ; ,
    "!?;,"                      # ASCII counterparts
    "\n"                        # hard line break
    "])"
)
|
|
|
|
def clauses(text):
    """Split *text* into clauses at the punctuation matched by SPLIT_RE.

    Each clause keeps its trailing punctuation mark. Pieces that are empty or
    whitespace-only are dropped. If nothing survives, the original text is
    returned as the single clause so callers always get a non-empty list.
    """
    kept = []
    for piece in SPLIT_RE.split(text):
        if piece.strip():
            kept.append(piece)
    if not kept:
        return [text]
    return kept
|
|
|
|
class LongSynth:
    """Chunked text-to-speech driver built on three ONNX Runtime sessions.

    Text is split into clauses, the clauses are greedily packed into chunks
    whose predicted frame total stays within a budget, every chunk is
    synthesized on its own, and the resulting waveforms are concatenated with
    a short block of silence between them.
    """

    def __init__(self, onnx_dir, frame_budget=1100, gap_ms=70):
        """Load model metadata and open the encoder, decoder and vocoder.

        Args:
            onnx_dir: directory containing ``meta.json``,
                ``acoustic_encoder.onnx``, ``acoustic_decoder.onnx`` and
                ``vocoder.onnx``.
            frame_budget: maximum predicted frame total per packed chunk.
            gap_ms: length of the silence inserted between chunks, in
                milliseconds.
        """
        # Context manager so the metadata file handle is closed right away.
        with open(f"{onnx_dir}/meta.json", encoding="utf-8") as fh:
            self.meta = json.load(fh)
        so = ort.SessionOptions()
        so.intra_op_num_threads = 4
        providers = ["CPUExecutionProvider"]
        self.sA = ort.InferenceSession(f"{onnx_dir}/acoustic_encoder.onnx", so, providers=providers)
        self.sB = ort.InferenceSession(f"{onnx_dir}/acoustic_decoder.onnx", so, providers=providers)
        self.sV = ort.InferenceSession(f"{onnx_dir}/vocoder.onnx", so, providers=providers)
        self.sr = self.meta["sample_rate"]
        self.budget = frame_budget
        # Pre-built silence block reused between every pair of chunks.
        self.gap = np.zeros(int(self.sr * gap_ms / 1000), np.float32)

    def _encode(self, text):
        """Run the frontend and the acoustic encoder on *text*.

        Returns:
            The encoder outputs ``(cond, dur, pitch)``, or ``None`` when the
            frontend produces no phone ids for the text.
        """
        o = F.text_to_ids(text)
        if not o["phone_ids"]:
            return None
        # Each id list is wrapped in an outer list, i.e. a leading axis of 1.
        phone = np.array([o["phone_ids"]], np.int64)
        tone = np.array([o["tone_ids"]], np.int64)
        lang = np.array([o["lang_ids"]], np.int64)
        spk = np.zeros(1, np.int64)  # speaker id is fixed to 0
        cond, dur, pitch = self.sA.run(
            None, {"phone": phone, "tone": tone, "lang": lang, "speaker": spk}
        )
        return cond, dur, pitch

    def _decode(self, enc):
        """Turn encoder outputs into a flat waveform array.

        Args:
            enc: the ``(cond, dur, pitch)`` tuple returned by ``_encode``.

        Returns:
            The first vocoder output reshaped to one dimension.
        """
        cond, dur, pitch = enc
        reg = host_regulate(cond, dur, pitch, self.meta["abs_frame_bins"], self.meta["max_frames"])
        # Boolean entries are passed through untouched; everything else is
        # fed as float32, except abs_pos which is overridden to int64 below.
        feeds = {n: (reg[n].astype(np.float32) if reg[n].dtype != bool else reg[n]) for n in BN}
        feeds["abs_pos"] = reg["abs_pos"].astype(np.int64)
        mel = self.sB.run(None, feeds)[0]
        return self.sV.run(None, {"mel": mel.astype(np.float32)})[0].reshape(-1)

    def pack(self, text):
        """Greedily pack clauses into chunks whose predicted frame total <= budget.

        Each clause is encoded once to obtain its predicted frame count (the
        sum of the encoder's ``dur`` output); clauses the frontend cannot
        encode count as 0 frames. A single clause that alone exceeds the
        budget cannot be split further here and is emitted as its own chunk.

        Returns:
            List of chunk strings; concatenated they reproduce the packed
            clauses in order.
        """
        chunks, cur, cur_frames = [], "", 0
        for cl in clauses(text):
            enc = self._encode(cl)
            f = int(enc[1].sum()) if enc is not None else 0
            # Adding this clause would overflow the current chunk: close it.
            if cur and cur_frames + f > self.budget:
                chunks.append(cur)
                cur, cur_frames = "", 0
            cur += cl
            cur_frames += f
            # The chunk is exactly this one over-budget clause: flush it now
            # so nothing else gets appended to an already oversized chunk.
            if cur_frames > self.budget and cur == cl:
                chunks.append(cur)
                cur, cur_frames = "", 0
        if cur:
            chunks.append(cur)
        return chunks

    def synth(self, text):
        """Synthesize *text* chunk by chunk and join the audio.

        Chunks the frontend cannot encode are skipped. The silence gap is
        inserted only BETWEEN chunks that were actually synthesized, so a
        skipped trailing chunk no longer leaves a dangling gap at the end of
        the waveform.

        Returns:
            ``(wav, chunks)`` where ``wav`` is the concatenated waveform (a
            single zero sample if nothing could be synthesized) and
            ``chunks`` is the list returned by ``pack``.
        """
        chs = self.pack(text)
        wavs = []
        for c in chs:
            enc = self._encode(c)
            if enc is None:
                continue
            if wavs:
                wavs.append(self.gap)
            wavs.append(self._decode(enc))
        return (np.concatenate(wavs) if wavs else np.zeros(1, np.float32)), chs
|
|
|
|
def main():
    """Command-line entry point: synthesize every record of a JSONL file.

    Reads ``{id, text}`` records from ``--texts`` (blank lines are skipped),
    writes one ``<id>.wav`` per record into ``--out-dir`` and a
    ``synth.jsonl`` manifest describing each result.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--onnx-dir", required=True)
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--texts", required=True, help="jsonl with {id,text}")
    ap.add_argument("--frame-budget", type=int, default=1100)
    ap.add_argument("--gap-ms", type=int, default=70)
    a = ap.parse_args()
    ls = LongSynth(a.onnx_dir, a.frame_budget, a.gap_ms)
    Path(a.out_dir).mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: the manifest is written with ensure_ascii=False, so the
    # locale default encoding must not decide how the text is read or
    # stored. The with-block closes both files even if synthesis raises.
    with open(a.texts, encoding="utf-8") as src, \
            open(f"{a.out_dir}/synth.jsonl", "w", encoding="utf-8") as man:
        for line in src:
            if not line.strip():
                continue
            r = json.loads(line)
            wav, chs = ls.synth(r["text"])
            wp = f"{a.out_dir}/{r['id']}.wav"
            sf.write(wp, wav, ls.sr)
            man.write(json.dumps({"id": r["id"], "text": r["text"], "wav": wp, "chunks": len(chs),
                                  "dur": round(len(wav)/ls.sr, 2)}, ensure_ascii=False) + "\n")
            print(f" {r['id']}: {len(chs)} chunks, {len(wav)/ls.sr:.1f}s -> {wp}")
    print(f"DONE synth_long -> {a.out_dir}/synth.jsonl")
|
|
|
|
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
|
|