rahul7star committed on
Commit e94e7ec · verified · 1 Parent(s): a966295

Update app.py

Files changed (1)
  1. app.py +476 -107
app.py CHANGED
@@ -1,122 +1,491 @@
1
- # app.py
2
  import os
3
- import gradio as gr
4
  import torch
5
- from huggingface_hub import HfApi, Repository, upload_folder
6
  from diffusers import DiffusionPipeline
7
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
8
- from transformers import AutoModelForCausalLM, AutoTokenizer
9
-
10
- # =========================
11
- # 🧩 Helper Functions
12
- # =========================
13
- def load_diffusion_model(model_name, dtype=torch.float16):
14
- print(f"Loading base model: {model_name}")
15
- pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=dtype)
16
- device = "cuda" if torch.cuda.is_available() else "cpu"
17
- pipe.to(device)
18
- return pipe
19
 
20
- def apply_lora(pipe, lora_path=None, lora_rank=8, alpha=16):
21
- if lora_path and os.path.exists(lora_path):
22
- print(f"Loading LoRA weights from {lora_path}")
23
- pipe.load_lora_weights(lora_path)
24
  else:
25
- print("Initializing new LoRA config.")
26
- config = LoraConfig(r=lora_rank, lora_alpha=alpha)
27
- pipe.unet = get_peft_model(pipe.unet, config)
28
  return pipe
29
 
30
- def enhance_prompt(prompt, model_name="Qwen/Qwen2.5-1.5B-Instruct"):
31
- print(f"Enhancing prompt with {model_name}")
32
- tok = AutoTokenizer.from_pretrained(model_name)
33
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda" if torch.cuda.is_available() else "cpu")
34
- inputs = tok(prompt, return_tensors="pt").to(model.device)
35
- outputs = model.generate(**inputs, max_new_tokens=100)
36
- enhanced = tok.decode(outputs[0], skip_special_tokens=True)
37
- return enhanced
38
-
39
- def train_lora(base_model, dataset_path, output_dir, steps=100, lr=1e-4, progress=gr.Progress(track_tqdm=True)):
40
- progress(0, desc="Loading model...")
41
- pipe = load_diffusion_model(base_model)
42
- pipe = apply_lora(pipe)
43
- progress(0.2, desc="Preparing dataset...")
44
-
45
- # Dummy dataset loader (replace with your dataset logic)
46
- import pandas as pd
47
- df = pd.read_csv(dataset_path)
48
- prompts = df['text'].tolist()
49
-
50
- progress(0.3, desc="Training...")
51
- for i, text in enumerate(prompts):
52
- progress(0.3 + 0.6*(i/len(prompts)), desc=f"Training on sample {i+1}/{len(prompts)}")
53
- # Simulate training step
54
- torch.cuda.empty_cache()
55
- torch.manual_seed(i)
56
- _ = pipe(prompt=text, num_inference_steps=1)
57
-
58
- os.makedirs(output_dir, exist_ok=True)
59
- pipe.save_pretrained(output_dir)
60
- progress(1, desc="Training complete ✅")
61
- return output_dir
62
-
63
- def upload_to_hub(model_path, repo_id, token):
64
- print(f"Uploading {model_path} to {repo_id}...")
65
- api = HfApi()
66
- upload_folder(repo_id=repo_id, folder_path=model_path, token=token)
67
- return f"✅ Model uploaded to: https://huggingface.co/{repo_id}"
68
-
69
- def test_model(model_path, prompt):
70
- pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
71
- pipe.to("cuda" if torch.cuda.is_available() else "cpu")
72
- image = pipe(prompt=prompt, num_inference_steps=8).images[0]
73
- return image
74
-
75
- # =========================
76
- # 🎨 Gradio UI
77
- # =========================
78
- def gradio_app():
79
- with gr.Blocks(title="Universal Diffusion Trainer") as demo:
80
- gr.Markdown("## 🌌 Universal Diffusion Fine-tuner & Tester\nTrain or Test any Diffusion Model (T2I, T2V, LoRA, Prompt Enhancer)")
81
-
82
- with gr.Tab("🔧 Training"):
83
- base_model = gr.Textbox(label="Base Model (e.g., nvidia/ChronoEdit-14B-Diffusers)", value="runwayml/stable-diffusion-v1-5")
84
- dataset = gr.Textbox(label="CSV Dataset Path (with columns file_name,text)", value="data.csv")
85
- steps = gr.Slider(10, 1000, 100, step=10, label="Training Steps")
86
- output_dir = gr.Textbox(label="Output Folder", value="./trained_model")
87
- hf_repo = gr.Textbox(label="Upload to HF Repo (e.g., rahul7star/my-lora-model)")
88
- hf_token = gr.Textbox(label="Hugging Face Token", type="password")
89
- run_train = gr.Button("🚀 Start Training")
90
- log = gr.Textbox(label="Logs")
91
- progress = gr.HTML()
92
-
93
- def train_and_upload(base_model, dataset, steps, output_dir, hf_repo, hf_token, progress=gr.Progress(track_tqdm=True)):
94
- output_path = train_lora(base_model, dataset, output_dir, steps, progress=progress)
95
- if hf_repo and hf_token:
96
- url = upload_to_hub(output_path, hf_repo, hf_token)
97
- return f"Training done ✅\nUploaded to: {url}"
98
  else:
99
- return f"Training done ✅\nModel saved at {output_path}"
100
 
101
- run_train.click(train_and_upload,
102
- inputs=[base_model, dataset, steps, output_dir, hf_repo, hf_token],
103
- outputs=[log])
104
 
105
- with gr.Tab("🧪 Test Model"):
106
- test_model_path = gr.Textbox(label="Model Path or Repo", value="./trained_model")
107
- test_prompt = gr.Textbox(label="Prompt", value="A futuristic city with flying cars at sunset")
108
- test_btn = gr.Button("🖼️ Generate")
109
- test_output = gr.Image(label="Generated Output")
110
- test_btn.click(test_model, inputs=[test_model_path, test_prompt], outputs=test_output)
111
 
112
- with gr.Tab("✨ Prompt Enhancement"):
113
- prompt_input = gr.Textbox(label="Input Prompt")
114
- enhance_btn = gr.Button("Enhance")
115
- enhanced_out = gr.Textbox(label="Enhanced Prompt")
116
- enhance_btn.click(enhance_prompt, inputs=[prompt_input], outputs=[enhanced_out])
117
 
118
  return demo
119
 
 
120
  if __name__ == "__main__":
121
- demo = gradio_app()
122
- demo.launch()
 
1
+ # universal_lora_trainer_accelerate.py
2
+ """
3
+ Universal LoRA Trainer (Accelerate + PEFT) with Gradio UI.
4
+
5
+ - Real LoRA training (UNet / ChronoEdit transformer / prompt enhancer)
6
+ - Dataset: local folder or HF repo id containing dataset.csv with columns: file_name,text
7
+ - HF_TOKEN is read from environment for uploads
8
+ - QwenEdit / prompt-enhancer LoRA optional
9
+ """
10
+
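For reference, a minimal dataset matching the format described in the docstring above can be produced like this (a sketch, not part of the commit; the folder and file names are placeholders):

# Illustrative only: build a tiny dataset.csv with the expected columns.
import os
import pandas as pd

os.makedirs("dataset", exist_ok=True)
rows = [
    {"file_name": "cat.jpg", "text": "A cat sitting on a windowsill"},
    {"file_name": "clip.mp4", "text": "A skateboarder grinding a rail at sunset"},
]
pd.DataFrame(rows).to_csv("dataset/dataset.csv", index=False)
# Hub downloads/uploads take their token from the HF_TOKEN environment variable.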
11
  import os
12
+ import math
13
+ import time
14
+ import tempfile
15
+ from pathlib import Path
16
+ from typing import Optional, List, Tuple
17
+
18
  import torch
19
+ import torch.nn as nn
20
+ from torch.utils.data import Dataset, DataLoader
21
+ import torchvision
22
+ import torchvision.transforms as T
23
+ import pandas as pd
24
+ import numpy as np
25
+ import gradio as gr
26
+ from tqdm.auto import tqdm
27
+
28
+ from huggingface_hub import create_repo, upload_folder, hf_hub_download, HfApi
29
+
30
  from diffusers import DiffusionPipeline
31
 
32
+ # optional ChronoEdit
33
+ try:
34
+ from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
35
+ CHRONOEDIT_AVAILABLE = True
36
+ except Exception:
37
+ CHRONOEDIT_AVAILABLE = False
38
+
39
+ # PEFT + Accelerate
40
+ from peft import LoraConfig, get_peft_model
41
+ from accelerate import Accelerator
42
+
43
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
44
+ IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
45
+ VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}
46
+
47
+
48
+ # -------------------------
49
+ # Utilities
50
+ # -------------------------
51
+ def is_hub_repo_like(s: str) -> bool:
52
+ # simple heuristic: contains a slash and no local path separators
53
+ return "/" in s and not os.path.exists(s)
54
+
55
+
56
+ def hf_download_file(repo_id: str, filename: str, local_cache_dir: Optional[str] = None, token: Optional[str] = None) -> str:
57
+ """Download a single file from HF repo to a temporary local path and return path."""
58
+ token = token or os.environ.get("HF_TOKEN")
59
+ out = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
60
+ return out
61
+
62
+
63
+ def find_target_modules(model, candidates=("q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_k", "to_v", "proj_out", "to_out")):
64
+ names = [n for n, _ in model.named_modules()]
65
+ selected = set()
66
+ for cand in candidates:
67
+ for n in names:
68
+ if cand in n:
69
+ selected.add(n.split(".")[-1])
70
+ if not selected:
71
+ return ["to_q", "to_k", "to_v", "to_out"]
72
+ return list(selected)
73
+
74
+
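A quick sanity check of the module-name heuristic above (sketch only; the toy module below is hypothetical, not part of the commit):

import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(8, 8)  # matches the "to_q" candidate
        self.to_k = nn.Linear(8, 8)  # matches the "to_k" candidate

# Prints the matching leaf names, e.g. ['to_q', 'to_k'] (order may vary);
# with no matches it would fall back to the default attention projections.
print(find_target_modules(ToyAttention()))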
75
+ # -------------------------
76
+ # Dataset: local or HF repo
77
+ # -------------------------
78
+ class MediaTextDataset(Dataset):
79
+ """
80
+ CSV must have columns: file_name, text
81
+ If file_name is a local path (exists), loads from local; otherwise if dataset_repo is provided it downloads from hub.
82
+ """
83
+
84
+ def __init__(self, dataset_dir_or_repo: str, csv_name: str = "dataset.csv", max_frames: int = 5,
85
+ image_size=(512, 512), video_frame_size=(128, 256), hub_token: Optional[str] = None):
86
+ self.source = dataset_dir_or_repo
87
+ self.is_hub = is_hub_repo_like(dataset_dir_or_repo)
88
+ self.df = None
89
+ self.root = None
90
+ self.tmpdir = None
91
+ self.max_frames = max_frames
92
+ self.image_size = image_size
93
+ self.video_frame_size = video_frame_size
94
+ self.hub_token = hub_token or os.environ.get("HF_TOKEN")
95
+
96
+ if self.is_hub:
97
+ # download CSV into temp dir
98
+ self.tmpdir = Path(tempfile.mkdtemp(prefix="dataset_hf_"))
99
+ csv_local = hf_download_file(self.source, csv_name, token=self.hub_token)
100
+ # hf_hub_download returns path inside local cache; copy into tmpdir for consistent file reads
101
+ csv_df = pd.read_csv(csv_local)
102
+ self.df = csv_df
103
+ # ensure we will download each referenced file on demand by storing repo id
104
+ self.root = None
105
+ else:
106
+ self.root = Path(dataset_dir_or_repo)
107
+ csv_path = self.root / csv_name
108
+ if not csv_path.exists():
109
+ raise FileNotFoundError(f"{csv_path} not found")
110
+ self.df = pd.read_csv(csv_path)
111
+
112
+ # transforms
113
+ self.image_transform = T.Compose([T.ToPILImage(), T.Resize(image_size), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
114
+ self.video_transform = T.Compose([T.ToPILImage(), T.Resize(video_frame_size), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
115
+
116
+ def __len__(self):
117
+ return len(self.df)
118
+
119
+ def _maybe_download_from_hub(self, file_name: str) -> str:
120
+ # returns a local path for the file (cached)
121
+ # if local path exists return as-is
122
+ if self.root is not None:
123
+ p = self.root / file_name
124
+ if p.exists():
125
+ return str(p)
126
+ # else download from hub repo
127
+ local_path = hf_hub_download(repo_id=self.source, filename=file_name, token=self.hub_token)
128
+ return local_path
129
+
130
+ def _read_video_frames(self, path: str, num_frames: int):
131
+ video_frames, _, _ = torchvision.io.read_video(str(path), pts_unit='sec')
132
+ total = len(video_frames)
133
+ if total == 0:
134
+ C, H, W = 3, self.video_frame_size[0], self.video_frame_size[1]
135
+ return torch.zeros((num_frames, C, H, W), dtype=torch.float32)
136
+ if total < num_frames:
137
+ idxs = list(range(total)) + [total-1]*(num_frames-total)
138
+ else:
139
+ idxs = np.linspace(0, total-1, num_frames).round().astype(int).tolist()
140
+ frames = []
141
+ for i in idxs:
142
+ arr = video_frames[i].numpy() if hasattr(video_frames[i], "numpy") else np.array(video_frames[i])
143
+ frames.append(self.video_transform(arr))
144
+ frames = torch.stack(frames, dim=0)
145
+ return frames
146
+
147
+ def __getitem__(self, idx):
148
+ rec = self.df.iloc[idx]
149
+ file_name = rec["file_name"]
150
+ caption = rec["text"]
151
+ if self.is_hub:
152
+ local_path = self._maybe_download_from_hub(file_name)
153
+ else:
154
+ local_path = str(Path(self.root) / file_name)
155
+ p = Path(local_path)
156
+ suffix = p.suffix.lower()
157
+ if suffix in IMAGE_EXTS:
158
+ img = torchvision.io.read_image(local_path) # [C,H,W]
159
+ if isinstance(img, torch.Tensor):
160
+ img = img.permute(1,2,0).numpy()
161
+ return {"type": "image", "image": self.image_transform(img), "caption": caption, "file_name": file_name}
162
+ elif suffix in VIDEO_EXTS:
163
+ frames = self._read_video_frames(local_path, self.max_frames) # [T,C,H,W]
164
+ return {"type": "video", "frames": frames, "caption": caption, "file_name": file_name}
165
+ else:
166
+ raise RuntimeError(f"Unsupported media type: {local_path}")
167
+
168
+
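A short sketch of how the dataset class is consumed (paths follow the illustrative layout earlier; not part of the commit):

# Illustrative only: load one sample from a local dataset folder.
ds = MediaTextDataset("./dataset", csv_name="dataset.csv", max_frames=5)
sample = ds[0]
if sample["type"] == "image":
    print(sample["caption"], tuple(sample["image"].shape))   # (3, 512, 512) with the default image_size
else:
    print(sample["caption"], tuple(sample["frames"].shape))  # (5, 3, 128, 256) with the defaults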
169
+ # -------------------------
170
+ # Pipeline loading helpers
171
+ # -------------------------
172
+ def load_pipeline_auto(base_model_id: str, torch_dtype=torch.float16):
173
+ is_chrono = "chrono" in base_model_id.lower() or "chronoedit" in base_model_id.lower()
174
+ if CHRONOEDIT_AVAILABLE and is_chrono:
175
+ print(f"Loading ChronoEdit pipeline: {base_model_id}")
176
+ pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
177
  else:
178
+ print(f"Loading standard Diffusers pipeline: {base_model_id}")
179
+ pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
 
180
  return pipe
181
 
182
+
183
+ def attach_lora(pipe, target: str, r: int = 8, alpha: int = 16, dropout: float = 0.0):
184
+ """
185
+ Attach LoRA to pipe.unet (image), pipe.transformer (video), or pipe.text_encoder (prompt)
186
+ Returns: modified pipe and the attribute name used
187
+ """
188
+ if target == "unet":
189
+ if not hasattr(pipe, "unet"):
190
+ raise RuntimeError("Chosen pipeline has no UNet")
191
+ target_module = pipe.unet
192
+ attr = "unet"
193
+ elif target == "transformer":
194
+ if not hasattr(pipe, "transformer"):
195
+ raise RuntimeError("Chosen pipeline has no transformer")
196
+ target_module = pipe.transformer
197
+ attr = "transformer"
198
+ elif target == "text_encoder":
199
+ if not hasattr(pipe, "text_encoder"):
200
+ raise RuntimeError("Chosen pipeline has no text_encoder")
201
+ target_module = pipe.text_encoder
202
+ attr = "text_encoder"
203
+ else:
204
+ raise RuntimeError("Unknown target for LoRA")
205
+
206
+ target_modules = find_target_modules(target_module)
207
+ print("LoRA target_modules:", target_modules)
208
+ # No task_type: these targets are diffusion modules, not seq2seq language models.
+ lora_config = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=dropout, bias="none")
209
+ peft_model = get_peft_model(target_module, lora_config)
210
+
211
+ # set back into pipeline
212
+ setattr(pipe, attr, peft_model)
213
+ return pipe, attr
214
+
215
+
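A minimal sketch of attaching an adapter and checking how few parameters become trainable (the model id is just an example, not part of the commit):

# Illustrative only: wrap the UNet of a standard SD pipeline with LoRA.
pipe = load_pipeline_auto("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32)
pipe, attr = attach_lora(pipe, "unet", r=8, alpha=16)
module = getattr(pipe, attr)
trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
total = sum(p.numel() for p in module.parameters())
print(f"trainable params: {trainable:,} of {total:,}")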
216
+ # -------------------------
217
+ # Training loop (Accelerate)
218
+ # -------------------------
219
+ def train_lora_accelerate(base_model_id: str,
220
+ dataset_dir_or_repo: str,
221
+ csv_name: str,
222
+ adapter_target: str,
223
+ output_dir: str,
224
+ epochs: int = 1,
225
+ batch_size: int = 1,
226
+ lr: float = 1e-4,
227
+ max_train_steps: Optional[int] = None,
228
+ lora_r: int = 8,
229
+ lora_alpha: int = 16,
230
+ max_frames: int = 5,
231
+ hub_token: Optional[str] = None,
232
+ save_every_steps: int = 200) -> Tuple[str, List[str]]:
233
+ """
234
+ Run training using Accelerate. Returns (output_dir, logs)
235
+ """
236
+
237
+ accelerator = Accelerator()
238
+ device = accelerator.device
239
+
240
+ pipe = load_pipeline_auto(base_model_id, torch_dtype=torch.float16 if device.type == "cuda" else torch.float32)
241
+
242
+ dataset = MediaTextDataset(dataset_dir_or_repo, csv_name=csv_name, max_frames=max_frames)
243
+ # NOTE: the loop below processes one example at a time, so the batch_size argument is not applied here.
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x)
244
+
245
+ # attach LoRA to the chosen target
246
+ pipe, attr = attach_lora(pipe, adapter_target, r=lora_r, alpha=lora_alpha)
247
+ # Move model parts to device via accelerator.prepare
248
+ # For simplicity, we'll collect parameters to optimize
249
+ if adapter_target == "unet":
250
+ peft_module = pipe.unet
251
+ elif adapter_target == "transformer":
252
+ peft_module = pipe.transformer
253
+ else:
254
+ peft_module = pipe.text_encoder
255
+
256
+ # Collect trainable params
257
+ trainable_params = [p for _, p in peft_module.named_parameters() if p.requires_grad]
258
+ optimizer = torch.optim.AdamW(trainable_params, lr=lr)
259
+ # prepare with accelerator
260
+ peft_module, optimizer, dataloader = accelerator.prepare(peft_module, optimizer, dataloader)
261
+
262
+ # Also move pipeline core bits to device if required (VAE, scheduler) - only for inference functions
263
+ # We'll call pipeline components when needed, moving them to device manually
264
+ logs = []
265
+ global_step = 0
266
+ loss_fn = nn.MSELoss()
267
+
268
+ # prepare scheduler timesteps if available
269
+ if hasattr(pipe, "scheduler"):
270
+ pipe.scheduler.set_timesteps(50, device=device)
271
+ timesteps = pipe.scheduler.timesteps
272
+ else:
273
+ timesteps = None
274
+
275
+ for epoch in range(epochs):
276
+ pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
277
+ for batch in pbar:
278
+ example = batch[0]
279
+ if example["type"] == "image":
280
+ # image flow
281
+ img = example["image"].unsqueeze(0).to(device)
282
+ caption = [example["caption"]]
283
+ if not hasattr(pipe, "encode_prompt"):
284
+ raise RuntimeError("Pipeline has no encode_prompt")
285
+ # Encode the caption into prompt embeddings with the pipeline's text encoder.
+ # NOTE: this keyword signature follows the ChronoEdit/Wan-style encode_prompt; plain SD pipelines expose a different encode_prompt signature.
286
+ prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt=caption, negative_prompt=None, do_classifier_free_guidance=True, num_videos_per_prompt=1, prompt_embeds=None, negative_prompt_embeds=None, max_sequence_length=512, device=device)
287
+ # VAE encode
288
+ if not hasattr(pipe, "vae"):
289
+ raise RuntimeError("Pipeline missing VAE")
290
+ with torch.no_grad():
291
+ latents = pipe.vae.encode(img.to(device)).latent_dist.sample() * pipe.vae.config.scaling_factor
292
+ noise = torch.randn_like(latents).to(device)
293
+ t = pipe.scheduler.timesteps[torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()].to(device)
294
+ noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
295
+ # UNet forward (peft_module is already prepared and on device by accelerator)
296
+ # For accelerate we must call through the pipeline's UNet wrapper; if we replaced pipe.unet earlier,
297
+ # ensure the call signature matches: many UNets return a ModelOutput; for simplicity attempt common API
298
+ unet_out = peft_module(noisy_latents, t.expand(noisy_latents.shape[0]), encoder_hidden_states=prompt_embeds)
+ noise_pred = unet_out[0] if isinstance(unet_out, tuple) else unet_out.sample
299
+ loss = loss_fn(noise_pred, noise)
300
+ else:
301
+ # video flow (ChronoEdit simplified)
302
+ if not CHRONOEDIT_AVAILABLE:
303
+ raise RuntimeError("ChronoEdit pipeline not available in this environment")
304
+ frames = example["frames"].unsqueeze(0).to(device) # [1, T, C, H, W]
305
+ # preprocess frames into pipeline expected format
306
+ frames_np = frames.squeeze(0).permute(0,2,3,1).cpu().numpy().tolist()
307
+ video_tensor = pipe.video_processor.preprocess(frames_np, height=frames.shape[-2], width=frames.shape[-1]).to(device)
308
+ latents_out = pipe.prepare_latents(video_tensor, batch_size=1, num_channels_latents=pipe.vae.config.z_dim, height=video_tensor.shape[-2], width=video_tensor.shape[-1], num_frames=frames.shape[1], dtype=video_tensor.dtype, device=device, generator=None, latents=None, last_image=None)
309
+ if pipe.config.expand_timesteps:
310
+ latents, condition, first_frame_mask = latents_out
311
+ else:
312
+ latents, condition = latents_out
313
+ first_frame_mask = None
314
+ noise = torch.randn_like(latents).to(device)
315
+ t = pipe.scheduler.timesteps[torch.randint(0, len(pipe.scheduler.timesteps), (1,)).item()].to(device)
316
+ noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
317
+ if pipe.config.expand_timesteps:
318
+ latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * noisy_latents
319
  else:
320
+ latent_model_input = torch.cat([noisy_latents, condition], dim=1)
321
+ # transformer forward
322
+ out = peft_module(hidden_states=latent_model_input, timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]), encoder_hidden_states=None, encoder_hidden_states_image=None, return_dict=False)
323
+ noise_pred = out[0]
324
+ loss = loss_fn(noise_pred, noise)
325
+
326
+ accelerator.backward(loss)
327
+ optimizer.step()
328
+ optimizer.zero_grad()
329
+ global_step += 1
330
+ logs.append(f"step {global_step} loss {loss.item():.6f}")
331
+ pbar.set_postfix({"loss": f"{loss.item():.6f}"})
332
 
333
+ if max_train_steps and global_step >= max_train_steps:
334
+ break
 
335
 
336
+ if global_step % save_every_steps == 0:
337
+ # save PEFT adapter
338
+ out_sub = Path(output_dir) / f"lora_step_{global_step}"
339
+ out_sub.mkdir(parents=True, exist_ok=True)
340
+ try:
341
+ # try to call save_pretrained on peft wrapper
342
+ peft_module.save_pretrained(str(out_sub))
343
+ except Exception:
344
+ # fallback to saving state_dict
345
+ torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(out_sub / "adapter_state_dict.pt"))
346
+ print(f"Saved intermediate adapter at {out_sub}")
347
 
348
+ if max_train_steps and global_step >= max_train_steps:
349
+ break
 
 
 
350
 
351
+ # final save
352
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
353
+ try:
354
+ peft_module.save_pretrained(output_dir)
355
+ except Exception:
356
+ torch.save({k: v.cpu() for k, v in peft_module.state_dict().items()}, str(Path(output_dir) / "adapter_state_dict.pt"))
357
+
358
+ return output_dir, logs
359
+
360
+
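The trainer can also be driven directly, outside the Gradio UI; a minimal smoke-test sketch with placeholder paths (run via `accelerate launch` or plain `python` so the Accelerator can initialise; not part of the commit):

# Illustrative only: a short run of the training loop.
out_dir, logs = train_lora_accelerate(
    base_model_id="runwayml/stable-diffusion-v1-5",
    dataset_dir_or_repo="./dataset",
    csv_name="dataset.csv",
    adapter_target="unet",
    output_dir="./adapter_out",
    epochs=1,
    lr=1e-4,
    max_train_steps=10,  # stop after a few steps for a smoke test
    lora_r=8,
    lora_alpha=16,
)
print(out_dir, logs[-1] if logs else "no steps ran")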
361
+ # -------------------------
362
+ # Test generation
363
+ # -------------------------
364
+ def test_generation_load_and_run(base_model_id: str, adapter_dir: Optional[str], adapter_target: str, prompt: str, num_inference_steps: int = 8):
365
+ # Load base pipeline
366
+ pipe = load_pipeline_auto(base_model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
367
+ # If adapter_dir is provided, load the adapter into the same target
368
+ if adapter_dir:
369
+ if adapter_target == "unet":
370
+ # peft: load_pretrained onto the module
371
+ if hasattr(pipe, "unet"):
372
+ pipe.unet = get_peft_model(pipe.unet, LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.unet)))
373
+ try:
374
+ pipe.unet.load_state_dict(torch.load(Path(adapter_dir) / "adapter_model.bin"), strict=False)  # PEFT's save_pretrained writes adapter_model.bin / .safetensors
375
+ except Exception:
376
+ try:
377
+ pipe.unet.load_adapter(adapter_dir)
378
+ except Exception:
379
+ print("Adapter loading fallbacks")
380
+ elif adapter_target == "transformer":
381
+ if hasattr(pipe, "transformer"):
382
+ pipe.transformer = get_peft_model(pipe.transformer, LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.transformer)))
383
+ # loader fallback
384
+ elif adapter_target == "text_encoder":
385
+ if hasattr(pipe, "text_encoder"):
386
+ pipe.text_encoder = get_peft_model(pipe.text_encoder, LoraConfig(r=8, lora_alpha=16, target_modules=find_target_modules(pipe.text_encoder)))
387
+
388
+ pipe.to(DEVICE)
389
+ out = pipe(prompt=prompt, num_inference_steps=num_inference_steps)
390
+ if hasattr(out, "images"):
391
+ return out.images[0]
392
+ elif hasattr(out, "frames"):
393
+ frames = out.frames[0]
394
+ from PIL import Image
395
+ return Image.fromarray((frames[-1] * 255).clip(0, 255).astype("uint8"))
396
+ raise RuntimeError("Pipeline returned no images/frames")
397
+
398
+
399
+ # -------------------------
400
+ # Upload adapter to HF Hub
401
+ # -------------------------
402
+ def upload_adapter(local_dir: str, repo_id: str) -> str:
403
+ token = os.environ.get("HF_TOKEN")
404
+ if token is None:
405
+ raise RuntimeError("HF_TOKEN not set in environment for upload")
406
+ create_repo(repo_id, exist_ok=True, token=token)
407
+ upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="model", token=token)
408
+ return f"https://huggingface.co/{repo_id}"
409
+
410
+
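Usage sketch for the upload helper (the repo id is a placeholder and HF_TOKEN must already be exported; not part of the commit):

# Illustrative only: push a trained adapter folder to the Hub.
url = upload_adapter("./adapter_out", "your-username/my-lora-adapter")
print(url)  # -> https://huggingface.co/your-username/my-lora-adapter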
411
+ # -------------------------
412
+ # Gradio UI wiring
413
+ # -------------------------
414
+ def run_all_ui(base_model_id: str,
415
+ dataset_source: str,
416
+ csv_name: str,
417
+ mode: str,
418
+ adapter_target: str,
419
+ lora_r: int,
420
+ lora_alpha: int,
421
+ epochs: int,
422
+ batch_size: int,
423
+ lr: float,
424
+ max_train_steps: int,
425
+ output_dir: str,
426
+ upload_repo: str,
427
+ save_every_steps: int):
428
+ # training
429
+ try:
430
+ out_dir, logs = train_lora_accelerate(base_model_id, dataset_source, csv_name, adapter_target, output_dir,
431
+ epochs=epochs, batch_size=batch_size, lr=lr, max_train_steps=(max_train_steps if max_train_steps>0 else None),
432
+ lora_r=lora_r, lora_alpha=lora_alpha, max_frames=5, save_every_steps=save_every_steps)
433
+ except Exception as e:
434
+ return f"Training failed: {e}", None, None
435
+
436
+ # upload (if requested)
437
+ link = None
438
+ if upload_repo:
439
+ try:
440
+ link = upload_adapter(out_dir, upload_repo)
441
+ except Exception as e:
442
+ link = f"Upload failed: {e}"
443
+
444
+ # test generation with first prompt in dataset
445
+ try:
446
+ ds = MediaTextDataset(dataset_source, csv_name=csv_name, max_frames=5)
447
+ test_prompt = ds.df.iloc[0]["text"] if len(ds.df) > 0 else "A cat on a skateboard"
448
+ except Exception:
449
+ test_prompt = "A cat on a skateboard"
450
+
451
+ test_img = None
452
+ try:
453
+ test_img = test_generation_load_and_run(base_model_id, out_dir, adapter_target, test_prompt)
454
+ except Exception as e:
455
+ print("Test generation error:", e)
456
+
457
+ return "\n".join(logs[-200:]), test_img, link
458
+
459
+
460
+ def build_ui():
461
+ with gr.Blocks() as demo:
462
+ gr.Markdown("# Universal LoRA Trainer (Accelerate + PEFT)")
463
+ with gr.Row():
464
+ with gr.Column(scale=2):
465
+ base_model = gr.Textbox(label="Base model id (Diffusers)", value="runwayml/stable-diffusion-v1-5")
466
+ dataset_source = gr.Textbox(label="Dataset folder or HF repo (e.g. username/repo)", value="./dataset")
467
+ csv_name = gr.Textbox(label="CSV filename", value="dataset.csv")
468
+ mode = gr.Radio(["text-image", "text-video", "prompt-lora"], label="Mode", value="text-image")
469
+ adapter_target = gr.Dropdown(label="Adapter target (unet/transformer/text_encoder)", choices=["unet", "transformer", "text_encoder"], value="unet")
470
+ lora_r = gr.Slider(1, 32, value=8, step=1, label="LoRA rank (r)")
471
+ lora_alpha = gr.Slider(1, 64, value=16, step=1, label="LoRA alpha")
472
+ epochs = gr.Number(label="Epochs", value=1)
473
+ batch_size = gr.Number(label="Batch size (per device)", value=1)
474
+ lr = gr.Number(label="Learning rate", value=1e-4)
475
+ max_train_steps = gr.Number(label="Max train steps (0 = unlimited)", value=0)
476
+ save_every_steps = gr.Number(label="Save every steps", value=200)
477
+ output_dir = gr.Textbox(label="Local output dir for adapter", value="./adapter_out")
478
+ upload_repo = gr.Textbox(label="Upload adapter to HF repo (optional, user/repo)", value="")
479
+ start_btn = gr.Button("Start training")
480
+ with gr.Column(scale=1):
481
+ logs = gr.Textbox(label="Training logs (tail)", lines=20)
482
+ sample_image = gr.Image(label="Sample generated frame after training")
+ upload_link = gr.Textbox(label="Upload link (shown when an upload repo is set)")
483
+ def on_start(base_model_id, dataset_source, csv_name, mode, adapter_target, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, save_every_steps):
484
+ return run_all_ui(base_model_id, dataset_source, csv_name, mode, adapter_target, int(lora_r), int(lora_alpha), int(epochs), int(batch_size), float(lr), int(max_train_steps), output_dir, upload_repo, int(save_every_steps))
485
+ start_btn.click(on_start, inputs=[base_model, dataset_source, csv_name, mode, adapter_target, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, save_every_steps], outputs=[logs, sample_image, upload_link])
486
  return demo
487
 
488
+
489
  if __name__ == "__main__":
490
+ demo = build_ui()
491
+ demo.launch(server_name="0.0.0.0", server_port=7860)