File size: 4,780 Bytes
0998dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3
"""Long-text synthesis with automatic chunking. The acoustic model degrades past ~max_frames
(1400 ~ 15s) because the absolute positional encoding saturates -> garbled syllables in the back
half of long utterances. Fix: split text at punctuation into clauses, greedily pack clauses so each
chunk's PREDICTED frame count stays under a safe budget, synth each chunk, concatenate with a short
gap. No retrain. Same pipeline as synth_from_text.py."""
import argparse, json, re, sys
from pathlib import Path
sys.path.insert(0, "/home/luigi/jetson-tts/mossnano/zhtw8k")
import numpy as np, soundfile as sf, onnxruntime as ort
import frontend_bopomofo as F
from synth_from_text import host_regulate

# Keys read from the ``host_regulate`` result and passed, under the same names,
# as inputs to the acoustic decoder session.
BN = ["frames", "frame_meta", "local_ctx_raw", "abs_pos", "pitch_frame", "frame_mask"]
# Split AFTER these delimiters. The zero-width lookbehind keeps each delimiter
# attached to the preceding clause, for prosody. The class covers full-width CJK
# punctuation (。！？；，、), its ASCII counterparts (!?;,) and newlines.
SPLIT_RE = re.compile(r'(?<=[。！？；;!?\n，,、])')


def clauses(text):
    """Break *text* into clauses, each ending with its delimiter.

    Fragments that are empty or whitespace-only are discarded. If nothing
    survives, the original *text* is returned as a one-element list, so the
    result is never empty.
    """
    kept = []
    for piece in SPLIT_RE.split(text):
        if piece.strip():
            kept.append(piece)
    if not kept:
        return [text]
    return kept


class LongSynth:
    """Chunked long-text synthesizer built from three ONNX Runtime sessions.

    Each chunk goes through: acoustic encoder -> ``host_regulate`` (host-side
    regulation) -> acoustic decoder -> vocoder. Text is first split into
    clauses, then greedily packed so that each chunk's predicted frame count
    (summed encoder durations) stays within ``frame_budget``. Chunk waveforms
    are joined with ``gap_ms`` of silence.

    Args:
        onnx_dir: directory holding ``meta.json``, ``acoustic_encoder.onnx``,
            ``acoustic_decoder.onnx`` and ``vocoder.onnx``.
        frame_budget: maximum predicted frames per chunk.
        gap_ms: silence inserted between consecutive voiced chunks, in ms.
        intra_op_threads: ONNX Runtime intra-op thread count (default 4).
    """

    def __init__(self, onnx_dir, frame_budget=1100, gap_ms=70, intra_op_threads=4):
        # `with` closes meta.json (the original left the handle open).
        with open(f"{onnx_dir}/meta.json", encoding="utf-8") as fh:
            self.meta = json.load(fh)
        so = ort.SessionOptions()
        so.intra_op_num_threads = intra_op_threads

        def session(name):
            # All three models share the same options and run on CPU.
            return ort.InferenceSession(
                f"{onnx_dir}/{name}.onnx", so, providers=["CPUExecutionProvider"]
            )

        self.sA = session("acoustic_encoder")
        self.sB = session("acoustic_decoder")
        self.sV = session("vocoder")
        self.sr = self.meta["sample_rate"]
        self.budget = frame_budget
        # Clamp at 0 so that a negative gap_ms yields no gap instead of a
        # ValueError from np.zeros.
        self.gap = np.zeros(max(0, int(self.sr * gap_ms / 1000)), np.float32)

    def _encode(self, text):
        """Run the frontend and acoustic encoder on *text*.

        Returns:
            ``(cond, dur, pitch)`` as output by the acoustic encoder (speaker
            id fixed to 0). Returns ``None`` when the frontend yields no phone
            ids (presumably punctuation- or whitespace-only text).
        """
        o = F.text_to_ids(text)
        if not o["phone_ids"]:
            return None
        phone = np.array([o["phone_ids"]], np.int64)
        tone = np.array([o["tone_ids"]], np.int64)
        lang = np.array([o["lang_ids"]], np.int64)
        spk = np.zeros(1, np.int64)
        cond, dur, pitch = self.sA.run(None, {"phone": phone, "tone": tone, "lang": lang, "speaker": spk})
        return cond, dur, pitch

    def _decode(self, enc):
        """Regulate, decode and vocode one encoder result into a 1-D waveform."""
        cond, dur, pitch = enc
        reg = host_regulate(cond, dur, pitch, self.meta["abs_frame_bins"], self.meta["max_frames"])
        # Decoder inputs are float32, except boolean tensors (kept as-is) and
        # abs_pos (int64).
        feeds = {n: (reg[n].astype(np.float32) if reg[n].dtype != bool else reg[n]) for n in BN}
        feeds["abs_pos"] = reg["abs_pos"].astype(np.int64)
        mel = self.sB.run(None, feeds)[0]
        return self.sV.run(None, {"mel": mel.astype(np.float32)})[0].reshape(-1)

    def _predicted_frames(self, text):
        """Predicted frame count for *text* (sum of encoder durations; 0 if unvoiced)."""
        enc = self._encode(text)
        return int(enc[1].sum()) if enc is not None else 0

    @staticmethod
    def _pack_by_frames(sized_clauses, budget):
        """Greedy packing core used by :meth:`pack`.

        Args:
            sized_clauses: iterable of ``(clause, frames)`` pairs in text order.
            budget: maximum summed ``frames`` per chunk.

        Returns:
            A list of chunk strings, each made of consecutive clauses. A clause
            whose own ``frames`` already exceeds ``budget`` becomes a chunk by
            itself.
        """
        chunks, cur, cur_frames = [], "", 0
        for cl, f in sized_clauses:
            if cur and cur_frames + f > budget:
                chunks.append(cur)
                cur, cur_frames = "", 0
            if not cur and f > budget:
                # Oversize clause: it cannot be split further here, so emit it alone.
                chunks.append(cl)
                continue
            cur += cl
            cur_frames += f
        if cur:
            chunks.append(cur)
        return chunks

    def pack(self, text):
        """Greedily pack clauses into chunks whose predicted frame total <= budget.

        A single clause that is already over budget is emitted on its own.
        """
        sized = ((cl, self._predicted_frames(cl)) for cl in clauses(text))
        return self._pack_by_frames(sized, self.budget)

    @staticmethod
    def _join_with_gap(segments, gap):
        """Concatenate waveform *segments*, putting *gap* only between neighbours.

        Keying the gap on segments that were actually decoded, rather than on
        chunk index, avoids a stray trailing gap when the final chunk(s) had no
        phones and were skipped. Returns one float32 zero sample if *segments*
        is empty.
        """
        if not segments:
            return np.zeros(1, np.float32)
        parts = [segments[0]]
        for seg in segments[1:]:
            parts.extend((gap, seg))
        return np.concatenate(parts)

    def synth(self, text):
        """Synthesize *text* chunk by chunk.

        Returns:
            ``(wav, chunks)``. ``wav`` is the concatenated 1-D waveform (a
            single zero sample if no chunk produced audio). ``chunks`` is the
            list of chunk strings returned by :meth:`pack`.
        """
        chs = self.pack(text)
        segments = []
        for chunk in chs:
            enc = self._encode(chunk)
            if enc is not None:  # unvoiced chunk: nothing to decode
                segments.append(self._decode(enc))
        return self._join_with_gap(segments, self.gap), chs


def main():
    """CLI entry point: synthesize each ``{id, text}`` record of a JSONL file.

    For every non-blank line of ``--texts``, this writes ``<out-dir>/<id>.wav``
    and appends a manifest row to ``<out-dir>/synth.jsonl``. Each row holds the
    id, text, wav path, chunk count and duration in seconds.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--onnx-dir", required=True)
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--texts", required=True, help="jsonl with {id,text}")
    ap.add_argument("--frame-budget", type=int, default=1100)
    ap.add_argument("--gap-ms", type=int, default=70)
    a = ap.parse_args()
    ls = LongSynth(a.onnx_dir, a.frame_budget, a.gap_ms)
    out_dir = Path(a.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Both files are opened as UTF-8 explicitly. The manifest is written with
    # ensure_ascii=False, so non-ASCII text must not depend on the locale's
    # default encoding. `with` also closes both handles if a record fails.
    with open(a.texts, encoding="utf-8") as src, open(
        out_dir / "synth.jsonl", "w", encoding="utf-8"
    ) as man:
        for line in src:
            if not line.strip():
                continue
            r = json.loads(line)
            wav, chs = ls.synth(r["text"])
            wp = f"{a.out_dir}/{r['id']}.wav"
            sf.write(wp, wav, ls.sr)
            dur = len(wav) / ls.sr
            man.write(json.dumps({"id": r["id"], "text": r["text"], "wav": wp, "chunks": len(chs),
                                  "dur": round(dur, 2)}, ensure_ascii=False) + "\n")
            print(f"  {r['id']}: {len(chs)} chunks, {dur:.1f}s -> {wp}")
    print(f"DONE synth_long -> {a.out_dir}/synth.jsonl")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()