import math
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import reduce
from itertools import accumulate
from random import choices
from typing import List, Optional, Sequence, Tuple

import pytorch_lightning as ptl
import torch
from datasets import (DatasetDict, concatenate_datasets, load_dataset,
                      load_from_disk)
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import (BatchSampler, DataLoader, Sampler,
                              SubsetRandomSampler)
from transformers import PreTrainedTokenizerFast

from tts.tools import (audio_to_text_partial_neighbor_mask, packmask_2d,
                       pad_2d_sequence, sequence_mask)


class BucketSampler(Sampler[List[int]]):
    """Yields batches of dataset indices, each batch drawn from a single
    duration bucket so that samples of similar length end up together."""

    def __init__(
        self,
        buckets: List[List[int]],
        batch_sizes: List[int] | int,
        bucket_sampling_weights: Optional[List[float]] = None,
        drop_last: bool = True,
        distributed: bool = True,  # TODO: also support the non-distributed case
        sample_bucket: Optional[int] = None,
        seed: int = 123,
        epoch_seed: bool = True,
    ):
        if type(batch_sizes) is int:
            batch_sizes = [batch_sizes] * len(buckets)
        else:
            assert len(buckets) == len(batch_sizes)
        if bucket_sampling_weights is not None:
            assert len(bucket_sampling_weights) == len(batch_sizes)
        self.bucket_sampling_weights = bucket_sampling_weights
        self.num_replicas = torch.distributed.get_world_size()
        self.rank = torch.distributed.get_rank()
        # Shard every bucket across ranks with a strided slice, truncating the
        # tail so all ranks hold the same number of indices per bucket.
        self.buckets = [
            b[self.rank : len(b) - len(b) % self.num_replicas : self.num_replicas]
            for b in buckets
        ]
        self.num_samples = [len(b) // self.num_replicas for b in buckets]
        self.batch_sizes = batch_sizes
        self.total_sizes = [
            ns // bs for ns, bs in zip(self.num_samples, self.batch_sizes)
        ]
        self.drop_last = drop_last
        self.seed = seed
        self.epoch = 0
        self.sample_bucket = sample_bucket
        self.epoch_seed = epoch_seed
        self.batch_size = batch_sizes[0]

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __len__(self):
        return sum(self.total_sizes)

    def __iter__(self):
        generator = torch.Generator()
        # epoch_seed toggles whether the shuffle changes from epoch to epoch.
        generator.manual_seed(self.seed + self.epoch * self.epoch_seed + self.rank)
        pool = [
            BatchSampler(
                SubsetRandomSampler(b, generator=generator),
                bs,
                drop_last=self.drop_last,
            )
            for b, bs in zip(self.buckets, self.batch_sizes)
        ]
        pool = [iter(b) for b in pool]
        weights = (
            [w for w in self.bucket_sampling_weights]
            if self.bucket_sampling_weights is not None
            else None
        )
        while pool:  # sample until all buckets are exhausted
            idx, bucket = choices(list(enumerate(pool)), weights=weights)[0]
            try:
                batch = next(bucket)
                yield batch
            except StopIteration:
                pool.pop(idx)  # this bucket ran out of batches; drop it
                if weights is not None:
                    weights.pop(idx)
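
# Illustrative sketch (not part of the training pipeline): driving
# BucketSampler by hand. BucketSampler queries torch.distributed, so this
# assumes an initialised process group; the single-rank gloo bootstrap below
# is only for a local dry run.
def _demo_bucket_sampler():
    import os

    import torch.distributed as dist

    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)
    # Three duration buckets of dataset indices, shortest clips first.
    buckets = [list(range(0, 40)), list(range(40, 70)), list(range(70, 90))]
    # Longer buckets get smaller batches so per-batch token counts stay comparable.
    sampler = BucketSampler(buckets, batch_sizes=[8, 4, 2], seed=0)
    sampler.set_epoch(0)
    for batch in sampler:
        # Every batch comes from a single bucket, so all of its indices share
        # a similar audio duration.
        print(len(batch), batch)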
class DatasetFactory(ABC):
    @abstractmethod
    def build(self):
        pass


class HiFiTTS2_AudioLatent(DatasetFactory):
    def __init__(
        self,
        path: str | list[str] = "hifitts2_vae8_dataset",
        duration_column: str = "audio_duration",
        duration_path: str | None = None,
        expresso_path: str | None = None,
        min_dur: float = 3.0,
        max_dur: float = 20.1,
        framerate: float = 25.0,
    ):
        self.min_dur = min_dur
        self.max_dur = max_dur
        self.path = path
        self.duration_column = duration_column
        self.duration_path = duration_path
        self.framerate = framerate
        self.expresso_path = expresso_path

    def build(self):
        if type(self.path) is str:
            self.path = [self.path]
        datasets = [load_from_disk(x) for x in self.path]
        dataset = concatenate_datasets(datasets).with_format("torch")
        if self.duration_path is not None:
            # Durations live in a side dataset: join them column-wise.
            duration_dataset = load_from_disk(self.duration_path)
            dataset = concatenate_datasets(
                [dataset, duration_dataset], axis=1
            ).with_format("torch")
        dataset = dataset.filter(
            lambda dur: dur > self.min_dur and dur < self.max_dur,
            input_columns=self.duration_column,
        )
        if self.duration_column != "audio_duration":
            dataset = dataset.rename_column(self.duration_column, "audio_duration")
        if self.expresso_path is not None:
            # Append the Expresso data before sorting by duration.
            expresso_dataset = load_from_disk(self.expresso_path).with_format("torch")
            dataset = concatenate_datasets([dataset, expresso_dataset])
        dataset = dataset.sort("audio_duration")
        return DatasetDict({"train": dataset})


@dataclass
class SegmentsCollateArgs:
    abs_style_intensity: bool = False
    merge_endpoints: bool = True
    block_crossatt_mask: bool = True
    alternate_crossatt_pos: bool = False
    block_crossatt_past_tokens: int = 0
    block_crossatt_future_tokens: int = 0
    eos: bool = True
    bos: bool = True


@dataclass
class CollateArgs:
    abs_style_intensity: bool = False
    random_text_segment: bool = False
    eos: bool = True
    bos: bool = True
    num_stop_tokens: int = 1


def random_log_breakpoints(
    seq: Sequence, a: int, b: int, gap: bool = False
) -> List[int]:
    """
    Generate random breakpoints in a sequence where the gap X between
    successive breakpoints satisfies log2(X) ~ Uniform[log2(a), log2(b)].
    Gaps are then rounded to the nearest integer in [a, b].

    Parameters
    ----------
    seq : Sequence
        The input sequence in which to place breakpoints.
    a : int
        Minimum gap (>= 1).
    b : int
        Maximum gap (>= a).
    gap : bool
        If True, return the gap lengths (which sum to len(seq));
        if False, return breakpoint positions.

    Returns
    -------
    List[int]
        If gap is False, a sorted list of breakpoint indices
        (0 < idx < len(seq)); otherwise the list of gap lengths.
    """
    if a < 1 or b < a:
        raise ValueError("Require 1 <= a <= b")
    n = len(seq)
    breakpoints: List[int] = []
    pos = 0
    while True:
        # sample U ~ Uniform(log2(a), log2(b))
        u = random.uniform(math.log2(a), math.log2(b))
        # map back to X = 2^U, then round to the nearest integer
        x = 2**u
        step = int(math.floor(x + 0.5))
        # enforce integer bounds exactly
        step = max(a, min(b, step))
        pos += step
        if pos >= n:
            if gap:
                # the final gap covers whatever remains of the sequence
                breakpoints.append(n - sum(breakpoints))
            break
        if gap:
            breakpoints.append(step)
        else:
            breakpoints.append(pos)
    return breakpoints
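
# A minimal sketch of random_log_breakpoints output. With gap=True the values
# are segment lengths that sum to len(seq); with gap=False they are interior
# breakpoint positions.
def _demo_breakpoints():
    random.seed(0)
    seq = list(range(100))
    gaps = random_log_breakpoints(seq, 8, 32, gap=True)
    assert sum(gaps) == len(seq)
    positions = random_log_breakpoints(seq, 8, 32, gap=False)
    assert all(0 < p < len(seq) for p in positions)
    print(gaps, positions)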
def standalone_collate_latent(
    batch,
    tokenizer,
    abs_style_intensity: bool = False,
    random_text_segment: bool = False,
    bos: bool = True,
    eos: bool = True,
    num_stop_tokens: int = 1,
):
    audio_latent, text = zip(*[(x["audio_latent"], x["text"]) for x in batch])
    audio_latent = [x.squeeze() for x in audio_latent]
    text_pp = []
    for t in text:
        if bos:
            t = "[BOS]" + t
        if eos:
            t = t + "[EOS]"
        text_pp.append(t)
    text_token = [torch.LongTensor(tokenizer.encode(x)) for x in text_pp]
    xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_latent))
    # Stop tokens mark sequence ends: a random index in [1, num_stop_tokens]
    # is written at the final position of both the text and audio tracks.
    stop_token = []
    text_stop_token = []
    for x, y in zip(xlen, ylen):
        tst = torch.zeros(x)
        st = torch.zeros(y)
        st_idx = random.randint(1, num_stop_tokens)
        st[-1] = st_idx
        tst[-1] = st_idx
        stop_token.append(st)
        text_stop_token.append(tst)
    stop_token = pad_sequence(stop_token, batch_first=True).long()
    text_stop_token = pad_sequence(text_stop_token, batch_first=True).long()
    x_mask, y_mask = map(
        lambda x: sequence_mask(x, device="cpu"),
        (torch.tensor(xlen), torch.tensor(ylen)),
    )
    text_rel_pos = None
    if random_text_segment:
        # Split each text into random chunks, restrict self-attention to
        # within-chunk blocks, and restart positions from 0 in every chunk.
        breakpoints = [random_log_breakpoints(t, 32, 256, gap=True) for t in text_token]
        encoder_mask = pad_2d_sequence([packmask_2d(b, b) for b in breakpoints])
        text_rel_pos = [torch.cat([torch.arange(bb) for bb in b]) for b in breakpoints]
        text_rel_pos = pad_sequence(text_rel_pos, batch_first=True)
    else:
        encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)
    crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)
    audio_latent, text_token = map(
        lambda x: pad_sequence(x, batch_first=True, padding_value=0.0),
        (audio_latent, text_token),
    )
    if abs_style_intensity:
        abs_style_intensity = [x["abs_style_intensity"] for x in batch]
        abs_style_intensity = [
            torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity
        ]
        abs_style_intensity = torch.stack(abs_style_intensity)
    else:
        abs_style_intensity = None
    return {
        "text_token": text_token,
        "audio_token": audio_latent,
        "crossatt_mask": crossatt_mask,
        "encoder_mask": encoder_mask,
        "y_mask": y_mask,
        "stop_token": stop_token,
        "text_stop_token": text_stop_token,
        "x_len": xlen,
        "y_len": ylen,
        "abs_style_intensity": abs_style_intensity,
        "text_rel_pos": text_rel_pos,
    }
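
# Shape sketch for the masks built above, using plain torch rather than
# tts.tools.sequence_mask (whose exact signature this demo does not assume).
# x_mask is (B, Tx), y_mask is (B, Ty); the cross-attention mask is
# (B, Ty, Tx), True wherever an audio query may attend to a text key.
def _demo_mask_shapes():
    xlen, ylen = torch.tensor([3, 5]), torch.tensor([4, 2])
    x_mask = torch.arange(int(xlen.max())).unsqueeze(0) < xlen.unsqueeze(1)
    y_mask = torch.arange(int(ylen.max())).unsqueeze(0) < ylen.unsqueeze(1)
    crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)  # (2, 4, 5)
    encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)   # (2, 5, 5)
    print(crossatt_mask.shape, encoder_mask.shape)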
def standalone_collate_latent_segments(
    batch,
    tokenizer,
    abs_style_intensity: bool = False,
    merge_endpoints: bool = True,
    block_crossatt_mask: bool = True,
    block_crossatt_past_tokens: int = 0,
    block_crossatt_future_tokens: int = 0,
    alternate_crossatt_pos: bool = False,
    alternate_crossatt_shift: int = 1000,
    eos: bool = True,
    bos: bool = True,
):
    audio_latent, text, token_duration = zip(
        *[(x["audio_latent"], x["text"], x["token_duration"]) for x in batch]
    )
    text_pp = []
    for t in text:
        if bos:
            t = "[BOS]" + t
        if eos:
            t = t + "[EOS]"
        text_pp.append(t)
    if merge_endpoints:
        # Tokenize without the [BOS]/[EOS] wrappers and fold the first and
        # last token durations into their neighbours.
        tokens = [tokenizer.encode(x) for x in text]
        new_td = []
        for td in token_duration:
            begin, end = td[0], td[-1]
            tdd = td[1:-1]
            tdd[0] += begin
            tdd[-1] += end
            new_td.append(tdd)
        token_duration = new_td
    else:
        tokens = [tokenizer.encode(x) for x in text_pp]
    segments = [
        random_segments_from_text_and_durations(t, td.tolist())
        for t, td in zip(tokens, token_duration)
    ]
    # Distinct names for the encoded wrappers, so the bos/eos flags above are
    # not shadowed.
    bos_token, eos_token = map(tokenizer.encode, ("[BOS]", "[EOS]"))
    audio_segments = []
    text_segments = []
    audio_segments_len = []
    text_segments_len = []
    for aud, seg in zip(audio_latent, segments):
        tt, at, tt_l, at_l = [], [], [], []
        for i, s in enumerate(seg):
            ttoken = s["text_token"]
            if bos:
                ttoken = bos_token + ttoken
            if eos:
                ttoken = ttoken + eos_token
            tt.append(ttoken)
            a_s = aud[:, s["start"] : s["end"]]
            at.append(a_s)
            at_l.append(a_s.shape[1])
            tt_l.append(len(ttoken))
        audio_segments.append(at)
        text_segments.append(tt)
        audio_segments_len.append(at_l)
        text_segments_len.append(tt_l)
    text_token = [torch.LongTensor(reduce(list.__add__, x)) for x in text_segments]
    audio_latent = [torch.cat(a_ss, dim=1).squeeze(0) for a_ss in audio_segments]
    xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_latent))
    x_mask, y_mask = map(
        lambda x: sequence_mask(x, device="cpu"),
        (torch.tensor(xlen), torch.tensor(ylen)),
    )
    audio_latent, text_token = map(
        lambda x: pad_sequence(x, batch_first=True, padding_value=0),
        (audio_latent, text_token),
    )
    encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)
    if block_crossatt_mask:
        # Each audio segment may only attend to its own text segment, plus an
        # optional window of neighbouring segments.
        crossatt_mask = [
            audio_to_text_partial_neighbor_mask(
                x,
                y,
                past_tokens=block_crossatt_past_tokens,
                future_tokens=block_crossatt_future_tokens,
            )
            for x, y in zip(text_segments_len, audio_segments_len)
        ]
        crossatt_mask = pad_2d_sequence(crossatt_mask)
    else:
        crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)
    text_rel_pos = pad_sequence(
        [torch.cat([torch.arange(x) for x in tsl]) for tsl in text_segments_len],
        batch_first=True,
    )
    crossatt_rel_pos = None
    if alternate_crossatt_pos:
        # Offset the positions of every other segment by a large constant so
        # neighbouring segments stay distinguishable to cross-attention.
        crossatt_rel_pos = []
        for tsl in text_segments_len:
            rel_pos = []
            random_shift = int(random.random() < 0.5)
            for i, x in enumerate(tsl):
                rel_pos.append(
                    torch.arange(x)
                    + ((random_shift + i) % 2) * alternate_crossatt_shift
                )
            crossatt_rel_pos.append(torch.cat(rel_pos))
        crossatt_rel_pos = pad_sequence(crossatt_rel_pos, batch_first=True)
    audio_rel_pos = pad_sequence(
        [torch.cat([torch.arange(x) for x in asl]) for asl in audio_segments_len],
        batch_first=True,
    )
    # A stop flag at the last frame of every audio segment...
    stop_token = []
    for asl in audio_segments_len:
        sts = []
        for x in asl:
            st = torch.zeros(x)
            st[-1] = 1
            sts.append(st)
        stop_token.append(torch.cat(sts))
    stop_token = pad_sequence(stop_token, batch_first=True).int()
    # ...and likewise at the last token of every text segment.
    text_stop_token = []
    for tsl in text_segments_len:
        sts = []
        for x in tsl:
            st = torch.zeros(x)
            st[-1] = 1
            sts.append(st)
        text_stop_token.append(torch.cat(sts))
    text_stop_token = pad_sequence(text_stop_token, batch_first=True).int()
    if abs_style_intensity:
        abs_style_intensity = [x["abs_style_intensity"] for x in batch]
        abs_style_intensity = [
            torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity
        ]
        abs_style_intensity = torch.stack(abs_style_intensity)
    else:
        abs_style_intensity = None
    return {
        "text_token": text_token,
        "audio_token": audio_latent,
        "crossatt_mask": crossatt_mask,
        "encoder_mask": encoder_mask,
        "y_mask": y_mask,
        "stop_token": stop_token,
        "text_stop_token": text_stop_token,
        "x_mask": x_mask,
        "x_len": xlen,
        "y_len": ylen,
        "abs_style_intensity": abs_style_intensity,
        "text_rel_pos": text_rel_pos,
        "crossatt_rel_pos": crossatt_rel_pos,
        "audio_rel_pos": audio_rel_pos,
        "segments": segments,
    }


def random_segments_from_text_and_durations(
    text,
    dur,
    low_bnd: int = 8,
    up_bnd: int = 384,
):
    b = random_log_breakpoints(text, low_bnd, up_bnd)
    bounds = [0] + b + [len(text)]
    segs, durs = [], []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        segs.append(text[start:stop])
        durs.append(sum(dur[start:stop]))
    # Cumulative durations give the latent-frame boundary of every segment.
    bounds = [0] + list(accumulate(durs, int.__add__))
    segs_dicts = []
    for t, s, e in zip(segs, bounds[:-1], bounds[1:]):
        segs_dicts.append(
            {
                "start": s,
                "end": e,
                "text_token": t,
            }
        )
    segs_dicts[-1]["end"] += 1
    return segs_dicts
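
# Worked example of random_segments_from_text_and_durations with toy numbers:
# token ids plus per-token latent-frame durations yield contiguous segments
# whose "start"/"end" index into the audio latent (the last end is extended
# by one frame, as in the function above).
def _demo_segments():
    random.seed(0)
    text = list(range(40))  # 40 token ids
    dur = [4] * 40          # 4 latent frames per token
    segs = random_segments_from_text_and_durations(text, dur, low_bnd=8, up_bnd=16)
    # Consecutive segments tile the latent: each starts where the last ended.
    assert all(a["end"] == b["start"] for a, b in zip(segs, segs[1:]))
    for s in segs:
        print(len(s["text_token"]), s["start"], s["end"])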
class LinaDataModule(ptl.LightningDataModule):
    def __init__(
        self,
        path: str | DatasetFactory,
        quant_layer: list[int],
        train_batch_size: int = 8,
        token_by_batch: int | None = None,
        n_buckets: int = 5,
        codec_rate_hz: int = 75,
        num_workers: int = 8,
        test_size: int = 2000,
        val_batch_size: int = 8,
        seed: int = 123,
        train_test_seed: int = 123,
        segments: bool = False,
        segments_args: SegmentsCollateArgs | None = None,
        collate_args: CollateArgs | None = None,
        block_mask_segments: bool = False,
        tokenizer_file=None,
        trail_end_frame: int | None = None,
        split: str = "train",
        add_columns: str | list[str] | None = None,
        add_text_tokens: list[str] | None = None,
        type: str = "latent",
    ):
        super().__init__()
        self.path = path
        self.codec_rate_hz = codec_rate_hz
        self.num_workers = num_workers
        self.quant_layer = quant_layer
        self.seed = seed
        self.segments = segments
        # dataclass instances are mutable, so None is the default and fresh
        # instances are created here.
        self.segments_args = (
            segments_args if segments_args is not None else SegmentsCollateArgs()
        )
        self.collate_args = collate_args if collate_args is not None else CollateArgs()
        self.train_test_seed = train_test_seed
        self.test_size = test_size
        self.val_batch_size = val_batch_size
        self.train_batch_size = train_batch_size
        self.split = split
        self.trail_end_frame = trail_end_frame
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
        if add_text_tokens:
            self.tokenizer.add_tokens(add_text_tokens)
        self.add_columns = add_columns
        self.n_buckets = n_buckets
        self.token_by_batch = token_by_batch
        self.type = type

    def setup(self, stage):
        if isinstance(self.path, DatasetFactory):
            self.dataset = self.path.build()
        else:
            self.dataset = load_dataset(self.path)
        split = self.split
        columns = [
            "audio_latent" if self.type == "latent" else "audio_token",
            "text",
            "audio_duration",
        ]
        if self.add_columns is not None:
            if type(self.add_columns) is str:
                self.add_columns = [self.add_columns]
            columns += self.add_columns
        if self.segments:
            columns += ["token_duration"]
        if self.type == "latent":
            if self.segments:
                self.collate_fn = lambda x: standalone_collate_latent_segments(
                    x,
                    self.tokenizer,
                    **vars(self.segments_args),
                )
            else:
                self.collate_fn = lambda x: standalone_collate_latent(
                    x,
                    self.tokenizer,
                    **vars(self.collate_args),
                )
        else:
            # Codec-token datasets use the token-based collates, assumed to be
            # provided elsewhere in the package.
            if self.segments:
                self.collate_fn = lambda x: segments_collate(x, self.tokenizer)
            else:
                self.collate_fn = lambda x: standalone_collate(
                    x,
                    self.tokenizer,
                    abs_style_intensity="abs_style_intensity" in columns,
                )
        self.dataset = (
            self.dataset[split]
            .train_test_split(test_size=self.test_size, seed=self.train_test_seed)
            .select_columns(columns)
        )

        def get_buckets_by_quantile(duration, n_quantile, is_sorted=False):
            if is_sorted:
                size = len(duration)
                bucket_size = size // n_quantile
                buckets = [
                    list(range(i, min(i + bucket_size, size)))
                    for i in range(0, size, bucket_size)
                ]
            else:
                # Sort indices by duration, then chunk into equal-size buckets
                # (the remainder is dropped).
                idxdur = list(enumerate(duration))
                idxdur.sort(key=lambda x: x[1])
                idx, _ = zip(*idxdur)
                bucket_size = len(idx) // n_quantile
                buckets = [list(x) for x in zip(*[iter(idx)] * bucket_size)]
            return buckets

        if self.token_by_batch is not None:
            train_buckets = get_buckets_by_quantile(
                self.dataset["train"]["audio_duration"], self.n_buckets
            )
            # Size each bucket's batches so a batch holds roughly
            # token_by_batch audio tokens at the longest duration it contains.
            max_audio_durations = [
                self.dataset["train"]["audio_duration"][x[-1]] for x in train_buckets
            ]
            batch_sizes = [
                int(self.token_by_batch // (self.codec_rate_hz * ad))
                for ad in max_audio_durations
            ]
            self.train_batch_sampler = BucketSampler(train_buckets, batch_sizes)

    def train_dataloader(self):
        if self.token_by_batch is not None:
            return DataLoader(
                self.dataset["train"].with_format("torch"),
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                batch_sampler=self.train_batch_sampler,
            )
        return DataLoader(
            self.dataset["train"].with_format("torch"),
            num_workers=self.num_workers,
            batch_size=self.train_batch_size,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.dataset["test"].with_format("torch"),
            batch_size=self.val_batch_size,
            num_workers=0,
            collate_fn=self.collate_fn,
        )
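
# End-to-end sketch of wiring the pieces together. The dataset path and
# tokenizer file are placeholders; point them at your own assets. With
# token_by_batch set, train_dataloader uses BucketSampler, which requires an
# initialised torch.distributed group (see the sampler sketch above).
def _demo_datamodule():
    dm = LinaDataModule(
        path=HiFiTTS2_AudioLatent(path="hifitts2_vae8_dataset"),
        quant_layer=[0],                  # placeholder value
        token_by_batch=20_000,
        n_buckets=5,
        codec_rate_hz=25,                 # latent frame rate, matches framerate above
        tokenizer_file="tokenizer.json",  # placeholder path
        segments=True,                    # needs a token_duration column
        type="latent",
    )
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    print(batch["text_token"].shape, batch["audio_token"].shape)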