| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| from pathlib import Path |
| import argparse |
| from tqdm import tqdm |
| from safetensors.torch import save_file, load_file |
| from collections import deque |
| from model import LocalSongModel |
|
|
# Conditioning tag id(s) applied to every training example — this script
# fine-tunes for a single fixed style, so tags are not read from data.
HARDCODED_TAGS = [1908]
# Allow TF32 matmuls on Ampere+ GPUs for faster training at slightly
# reduced float32 precision.
torch.set_float32_matmul_precision('high')
|
|
class LoRALinear(nn.Module):
    """Wrap a frozen ``nn.Linear`` with a trainable low-rank (LoRA) update.

    The forward pass computes ``original(x) + (x @ A @ B) * (alpha / rank)``.
    ``A`` is Kaiming-initialized and ``B`` starts at zero, so the wrapper is an
    exact no-op at initialization and training only moves away from the base
    model gradually.

    Args:
        original_linear: the linear layer to adapt; its weight/bias are frozen.
        rank: LoRA rank (must be positive).
        alpha: scaling numerator; effective scale is ``alpha / rank``.
    """

    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"LoRA rank must be positive, got {rank}")
        self.original_linear = original_linear
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        # Create the adapters directly on the wrapped layer's device/dtype so
        # the wrapper also works for half-precision or already-on-GPU layers
        # without requiring a manual move by the caller.
        weight = original_linear.weight
        self.lora_A = nn.Parameter(torch.zeros(
            original_linear.in_features, rank,
            device=weight.device, dtype=weight.dtype))
        self.lora_B = nn.Parameter(torch.zeros(
            rank, original_linear.out_features,
            device=weight.device, dtype=weight.dtype))

        # Standard LoRA init: A random (Kaiming uniform), B zero, so the
        # initial low-rank delta is exactly zero.
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        nn.init.zeros_(self.lora_B)

        # Freeze the base layer; only the adapter matrices are trainable.
        self.original_linear.weight.requires_grad = False
        if self.original_linear.bias is not None:
            self.original_linear.bias.requires_grad = False

    def forward(self, x):
        """Return the base projection plus the scaled low-rank correction."""
        result = self.original_linear(x)
        lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling
        return result + lora_out
|
|
def inject_lora(model: LocalSongModel, rank: int = 8, alpha: float = 16.0, target_modules=None, device=None):
    """Replace targeted ``nn.Linear`` layers in *model* with ``LoRALinear`` wrappers.

    A layer is targeted when any entry of *target_modules* is a substring of
    its dotted module name. The replacement happens in place via ``setattr``
    on the parent module.

    Args:
        model: model to modify in place.
        rank: LoRA rank for every injected adapter.
        alpha: LoRA alpha (scaling numerator) for every injected adapter.
        target_modules: substrings selecting which linears to wrap; defaults
            to the attention/MLP projection names used by LocalSongModel.
        device: device for the new adapter parameters; defaults to the
            device of the model's first parameter.

    Returns:
        The same *model* instance, with adapters injected.
    """
    # Avoid a mutable default argument; fall back to the standard targets.
    if target_modules is None:
        target_modules = ('qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj')

    lora_modules = []

    if device is None:
        device = next(model.parameters()).device

    # Snapshot named_modules() before mutating the module tree: replacing
    # children while the generator is live is fragile.
    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        if not any(target in name for target in target_modules):
            continue

        # Walk down to the parent module so we can swap the attribute.
        *parent_path, attr_name = name.split('.')
        parent = model
        for p in parent_path:
            parent = getattr(parent, p)

        lora_layer = LoRALinear(module, rank=rank, alpha=alpha)
        # Move the whole wrapper (adapters included) onto the target device.
        lora_layer.to(device)
        setattr(parent, attr_name, lora_layer)
        lora_modules.append(name)

    print(f"Injected LoRA into {len(lora_modules)} layers:")
    for name in lora_modules[:5]:
        print(f"  - {name}")
    if len(lora_modules) > 5:
        print(f"  ... and {len(lora_modules) - 5} more")

    return model
|
|
def get_lora_parameters(model):
    """Collect the trainable adapter tensors (A and B of every LoRALinear)."""
    params = []
    for submodule in model.modules():
        if not isinstance(submodule, LoRALinear):
            continue
        params.append(submodule.lora_A)
        params.append(submodule.lora_B)
    return params
|
|
def save_lora_weights(model, output_path):
    """Serialize every LoRA A/B matrix in *model* to a safetensors file."""
    lora_state_dict = {
        f"{name}.{suffix}": getattr(module, suffix)
        for name, module in model.named_modules()
        if isinstance(module, LoRALinear)
        for suffix in ("lora_A", "lora_B")
    }
    save_file(lora_state_dict, output_path)
    print(f"Saved {len(lora_state_dict)} LoRA parameters to {output_path}")
|
|
class LatentDataset(Dataset):
    """Dataset of pre-encoded VAE latents stored as individual ``.pt`` files.

    Each item is one latent tensor; 3-D tensors are promoted to 4-D by adding
    a leading batch dimension of size 1.
    """

    def __init__(self, latents_dir: str):
        """Index every ``*.pt`` file under *latents_dir* (sorted for determinism).

        Raises:
            ValueError: if the directory contains no ``.pt`` files.
        """
        self.latents_dir = Path(latents_dir)
        # sorted() already returns a list; no extra list() needed.
        self.latent_files = sorted(self.latents_dir.glob("*.pt"))

        if len(self.latent_files) == 0:
            raise ValueError(f"No .pt files found in {latents_dir}")

        print(f"Found {len(self.latent_files)} latent files")

    def __len__(self):
        return len(self.latent_files)

    def __getitem__(self, idx):
        # map_location='cpu' so latents that were saved from a GPU process
        # still load inside CPU-only DataLoader workers/processes.
        latent = torch.load(self.latent_files[idx], map_location='cpu')

        if latent.ndim == 3:
            latent = latent.unsqueeze(0)

        return latent
|
|
class RectifiedFlow:
    """Minimal rectified-flow (flow-matching) training objective."""

    def __init__(self, model):
        # Velocity model called as model(z_t, t, cond).
        self.model = model

    def forward(self, x, cond):
        """Return the MSE flow-matching loss for one batch of clean latents."""
        batch = x.size(0)

        # Sample timesteps as sigmoid of a standard normal (logit-normal
        # schedule concentrating mass away from the endpoints).
        noise_t = torch.randn((batch,), device=x.device)
        t = torch.sigmoid(noise_t)

        # Reshape t to broadcast over every non-batch dimension, then form
        # the linear interpolation between data x and pure noise z1.
        t_exp = t.view([batch, *([1] * (x.ndim - 1))])
        z1 = torch.randn_like(x)
        zt = (1 - t_exp) * x + t_exp * z1

        vtheta = self.model(zt, t, cond)

        # Target velocity of the straight path from x to z1.
        velocity = z1 - x
        return ((vtheta - velocity) ** 2).mean()
|
|
def collate_fn(batch, subsection_length=1024, tags=None):
    """Collate latents into a batch by cropping/padding each to a fixed width.

    Args:
        batch: list of latent tensors, 3-D (C, H, W) or 4-D (1, C, H, W).
        subsection_length: target width; narrower latents are zero-padded on
            the right, wider ones receive a random temporal crop.
        tags: conditioning tag list applied to every item; defaults to the
            module-level HARDCODED_TAGS (backward compatible).

    Returns:
        (latents, tags_per_item): a stacked (B, C, H, subsection_length)
        tensor and a list with one tag list per batch item.
    """
    if tags is None:
        tags = HARDCODED_TAGS

    sampled_latents = []
    for latent in batch:
        if latent.ndim == 3:
            latent = latent.unsqueeze(0)

        width = latent.shape[-1]

        if width < subsection_length:
            # Too short: zero-pad the time axis up to the target width.
            pad_amount = subsection_length - width
            latent = torch.nn.functional.pad(latent, (0, pad_amount), mode='constant', value=0)
        else:
            # Long enough: take a random crop (acts as data augmentation
            # across epochs). randint's upper bound is exclusive.
            max_start = width - subsection_length
            start_idx = torch.randint(0, max_start + 1, (1,)).item()
            latent = latent[:, :, :, start_idx:start_idx + subsection_length]

        sampled_latents.append(latent.squeeze(0))

    batch_latents = torch.stack(sampled_latents)
    batch_tags = [tags] * len(batch)

    return batch_latents, batch_tags
|
|
def main():
    """CLI entry point: LoRA-finetune LocalSongModel on pre-encoded latents.

    Loads the frozen base checkpoint, injects LoRA adapters, trains only the
    adapter matrices with rectified-flow loss, and periodically saves the
    adapter weights to safetensors.
    """
    parser = argparse.ArgumentParser(description='LoRA training for LocalSong model with embedding training')

    parser.add_argument('--latents_dir', type=str, required=True,
                        help='Directory containing VAE-encoded latents (.pt files)')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/checkpoint_461260.safetensors',
                        help='Path to base model checkpoint')
    parser.add_argument('--lora_rank', type=int, default=16,
                        help='LoRA rank')
    parser.add_argument('--lora_alpha', type=float, default=16,
                        help='LoRA alpha (scaling factor)')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='Learning rate')
    parser.add_argument('--steps', type=int, default=1500,
                        help='Number of training steps')
    parser.add_argument('--subsection_length', type=int, default=512,
                        help='Latent subsection length')
    parser.add_argument('--output', type=str, default='lora.safetensors',
                        help='Output path for LoRA weights')
    parser.add_argument('--save_every', type=int, default=500,
                        help='Save checkpoint every N steps')

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print(f"Using hardcoded tags: {HARDCODED_TAGS}")

    # Architecture hyperparameters must match the pretrained checkpoint
    # exactly, or load_state_dict(strict=True) below will fail.
    print(f"Loading base model from {args.checkpoint}")
    model = LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=2304,
        max_tags=8,
    )

    print(f"Loading checkpoint from {args.checkpoint}")
    state_dict = load_file(args.checkpoint)
    model.load_state_dict(state_dict, strict=True)
    print("Base model loaded")

    # Move first, then inject, so adapters land on the right device.
    model = model.to(device)
    model = inject_lora(model, rank=args.lora_rank, alpha=args.lora_alpha, device=device)

    model.train()

    # Only the LoRA A/B matrices are optimized; base weights stay frozen.
    lora_params = get_lora_parameters(model)
    optimizer = optim.Adam(lora_params, lr=args.lr)
    print(f"Training {len(lora_params)} LoRA parameters")

    dataset = LatentDataset(args.latents_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=lambda batch: collate_fn(batch, args.subsection_length)
    )

    rf = RectifiedFlow(model)

    print("\nStarting training...")
    step = 0
    pbar = tqdm(total=args.steps, desc="Training")

    # Sliding window so the displayed loss is a smoothed running average.
    loss_history = deque(maxlen=50)

    # Re-iterate the dataloader until the step budget is spent.
    while step < args.steps:
        for batch_latents, batch_tags in dataloader:
            batch_latents = batch_latents.to(device)

            optimizer.zero_grad()
            loss = rf.forward(batch_latents, batch_tags)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(lora_params, 1.0)
            optimizer.step()

            loss_history.append(loss.item())
            avg_loss = sum(loss_history) / len(loss_history)

            pbar.set_postfix({'loss': f'{avg_loss:.4f}'})
            pbar.update(1)
            step += 1

            if step % args.save_every == 0:
                # Insert the step count before the file suffix via pathlib:
                # unlike str.replace, this also works when --output does not
                # end in '.safetensors' (which would overwrite the final file).
                out = Path(args.output)
                save_path = str(out.with_name(f"{out.stem}_step{step}{out.suffix}"))
                save_lora_weights(model, save_path)

            if step >= args.steps:
                break

    pbar.close()  # ensure the progress bar flushes before the final prints

    save_lora_weights(model, args.output)
    print(f"\nTraining complete! LoRA weights saved to {args.output}")
|
|
# Script entry point: run training when executed directly (not on import).
if __name__ == '__main__':
    main()
|
|