Elea Zhong committed
Commit 49cbc74 · 1 Parent(s): 16d51ab

regression training
configs/base.yaml CHANGED
@@ -13,7 +13,7 @@ optim: "adamw"
13
  learning_rate: 1.0e-4
14
  num_workers: 4
15
  resume_from_checkpoint: null
16
- log_model_steps: 100
17
  preprocessing_epoch_len: 64
18
  preprocessing_epoch_repetitions: 1
19
 
 
13
  learning_rate: 1.0e-4
14
  num_workers: 4
15
  resume_from_checkpoint: null
16
+ log_model_steps: null
17
  preprocessing_epoch_len: 64
18
  preprocessing_epoch_repetitions: 1
19
 
configs/regression/base-reg.yaml ADDED
@@ -0,0 +1,21 @@
1
+ wandb_run_name: "reg-base"
2
+ output_dir: "/data/checkpoints/reg-base"
3
+
4
+ learning_rate: 1e-4
5
+ num_train_epochs: 1
6
+ max_train_steps: null
7
+ preprocessing_epoch_len: 0
8
+ preprocessing_epoch_repetitions: 1
9
+ num_validation_images: &val_num 32
10
+ num_sample_images: 2
11
+ train_range: [*val_num, null]
12
+ val_range: [0, *val_num]
13
+ test_range: [2, 4]
14
+ regression_base_pipe_steps: 8
15
+
16
+ training_type: "regression"
17
+
18
+ regression_data_dir: "/data/regression_output"
19
+ regression_gen_steps: 50
20
+ editing_data_dir: "/data/CrispEdit"
21
+ editing_total_per: 1
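
Note on the YAML anchors above: `&val_num` binds the value 32 and `*val_num` reuses it, so the first 32 cached samples form the validation range and training starts at index 32. A minimal sketch of how the anchors resolve at load time (assuming standard yaml.safe_load, as used by run_training in qwenimage/training.py):

    import yaml

    cfg = yaml.safe_load(
        "num_validation_images: &val_num 32\n"
        "train_range: [*val_num, null]\n"
        "val_range: [0, *val_num]\n"
    )
    # Aliases resolve on load; null becomes None (an open-ended range):
    # {'num_validation_images': 32, 'train_range': [32, None], 'val_range': [0, 32]}
    print(cfg["train_range"], cfg["val_range"])
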
configs/regression/reg-mse-triplet.yaml ADDED
@@ -0,0 +1,11 @@
1
+ wandb_run_name: "reg-mse-triplet"
2
+ output_dir: "/data/checkpoints/reg-mse-triplet"
3
+
4
+ train_loss_terms:
5
+ mse: 1.0
6
+ triplet: 1.0
7
+
8
+ validation_loss_terms:
9
+ mse: 1.0
10
+ triplet: 1.0
11
+
configs/regression/reg-mse.yaml ADDED
@@ -0,0 +1,9 @@
1
+ wandb_run_name: "reg-mse"
2
+ output_dir: "/data/checkpoints/reg-mse"
3
+
4
+ train_loss_terms:
5
+ mse: 1.0
6
+
7
+ validation_loss_terms:
8
+ mse: 1.0
9
+
qwenimage/datamodels.py CHANGED
@@ -1,10 +1,12 @@
1
  import enum
 
2
  from typing import Literal
3
 
4
  import torch
5
  from diffusers.image_processor import PipelineImageInput
6
  from pydantic import BaseModel, ConfigDict, Field
7
 
 
8
  from wandml.foundation.datamodels import FluxInputs
9
  from wandml.trainers.datamodels import ExperimentTrainerParameters
10
 
@@ -27,12 +29,33 @@ class QwenInputs(BaseModel):
27
  # extra="allow",
28
  )
29
 
30
 
31
  class QuantOptions(str, enum.Enum):
32
  INT8WO = "int8wo"
33
  INT4WO = "int4wo"
34
  FP8ROW = "fp8row"
35
 
36
 
37
  class QwenConfig(ExperimentTrainerParameters):
38
  load_multi_view_lora: bool = False
@@ -49,14 +72,27 @@ class QwenConfig(ExperimentTrainerParameters):
49
  quantize_text_encoder: bool = False
50
  quantize_transformer: bool = False
51
 
52
- source_type: str = "im2im"
53
  style_title: str|None = None
54
- base_dir: str|None = None
55
- csv_path: str|None = None
56
- data_dir: str|None = None
57
- ref_dir: str|None = None
58
- prompt: str|None = None
59
- train_range: tuple[int|float,int|float]|None=None
60
- test_range: tuple[int|float,int|float]|None=None
61
- val_with: str = "train"
 
 
 
 
 
62
 
 
1
  import enum
2
+ from pathlib import Path
3
  from typing import Literal
4
 
5
  import torch
6
  from diffusers.image_processor import PipelineImageInput
7
  from pydantic import BaseModel, ConfigDict, Field
8
 
9
+ from qwenimage.types import DataRange
10
  from wandml.foundation.datamodels import FluxInputs
11
  from wandml.trainers.datamodels import ExperimentTrainerParameters
12
 
 
29
  # extra="allow",
30
  )
31
 
32
+ class TrainingType(str, enum.Enum):
33
+ IM2IM = "im2im"
34
+ NAIVE = "naive"
35
+ REGRESSION = "regression"
36
+
37
+ @property
38
+ def is_style(self):
39
+ return self in [TrainingType.NAIVE, TrainingType.IM2IM]
40
 
41
  class QuantOptions(str, enum.Enum):
42
  INT8WO = "int8wo"
43
  INT4WO = "int4wo"
44
  FP8ROW = "fp8row"
45
 
46
+ LossTermSpecType = int|float|dict[str,int|float]|None
47
+
48
+ class QwenLossTerms(BaseModel):
49
+ mse: LossTermSpecType = 1.0
50
+ triplet: LossTermSpecType = 0.0
51
+ negative_mse: LossTermSpecType = 0.0
52
+ distribution_matching: LossTermSpecType = 0.0
53
+ negative_exponential: LossTermSpecType = 0.0
54
+ pixel_lpips: LossTermSpecType = 0.0
55
+ pixel_mse: LossTermSpecType = 0.0
56
+ adversarial: LossTermSpecType = 0.0
57
+
58
+ triplet_margin: float = 0.2
59
 
60
  class QwenConfig(ExperimentTrainerParameters):
61
  load_multi_view_lora: bool = False
 
72
  quantize_text_encoder: bool = False
73
  quantize_transformer: bool = False
74
 
75
+
76
+ train_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
77
+ validation_loss_terms:QwenLossTerms = Field(default_factory=QwenLossTerms)
78
+
79
+ training_type: TrainingType
80
+ train_range: DataRange|None=None
81
+ val_range: DataRange|None=None
82
+ test_range: DataRange|None=None
83
+
84
  style_title: str|None = None
85
+ style_base_dir: str|None = None
86
+ style_csv_path: str|None = None
87
+ style_data_dir: str|None = None
88
+ style_ref_dir: str|None = None
89
+ style_val_with: str = "train"
90
+ naive_static_prompt: str|None = None
91
+
92
+ regression_data_dir: str|Path|None = None
93
+ regression_gen_steps: int = 50
94
+ editing_data_dir: str|Path|None = None
95
+ editing_total_per: int = 1
96
+ regression_base_pipe_steps: int = 8
97
+
98
 
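
Because LossTermSpecType also admits a dict, a term weight can be scheduled over training steps instead of being a fixed scalar; the schedule keys (start/end/min/max) are the ones consumed by LossAccumulator.resolve_weight added later in this commit. A hedged sketch of both forms:

    from qwenimage.datamodels import QwenLossTerms

    # fixed weights, matching configs/regression/reg-mse-triplet.yaml
    fixed = QwenLossTerms(mse=1.0, triplet=1.0)

    # scheduled weight: ramp the triplet term from 0.0 to 1.0 over the first 1000 steps
    ramped = QwenLossTerms(mse=1.0, triplet={"start": 0, "end": 1000, "min": 0.0, "max": 1.0})
    print(ramped.model_dump()["triplet"])   # {'start': 0, 'end': 1000, 'min': 0.0, 'max': 1.0}
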
qwenimage/debug.py CHANGED
@@ -241,12 +241,16 @@ def fretry(func=None, *, exceptions=(Exception,), mod_args:tuple[Callable|None,
241
  return decorator(func)
242
 
243
 
244
- def texam(t: torch.Tensor):
245
- print(f"Shape: {tuple(t.shape)}")
 
 
 
 
246
  if t.dtype.is_floating_point or t.dtype.is_complex:
247
  mean_val = t.mean().item()
248
  else:
249
  mean_val = "N/A"
250
- print(f"Min: {t.min().item()}, Max: {t.max().item()}, Mean: {mean_val}")
251
- print(f"Device: {t.device}, Dtype: {t.dtype}, Requires Grad: {t.requires_grad}")
252
 
 
241
  return decorator(func)
242
 
243
 
244
+ def texam(t: torch.Tensor, name=None):
245
+ if name is None:
246
+ name = ""
247
+ else:
248
+ name += " " # spacing
249
+ print(f"{name}Shape: {tuple(t.shape)}")
250
  if t.dtype.is_floating_point or t.dtype.is_complex:
251
  mean_val = t.mean().item()
252
  else:
253
  mean_val = "N/A"
254
+ print(f"{name}Min: {t.min().item()}, Max: {t.max().item()}, Mean: {mean_val}")
255
+ print(f"{name}Device: {t.device}, Dtype: {t.dtype}, Requires Grad: {t.requires_grad}")
256
 
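
A quick usage sketch of the extended texam signature; the optional name is simply prepended to each of the three printed lines so interleaved dumps stay readable:

    import torch
    from qwenimage.debug import texam

    latents = torch.randn(1, 16, 64, 64)
    texam(latents, name="latents")
    # latents Shape: (1, 16, 64, 64)
    # latents Min: ..., Max: ..., Mean: ...
    # latents Device: cpu, Dtype: torch.float32, Requires Grad: False
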
qwenimage/foundation.py CHANGED
@@ -5,6 +5,7 @@ import warnings
5
 
6
  from PIL import Image
7
  from diffusers.pipelines.qwenimage.pipeline_qwenimage import QwenImagePipeline
 
8
  import torch
9
  from safetensors.torch import load_file, save_model
10
  import torch.nn.functional as F
@@ -12,15 +13,18 @@ import torchvision.transforms.v2.functional as TF
12
  from einops import rearrange
13
 
14
  from qwenimage.datamodels import QwenConfig, QwenInputs
15
- from qwenimage.debug import ctimed, ftimed, print_gpu_memory, texam
16
  from qwenimage.experiments.quantize_text_encoder_experiments import quantize_text_encoder_int4wo_linear
17
  from qwenimage.experiments.quantize_experiments import quantize_transformer_fp8darow_nolast
 
18
  from qwenimage.models.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, QwenImageEditPlusPipeline, calculate_dimensions
19
  from qwenimage.models.pipeline_qwenimage_edit_save_interm import QwenImageEditSaveIntermPipeline
20
  from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
21
  from qwenimage.optimization import simple_quantize_model
22
  from qwenimage.sampling import TimestepDistUtils
23
  from wandml import WandModel
 
 
24
 
25
 
26
  class QwenImageFoundation(WandModel):
@@ -159,8 +163,13 @@ class QwenImageFoundation(WandModel):
159
  texam(latents)
160
  return latents.to(dtype=self.dtype)
161
 
162
- def latents_to_pil(self, latents):
163
- latents = latents.clone().detach()
 
 
 
 
 
164
  latents = latents.unsqueeze(2)
165
 
166
  latents = latents.to(self.dtype)
@@ -178,6 +187,11 @@ class QwenImageFoundation(WandModel):
178
 
179
  latents = latents.to(device=self.device, dtype=self.dtype)
180
  image = self.pipe.vae.decode(latents, return_dict=False)[0][:, :, 0] # F = 1
 
 
 
 
 
181
  image = self.pipe.image_processor.postprocess(image)
182
  return image
183
 
@@ -191,6 +205,15 @@ class QwenImageFoundation(WandModel):
191
  latents = rearrange(packed, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=h, w=w)
192
  return latents
193
 
194
 
195
  @ftimed
196
  def preprocess_batch(self, batch):
@@ -204,11 +227,7 @@ class QwenImageFoundation(WandModel):
204
  print("preprocess_batch.references")
205
  texam(references)
206
 
207
-
208
- with ctimed("text_encoder.cuda()"):
209
- self.text_encoder.to(device=self.device)
210
- if self.text_encoder_device != "cuda":
211
- self.text_encoder_device = "cuda"
212
 
213
  with torch.no_grad():
214
  prompt_embeds, prompt_embeds_mask = self.pipe.encode_prompt(
@@ -229,12 +248,7 @@ class QwenImageFoundation(WandModel):
229
 
230
  @ftimed
231
  def single_step(self, batch) -> torch.Tensor:
232
- with ctimed("text_encoder.cpu()"):
233
- if self.config.offload_text_encoder:
234
- self.text_encoder.to(device="cpu") # offload
235
- if self.text_encoder_device != "cpu":
236
- self.text_encoder_device = "cpu"
237
- print_gpu_memory()
238
 
239
  if "prompt_embeds" not in batch:
240
  batch = self.preprocess_batch(batch)
@@ -302,10 +316,7 @@ class QwenImageFoundation(WandModel):
302
 
303
  def base_pipe(self, inputs: QwenInputs) -> list[Image]:
304
  print(inputs)
305
- with ctimed("text_encoder.cuda()"):
306
- self.text_encoder.to(device=self.device)
307
- if self.text_encoder_device != "cuda":
308
- self.text_encoder_device = "cuda"
309
  image = inputs.image[0]
310
  w,h = image.size
311
  h_r, w_r = calculate_dimensions(self.config.vae_image_size, h/w)
@@ -327,3 +338,189 @@ class QwenImageFoundationSaveInterm(QwenImageFoundation):
327
  inputs.image = [image]
328
  return self.pipe(**inputs.model_dump())
329
 
5
 
6
  from PIL import Image
7
  from diffusers.pipelines.qwenimage.pipeline_qwenimage import QwenImagePipeline
8
+ import lpips
9
  import torch
10
  from safetensors.torch import load_file, save_model
11
  import torch.nn.functional as F
 
13
  from einops import rearrange
14
 
15
  from qwenimage.datamodels import QwenConfig, QwenInputs
16
+ from qwenimage.debug import clear_cuda_memory, ctimed, ftimed, print_gpu_memory, texam
17
  from qwenimage.experiments.quantize_text_encoder_experiments import quantize_text_encoder_int4wo_linear
18
  from qwenimage.experiments.quantize_experiments import quantize_transformer_fp8darow_nolast
19
+ from qwenimage.loss import LossAccumulator
20
  from qwenimage.models.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, QwenImageEditPlusPipeline, calculate_dimensions
21
  from qwenimage.models.pipeline_qwenimage_edit_save_interm import QwenImageEditSaveIntermPipeline
22
  from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
23
  from qwenimage.optimization import simple_quantize_model
24
  from qwenimage.sampling import TimestepDistUtils
25
  from wandml import WandModel
26
+ from wandml.core.logger import wand_logger
27
+ from wandml.trainers.experiment_trainer import ExperimentTrainer
28
 
29
 
30
  class QwenImageFoundation(WandModel):
 
163
  texam(latents)
164
  return latents.to(dtype=self.dtype)
165
 
166
+ def latents_to_pil(self, latents, h=None, w=None, with_grad=False):
167
+ if not with_grad:
168
+ latents = latents.clone().detach()
169
+ if latents.dim() == 3: # 1d latent
170
+ if h is None or w is None:
171
+ raise ValueError(f"auto unpack needs h,w, got {h=}, {w=}")
172
+ latents = self.unpack_latents(latents, h=h, w=w)
173
  latents = latents.unsqueeze(2)
174
 
175
  latents = latents.to(self.dtype)
 
187
 
188
  latents = latents.to(device=self.device, dtype=self.dtype)
189
  image = self.pipe.vae.decode(latents, return_dict=False)[0][:, :, 0] # F = 1
190
+
191
+ if with_grad:
192
+ texam(image, "latents_to_pil.image")
193
+ return image
194
+
195
  image = self.pipe.image_processor.postprocess(image)
196
  return image
197
 
 
205
  latents = rearrange(packed, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=h, w=w)
206
  return latents
207
 
208
+ @ftimed
209
+ def offload_text_encoder(self, device: str|torch.device):
210
+ if self.text_encoder_device == device:
211
+ return
212
+ print(f"Moving text encoder to {device}")
213
+ self.text_encoder_device = device
214
+ self.text_encoder.to(device)
215
+ if device == "cpu" or device == torch.device("cpu"):
216
+ print_gpu_memory(clear_mem="pre")
217
 
218
  @ftimed
219
  def preprocess_batch(self, batch):
 
227
  print("preprocess_batch.references")
228
  texam(references)
229
 
230
+ self.offload_text_encoder("cuda")
 
 
 
 
231
 
232
  with torch.no_grad():
233
  prompt_embeds, prompt_embeds_mask = self.pipe.encode_prompt(
 
248
 
249
  @ftimed
250
  def single_step(self, batch) -> torch.Tensor:
251
+ self.offload_text_encoder("cpu")
 
 
 
 
 
252
 
253
  if "prompt_embeds" not in batch:
254
  batch = self.preprocess_batch(batch)
 
316
 
317
  def base_pipe(self, inputs: QwenInputs) -> list[Image]:
318
  print(inputs)
319
+ self.offload_text_encoder("cuda")
 
 
 
320
  image = inputs.image[0]
321
  w,h = image.size
322
  h_r, w_r = calculate_dimensions(self.config.vae_image_size, h/w)
 
338
  inputs.image = [image]
339
  return self.pipe(**inputs.model_dump())
340
 
341
+
342
+ class QwenImageRegressionFoundation(QwenImageFoundation):
343
+ def __init__(self, config:QwenConfig, device=None):
344
+ super().__init__(config, device=device)
345
+ self.lpips_fn = lpips.LPIPS(net='vgg').to(device=self.device)
346
+
347
+ def preprocess_batch(self, batch):
348
+ return batch
349
+
350
+ @ftimed
351
+ def single_step(self, batch) -> torch.Tensor:
352
+ self.offload_text_encoder("cpu")
353
+
354
+ out_dict = batch["data"]
355
+ assert len(out_dict) == 1
356
+ out_dict = out_dict[0]
357
+
358
+ prompt_embeds = out_dict["prompt_embeds"]
359
+ prompt_embeds_mask = out_dict["prompt_embeds_mask"]
360
+ prompt_embeds = prompt_embeds.to(device=self.device, dtype=self.dtype)
361
+ prompt_embeds_mask = prompt_embeds_mask.to(device=self.device, dtype=self.dtype)
362
+
363
+ h_f16 = out_dict["height"] // 16
364
+ w_f16 = out_dict["width"] // 16
365
+
366
+ refs_1d = out_dict["image_latents"].to(device=self.device, dtype=self.dtype)
367
+ t = out_dict["t"].to(device=self.device, dtype=self.dtype)
368
+ x_0_1d = out_dict["output"].to(device=self.device, dtype=self.dtype)
369
+ x_t_1d = out_dict["latents_start"].to(device=self.device, dtype=self.dtype)
370
+ v_neg_1d = out_dict["noise_pred"].to(device=self.device, dtype=self.dtype)
371
+
372
+
373
+ v_gt_1d = (x_t_1d - x_0_1d) / t
374
+
375
+ inp_1d = torch.cat([x_t_1d, refs_1d], dim=1)
376
+ print("inp_1d")
377
+ texam(inp_1d)
378
+
379
+ img_shapes = out_dict["img_shapes"]
380
+ txt_seq_lens = out_dict["txt_seq_lens"]
381
+ image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=self.device)
382
+
383
+ v_pred_1d = self.transformer(
384
+ hidden_states=inp_1d,
385
+ encoder_hidden_states=prompt_embeds,
386
+ encoder_hidden_states_mask=prompt_embeds_mask,
387
+ timestep=t,
388
+ image_rotary_emb=image_rotary_emb,
389
+ return_dict=False,
390
+ )[0]
391
+
392
+ v_pred_1d = v_pred_1d[:, : x_t_1d.size(1)]
393
+
394
+
395
+ split = batch["split"]
396
+ step = batch["step"]
397
+ if split == "train":
398
+ loss_terms = self.config.train_loss_terms.model_dump()
399
+ elif split == "validation":
400
+ loss_terms = self.config.validation_loss_terms.model_dump()
401
+ loss_accumulator = LossAccumulator(
402
+ terms=loss_terms,
403
+ step=step,
404
+ split=split,
405
+ )
406
+
407
+ if loss_accumulator.has("mse"):
408
+ if self.config.loss_weight_dist is not None:
409
+ mse_loss = F.mse_loss(v_pred_1d, v_gt_1d, reduction="none").mean(dim=[1,2,3])
410
+ weights = self.timestep_dist_utils.get_loss_weighting(t)
411
+ mse_loss = torch.mean(mse_loss * weights)
412
+ else:
413
+ mse_loss = F.mse_loss(v_pred_1d, v_gt_1d, reduction="mean")
414
+ loss_accumulator.accum("mse", mse_loss)
415
+
416
+ if loss_accumulator.has("triplet"):
417
+ eps = 1e-6
418
+ margin = loss_terms["triplet_margin"]
419
+ v_span = (v_gt_1d - v_neg_1d).pow(2).sum(dim=(-2,-1))
420
+ diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2).sum(dim=(-2,-1))
421
+ diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2).sum(dim=(-2,-1))
422
+ diffv_gt_pred_reg = diffv_gt_pred / (v_span + eps)
423
+ diffv_neg_pred_reg = diffv_neg_pred / (v_span + eps)
424
+
425
+ texam(v_span, name="v_span")
426
+ texam(diffv_gt_pred, name="diffv_gt_pred")
427
+ texam(diffv_neg_pred, name="diffv_neg_pred")
428
+ texam(diffv_gt_pred_reg, name="diffv_gt_pred_reg")
429
+ texam(diffv_neg_pred_reg, name="diffv_neg_pred_reg")
430
+ texam(diffv_gt_pred_reg - diffv_neg_pred_reg, name="diffv_gt_pred_reg - diffv_neg_pred_reg")
431
+
432
+ triplet_loss = F.relu(diffv_gt_pred_reg - diffv_neg_pred_reg + margin).mean()
433
+ loss_accumulator.accum("triplet", triplet_loss)
434
+
435
+
436
+ if loss_accumulator.has("negative_mse"):
437
+ neg_mse_loss = -F.mse_loss(v_pred_1d, v_neg_1d, reduction="mean")
438
+ loss_accumulator.accum("negative_mse", neg_mse_loss)
439
+
440
+
441
+ if loss_accumulator.has("distribution_matching"):
442
+ dm_v = (v_pred_1d - v_neg_1d + v_gt_1d).detach()
443
+ dm_mse = F.mse_loss(v_pred_1d, dm_v, reduction="mean")
444
+ loss_accumulator.accum("distribution_matching", dm_mse)
445
+
446
+ if loss_accumulator.has("negative_exponential"):
447
+ raise NotImplementedError()
448
+
449
+ if loss_accumulator.has("pixel_lpips") or loss_accumulator.has("pixel_mse"):
450
+ x_0_pred = x_t_1d - t * v_pred_1d
451
+ pixel_values_x0_gt = self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16, with_grad=True).detach()
452
+ pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16, with_grad=True)
453
+
454
+ if loss_accumulator.has("pixel_lpips"):
455
+ lpips_loss = self.lpips_fn(pixel_values_x0_gt, pixel_values_x0_pred)
456
+ loss_accumulator.accum("pixel_lpips", lpips_loss)
457
+
458
+ if loss_accumulator.has("pixel_mse"):
459
+ pixel_mse_loss = F.mse_loss(pixel_values_x0_pred, pixel_values_x0_gt, reduction="mean")
460
+ loss_accumulator.accum("pixel_mse", pixel_mse_loss)
461
+
462
+ if loss_accumulator.has("adversarial"):
463
+ raise NotImplementedError()
464
+
465
+
466
+ loss = loss_accumulator.total
467
+
468
+ logs = loss_accumulator.logs()
469
+ wand_logger.log(logs, step=step, commit=False)
470
+
471
+ if self.should_log_training(step):
472
+ self.log_single_step_images(
473
+ h_f16,
474
+ w_f16,
475
+ t,
476
+ x_0_1d,
477
+ x_t_1d,
478
+ v_gt_1d,
479
+ v_neg_1d,
480
+ v_pred_1d,
481
+ visualize_velocities=True,
482
+ )
483
+
484
+ return loss
485
+
486
+ def should_log_training(self, step) -> bool:
487
+ return (
488
+ self.training # don't log when validating
489
+ and ExperimentTrainer._is_step_trigger(step, self.config.log_batch_steps)
490
+ )
491
+
492
+ def log_single_step_images(
493
+ self,
494
+ h_f16,
495
+ w_f16,
496
+ t,
497
+ x_0_1d,
498
+ x_t_1d,
499
+ v_gt_1d,
500
+ v_neg_1d,
501
+ v_pred_1d,
502
+ visualize_velocities=True,
503
+ ):
504
+ x_0_pred = x_t_1d - t * v_pred_1d
505
+ x_0_neg = x_t_1d - t * v_neg_1d
506
+ log_pils = {
507
+ "x_t_1d": self.latents_to_pil(x_t_1d, h=h_f16, w=w_f16),
508
+ "x_0": self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16),
509
+ "x_0_pred": self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16),
510
+ "x_0_neg": self.latents_to_pil(x_0_neg, h=h_f16, w=w_f16),
511
+ }
512
+ if visualize_velocities:
513
+ log_pils.update({
514
+ "v_gt_1d": self.latents_to_pil(v_gt_1d, h=h_f16, w=w_f16),
515
+ "v_pred_1d": self.latents_to_pil(v_pred_1d, h=h_f16, w=w_f16),
516
+ "v_neg_1d": self.latents_to_pil(v_neg_1d, h=h_f16, w=w_f16),
517
+ })
518
+
519
+ wand_logger.log({
520
+ "train_images": log_pils,
521
+ }, commit=False)
522
+
523
+
524
+ def base_pipe(self, inputs: QwenInputs) -> list[Image]:
525
+ inputs.num_inference_steps = self.config.regression_base_pipe_steps # override
526
+ return super().base_pipe(inputs)
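
For orientation, the regression step above follows the usual flow-matching algebra: the cached trajectory supplies x_t and the final latent x_0, the target velocity is v_gt = (x_t - x_0) / t, and the triplet term compares squared distances from the prediction to v_gt (positive) and to the cached base-model output v_neg (negative), each normalised by the gt-to-neg span. A standalone sketch on dummy packed latents (shapes are hypothetical):

    import torch
    import torch.nn.functional as F

    B, L, C = 1, 1008, 64                        # packed (batch, seq, channels) latents
    v_gt = torch.randn(B, L, C)                  # (x_t - x_0) / t
    v_neg = torch.randn(B, L, C)                 # cached noise_pred from the base pipe
    v_pred = v_gt + 0.1 * torch.randn(B, L, C)   # stand-in for the LoRA prediction

    eps, margin = 1e-6, 0.2
    span = (v_gt - v_neg).pow(2).sum(dim=(-2, -1))
    d_pos = (v_gt - v_pred).pow(2).sum(dim=(-2, -1)) / (span + eps)
    d_neg = (v_neg - v_pred).pow(2).sum(dim=(-2, -1)) / (span + eps)
    triplet = F.relu(d_pos - d_neg + margin).mean()
    print(float(triplet))   # ~0 when the prediction is much closer to v_gt than to v_neg
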
qwenimage/loss.py ADDED
@@ -0,0 +1,87 @@
1
+ import warnings
2
+ from torch import Tensor
3
+ import torch
4
+
5
+ from wandml.core.wandmodel import WandModel
6
+
7
+
8
+ class LossAccumulator:
9
+
10
+ def __init__(
11
+ self,
12
+ terms: dict[str, int|float|dict],
13
+ step: int|None=None,
14
+ split: str|None=None,
15
+ ):
16
+ self.terms = terms
17
+ self.step = step
18
+ if split is not None:
19
+ self.split = split
20
+ self.prefix = f"{self.split}_"
21
+ else:
22
+ self.split = ""
23
+ self.prefix = ""
24
+ self.unweighted: dict[str, Tensor] = {}
25
+ self.weighted: dict[str, Tensor] = {}
26
+
27
+ def resolve_weight(self, name: str, step: int|None = None) -> float:
28
+ """
29
+ loss weight spec:
30
+ - float | int
31
+ - dict: {"start": int, "end": int, "min": float, "max": float}
32
+ """
33
+ spec = self.terms.get(name, 0.0)
34
+
35
+ if isinstance(spec, (int, float)):
36
+ return float(spec)
37
+
38
+ if isinstance(spec, dict):
39
+ try:
40
+ start = int(spec.get("start", 0))
41
+ end = int(spec["end"]) # required
42
+ vmin = float(spec.get("min", 0.0))
43
+ vmax = float(spec["max"]) # required
44
+ except Exception:
45
+ warnings.warn(f"Malformed dict {spec}; treat as disabled")
46
+ return 0.0
47
+
48
+ if self.step <= start:
49
+ return vmin
50
+ if self.step >= end:
51
+ return vmax
52
+ span = max(1, end - start)
53
+ t = (self.step - start) / span
54
+ return vmin + (vmax - vmin) * t
55
+
56
+ warnings.warn(f"Unknown spec type {spec}; treat as disabled")
57
+ return 0.0
58
+
59
+ def has(self, name: str) -> bool:
60
+ return self.resolve_weight(name) > 0
61
+
62
+ def accum(self, name: str, loss_value: Tensor, extra_weight: float|None = None) -> Tensor:
63
+ self.unweighted[name] = loss_value
64
+
65
+ w = self.resolve_weight(name)
66
+ if extra_weight is not None:
67
+ w *= float(extra_weight)
68
+
69
+ weighted = loss_value * w
70
+ self.weighted[name] = weighted
71
+ return weighted
72
+
73
+ @property
74
+ def total(self):
75
+ weighted_losses = list(self.weighted.values())
76
+ return torch.stack(weighted_losses).sum()
77
+
78
+
79
+ def logs(self) -> dict[str, float]:
80
+ # append prefix and suffix for logs
81
+ logs: dict[str, float] = {}
82
+ for k, v in self.unweighted.items():
83
+ logs[f"{self.prefix}_{k}"] = float(v.detach().item())
84
+ for k, v in self.weighted.items():
85
+ logs[f"{self.prefix}_{k}_weighted"] = float(v.detach().item())
86
+ return logs
87
+
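
A short usage sketch of LossAccumulator as it is driven from QwenImageRegressionFoundation.single_step (the weights and values here are illustrative):

    import torch
    from qwenimage.loss import LossAccumulator

    acc = LossAccumulator(
        terms={"mse": 1.0, "triplet": {"start": 0, "end": 1000, "min": 0.0, "max": 1.0}},
        step=500,
        split="train",
    )
    if acc.has("mse"):
        acc.accum("mse", torch.tensor(0.04))
    if acc.has("triplet"):                  # scheduled weight resolves to 0.5 at step 500
        acc.accum("triplet", torch.tensor(0.10))

    loss = acc.total    # 0.04 * 1.0 + 0.10 * 0.5 = 0.09
    logs = acc.logs()   # per-term values keyed by the split prefix, plus *_weighted entries
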
qwenimage/sources.py CHANGED
@@ -3,11 +3,35 @@
3
  import csv
4
  from pathlib import Path
5
  import random
 
6
 
7
  from PIL import Image
 
 
 
 
8
  from wandml.core.datamodels import SourceDataType
9
  from wandml.core.source import Source
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class StyleSource(Source):
12
  _data_types = [
13
  SourceDataType(name="image", type=Image.Image),
@@ -63,7 +87,7 @@ class StyleImagetoImageSource(Source):
63
  SourceDataType(name="image", type=Image.Image),
64
  SourceDataType(name="reference", type=Image.Image),
65
  ]
66
- def __init__(self, csv_path, base_dir, style_title=None, data_range:tuple[int|float,int|float]|None=None):
67
  self.csv_path = Path(csv_path)
68
  self.base_dir = Path(base_dir)
69
  self.style_title = style_title
@@ -85,16 +109,9 @@ class StyleImagetoImageSource(Source):
85
  })
86
 
87
  if data_range is not None:
88
- left, right = data_range
89
- if (isinstance(left, float) or isinstance(right, float)) and (left<1 and right<1):
90
- left = left * len(self.data)
91
- right = right * len(self.data)
92
- remain_data = []
93
- for i, d in enumerate(self.data):
94
- if left <= i and i < right:
95
- remain_data.append(d)
96
- self.data = remain_data
97
-
98
  print(f"{self.__class__} of len{len(self)}")
99
 
100
 
@@ -106,4 +123,85 @@ class StyleImagetoImageSource(Source):
106
  prompt = item["prompt"]
107
  input_pil = Image.open(item['input_image']).convert("RGB")
108
  output_pil = Image.open(item['output_image']).convert("RGB")
109
- return prompt, output_pil, input_pil
3
  import csv
4
  from pathlib import Path
5
  import random
6
+ from typing import Literal
7
 
8
  from PIL import Image
9
+ import torch
10
+ from datasets import concatenate_datasets, load_dataset, interleave_datasets
11
+
12
+ from qwenimage.types import DataRange
13
  from wandml.core.datamodels import SourceDataType
14
  from wandml.core.source import Source
15
 
16
+ def parse_datarange(dr: DataRange, length: int, return_as: Literal['list', 'range']='list'):
17
+ if not isinstance(length, int):
18
+ raise ValueError()
19
+ left, right = dr
20
+ if left is None:
21
+ left = 0
22
+ if right is None:
23
+ right = length
24
+ if (isinstance(left, float) or isinstance(right, float)) and (left<1 and right<1):
25
+ left = int(left * length)
26
+ right = int(right * length)
27
+ if return_as=="list":
28
+ return list(range(left, right))
29
+ elif return_as=="range":
30
+ return range(left, right)
31
+ else:
32
+ raise ValueError()
33
+
34
+
35
  class StyleSource(Source):
36
  _data_types = [
37
  SourceDataType(name="image", type=Image.Image),
 
87
  SourceDataType(name="image", type=Image.Image),
88
  SourceDataType(name="reference", type=Image.Image),
89
  ]
90
+ def __init__(self, csv_path, base_dir, style_title=None, data_range:DataRange|None=None):
91
  self.csv_path = Path(csv_path)
92
  self.base_dir = Path(base_dir)
93
  self.style_title = style_title
 
109
  })
110
 
111
  if data_range is not None:
112
+ indexes = parse_datarange(data_range, len(self.data))
113
+ self.data = [self.data[i] for i in indexes]
114
+
 
 
 
 
 
 
 
115
  print(f"{self.__class__} of len{len(self)}")
116
 
117
 
 
123
  prompt = item["prompt"]
124
  input_pil = Image.open(item['input_image']).convert("RGB")
125
  output_pil = Image.open(item['output_image']).convert("RGB")
126
+ return prompt, output_pil, input_pil
127
+
128
+
129
+ class RegressionSource(Source):
130
+ _data_types = [
131
+ SourceDataType(name="data", type=dict),
132
+ ]
133
+
134
+ def __init__(self, data_dir, gen_steps=50, data_range:DataRange|None=None):
135
+ if not isinstance(data_dir, Path):
136
+ data_dir = Path(data_dir)
137
+ self.data_paths = list(data_dir.glob("*.pt"))
138
+ if data_range is not None:
139
+ indexes = parse_datarange(data_range, len(self.data_paths))
140
+ self.data_paths = [self.data_paths[i] for i in indexes]
141
+ self.gen_steps = gen_steps
142
+ self._len = gen_steps * len(self.data_paths)
143
+ print(f"{self.__class__} of len{len(self)}")
144
+
145
+ def __len__(self):
146
+ return self._len
147
+
148
+ def __getitem__(self, idx):
149
+ data_idx = idx // self.gen_steps
150
+ step_idx = idx % self.gen_steps
151
+ out_dict = torch.load(self.data_paths[data_idx])
152
+ t = out_dict.pop(f"t_{step_idx}")
153
+ latents_start = out_dict.pop(f"latents_{step_idx}_start")
154
+ noise_pred = out_dict.pop(f"noise_pred_{step_idx}")
155
+ out_dict["t"] = t
156
+ out_dict["latents_start"] = latents_start
157
+ out_dict["noise_pred"] = noise_pred
158
+ return [out_dict,]
159
+
160
+
161
+ class EditingSource(Source):
162
+ _data_types = [
163
+ SourceDataType(name="text", type=str),
164
+ SourceDataType(name="image", type=Image.Image),
165
+ SourceDataType(name="reference", type=Image.Image),
166
+ ]
167
+ EDIT_TYPES = [
168
+ "color",
169
+ "style",
170
+ "replace",
171
+ "remove",
172
+ "add",
173
+ "motion change",
174
+ "background change",
175
+ ]
176
+ def __init__(self, data_dir:Path, total_per=1, data_range:DataRange|None=None):
177
+ data_dir = Path(data_dir)
178
+ self.join_ds = self.build_dataset(data_dir, total_per)
179
+
180
+ if data_range is not None:
181
+ indexes = parse_datarange(data_range, len(self.join_ds))
182
+ self.join_ds = self.join_ds.select(indexes)
183
+
184
+ print(f"{self.__class__} of len{len(self)}")
185
+
186
+ def build_dataset(self, data_dir:Path, total_per:int):
187
+ all_edit_datasets = []
188
+ for edit_type in self.EDIT_TYPES:
189
+ to_concat = []
190
+ for ds_n in range(total_per):
191
+ ds = load_dataset("parquet", data_files=str(data_dir/f"{edit_type}_{ds_n:05d}.parquet"), split="train")
192
+ to_concat.append(ds)
193
+ edit_type_concat = concatenate_datasets(to_concat)
194
+ all_edit_datasets.append(edit_type_concat)
195
+ # consistent ordering for indexing, also allow extension by increasing total_per
196
+ join_ds = interleave_datasets(all_edit_datasets)
197
+ return join_ds
198
+
199
+ def __len__(self):
200
+ return len(self.join_ds)
201
+
202
+ def __getitem__(self, idx):
203
+ data = self.join_ds[idx]
204
+ reference = data["input_img"]
205
+ image = data["output_img"]
206
+ text = data["instruction"]
207
+ return text, image, reference
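
The regression source flattens (sample, step) pairs into one index: each cached .pt file holds gen_steps denoising steps, so idx // gen_steps selects the file and idx % gen_steps selects which t_{i} / latents_{i}_start / noise_pred_{i} triplet is renamed to t / latents_start / noise_pred. A small sketch of the index arithmetic and of parse_datarange (lengths are hypothetical):

    from qwenimage.sources import parse_datarange

    gen_steps = 50
    data_idx, step_idx = divmod(137, gen_steps)
    print(data_idx, step_idx)                           # 2 37 -> third .pt file, step 37

    # DataRange: absolute bounds, with None meaning an open end
    print(parse_datarange((32, None), length=100)[:3])  # [32, 33, 34]
    print(parse_datarange((0, 32), length=100)[-1])     # 31
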
qwenimage/task.py CHANGED
@@ -45,3 +45,17 @@ class TextToImageWithRefTask(Task):
45
  "image": SourceDataType(name="reference", type=Image.Image),
46
  }
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  "image": SourceDataType(name="reference", type=Image.Image),
46
  }
47
 
48
+ class RegressionTask(Task):
49
+ data_types = [
50
+ SourceDataType(name="data", type=dict),
51
+ ]
52
+ type_transforms = [
53
+ Task.identity,
54
+ ]
55
+ batch_type_transforms = [
56
+ Task.identity,
57
+ ]
58
+ sample_type_transforms = [
59
+ Task.identity,
60
+ ]
61
+ sample_input_dict = {}
qwenimage/training.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import subprocess
3
  from pathlib import Path
4
  import argparse
 
5
 
6
  import yaml
7
  import diffusers
@@ -17,10 +18,10 @@ from wandml.trainers.experiment_trainer import ExperimentTrainer
17
 
18
 
19
  from qwenimage.finetuner import QwenLoraFinetuner
20
- from qwenimage.sources import StyleSourceWithRandomRef, StyleImagetoImageSource
21
- from qwenimage.task import TextToImageWithRefTask
22
  from qwenimage.datamodels import QwenConfig
23
- from qwenimage.foundation import QwenImageFoundation
24
 
25
 
26
  wandml_utils.debug.DEBUG = True
@@ -33,65 +34,115 @@ def _deep_update(base: dict, updates: dict) -> dict:
33
  base[k] = v
34
  return base
35
 
36
- def run_training(config_path: Path | str, update_config_paths: list[Path] | None = None):
37
- WandAuth(ignore=True)
38
-
39
- with open(config_path, "r") as f:
40
- config = yaml.safe_load(f)
41
- if update_config_paths is not None:
42
- for update_config_path in update_config_paths:
43
- with open(update_config_path, "r") as uf:
44
- update_cfg = yaml.safe_load(uf)
45
- if isinstance(update_cfg, dict):
46
- config = _deep_update(config if isinstance(config, dict) else {}, update_cfg)
47
-
48
- config = QwenConfig(
49
- **config,
50
- )
51
-
52
- # Data
53
  dp = WandDataPipe()
54
  dp.set_task(TextToImageWithRefTask())
55
  dp_test = WandDataPipe()
56
  dp_test.set_task(TextToImageWithRefTask())
57
- if config.source_type == "naive":
58
  src = StyleSourceWithRandomRef(
59
- config.data_dir, config.prompt, config.ref_dir, set_len=config.max_train_steps
 
 
 
60
  )
61
  dp.add_source(src)
62
- elif config.source_type == "im2im":
63
  src = StyleImagetoImageSource(
64
- csv_path=config.csv_path,
65
- base_dir=config.base_dir,
66
  style_title=config.style_title,
67
  data_range=config.train_range,
68
  )
69
  dp.add_source(src)
70
  src_test = StyleImagetoImageSource(
71
- csv_path=config.csv_path,
72
- base_dir=config.base_dir,
73
  style_title=config.style_title,
74
  data_range=config.test_range,
75
  )
76
  dp_test.add_source(src_test)
77
  else:
78
  raise ValueError()
 
 
 
 
 
 
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  # Model
82
- foundation = QwenImageFoundation(config=config)
 
 
 
 
 
83
  finetuner = QwenLoraFinetuner(foundation, config)
84
- finetuner.load(None)
85
-
86
 
87
  if len(dp_test) == 0:
 
88
  dp_test = None
89
- if config.val_with == "train":
90
- dp_val = dp
91
- elif config.val_with == "test":
92
- dp_val = dp_test
93
- else:
94
- raise ValueError()
95
  trainer = ExperimentTrainer(
96
  model=foundation,
97
  datapipe=dp,
 
2
  import subprocess
3
  from pathlib import Path
4
  import argparse
5
+ import warnings
6
 
7
  import yaml
8
  import diffusers
 
18
 
19
 
20
  from qwenimage.finetuner import QwenLoraFinetuner
21
+ from qwenimage.sources import EditingSource, RegressionSource, StyleSourceWithRandomRef, StyleImagetoImageSource
22
+ from qwenimage.task import RegressionTask, TextToImageWithRefTask
23
  from qwenimage.datamodels import QwenConfig
24
+ from qwenimage.foundation import QwenImageFoundation, QwenImageRegressionFoundation
25
 
26
 
27
  wandml_utils.debug.DEBUG = True
 
34
  base[k] = v
35
  return base
36
 
37
+ def styles_datapipe(config):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  dp = WandDataPipe()
39
  dp.set_task(TextToImageWithRefTask())
40
  dp_test = WandDataPipe()
41
  dp_test.set_task(TextToImageWithRefTask())
42
+ if config.training_type == "naive":
43
  src = StyleSourceWithRandomRef(
44
+ config.style_data_dir,
45
+ config.naive_static_prompt,
46
+ config.style_ref_dir,
47
+ set_len=config.max_train_steps,
48
  )
49
  dp.add_source(src)
50
+ elif config.training_type == "im2im":
51
  src = StyleImagetoImageSource(
52
+ csv_path=config.style_csv_path,
53
+ base_dir=config.style_base_dir,
54
  style_title=config.style_title,
55
  data_range=config.train_range,
56
  )
57
  dp.add_source(src)
58
  src_test = StyleImagetoImageSource(
59
+ csv_path=config.style_csv_path,
60
+ base_dir=config.style_base_dir,
61
  style_title=config.style_title,
62
  data_range=config.test_range,
63
  )
64
  dp_test.add_source(src_test)
65
  else:
66
  raise ValueError()
67
+ if config.style_val_with == "train":
68
+ dp_val = dp
69
+ elif config.style_val_with == "test":
70
+ dp_val = dp_test
71
+ else:
72
+ raise ValueError()
73
+ return dp, dp_val, dp_test
74
 
75
+ def regression_datapipe(config):
76
+ dp = WandDataPipe()
77
+ dp.set_task(RegressionTask())
78
+ dp_val = WandDataPipe()
79
+ dp_val.set_task(RegressionTask())
80
+ dp_test = WandDataPipe()
81
+ dp_test.set_task(TextToImageWithRefTask())
82
+
83
+ src_train = RegressionSource(
84
+ data_dir=config.regression_data_dir,
85
+ gen_steps=config.regression_gen_steps,
86
+ data_range=config.train_range,
87
+ )
88
+ dp.add_source(src_train)
89
+
90
+ src_val = RegressionSource(
91
+ data_dir=config.regression_data_dir,
92
+ gen_steps=config.regression_gen_steps,
93
+ data_range=config.val_range,
94
+ )
95
+ dp_val.add_source(src_val)
96
+
97
+ src_test = EditingSource(
98
+ data_dir=config.editing_data_dir,
99
+ total_per=config.editing_total_per,
100
+ data_range=config.test_range,
101
+ )
102
+ dp_test.add_source(src_test)
103
+
104
+ return dp, dp_val, dp_test
105
+
106
+ def run_training(config_path: Path | str, update_config_paths: list[Path] | None = None):
107
+ WandAuth(ignore=True)
108
+
109
+ with open(config_path, "r") as f:
110
+ config = yaml.safe_load(f)
111
+ if update_config_paths is not None:
112
+ for update_config_path in update_config_paths:
113
+ with open(update_config_path, "r") as uf:
114
+ update_cfg = yaml.safe_load(uf)
115
+ if isinstance(update_cfg, dict):
116
+ config = _deep_update(config if isinstance(config, dict) else {}, update_cfg)
117
+
118
+ config = QwenConfig(
119
+ **config,
120
+ )
121
+
122
+ # Data
123
+ if config.training_type.is_style:
124
+ dp, dp_val, dp_test = styles_datapipe(config)
125
+ elif config.training_type == "regression":
126
+ dp, dp_val, dp_test = regression_datapipe(config)
127
+ else:
128
+ raise ValueError(f"Got {config.training_type=}")
129
 
130
  # Model
131
+ if config.training_type.is_style:
132
+ foundation = QwenImageFoundation(config=config)
133
+ elif config.training_type == "regression":
134
+ foundation = QwenImageRegressionFoundation(config=config)
135
+ else:
136
+ raise ValueError(f"Got {config.training_type=}")
137
  finetuner = QwenLoraFinetuner(foundation, config)
138
+ finetuner.load(config.resume_from_checkpoint, config.lora_rank)
 
139
 
140
  if len(dp_test) == 0:
141
+ warnings.warn("Test datapipe is removed for being len=0")
142
  dp_test = None
143
+ if len(dp_val) == 0:
144
+ warnings.warn("Validation datapipe is removed for being len=0")
145
+ dp_val = None
 
 
 
146
  trainer = ExperimentTrainer(
147
  model=foundation,
148
  datapipe=dp,
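
run_training now loads the base YAML, deep-merges any update configs on top (later files win, via _deep_update), validates the result as a QwenConfig, and dispatches both the datapipes and the foundation class on training_type. One plausible way to launch a regression run with the configs added in this commit (the exact layering of base vs. override files is an assumption):

    from qwenimage.training import run_training

    run_training(
        "configs/base.yaml",
        update_config_paths=[
            "configs/regression/base-reg.yaml",         # regression data dirs + ranges
            "configs/regression/reg-mse-triplet.yaml",  # loss-term weights
        ],
    )
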
qwenimage/types.py ADDED
@@ -0,0 +1,3 @@
1
+
2
+
3
+ DataRange = tuple[int|float|None,int|float|None]
scripts/edit_datasets.ipynb CHANGED
@@ -89,7 +89,7 @@
89
  " edit_type_concat = concatenate_datasets(to_concat)\n",
90
  " all_edit_datasets.append(edit_type_concat)\n",
91
  "\n",
92
- "# consistent ordering for indexing, also allow extension\n",
93
  "join_ds = interleave_datasets(all_edit_datasets)"
94
  ]
95
  },
 
89
  " edit_type_concat = concatenate_datasets(to_concat)\n",
90
  " all_edit_datasets.append(edit_type_concat)\n",
91
  "\n",
92
+ "# consistent ordering for indexing, also allow extension by increasing total_per\n",
93
  "join_ds = interleave_datasets(all_edit_datasets)"
94
  ]
95
  },
scripts/process_regression_data.ipynb ADDED
@@ -0,0 +1,1274 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "e1e781e9",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "%cd /home/ubuntu/Qwen-Image-Edit-Angles"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "d6192ee5",
17
+ "metadata": {},
18
+ "outputs": [
19
+ {
20
+ "data": {
21
+ "text/plain": [
22
+ "4941"
23
+ ]
24
+ },
25
+ "execution_count": 11,
26
+ "metadata": {},
27
+ "output_type": "execute_result"
28
+ }
29
+ ],
30
+ "source": [
31
+ "import glob\n",
32
+ "from pathlib import Path\n",
33
+ "\n",
34
+ "base_data = Path(\"/data/regression_output\")\n",
35
+ "\n",
36
+ "all_reg = list(base_data.glob(\"*.pt\"))\n",
37
+ "max_ind = max([int(reg_pth.stem) for reg_pth in all_reg])\n",
38
+ "\n",
39
+ "max_ind"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": 14,
45
+ "id": "b5124900",
46
+ "metadata": {},
47
+ "outputs": [
48
+ {
49
+ "name": "stdout",
50
+ "output_type": "stream",
51
+ "text": [
52
+ "prompt_embeds\n",
53
+ "prompt_embeds_mask\n",
54
+ "noise\n",
55
+ "image_latents\n",
56
+ "vae_image_sizes\n",
57
+ "img_shapes\n",
58
+ "txt_seq_lens\n",
59
+ "t_0\n",
60
+ "latents_0_start\n",
61
+ "noise_pred_0\n",
62
+ "t_1\n",
63
+ "latents_1_start\n",
64
+ "noise_pred_1\n",
65
+ "t_2\n",
66
+ "latents_2_start\n",
67
+ "noise_pred_2\n",
68
+ "t_3\n",
69
+ "latents_3_start\n",
70
+ "noise_pred_3\n",
71
+ "t_4\n",
72
+ "latents_4_start\n",
73
+ "noise_pred_4\n",
74
+ "t_5\n",
75
+ "latents_5_start\n",
76
+ "noise_pred_5\n",
77
+ "t_6\n",
78
+ "latents_6_start\n",
79
+ "noise_pred_6\n",
80
+ "t_7\n",
81
+ "latents_7_start\n",
82
+ "noise_pred_7\n",
83
+ "t_8\n",
84
+ "latents_8_start\n",
85
+ "noise_pred_8\n",
86
+ "t_9\n",
87
+ "latents_9_start\n",
88
+ "noise_pred_9\n",
89
+ "t_10\n",
90
+ "latents_10_start\n",
91
+ "noise_pred_10\n",
92
+ "t_11\n",
93
+ "latents_11_start\n",
94
+ "noise_pred_11\n",
95
+ "t_12\n",
96
+ "latents_12_start\n",
97
+ "noise_pred_12\n",
98
+ "t_13\n",
99
+ "latents_13_start\n",
100
+ "noise_pred_13\n",
101
+ "t_14\n",
102
+ "latents_14_start\n",
103
+ "noise_pred_14\n",
104
+ "t_15\n",
105
+ "latents_15_start\n",
106
+ "noise_pred_15\n",
107
+ "t_16\n",
108
+ "latents_16_start\n",
109
+ "noise_pred_16\n",
110
+ "t_17\n",
111
+ "latents_17_start\n",
112
+ "noise_pred_17\n",
113
+ "t_18\n",
114
+ "latents_18_start\n",
115
+ "noise_pred_18\n",
116
+ "t_19\n",
117
+ "latents_19_start\n",
118
+ "noise_pred_19\n",
119
+ "t_20\n",
120
+ "latents_20_start\n",
121
+ "noise_pred_20\n",
122
+ "t_21\n",
123
+ "latents_21_start\n",
124
+ "noise_pred_21\n",
125
+ "t_22\n",
126
+ "latents_22_start\n",
127
+ "noise_pred_22\n",
128
+ "t_23\n",
129
+ "latents_23_start\n",
130
+ "noise_pred_23\n",
131
+ "t_24\n",
132
+ "latents_24_start\n",
133
+ "noise_pred_24\n",
134
+ "t_25\n",
135
+ "latents_25_start\n",
136
+ "noise_pred_25\n",
137
+ "t_26\n",
138
+ "latents_26_start\n",
139
+ "noise_pred_26\n",
140
+ "t_27\n",
141
+ "latents_27_start\n",
142
+ "noise_pred_27\n",
143
+ "t_28\n",
144
+ "latents_28_start\n",
145
+ "noise_pred_28\n",
146
+ "t_29\n",
147
+ "latents_29_start\n",
148
+ "noise_pred_29\n",
149
+ "t_30\n",
150
+ "latents_30_start\n",
151
+ "noise_pred_30\n",
152
+ "t_31\n",
153
+ "latents_31_start\n",
154
+ "noise_pred_31\n",
155
+ "t_32\n",
156
+ "latents_32_start\n",
157
+ "noise_pred_32\n",
158
+ "t_33\n",
159
+ "latents_33_start\n",
160
+ "noise_pred_33\n",
161
+ "t_34\n",
162
+ "latents_34_start\n",
163
+ "noise_pred_34\n",
164
+ "t_35\n",
165
+ "latents_35_start\n",
166
+ "noise_pred_35\n",
167
+ "t_36\n",
168
+ "latents_36_start\n",
169
+ "noise_pred_36\n",
170
+ "t_37\n",
171
+ "latents_37_start\n",
172
+ "noise_pred_37\n",
173
+ "t_38\n",
174
+ "latents_38_start\n",
175
+ "noise_pred_38\n",
176
+ "t_39\n",
177
+ "latents_39_start\n",
178
+ "noise_pred_39\n",
179
+ "t_40\n",
180
+ "latents_40_start\n",
181
+ "noise_pred_40\n",
182
+ "t_41\n",
183
+ "latents_41_start\n",
184
+ "noise_pred_41\n",
185
+ "t_42\n",
186
+ "latents_42_start\n",
187
+ "noise_pred_42\n",
188
+ "t_43\n",
189
+ "latents_43_start\n",
190
+ "noise_pred_43\n",
191
+ "t_44\n",
192
+ "latents_44_start\n",
193
+ "noise_pred_44\n",
194
+ "t_45\n",
195
+ "latents_45_start\n",
196
+ "noise_pred_45\n",
197
+ "t_46\n",
198
+ "latents_46_start\n",
199
+ "noise_pred_46\n",
200
+ "t_47\n",
201
+ "latents_47_start\n",
202
+ "noise_pred_47\n",
203
+ "t_48\n",
204
+ "latents_48_start\n",
205
+ "noise_pred_48\n",
206
+ "t_49\n",
207
+ "latents_49_start\n",
208
+ "noise_pred_49\n",
209
+ "output\n",
210
+ "height\n",
211
+ "width\n"
212
+ ]
213
+ }
214
+ ],
215
+ "source": [
216
+ "import torch\n",
217
+ "\n",
218
+ "out = all_reg[0]\n",
219
+ "out_dict = torch.load(out)\n",
220
+ "for k in out_dict.keys():\n",
221
+ " print(k)"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "id": "74f693db",
228
+ "metadata": {},
229
+ "outputs": [
230
+ {
231
+ "data": {
232
+ "text/plain": [
233
+ "'003329'"
234
+ ]
235
+ },
236
+ "execution_count": 4,
237
+ "metadata": {},
238
+ "output_type": "execute_result"
239
+ }
240
+ ],
241
+ "source": []
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 7,
246
+ "id": "da107d9f",
247
+ "metadata": {},
248
+ "outputs": [
249
+ {
250
+ "name": "stdout",
251
+ "output_type": "stream",
252
+ "text": [
253
+ "69G\t/data/regression_output\n"
254
+ ]
255
+ }
256
+ ],
257
+ "source": [
258
+ "!du -h {base_data}"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": null,
264
+ "id": "269c0bfb",
265
+ "metadata": {},
266
+ "outputs": [],
267
+ "source": []
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": 16,
272
+ "id": "5964bf2b",
273
+ "metadata": {},
274
+ "outputs": [],
275
+ "source": [
276
+ "class RegressionSource:\n",
277
+ " # WIP\n",
278
+ "\n",
279
+ " def __init__(self, data_dir, gen_steps=50):\n",
280
+ " if not isinstance(data_dir, Path):\n",
281
+ " data_dir = Path(data_dir)\n",
282
+ " self.data_paths = list(data_dir.glob(\"*.pt\"))\n",
283
+ " self.gen_steps = gen_steps\n",
284
+ " self._len = gen_steps * len(self.data_paths)\n",
285
+ " \n",
286
+ " def __len__(self):\n",
287
+ " return self._len\n",
288
+ " \n",
289
+ " def __getitem__(self, idx):\n",
290
+ " data_idx = idx // self.gen_steps\n",
291
+ " step_idx = idx % self.gen_steps\n",
292
+ " out_dict = torch.load(self.data_paths[data_idx])\n",
293
+ " t = out_dict.pop(f\"t_{step_idx}\")\n",
294
+ " latents_start = out_dict.pop(f\"latents_{step_idx}_start\")\n",
295
+ " noise_pred = out_dict.pop(f\"noise_pred_{step_idx}\")\n",
296
+ " out_dict[\"t\"] = t\n",
297
+ " out_dict[\"latents_start\"] = latents_start\n",
298
+ " out_dict[\"noise_pred\"] = noise_pred\n",
299
+ " return out_dict\n",
300
+ "\n",
301
+ " \n"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "code",
306
+ "execution_count": 17,
307
+ "id": "b62e7bec",
308
+ "metadata": {},
309
+ "outputs": [],
310
+ "source": [
311
+ "src = RegressionSource(\"/data/regression_output\")"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": null,
317
+ "id": "4ee68ab3",
318
+ "metadata": {},
319
+ "outputs": [],
320
+ "source": []
321
+ },
322
+ {
323
+ "cell_type": "code",
324
+ "execution_count": 18,
325
+ "id": "9738e1d4",
326
+ "metadata": {},
327
+ "outputs": [
328
+ {
329
+ "data": {
330
+ "text/plain": [
331
+ "{'prompt_embeds': tensor([[[ 3.2188, 3.4375, 3.1719, ..., 0.3535, 1.7812, 2.0312],\n",
332
+ " [ 3.0938, 1.9297, 0.7031, ..., 2.0625, -0.2314, 1.2266],\n",
333
+ " [ 2.6250, 1.7031, 3.5625, ..., 0.8828, 2.1719, 1.4766],\n",
334
+ " ...,\n",
335
+ " [ 4.7812, 0.1689, 4.4688, ..., 5.0000, -1.8359, -0.7500],\n",
336
+ " [-0.0654, 2.1406, -1.4922, ..., 0.7930, 3.9844, 1.6406],\n",
337
+ " [-2.7031, 1.5547, 2.6094, ..., -0.0481, 0.1582, 0.7383]]],\n",
338
+ " dtype=torch.bfloat16),\n",
339
+ " 'prompt_embeds_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
340
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
341
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
342
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
343
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
344
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
345
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
346
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
347
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
348
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),\n",
349
+ " 'noise': tensor([[[ 1.9766, -0.8047, 0.6367, ..., -1.7422, 1.0469, 0.3809],\n",
350
+ " [ 1.6562, 0.1147, -0.1562, ..., 0.7539, -0.1768, -1.6953],\n",
351
+ " [ 0.3984, 0.3926, 0.1914, ..., -0.9258, -1.3281, -2.3281],\n",
352
+ " ...,\n",
353
+ " [-1.4766, 0.2539, 1.3359, ..., 0.1797, -0.6250, 0.7617],\n",
354
+ " [ 1.0391, 1.3672, -0.1572, ..., 0.1152, 1.4688, -0.2852],\n",
355
+ " [ 0.4941, -1.1094, 2.3438, ..., 0.8281, -0.8320, 0.4258]]],\n",
356
+ " dtype=torch.bfloat16),\n",
357
+ " 'image_latents': tensor([[[ 0.1719, 0.0194, 0.0084, ..., -0.1494, 0.0552, 0.2295],\n",
358
+ " [ 0.1777, 0.1406, 0.1592, ..., 0.1260, -0.2412, -0.0041],\n",
359
+ " [ 0.1187, 0.2324, 0.1104, ..., 0.0801, 0.3516, 0.4414],\n",
360
+ " ...,\n",
361
+ " [-0.0972, -0.3242, -0.3027, ..., 0.3672, 0.1699, 0.4004],\n",
362
+ " [-0.1221, -0.0125, -0.3867, ..., 0.7031, 0.8477, 0.8320],\n",
363
+ " [-0.1416, -0.1914, -0.3359, ..., 0.9883, 1.3359, 0.7422]]],\n",
364
+ " dtype=torch.bfloat16),\n",
365
+ " 'vae_image_sizes': [(448, 576)],\n",
366
+ " 'img_shapes': [[(1, 36, 28), (1, 36, 28)]],\n",
367
+ " 'txt_seq_lens': [228],\n",
368
+ " 't_1': tensor([0.9883], dtype=torch.bfloat16),\n",
369
+ " 'latents_1_start': tensor([[[ 1.9531, -0.7930, 0.6289, ..., -1.7188, 1.0312, 0.3770],\n",
370
+ " [ 1.6406, 0.1143, -0.1533, ..., 0.7461, -0.1748, -1.6719],\n",
371
+ " [ 0.3945, 0.3887, 0.1895, ..., -0.9141, -1.3125, -2.2969],\n",
372
+ " ...,\n",
373
+ " [-1.4609, 0.2471, 1.3203, ..., 0.1826, -0.6133, 0.7578],\n",
374
+ " [ 1.0234, 1.3516, -0.1582, ..., 0.1226, 1.4609, -0.2715],\n",
375
+ " [ 0.4863, -1.1016, 2.3125, ..., 0.8281, -0.8086, 0.4297]]],\n",
376
+ " dtype=torch.bfloat16),\n",
377
+ " 'noise_pred_1': tensor([[[ 1.9062, -0.9102, 0.5742, ..., -1.7422, 1.0625, 0.3359],\n",
378
+ " [ 1.5859, 0.0306, -0.2637, ..., 0.7539, -0.1768, -1.7969],\n",
379
+ " [ 0.3184, 0.3066, 0.1592, ..., -1.0391, -1.5391, -2.5625],\n",
380
+ " ...,\n",
381
+ " [-1.2734, 0.4941, 1.5781, ..., -0.2344, -1.0156, 0.3477],\n",
382
+ " [ 1.2422, 1.5234, 0.0510, ..., -0.5820, 0.9219, -1.0859],\n",
383
+ " [ 0.6172, -0.9336, 2.5781, ..., -0.0801, -1.7734, -0.3730]]],\n",
384
+ " dtype=torch.bfloat16),\n",
385
+ " 't_2': tensor([0.9766], dtype=torch.bfloat16),\n",
386
+ " 'latents_2_start': tensor([[[ 1.9297, -0.7812, 0.6211, ..., -1.6953, 1.0156, 0.3730],\n",
387
+ " [ 1.6250, 0.1138, -0.1504, ..., 0.7383, -0.1729, -1.6484],\n",
388
+ " [ 0.3906, 0.3848, 0.1875, ..., -0.9023, -1.2969, -2.2656],\n",
389
+ " ...,\n",
390
+ " [-1.4453, 0.2412, 1.3047, ..., 0.1855, -0.6016, 0.7539],\n",
391
+ " [ 1.0078, 1.3359, -0.1592, ..., 0.1299, 1.4531, -0.2578],\n",
392
+ " [ 0.4785, -1.0938, 2.2812, ..., 0.8281, -0.7891, 0.4336]]],\n",
393
+ " dtype=torch.bfloat16),\n",
394
+ " 'noise_pred_2': tensor([[[ 1.8984, -0.9219, 0.5664, ..., -1.7188, 1.0703, 0.3633],\n",
395
+ " [ 1.5859, 0.0256, -0.2539, ..., 0.7578, -0.1719, -1.7656],\n",
396
+ " [ 0.3105, 0.3027, 0.1611, ..., -1.0000, -1.4688, -2.4688],\n",
397
+ " ...,\n",
398
+ " [-1.2969, 0.4453, 1.5625, ..., -0.1934, -0.9883, 0.4082],\n",
399
+ " [ 1.2188, 1.4844, -0.0028, ..., -0.4492, 1.0312, -0.9180],\n",
400
+ " [ 0.5820, -1.0156, 2.5156, ..., 0.1885, -1.5391, -0.1602]]],\n",
401
+ " dtype=torch.bfloat16),\n",
402
+ " 't_3': tensor([0.9648], dtype=torch.bfloat16),\n",
403
+ " 'latents_3_start': tensor([[[ 1.9062, -0.7695, 0.6133, ..., -1.6719, 1.0000, 0.3691],\n",
404
+ " [ 1.6094, 0.1133, -0.1475, ..., 0.7305, -0.1709, -1.6250],\n",
405
+ " [ 0.3867, 0.3809, 0.1855, ..., -0.8906, -1.2812, -2.2344],\n",
406
+ " ...,\n",
407
+ " [-1.4297, 0.2354, 1.2891, ..., 0.1875, -0.5898, 0.7500],\n",
408
+ " [ 0.9922, 1.3203, -0.1592, ..., 0.1357, 1.4375, -0.2461],\n",
409
+ " [ 0.4707, -1.0781, 2.2500, ..., 0.8242, -0.7695, 0.4355]]],\n",
410
+ " dtype=torch.bfloat16),\n",
411
+ " 'noise_pred_3': tensor([[[ 1.8984, -0.9180, 0.5430, ..., -1.7031, 1.0938, 0.3691],\n",
412
+ " [ 1.5703, 0.0308, -0.2676, ..., 0.7812, -0.1602, -1.7422],\n",
413
+ " [ 0.3164, 0.2949, 0.1514, ..., -0.9922, -1.4609, -2.4531],\n",
414
+ " ...,\n",
415
+ " [-1.3203, 0.4277, 1.5234, ..., -0.1611, -0.9688, 0.4434],\n",
416
+ " [ 1.1875, 1.4609, -0.0179, ..., -0.4355, 1.0312, -0.8867],\n",
417
+ " [ 0.5547, -1.0234, 2.4844, ..., 0.2344, -1.4844, -0.1025]]],\n",
418
+ " dtype=torch.bfloat16),\n",
419
+ " 't_4': tensor([0.9531], dtype=torch.bfloat16),\n",
420
+ " 'latents_4_start': tensor([[[ 1.8828, -0.7578, 0.6055, ..., -1.6484, 0.9844, 0.3652],\n",
421
+ " [ 1.5859, 0.1128, -0.1445, ..., 0.7188, -0.1689, -1.6016],\n",
422
+ " [ 0.3828, 0.3770, 0.1836, ..., -0.8789, -1.2656, -2.2031],\n",
423
+ " ...,\n",
424
+ " [-1.4141, 0.2305, 1.2734, ..., 0.1895, -0.5781, 0.7461],\n",
425
+ " [ 0.9766, 1.3047, -0.1592, ..., 0.1416, 1.4219, -0.2354],\n",
426
+ " [ 0.4629, -1.0625, 2.2188, ..., 0.8203, -0.7500, 0.4375]]],\n",
427
+ " dtype=torch.bfloat16),\n",
428
+ " 'noise_pred_4': tensor([[[ 1.8984, -0.9141, 0.5508, ..., -1.7109, 1.0859, 0.3672],\n",
429
+ " [ 1.5625, 0.0238, -0.2754, ..., 0.7656, -0.1768, -1.7578],\n",
430
+ " [ 0.3105, 0.2988, 0.1602, ..., -1.0156, -1.4766, -2.4375],\n",
431
+ " ...,\n",
432
+ " [-1.3125, 0.4316, 1.5469, ..., -0.1621, -0.9805, 0.4141],\n",
433
+ " [ 1.1953, 1.4844, -0.0118, ..., -0.4590, 1.0078, -0.9492],\n",
434
+ " [ 0.5703, -1.0156, 2.5156, ..., 0.1777, -1.5469, -0.1475]]],\n",
435
+ " dtype=torch.bfloat16),\n",
436
+ " 't_5': tensor([0.9414], dtype=torch.bfloat16),\n",
437
+ " 'latents_5_start': tensor([[[ 1.8594, -0.7461, 0.5977, ..., -1.6250, 0.9688, 0.3613],\n",
438
+ " [ 1.5625, 0.1123, -0.1406, ..., 0.7109, -0.1670, -1.5781],\n",
439
+ " [ 0.3789, 0.3730, 0.1816, ..., -0.8672, -1.2500, -2.1719],\n",
440
+ " ...,\n",
441
+ " [-1.3984, 0.2246, 1.2500, ..., 0.1914, -0.5664, 0.7422],\n",
442
+ " [ 0.9609, 1.2891, -0.1592, ..., 0.1475, 1.4062, -0.2236],\n",
443
+ " [ 0.4551, -1.0469, 2.1875, ..., 0.8164, -0.7305, 0.4395]]],\n",
444
+ " dtype=torch.bfloat16),\n",
445
+ " 'noise_pred_5': tensor([[[ 1.8906, -0.8984, 0.5586, ..., -1.7031, 1.0391, 0.3516],\n",
446
+ " [ 1.5625, 0.0388, -0.2871, ..., 0.8008, -0.1504, -1.7734],\n",
447
+ " [ 0.3008, 0.2949, 0.1777, ..., -1.0312, -1.5781, -2.5469],\n",
448
+ " ...,\n",
449
+ " [-1.3047, 0.4688, 1.5938, ..., -0.2188, -1.0781, 0.3945],\n",
450
+ " [ 1.2031, 1.4844, 0.0082, ..., -0.5469, 0.9414, -1.1875],\n",
451
+ " [ 0.5781, -0.9336, 2.5625, ..., -0.0903, -1.8047, -0.3828]]],\n",
452
+ " dtype=torch.bfloat16),\n",
453
+ " 't_6': tensor([0.9258], dtype=torch.bfloat16),\n",
454
+ " 'latents_6_start': tensor([[[ 1.8359, -0.7344, 0.5898, ..., -1.6016, 0.9570, 0.3574],\n",
455
+ " [ 1.5391, 0.1118, -0.1367, ..., 0.6992, -0.1650, -1.5547],\n",
456
+ " [ 0.3750, 0.3691, 0.1797, ..., -0.8555, -1.2266, -2.1406],\n",
457
+ " ...,\n",
458
+ " [-1.3828, 0.2188, 1.2266, ..., 0.1943, -0.5508, 0.7383],\n",
459
+ " [ 0.9453, 1.2734, -0.1592, ..., 0.1543, 1.3906, -0.2080],\n",
460
+ " [ 0.4473, -1.0312, 2.1562, ..., 0.8164, -0.7070, 0.4453]]],\n",
461
+ " dtype=torch.bfloat16),\n",
462
+ " 'noise_pred_6': tensor([[[ 1.9219, -0.8828, 0.5820, ..., -1.7109, 1.0781, 0.3613],\n",
463
+ " [ 1.5703, 0.0359, -0.2812, ..., 0.7773, -0.1865, -1.8203],\n",
464
+ " [ 0.3301, 0.2949, 0.1924, ..., -1.0781, -1.6016, -2.5312],\n",
465
+ " ...,\n",
466
+ " [-1.2734, 0.4941, 1.6094, ..., -0.2236, -1.0391, 0.3633],\n",
467
+ " [ 1.2656, 1.5469, 0.0796, ..., -0.6797, 0.8672, -1.2656],\n",
468
+ " [ 0.6484, -0.9102, 2.5938, ..., -0.1904, -1.8516, -0.4590]]],\n",
469
+ " dtype=torch.bfloat16),\n",
470
+ " 't_7': tensor([0.9102], dtype=torch.bfloat16),\n",
471
+ " 'latents_7_start': tensor([[[ 1.8125, -0.7227, 0.5820, ..., -1.5781, 0.9414, 0.3535],\n",
472
+ " [ 1.5156, 0.1113, -0.1328, ..., 0.6875, -0.1621, -1.5312],\n",
473
+ " [ 0.3711, 0.3652, 0.1768, ..., -0.8398, -1.2031, -2.1094],\n",
474
+ " ...,\n",
475
+ " [-1.3672, 0.2119, 1.2031, ..., 0.1973, -0.5352, 0.7344],\n",
476
+ " [ 0.9297, 1.2500, -0.1602, ..., 0.1631, 1.3828, -0.1914],\n",
477
+ " [ 0.4395, -1.0156, 2.1250, ..., 0.8203, -0.6836, 0.4512]]],\n",
478
+ " dtype=torch.bfloat16),\n",
479
+ " 'noise_pred_7': tensor([[[ 1.9531, -0.8906, 0.5938, ..., -1.7266, 1.0938, 0.4180],\n",
480
+ " [ 1.5781, 0.0309, -0.3008, ..., 0.7969, -0.1699, -1.8281],\n",
481
+ " [ 0.3262, 0.3008, 0.2314, ..., -1.0781, -1.6797, -2.6250],\n",
482
+ " ...,\n",
483
+ " [-1.2812, 0.5039, 1.5938, ..., -0.2314, -1.0547, 0.3828],\n",
484
+ " [ 1.2734, 1.5781, 0.0859, ..., -0.7930, 0.8555, -1.3516],\n",
485
+ " [ 0.6914, -0.9062, 2.6250, ..., -0.2598, -1.8516, -0.4902]]],\n",
486
+ " dtype=torch.bfloat16),\n",
487
+ " 't_8': tensor([0.8984], dtype=torch.bfloat16),\n",
488
+ " 'latents_8_start': tensor([[[ 1.7891, -0.7109, 0.5742, ..., -1.5547, 0.9258, 0.3477],\n",
489
+ " [ 1.4922, 0.1108, -0.1289, ..., 0.6758, -0.1602, -1.5078],\n",
490
+ " [ 0.3672, 0.3613, 0.1738, ..., -0.8242, -1.1797, -2.0781],\n",
491
+ " ...,\n",
492
+ " [-1.3516, 0.2051, 1.1797, ..., 0.2002, -0.5195, 0.7305],\n",
493
+ " [ 0.9141, 1.2266, -0.1611, ..., 0.1738, 1.3750, -0.1729],\n",
494
+ " [ 0.4297, -1.0000, 2.0938, ..., 0.8242, -0.6602, 0.4570]]],\n",
495
+ " dtype=torch.bfloat16),\n",
496
+ " 'noise_pred_8': tensor([[[ 1.9453, -0.8789, 0.6094, ..., -1.7266, 1.0781, 0.4082],\n",
497
+ " [ 1.5703, 0.0396, -0.3047, ..., 0.7891, -0.1826, -1.8516],\n",
498
+ " [ 0.3164, 0.2949, 0.2500, ..., -1.0859, -1.7031, -2.6250],\n",
499
+ " ...,\n",
500
+ " [-1.2578, 0.5234, 1.5938, ..., -0.2246, -1.0547, 0.3770],\n",
501
+ " [ 1.2734, 1.6016, 0.0884, ..., -0.8828, 0.8086, -1.3672],\n",
502
+ " [ 0.7070, -0.8828, 2.6094, ..., -0.2832, -1.8750, -0.5117]]],\n",
503
+ " dtype=torch.bfloat16),\n",
504
+ " 't_9': tensor([0.8828], dtype=torch.bfloat16),\n",
505
+ " 'latents_9_start': tensor([[[ 1.7656, -0.6992, 0.5664, ..., -1.5312, 0.9102, 0.3418],\n",
506
+ " [ 1.4688, 0.1104, -0.1245, ..., 0.6641, -0.1572, -1.4844],\n",
507
+ " [ 0.3633, 0.3574, 0.1699, ..., -0.8086, -1.1562, -2.0469],\n",
508
+ " ...,\n",
509
+ " [-1.3359, 0.1982, 1.1562, ..., 0.2031, -0.5039, 0.7266],\n",
510
+ " [ 0.8984, 1.2031, -0.1621, ..., 0.1855, 1.3672, -0.1543],\n",
511
+ " [ 0.4199, -0.9883, 2.0625, ..., 0.8281, -0.6328, 0.4648]]],\n",
512
+ " dtype=torch.bfloat16),\n",
513
+ " 'noise_pred_9': tensor([[[ 1.9531, -0.8828, 0.6172, ..., -1.7188, 1.0391, 0.4141],\n",
514
+ " [ 1.5703, 0.0583, -0.3125, ..., 0.7930, -0.1582, -1.8594],\n",
515
+ " [ 0.3203, 0.2910, 0.2598, ..., -1.1016, -1.7500, -2.6562],\n",
516
+ " ...,\n",
517
+ " [-1.2422, 0.5273, 1.6094, ..., -0.2168, -1.0391, 0.4121],\n",
518
+ " [ 1.2656, 1.6172, 0.1001, ..., -0.8984, 0.8008, -1.4062],\n",
519
+ " [ 0.7383, -0.8750, 2.6250, ..., -0.2891, -1.8672, -0.5312]]],\n",
520
+ " dtype=torch.bfloat16),\n",
521
+ " 't_10': tensor([0.8711], dtype=torch.bfloat16),\n",
522
+ " 'latents_10_start': tensor([[[ 1.7344, -0.6875, 0.5586, ..., -1.5078, 0.8945, 0.3359],\n",
523
+ " [ 1.4453, 0.1094, -0.1201, ..., 0.6523, -0.1553, -1.4609],\n",
524
+ " [ 0.3594, 0.3535, 0.1660, ..., -0.7930, -1.1328, -2.0156],\n",
525
+ " ...,\n",
526
+ " [-1.3203, 0.1904, 1.1328, ..., 0.2061, -0.4902, 0.7227],\n",
527
+ " [ 0.8789, 1.1797, -0.1631, ..., 0.1982, 1.3594, -0.1348],\n",
528
+ " [ 0.4102, -0.9766, 2.0312, ..., 0.8320, -0.6055, 0.4727]]],\n",
529
+ " dtype=torch.bfloat16),\n",
530
+ " 'noise_pred_10': tensor([[[ 1.9609, -0.8672, 0.6445, ..., -1.7188, 1.0156, 0.4180],\n",
531
+ " [ 1.5781, 0.0728, -0.3125, ..., 0.7812, -0.1602, -1.8828],\n",
532
+ " [ 0.3125, 0.2832, 0.2832, ..., -1.1172, -1.7734, -2.6875],\n",
533
+ " ...,\n",
534
+ " [-1.2500, 0.5273, 1.6016, ..., -0.2227, -1.0625, 0.4062],\n",
535
+ " [ 1.2500, 1.6094, 0.1016, ..., -0.9297, 0.8086, -1.4219],\n",
536
+ " [ 0.7617, -0.8555, 2.6406, ..., -0.3105, -1.8594, -0.5352]]],\n",
537
+ " dtype=torch.bfloat16),\n",
538
+ " 't_11': tensor([0.8555], dtype=torch.bfloat16),\n",
539
+ " 'latents_11_start': tensor([[[ 1.7031, -0.6758, 0.5508, ..., -1.4844, 0.8789, 0.3301],\n",
540
+ " [ 1.4219, 0.1084, -0.1157, ..., 0.6406, -0.1533, -1.4375],\n",
541
+ " [ 0.3555, 0.3496, 0.1621, ..., -0.7773, -1.1094, -1.9766],\n",
542
+ " ...,\n",
543
+ " [-1.3047, 0.1826, 1.1094, ..., 0.2090, -0.4746, 0.7188],\n",
544
+ " [ 0.8594, 1.1562, -0.1641, ..., 0.2119, 1.3516, -0.1143],\n",
545
+ " [ 0.3984, -0.9648, 1.9922, ..., 0.8359, -0.5781, 0.4805]]],\n",
546
+ " dtype=torch.bfloat16),\n",
547
+ " 'noise_pred_11': tensor([[[ 1.9688, -0.8516, 0.6484, ..., -1.7266, 1.0000, 0.4082],\n",
548
+ " [ 1.5938, 0.0679, -0.3105, ..., 0.8086, -0.1455, -1.8984],\n",
549
+ " [ 0.3203, 0.2812, 0.2949, ..., -1.1094, -1.7812, -2.6719],\n",
550
+ " ...,\n",
551
+ " [-1.2500, 0.5273, 1.5938, ..., -0.2119, -1.0625, 0.4102],\n",
552
+ " [ 1.2656, 1.6016, 0.1011, ..., -0.9180, 0.8281, -1.4531],\n",
553
+ " [ 0.7695, -0.8320, 2.6562, ..., -0.2891, -1.8516, -0.5234]]],\n",
554
+ " dtype=torch.bfloat16),\n",
555
+ " 't_12': tensor([0.8438], dtype=torch.bfloat16),\n",
556
+ " 'latents_12_start': tensor([[[ 1.6719, -0.6641, 0.5430, ..., -1.4609, 0.8633, 0.3242],\n",
557
+ " [ 1.3984, 0.1074, -0.1113, ..., 0.6289, -0.1514, -1.4062],\n",
558
+ " [ 0.3516, 0.3457, 0.1582, ..., -0.7617, -1.0859, -1.9375],\n",
559
+ " ...,\n",
560
+ " [-1.2891, 0.1748, 1.0859, ..., 0.2119, -0.4590, 0.7109],\n",
561
+ " [ 0.8398, 1.1328, -0.1660, ..., 0.2256, 1.3359, -0.0933],\n",
562
+ " [ 0.3867, -0.9531, 1.9531, ..., 0.8398, -0.5508, 0.4883]]],\n",
563
+ " dtype=torch.bfloat16),\n",
564
+ " 'noise_pred_12': tensor([[[ 1.9688, -0.8477, 0.6602, ..., -1.7422, 0.9805, 0.3965],\n",
565
+ " [ 1.5938, 0.0845, -0.3066, ..., 0.7891, -0.1816, -1.9062],\n",
566
+ " [ 0.3105, 0.2754, 0.3242, ..., -1.1328, -1.8047, -2.6875],\n",
567
+ " ...,\n",
568
+ " [-1.2422, 0.5195, 1.5938, ..., -0.2227, -1.0625, 0.4180],\n",
569
+ " [ 1.2500, 1.6172, 0.1138, ..., -0.9492, 0.8281, -1.4609],\n",
570
+ " [ 0.7852, -0.8555, 2.6562, ..., -0.3047, -1.8438, -0.5430]]],\n",
571
+ " dtype=torch.bfloat16),\n",
572
+ " 't_13': tensor([0.8281], dtype=torch.bfloat16),\n",
573
+ " 'latents_13_start': tensor([[[ 1.6406, -0.6523, 0.5312, ..., -1.4375, 0.8477, 0.3184],\n",
574
+ " [ 1.3750, 0.1060, -0.1069, ..., 0.6172, -0.1484, -1.3750],\n",
575
+ " [ 0.3477, 0.3418, 0.1533, ..., -0.7461, -1.0625, -1.8984],\n",
576
+ " ...,\n",
577
+ " [-1.2734, 0.1670, 1.0625, ..., 0.2148, -0.4434, 0.7031],\n",
578
+ " [ 0.8203, 1.1094, -0.1680, ..., 0.2393, 1.3203, -0.0718],\n",
579
+ " [ 0.3750, -0.9414, 1.9141, ..., 0.8438, -0.5234, 0.4961]]],\n",
580
+ " dtype=torch.bfloat16),\n",
581
+ " 'noise_pred_13': tensor([[[ 1.9688, -0.8438, 0.6562, ..., -1.7500, 0.9805, 0.3906],\n",
582
+ " [ 1.5938, 0.0791, -0.3066, ..., 0.7734, -0.1934, -1.9062],\n",
583
+ " [ 0.3145, 0.2734, 0.3203, ..., -1.1250, -1.8359, -2.7188],\n",
584
+ " ...,\n",
585
+ " [-1.2422, 0.5156, 1.6094, ..., -0.2021, -1.0312, 0.4609],\n",
586
+ " [ 1.2656, 1.6250, 0.1108, ..., -0.9453, 0.8789, -1.4688],\n",
587
+ " [ 0.7930, -0.8594, 2.6719, ..., -0.3105, -1.8359, -0.5469]]],\n",
588
+ " dtype=torch.bfloat16),\n",
589
+ " 't_14': tensor([0.8125], dtype=torch.bfloat16),\n",
590
+ " 'latents_14_start': tensor([[[ 1.6094, -0.6406, 0.5195, ..., -1.4141, 0.8320, 0.3125],\n",
591
+ " [ 1.3516, 0.1050, -0.1025, ..., 0.6055, -0.1455, -1.3438],\n",
592
+ " [ 0.3438, 0.3379, 0.1484, ..., -0.7305, -1.0312, -1.8594],\n",
593
+ " ...,\n",
594
+ " [-1.2578, 0.1592, 1.0391, ..., 0.2178, -0.4277, 0.6953],\n",
595
+ " [ 0.8008, 1.0859, -0.1699, ..., 0.2539, 1.3047, -0.0498],\n",
596
+ " [ 0.3633, -0.9297, 1.8750, ..., 0.8477, -0.4961, 0.5039]]],\n",
597
+ " dtype=torch.bfloat16),\n",
598
+ " 'noise_pred_14': tensor([[[ 1.9609, -0.8438, 0.6562, ..., -1.7500, 0.9727, 0.3828],\n",
599
+ " [ 1.5938, 0.0840, -0.3203, ..., 0.7695, -0.2061, -1.8906],\n",
600
+ " [ 0.3125, 0.2754, 0.3262, ..., -1.1328, -1.8359, -2.7031],\n",
601
+ " ...,\n",
602
+ " [-1.2422, 0.5117, 1.6094, ..., -0.2002, -1.0156, 0.4746],\n",
603
+ " [ 1.2656, 1.6172, 0.1108, ..., -0.9219, 0.8945, -1.4609],\n",
604
+ " [ 0.7969, -0.8672, 2.6406, ..., -0.3047, -1.7969, -0.5430]]],\n",
605
+ " dtype=torch.bfloat16),\n",
606
+ " 't_15': tensor([0.7969], dtype=torch.bfloat16),\n",
607
+ " 'latents_15_start': tensor([[[ 1.5781, -0.6289, 0.5078, ..., -1.3906, 0.8164, 0.3066],\n",
608
+ " [ 1.3281, 0.1035, -0.0977, ..., 0.5938, -0.1426, -1.3125],\n",
609
+ " [ 0.3398, 0.3340, 0.1436, ..., -0.7148, -1.0000, -1.8203],\n",
610
+ " ...,\n",
611
+ " [-1.2422, 0.1514, 1.0156, ..., 0.2207, -0.4121, 0.6875],\n",
612
+ " [ 0.7812, 1.0625, -0.1719, ..., 0.2676, 1.2891, -0.0275],\n",
613
+ " [ 0.3516, -0.9180, 1.8359, ..., 0.8516, -0.4688, 0.5117]]],\n",
614
+ " dtype=torch.bfloat16),\n",
615
+ " 'noise_pred_15': tensor([[[ 1.9531, -0.8320, 0.6641, ..., -1.7578, 0.9570, 0.3789],\n",
616
+ " [ 1.5938, 0.0806, -0.3184, ..., 0.7500, -0.2031, -1.8750],\n",
617
+ " [ 0.3242, 0.2656, 0.3301, ..., -1.1406, -1.8359, -2.7031],\n",
618
+ " ...,\n",
619
+ " [-1.2578, 0.5195, 1.6328, ..., -0.1875, -0.9883, 0.5117],\n",
620
+ " [ 1.2734, 1.6094, 0.1230, ..., -0.9297, 0.9102, -1.4531],\n",
621
+ " [ 0.7930, -0.8633, 2.6406, ..., -0.3027, -1.8203, -0.5312]]],\n",
622
+ " dtype=torch.bfloat16),\n",
623
+ " 't_16': tensor([0.7812], dtype=torch.bfloat16),\n",
624
+ " 'latents_16_start': tensor([[[ 1.5469, -0.6172, 0.4980, ..., -1.3594, 0.8008, 0.3008],\n",
625
+ " [ 1.3047, 0.1021, -0.0928, ..., 0.5820, -0.1396, -1.2812],\n",
626
+ " [ 0.3340, 0.3301, 0.1387, ..., -0.6953, -0.9727, -1.7812],\n",
627
+ " ...,\n",
628
+ " [-1.2188, 0.1436, 0.9883, ..., 0.2236, -0.3965, 0.6797],\n",
629
+ " [ 0.7617, 1.0391, -0.1738, ..., 0.2812, 1.2734, -0.0048],\n",
630
+ " [ 0.3398, -0.9062, 1.7969, ..., 0.8555, -0.4395, 0.5195]]],\n",
631
+ " dtype=torch.bfloat16),\n",
632
+ " 'noise_pred_16': tensor([[[ 1.9766, -0.8320, 0.6758, ..., -1.7734, 0.9844, 0.3867],\n",
633
+ " [ 1.6094, 0.0923, -0.3164, ..., 0.7617, -0.2148, -1.8984],\n",
634
+ " [ 0.3281, 0.2695, 0.3398, ..., -1.1484, -1.8516, -2.7500],\n",
635
+ " ...,\n",
636
+ " [-1.2500, 0.5156, 1.6406, ..., -0.1953, -0.9648, 0.5273],\n",
637
+ " [ 1.3125, 1.6250, 0.1113, ..., -0.9102, 0.9414, -1.4609],\n",
638
+ " [ 0.7969, -0.8750, 2.6719, ..., -0.2988, -1.7891, -0.5469]]],\n",
639
+ " dtype=torch.bfloat16),\n",
640
+ " 't_17': tensor([0.7656], dtype=torch.bfloat16),\n",
641
+ " 'latents_17_start': tensor([[[ 1.5156, -0.6055, 0.4883, ..., -1.3281, 0.7852, 0.2949],\n",
642
+ " [ 1.2812, 0.1006, -0.0879, ..., 0.5703, -0.1367, -1.2500],\n",
643
+ " [ 0.3281, 0.3262, 0.1328, ..., -0.6758, -0.9414, -1.7344],\n",
644
+ " ...,\n",
645
+ " [-1.1953, 0.1357, 0.9609, ..., 0.2266, -0.3809, 0.6719],\n",
646
+ " [ 0.7422, 1.0156, -0.1758, ..., 0.2949, 1.2578, 0.0184],\n",
647
+ " [ 0.3281, -0.8906, 1.7578, ..., 0.8594, -0.4102, 0.5273]]],\n",
648
+ " dtype=torch.bfloat16),\n",
649
+ " 'noise_pred_17': tensor([[[ 1.9688, -0.8242, 0.6719, ..., -1.7578, 0.9688, 0.3691],\n",
650
+ " [ 1.6094, 0.0869, -0.3145, ..., 0.7500, -0.2217, -1.8828],\n",
651
+ " [ 0.3203, 0.2754, 0.3457, ..., -1.1406, -1.8516, -2.7500],\n",
652
+ " ...,\n",
653
+ " [-1.2266, 0.5156, 1.6250, ..., -0.1904, -0.9492, 0.5273],\n",
654
+ " [ 1.3047, 1.6172, 0.1040, ..., -0.9141, 0.9570, -1.4531],\n",
655
+ " [ 0.7852, -0.8633, 2.6562, ..., -0.2949, -1.7969, -0.5430]]],\n",
656
+ " dtype=torch.bfloat16),\n",
657
+ " 't_18': tensor([0.7461], dtype=torch.bfloat16),\n",
658
+ " 'latents_18_start': tensor([[[ 1.4844, -0.5938, 0.4766, ..., -1.2969, 0.7695, 0.2891],\n",
659
+ " [ 1.2578, 0.0991, -0.0830, ..., 0.5586, -0.1328, -1.2188],\n",
660
+ " [ 0.3223, 0.3223, 0.1270, ..., -0.6562, -0.9102, -1.6875],\n",
661
+ " ...,\n",
662
+ " [-1.1719, 0.1270, 0.9336, ..., 0.2295, -0.3652, 0.6641],\n",
663
+ " [ 0.7227, 0.9883, -0.1777, ..., 0.3105, 1.2422, 0.0420],\n",
664
+ " [ 0.3145, -0.8750, 1.7109, ..., 0.8633, -0.3809, 0.5352]]],\n",
665
+ " dtype=torch.bfloat16),\n",
666
+ " 'noise_pred_18': tensor([[[ 1.9844, -0.8398, 0.6680, ..., -1.7578, 0.9727, 0.3730],\n",
667
+ " [ 1.6172, 0.0752, -0.3184, ..., 0.7578, -0.2148, -1.8828],\n",
668
+ " [ 0.3066, 0.2715, 0.3398, ..., -1.1328, -1.8516, -2.7500],\n",
669
+ " ...,\n",
670
+ " [-1.2422, 0.5156, 1.6484, ..., -0.1777, -0.9336, 0.5625],\n",
671
+ " [ 1.3047, 1.6328, 0.1147, ..., -0.8906, 0.9883, -1.4688],\n",
672
+ " [ 0.7734, -0.8672, 2.6406, ..., -0.2910, -1.7891, -0.5312]]],\n",
673
+ " dtype=torch.bfloat16),\n",
674
+ " 't_19': tensor([0.7305], dtype=torch.bfloat16),\n",
675
+ " 'latents_19_start': tensor([[[ 1.4531, -0.5781, 0.4648, ..., -1.2656, 0.7539, 0.2832],\n",
676
+ " [ 1.2344, 0.0977, -0.0776, ..., 0.5469, -0.1289, -1.1875],\n",
677
+ " [ 0.3164, 0.3184, 0.1211, ..., -0.6367, -0.8789, -1.6406],\n",
678
+ " ...,\n",
679
+ " [-1.1484, 0.1182, 0.9062, ..., 0.2324, -0.3496, 0.6562],\n",
680
+ " [ 0.6992, 0.9609, -0.1797, ..., 0.3262, 1.2266, 0.0664],\n",
681
+ " [ 0.3008, -0.8594, 1.6641, ..., 0.8672, -0.3516, 0.5430]]],\n",
682
+ " dtype=torch.bfloat16),\n",
683
+ " 'noise_pred_19': tensor([[[ 1.9844, -0.8516, 0.6641, ..., -1.7656, 0.9688, 0.3770],\n",
684
+ " [ 1.6094, 0.0684, -0.3281, ..., 0.7734, -0.2119, -1.8672],\n",
685
+ " [ 0.3086, 0.2695, 0.3301, ..., -1.1484, -1.8438, -2.7188],\n",
686
+ " ...,\n",
687
+ " [-1.2422, 0.5117, 1.6484, ..., -0.1738, -0.8984, 0.5781],\n",
688
+ " [ 1.3203, 1.6328, 0.1157, ..., -0.8828, 1.0078, -1.4844],\n",
689
+ " [ 0.7617, -0.8672, 2.6719, ..., -0.2480, -1.8125, -0.5273]]],\n",
690
+ " dtype=torch.bfloat16),\n",
691
+ " 't_20': tensor([0.7148], dtype=torch.bfloat16),\n",
692
+ " 'latents_20_start': tensor([[[ 1.4219, -0.5625, 0.4531, ..., -1.2344, 0.7383, 0.2773],\n",
693
+ " [ 1.2109, 0.0967, -0.0723, ..., 0.5352, -0.1250, -1.1562],\n",
694
+ " [ 0.3105, 0.3145, 0.1157, ..., -0.6172, -0.8477, -1.5938],\n",
695
+ " ...,\n",
696
+ " [-1.1250, 0.1094, 0.8789, ..., 0.2354, -0.3340, 0.6484],\n",
697
+ " [ 0.6758, 0.9336, -0.1816, ..., 0.3418, 1.2109, 0.0913],\n",
698
+ " [ 0.2871, -0.8438, 1.6172, ..., 0.8711, -0.3203, 0.5508]]],\n",
699
+ " dtype=torch.bfloat16),\n",
700
+ " 'noise_pred_20': tensor([[[ 1.9766, -0.8438, 0.6562, ..., -1.7656, 0.9570, 0.3809],\n",
701
+ " [ 1.6094, 0.0713, -0.3340, ..., 0.7734, -0.2246, -1.8594],\n",
702
+ " [ 0.2910, 0.2598, 0.3281, ..., -1.1250, -1.8359, -2.7500],\n",
703
+ " ...,\n",
704
+ " [-1.2422, 0.5039, 1.6406, ..., -0.1738, -0.8867, 0.6016],\n",
705
+ " [ 1.3125, 1.6172, 0.1187, ..., -0.8672, 1.0156, -1.4922],\n",
706
+ " [ 0.7578, -0.8711, 2.6719, ..., -0.2559, -1.7891, -0.5547]]],\n",
707
+ " dtype=torch.bfloat16),\n",
708
+ " 't_21': tensor([0.6992], dtype=torch.bfloat16),\n",
709
+ " 'latents_21_start': tensor([[[ 1.3906, -0.5469, 0.4414, ..., -1.2031, 0.7227, 0.2715],\n",
710
+ " [ 1.1797, 0.0952, -0.0664, ..., 0.5234, -0.1211, -1.1250],\n",
711
+ " [ 0.3047, 0.3105, 0.1099, ..., -0.5977, -0.8164, -1.5469],\n",
712
+ " ...,\n",
713
+ " [-1.1016, 0.1006, 0.8516, ..., 0.2383, -0.3184, 0.6367],\n",
714
+ " [ 0.6523, 0.9062, -0.1836, ..., 0.3574, 1.1953, 0.1172],\n",
715
+ " [ 0.2734, -0.8281, 1.5703, ..., 0.8750, -0.2891, 0.5586]]],\n",
716
+ " dtype=torch.bfloat16),\n",
717
+ " 'noise_pred_21': tensor([[[ 1.9844, -0.8281, 0.6680, ..., -1.7578, 0.9375, 0.3457],\n",
718
+ " [ 1.5938, 0.0811, -0.3203, ..., 0.7656, -0.2207, -1.8594],\n",
719
+ " [ 0.2773, 0.2559, 0.3340, ..., -1.1328, -1.8438, -2.7344],\n",
720
+ " ...,\n",
721
+ " [-1.2266, 0.4961, 1.6406, ..., -0.1953, -0.8750, 0.5859],\n",
722
+ " [ 1.2969, 1.6250, 0.1147, ..., -0.8594, 1.0156, -1.4922],\n",
723
+ " [ 0.7578, -0.8711, 2.6719, ..., -0.2617, -1.7891, -0.5508]]],\n",
724
+ " dtype=torch.bfloat16),\n",
725
+ " 't_22': tensor([0.6797], dtype=torch.bfloat16),\n",
726
+ " 'latents_22_start': tensor([[[ 1.3594, -0.5312, 0.4297, ..., -1.1719, 0.7070, 0.2656],\n",
727
+ " [ 1.1484, 0.0938, -0.0608, ..., 0.5117, -0.1172, -1.0938],\n",
728
+ " [ 0.3008, 0.3066, 0.1040, ..., -0.5781, -0.7852, -1.5000],\n",
729
+ " ...,\n",
730
+ " [-1.0781, 0.0918, 0.8242, ..., 0.2422, -0.3027, 0.6250],\n",
731
+ " [ 0.6289, 0.8789, -0.1855, ..., 0.3730, 1.1797, 0.1436],\n",
732
+ " [ 0.2598, -0.8125, 1.5234, ..., 0.8789, -0.2578, 0.5664]]],\n",
733
+ " dtype=torch.bfloat16),\n",
734
+ " 'noise_pred_22': tensor([[[ 1.9922, -0.8242, 0.6523, ..., -1.7500, 0.9375, 0.3477],\n",
735
+ " [ 1.5859, 0.0757, -0.3379, ..., 0.7578, -0.2178, -1.8438],\n",
736
+ " [ 0.2930, 0.2520, 0.3320, ..., -1.1250, -1.8516, -2.7500],\n",
737
+ " ...,\n",
738
+ " [-1.2031, 0.5000, 1.6406, ..., -0.2012, -0.8750, 0.5820],\n",
739
+ " [ 1.3047, 1.6094, 0.1309, ..., -0.8555, 1.0234, -1.5078],\n",
740
+ " [ 0.7617, -0.8711, 2.6562, ..., -0.2793, -1.7969, -0.5742]]],\n",
741
+ " dtype=torch.bfloat16),\n",
742
+ " 't_23': tensor([0.6641], dtype=torch.bfloat16),\n",
743
+ " 'latents_23_start': tensor([[[ 1.3203, -0.5156, 0.4180, ..., -1.1406, 0.6914, 0.2598],\n",
744
+ " [ 1.1172, 0.0923, -0.0547, ..., 0.4980, -0.1133, -1.0625],\n",
745
+ " [ 0.2949, 0.3027, 0.0981, ..., -0.5586, -0.7500, -1.4531],\n",
746
+ " ...,\n",
747
+ " [-1.0547, 0.0830, 0.7930, ..., 0.2461, -0.2871, 0.6133],\n",
748
+ " [ 0.6055, 0.8516, -0.1875, ..., 0.3887, 1.1641, 0.1709],\n",
749
+ " [ 0.2461, -0.7969, 1.4766, ..., 0.8828, -0.2256, 0.5781]]],\n",
750
+ " dtype=torch.bfloat16),\n",
751
+ " 'noise_pred_23': tensor([[[ 1.9688, -0.8203, 0.6562, ..., -1.7422, 0.9336, 0.3359],\n",
752
+ " [ 1.5781, 0.0923, -0.3164, ..., 0.7617, -0.2188, -1.8438],\n",
753
+ " [ 0.2969, 0.2617, 0.3203, ..., -1.1328, -1.8594, -2.7500],\n",
754
+ " ...,\n",
755
+ " [-1.2109, 0.5117, 1.6406, ..., -0.1914, -0.8711, 0.5938],\n",
756
+ " [ 1.2891, 1.6094, 0.1182, ..., -0.8477, 1.0469, -1.5000],\n",
757
+ " [ 0.7461, -0.8945, 2.6562, ..., -0.2852, -1.8047, -0.5586]]],\n",
758
+ " dtype=torch.bfloat16),\n",
759
+ " 't_24': tensor([0.6445], dtype=torch.bfloat16),\n",
760
+ " 'latents_24_start': tensor([[[ 1.2812, -0.5000, 0.4062, ..., -1.1094, 0.6758, 0.2539],\n",
761
+ " [ 1.0859, 0.0908, -0.0488, ..., 0.4844, -0.1094, -1.0312],\n",
762
+ " [ 0.2891, 0.2988, 0.0923, ..., -0.5391, -0.7148, -1.4062],\n",
763
+ " ...,\n",
764
+ " [-1.0312, 0.0737, 0.7617, ..., 0.2500, -0.2715, 0.6016],\n",
765
+ " [ 0.5820, 0.8203, -0.1895, ..., 0.4043, 1.1484, 0.1982],\n",
766
+ " [ 0.2324, -0.7812, 1.4297, ..., 0.8867, -0.1924, 0.5898]]],\n",
767
+ " dtype=torch.bfloat16),\n",
768
+ " 'noise_pred_24': tensor([[[ 1.9688, -0.8164, 0.6484, ..., -1.7422, 0.9492, 0.3574],\n",
769
+ " [ 1.5703, 0.0918, -0.3262, ..., 0.7734, -0.2207, -1.8438],\n",
770
+ " [ 0.2871, 0.2637, 0.3340, ..., -1.1250, -1.8516, -2.7656],\n",
771
+ " ...,\n",
772
+ " [-1.2031, 0.4961, 1.6328, ..., -0.1768, -0.8555, 0.6055],\n",
773
+ " [ 1.2891, 1.6172, 0.1211, ..., -0.8516, 1.0625, -1.5000],\n",
774
+ " [ 0.7461, -0.8867, 2.6562, ..., -0.2891, -1.8047, -0.5391]]],\n",
775
+ " dtype=torch.bfloat16),\n",
776
+ " 't_25': tensor([0.6289], dtype=torch.bfloat16),\n",
777
+ " 'latents_25_start': tensor([[[ 1.2422, -0.4844, 0.3945, ..., -1.0781, 0.6562, 0.2471],\n",
778
+ " [ 1.0547, 0.0889, -0.0427, ..., 0.4707, -0.1055, -0.9961],\n",
779
+ " [ 0.2832, 0.2930, 0.0859, ..., -0.5195, -0.6797, -1.3516],\n",
780
+ " ...,\n",
781
+ " [-1.0078, 0.0645, 0.7305, ..., 0.2539, -0.2559, 0.5898],\n",
782
+ " [ 0.5586, 0.7891, -0.1914, ..., 0.4199, 1.1250, 0.2266],\n",
783
+ " [ 0.2188, -0.7656, 1.3828, ..., 0.8906, -0.1582, 0.6016]]],\n",
784
+ " dtype=torch.bfloat16),\n",
785
+ " 'noise_pred_25': tensor([[[ 1.9609, -0.8242, 0.6523, ..., -1.7422, 0.9258, 0.3418],\n",
786
+ " [ 1.5625, 0.0850, -0.3359, ..., 0.7812, -0.2295, -1.8516],\n",
787
+ " [ 0.2871, 0.2520, 0.3184, ..., -1.1250, -1.8359, -2.7500],\n",
788
+ " ...,\n",
789
+ " [-1.1875, 0.4863, 1.6250, ..., -0.1924, -0.8633, 0.6055],\n",
790
+ " [ 1.2969, 1.6172, 0.1240, ..., -0.8359, 1.0625, -1.5078],\n",
791
+ " [ 0.7305, -0.8789, 2.6562, ..., -0.2969, -1.8203, -0.5469]]],\n",
792
+ " dtype=torch.bfloat16),\n",
793
+ " 't_26': tensor([0.6094], dtype=torch.bfloat16),\n",
794
+ " 'latents_26_start': tensor([[[ 1.2031, -0.4688, 0.3828, ..., -1.0469, 0.6406, 0.2402],\n",
795
+ " [ 1.0234, 0.0874, -0.0364, ..., 0.4551, -0.1011, -0.9609],\n",
796
+ " [ 0.2773, 0.2891, 0.0801, ..., -0.4980, -0.6445, -1.2969],\n",
797
+ " ...,\n",
798
+ " [-0.9844, 0.0552, 0.6992, ..., 0.2578, -0.2393, 0.5781],\n",
799
+ " [ 0.5352, 0.7578, -0.1934, ..., 0.4355, 1.1016, 0.2559],\n",
800
+ " [ 0.2051, -0.7500, 1.3359, ..., 0.8945, -0.1235, 0.6133]]],\n",
801
+ " dtype=torch.bfloat16),\n",
802
+ " 'noise_pred_26': tensor([[[ 1.9609, -0.8320, 0.6602, ..., -1.7578, 0.9219, 0.3496],\n",
803
+ " [ 1.5703, 0.0801, -0.3359, ..., 0.7812, -0.2266, -1.8281],\n",
804
+ " [ 0.2793, 0.2471, 0.3223, ..., -1.1172, -1.8281, -2.7188],\n",
805
+ " ...,\n",
806
+ " [-1.1797, 0.4863, 1.6172, ..., -0.2021, -0.8516, 0.6133],\n",
807
+ " [ 1.2969, 1.6016, 0.1226, ..., -0.8359, 1.0625, -1.5000],\n",
808
+ " [ 0.7305, -0.8906, 2.6562, ..., -0.2988, -1.8047, -0.5469]]],\n",
809
+ " dtype=torch.bfloat16),\n",
810
+ " 't_27': tensor([0.5898], dtype=torch.bfloat16),\n",
811
+ " 'latents_27_start': tensor([[[ 1.1641, -0.4531, 0.3691, ..., -1.0156, 0.6211, 0.2334],\n",
812
+ " [ 0.9922, 0.0859, -0.0298, ..., 0.4395, -0.0967, -0.9258],\n",
813
+ " [ 0.2715, 0.2852, 0.0737, ..., -0.4766, -0.6094, -1.2422],\n",
814
+ " ...,\n",
815
+ " [-0.9609, 0.0457, 0.6680, ..., 0.2617, -0.2227, 0.5664],\n",
816
+ " [ 0.5078, 0.7266, -0.1953, ..., 0.4512, 1.0781, 0.2852],\n",
817
+ " [ 0.1904, -0.7344, 1.2812, ..., 0.8984, -0.0884, 0.6250]]],\n",
818
+ " dtype=torch.bfloat16),\n",
819
+ " 'noise_pred_27': tensor([[[ 1.9453, -0.8398, 0.6562, ..., -1.7578, 0.9141, 0.3398],\n",
820
+ " [ 1.5469, 0.0654, -0.3379, ..., 0.7734, -0.2236, -1.8359],\n",
821
+ " [ 0.2773, 0.2354, 0.3223, ..., -1.1094, -1.8359, -2.7188],\n",
822
+ " ...,\n",
823
+ " [-1.1797, 0.4844, 1.6328, ..., -0.1992, -0.8359, 0.6172],\n",
824
+ " [ 1.2891, 1.5859, 0.1133, ..., -0.8242, 1.0469, -1.4922],\n",
825
+ " [ 0.7148, -0.9180, 2.6406, ..., -0.3047, -1.8203, -0.5508]]],\n",
826
+ " dtype=torch.bfloat16),\n",
827
+ " 't_28': tensor([0.5664], dtype=torch.bfloat16),\n",
828
+ " 'latents_28_start': tensor([[[ 1.1250, -0.4355, 0.3555, ..., -0.9805, 0.6016, 0.2266],\n",
829
+ " [ 0.9609, 0.0845, -0.0231, ..., 0.4238, -0.0923, -0.8906],\n",
830
+ " [ 0.2656, 0.2812, 0.0674, ..., -0.4551, -0.5742, -1.1875],\n",
831
+ " ...,\n",
832
+ " [-0.9375, 0.0361, 0.6367, ..., 0.2656, -0.2061, 0.5547],\n",
833
+ " [ 0.4824, 0.6953, -0.1973, ..., 0.4668, 1.0547, 0.3145],\n",
834
+ " [ 0.1758, -0.7148, 1.2266, ..., 0.9062, -0.0522, 0.6367]]],\n",
835
+ " dtype=torch.bfloat16),\n",
836
+ " 'noise_pred_28': tensor([[[ 1.9453, -0.8281, 0.6406, ..., -1.7344, 0.9219, 0.3672],\n",
837
+ " [ 1.5547, 0.0684, -0.3496, ..., 0.8047, -0.1953, -1.8281],\n",
838
+ " [ 0.2812, 0.2207, 0.3281, ..., -1.1016, -1.8359, -2.7188],\n",
839
+ " ...,\n",
840
+ " [-1.1953, 0.4668, 1.6328, ..., -0.1758, -0.8203, 0.6445],\n",
841
+ " [ 1.2734, 1.5781, 0.1045, ..., -0.7969, 1.0547, -1.4922],\n",
842
+ " [ 0.6953, -0.9258, 2.6562, ..., -0.3086, -1.7969, -0.5508]]],\n",
843
+ " dtype=torch.bfloat16),\n",
844
+ " 't_29': tensor([0.5469], dtype=torch.bfloat16),\n",
845
+ " 'latents_29_start': tensor([[[ 1.0859, -0.4180, 0.3418, ..., -0.9453, 0.5820, 0.2188],\n",
846
+ " [ 0.9297, 0.0830, -0.0159, ..., 0.4082, -0.0884, -0.8516],\n",
847
+ " [ 0.2598, 0.2773, 0.0608, ..., -0.4336, -0.5352, -1.1328],\n",
848
+ " ...,\n",
849
+ " [-0.9141, 0.0266, 0.6016, ..., 0.2695, -0.1895, 0.5430],\n",
850
+ " [ 0.4570, 0.6641, -0.1992, ..., 0.4824, 1.0312, 0.3457],\n",
851
+ " [ 0.1621, -0.6953, 1.1719, ..., 0.9141, -0.0156, 0.6484]]],\n",
852
+ " dtype=torch.bfloat16),\n",
853
+ " 'noise_pred_29': tensor([[[ 1.9688, -0.8320, 0.6445, ..., -1.7734, 0.9219, 0.3672],\n",
854
+ " [ 1.5469, 0.0732, -0.3477, ..., 0.7930, -0.2100, -1.8281],\n",
855
+ " [ 0.2793, 0.2354, 0.3262, ..., -1.1250, -1.8359, -2.7188],\n",
856
+ " ...,\n",
857
+ " [-1.1953, 0.4746, 1.6172, ..., -0.1738, -0.8086, 0.6484],\n",
858
+ " [ 1.2969, 1.5781, 0.0952, ..., -0.8164, 1.0859, -1.4766],\n",
859
+ " [ 0.6797, -0.9180, 2.6562, ..., -0.3145, -1.7812, -0.5508]]],\n",
860
+ " dtype=torch.bfloat16),\n",
861
+ " 't_30': tensor([0.5273], dtype=torch.bfloat16),\n",
862
+ " 'latents_30_start': tensor([[[ 1.0469, -0.4004, 0.3281, ..., -0.9102, 0.5625, 0.2109],\n",
863
+ " [ 0.8984, 0.0815, -0.0087, ..., 0.3926, -0.0840, -0.8125],\n",
864
+ " [ 0.2539, 0.2734, 0.0540, ..., -0.4102, -0.4961, -1.0781],\n",
865
+ " ...,\n",
866
+ " [-0.8906, 0.0168, 0.5664, ..., 0.2734, -0.1729, 0.5312],\n",
867
+ " [ 0.4297, 0.6328, -0.2012, ..., 0.5000, 1.0078, 0.3770],\n",
868
+ " [ 0.1484, -0.6758, 1.1172, ..., 0.9219, 0.0212, 0.6602]]],\n",
869
+ " dtype=torch.bfloat16),\n",
870
+ " 'noise_pred_30': tensor([[[ 1.9766, -0.8359, 0.6484, ..., -1.7812, 0.9219, 0.3691],\n",
871
+ " [ 1.5469, 0.0693, -0.3359, ..., 0.7969, -0.2090, -1.8203],\n",
872
+ " [ 0.2852, 0.2363, 0.3320, ..., -1.1172, -1.8359, -2.7344],\n",
873
+ " ...,\n",
874
+ " [-1.1953, 0.4766, 1.6406, ..., -0.1855, -0.8203, 0.6484],\n",
875
+ " [ 1.2891, 1.5781, 0.1064, ..., -0.8203, 1.0859, -1.4922],\n",
876
+ " [ 0.6953, -0.9219, 2.6719, ..., -0.3145, -1.7656, -0.5312]]],\n",
877
+ " dtype=torch.bfloat16),\n",
878
+ " 't_31': tensor([0.5078], dtype=torch.bfloat16),\n",
879
+ " 'latents_31_start': tensor([[[ 1.0078, -0.3828, 0.3145, ..., -0.8711, 0.5430, 0.2031],\n",
880
+ " [ 0.8672, 0.0801, -0.0015, ..., 0.3750, -0.0796, -0.7734],\n",
881
+ " [ 0.2480, 0.2676, 0.0469, ..., -0.3867, -0.4570, -1.0234],\n",
882
+ " ...,\n",
883
+ " [-0.8672, 0.0067, 0.5312, ..., 0.2773, -0.1553, 0.5156],\n",
884
+ " [ 0.4023, 0.5977, -0.2031, ..., 0.5156, 0.9844, 0.4082],\n",
885
+ " [ 0.1338, -0.6562, 1.0625, ..., 0.9297, 0.0588, 0.6719]]],\n",
886
+ " dtype=torch.bfloat16),\n",
887
+ " 'noise_pred_31': tensor([[[ 1.9688, -0.8242, 0.6523, ..., -1.7656, 0.9297, 0.3555],\n",
888
+ " [ 1.5469, 0.0796, -0.3516, ..., 0.7969, -0.2051, -1.8281],\n",
889
+ " [ 0.2754, 0.2344, 0.3262, ..., -1.1172, -1.8359, -2.7344],\n",
890
+ " ...,\n",
891
+ " [-1.1953, 0.4766, 1.6328, ..., -0.1895, -0.8125, 0.6289],\n",
892
+ " [ 1.2969, 1.5859, 0.0938, ..., -0.8125, 1.0781, -1.4766],\n",
893
+ " [ 0.6836, -0.9375, 2.6562, ..., -0.2949, -1.7734, -0.5234]]],\n",
894
+ " dtype=torch.bfloat16),\n",
895
+ " 't_32': tensor([0.4844], dtype=torch.bfloat16),\n",
896
+ " 'latents_32_start': tensor([[[ 0.9648, -0.3652, 0.3008, ..., -0.8320, 0.5234, 0.1953],\n",
897
+ " [ 0.8320, 0.0781, 0.0061, ..., 0.3574, -0.0752, -0.7344],\n",
898
+ " [ 0.2422, 0.2617, 0.0398, ..., -0.3633, -0.4180, -0.9648],\n",
899
+ " ...,\n",
900
+ " [-0.8398, -0.0037, 0.4961, ..., 0.2812, -0.1377, 0.5000],\n",
901
+ " [ 0.3750, 0.5625, -0.2051, ..., 0.5352, 0.9609, 0.4395],\n",
902
+ " [ 0.1191, -0.6367, 1.0078, ..., 0.9375, 0.0977, 0.6836]]],\n",
903
+ " dtype=torch.bfloat16),\n",
904
+ " 'noise_pred_32': tensor([[[ 1.9688, -0.8242, 0.6523, ..., -1.7656, 0.9258, 0.3535],\n",
905
+ " [ 1.5391, 0.0728, -0.3457, ..., 0.7891, -0.2021, -1.8203],\n",
906
+ " [ 0.2754, 0.2344, 0.3223, ..., -1.1172, -1.8281, -2.7344],\n",
907
+ " ...,\n",
908
+ " [-1.1875, 0.4688, 1.6250, ..., -0.1768, -0.8086, 0.6133],\n",
909
+ " [ 1.2812, 1.5703, 0.1079, ..., -0.8320, 1.0703, -1.4844],\n",
910
+ " [ 0.6914, -0.9297, 2.6562, ..., -0.3027, -1.7656, -0.5156]]],\n",
911
+ " dtype=torch.bfloat16),\n",
912
+ " 't_33': tensor([0.4629], dtype=torch.bfloat16),\n",
913
+ " 'latents_33_start': tensor([[[ 0.9219, -0.3477, 0.2871, ..., -0.7930, 0.5039, 0.1875],\n",
914
+ " [ 0.7969, 0.0767, 0.0138, ..., 0.3398, -0.0708, -0.6953],\n",
915
+ " [ 0.2363, 0.2559, 0.0327, ..., -0.3379, -0.3770, -0.9023],\n",
916
+ " ...,\n",
917
+ " [-0.8125, -0.0141, 0.4609, ..., 0.2852, -0.1196, 0.4863],\n",
918
+ " [ 0.3457, 0.5273, -0.2070, ..., 0.5547, 0.9375, 0.4727],\n",
919
+ " [ 0.1035, -0.6172, 0.9492, ..., 0.9453, 0.1367, 0.6953]]],\n",
920
+ " dtype=torch.bfloat16),\n",
921
+ " 'noise_pred_33': tensor([[[ 1.9609, -0.8242, 0.6484, ..., -1.7578, 0.9258, 0.3477],\n",
922
+ " [ 1.5312, 0.0684, -0.3379, ..., 0.7891, -0.2012, -1.8125],\n",
923
+ " [ 0.2812, 0.2158, 0.3164, ..., -1.1172, -1.8125, -2.7188],\n",
924
+ " ...,\n",
925
+ " [-1.1797, 0.4570, 1.6250, ..., -0.1826, -0.8086, 0.6055],\n",
926
+ " [ 1.2812, 1.5625, 0.1025, ..., -0.8125, 1.0625, -1.4922],\n",
927
+ " [ 0.6797, -0.9297, 2.6250, ..., -0.2949, -1.7656, -0.5117]]],\n",
928
+ " dtype=torch.bfloat16),\n",
929
+ " 't_34': tensor([0.4375], dtype=torch.bfloat16),\n",
930
+ " 'latents_34_start': tensor([[[ 0.8789, -0.3281, 0.2715, ..., -0.7539, 0.4824, 0.1797],\n",
931
+ " [ 0.7617, 0.0752, 0.0215, ..., 0.3223, -0.0664, -0.6523],\n",
932
+ " [ 0.2295, 0.2500, 0.0255, ..., -0.3125, -0.3359, -0.8398],\n",
933
+ " ...,\n",
934
+ " [-0.7852, -0.0245, 0.4238, ..., 0.2891, -0.1011, 0.4727],\n",
935
+ " [ 0.3164, 0.4922, -0.2090, ..., 0.5742, 0.9141, 0.5078],\n",
936
+ " [ 0.0879, -0.5977, 0.8906, ..., 0.9531, 0.1768, 0.7070]]],\n",
937
+ " dtype=torch.bfloat16),\n",
938
+ " 'noise_pred_34': tensor([[[ 1.9766, -0.8242, 0.6523, ..., -1.7812, 0.9258, 0.3535],\n",
939
+ " [ 1.5312, 0.0640, -0.3457, ..., 0.8086, -0.1934, -1.8047],\n",
940
+ " [ 0.2715, 0.1992, 0.3105, ..., -1.0938, -1.8047, -2.7344],\n",
941
+ " ...,\n",
942
+ " [-1.1875, 0.4609, 1.6172, ..., -0.1953, -0.8164, 0.6172],\n",
943
+ " [ 1.2812, 1.5625, 0.0942, ..., -0.8242, 1.0781, -1.4922],\n",
944
+ " [ 0.6562, -0.9531, 2.6406, ..., -0.3008, -1.7734, -0.4980]]],\n",
945
+ " dtype=torch.bfloat16),\n",
946
+ " 't_35': tensor([0.4160], dtype=torch.bfloat16),\n",
947
+ " 'latents_35_start': tensor([[[ 0.8320, -0.3086, 0.2559, ..., -0.7109, 0.4609, 0.1719],\n",
948
+ " [ 0.7266, 0.0737, 0.0295, ..., 0.3027, -0.0620, -0.6094],\n",
949
+ " [ 0.2236, 0.2451, 0.0183, ..., -0.2871, -0.2930, -0.7773],\n",
950
+ " ...,\n",
951
+ " [-0.7578, -0.0352, 0.3867, ..., 0.2930, -0.0820, 0.4590],\n",
952
+ " [ 0.2871, 0.4551, -0.2109, ..., 0.5938, 0.8906, 0.5430],\n",
953
+ " [ 0.0728, -0.5742, 0.8281, ..., 0.9609, 0.2178, 0.7188]]],\n",
954
+ " dtype=torch.bfloat16),\n",
955
+ " 'noise_pred_35': tensor([[[ 1.9766, -0.8242, 0.6484, ..., -1.7734, 0.9258, 0.3496],\n",
956
+ " [ 1.5312, 0.0708, -0.3418, ..., 0.8047, -0.1953, -1.7969],\n",
957
+ " [ 0.2832, 0.1875, 0.3086, ..., -1.0938, -1.7891, -2.7344],\n",
958
+ " ...,\n",
959
+ " [-1.1719, 0.4551, 1.6172, ..., -0.1953, -0.8047, 0.6055],\n",
960
+ " [ 1.2578, 1.5625, 0.0898, ..., -0.8242, 1.0859, -1.4766],\n",
961
+ " [ 0.6602, -0.9414, 2.6406, ..., -0.3164, -1.7656, -0.5078]]],\n",
962
+ " dtype=torch.bfloat16),\n",
963
+ " 't_36': tensor([0.3926], dtype=torch.bfloat16),\n",
964
+ " 'latents_36_start': tensor([[[ 0.7852, -0.2891, 0.2402, ..., -0.6680, 0.4395, 0.1641],\n",
965
+ " [ 0.6914, 0.0723, 0.0376, ..., 0.2832, -0.0574, -0.5664],\n",
966
+ " [ 0.2168, 0.2402, 0.0110, ..., -0.2617, -0.2500, -0.7109],\n",
967
+ " ...,\n",
968
+ " [-0.7305, -0.0459, 0.3477, ..., 0.2969, -0.0630, 0.4453],\n",
969
+ " [ 0.2578, 0.4180, -0.2129, ..., 0.6133, 0.8633, 0.5781],\n",
970
+ " [ 0.0571, -0.5508, 0.7656, ..., 0.9688, 0.2598, 0.7305]]],\n",
971
+ " dtype=torch.bfloat16),\n",
972
+ " 'noise_pred_36': tensor([[[ 1.9609, -0.8164, 0.6445, ..., -1.7656, 0.9102, 0.3477],\n",
973
+ " [ 1.5234, 0.0654, -0.3320, ..., 0.8164, -0.2041, -1.7812],\n",
974
+ " [ 0.2734, 0.1836, 0.3066, ..., -1.0938, -1.7891, -2.7031],\n",
975
+ " ...,\n",
976
+ " [-1.1875, 0.4688, 1.6016, ..., -0.2051, -0.8008, 0.5977],\n",
977
+ " [ 1.2500, 1.5234, 0.0747, ..., -0.8438, 1.0781, -1.4922],\n",
978
+ " [ 0.6484, -0.9219, 2.6250, ..., -0.3027, -1.7812, -0.5078]]],\n",
979
+ " dtype=torch.bfloat16),\n",
980
+ " 't_37': tensor([0.3652], dtype=torch.bfloat16),\n",
981
+ " 'latents_37_start': tensor([[[ 0.7383, -0.2695, 0.2246, ..., -0.6250, 0.4180, 0.1553],\n",
982
+ " [ 0.6562, 0.0708, 0.0457, ..., 0.2637, -0.0525, -0.5234],\n",
983
+ " [ 0.2100, 0.2354, 0.0035, ..., -0.2354, -0.2061, -0.6445],\n",
984
+ " ...,\n",
985
+ " [-0.7031, -0.0574, 0.3086, ..., 0.3027, -0.0435, 0.4316],\n",
986
+ " [ 0.2275, 0.3809, -0.2148, ..., 0.6328, 0.8359, 0.6133],\n",
987
+ " [ 0.0413, -0.5273, 0.7031, ..., 0.9766, 0.3027, 0.7422]]],\n",
988
+ " dtype=torch.bfloat16),\n",
989
+ " 'noise_pred_37': tensor([[[ 1.9922, -0.8320, 0.6523, ..., -1.7812, 0.9336, 0.3613],\n",
990
+ " [ 1.5312, 0.0713, -0.3477, ..., 0.8281, -0.1963, -1.7891],\n",
991
+ " [ 0.2754, 0.1777, 0.2930, ..., -1.0859, -1.7734, -2.7031],\n",
992
+ " ...,\n",
993
+ " [-1.1875, 0.4590, 1.6172, ..., -0.1855, -0.8086, 0.5820],\n",
994
+ " [ 1.2422, 1.5391, 0.0552, ..., -0.8633, 1.0781, -1.5156],\n",
995
+ " [ 0.6602, -0.9219, 2.6250, ..., -0.3047, -1.7734, -0.5000]]],\n",
996
+ " dtype=torch.bfloat16),\n",
997
+ " 't_38': tensor([0.3418], dtype=torch.bfloat16),\n",
998
+ " 'latents_38_start': tensor([[[ 0.6875, -0.2490, 0.2080, ..., -0.5820, 0.3945, 0.1465],\n",
999
+ " [ 0.6172, 0.0688, 0.0544, ..., 0.2432, -0.0476, -0.4785],\n",
1000
+ " [ 0.2031, 0.2305, -0.0038, ..., -0.2080, -0.1621, -0.5781],\n",
1001
+ " ...,\n",
1002
+ " [-0.6719, -0.0688, 0.2676, ..., 0.3066, -0.0232, 0.4180],\n",
1003
+ " [ 0.1963, 0.3418, -0.2158, ..., 0.6562, 0.8086, 0.6523],\n",
1004
+ " [ 0.0248, -0.5039, 0.6367, ..., 0.9844, 0.3477, 0.7539]]],\n",
1005
+ " dtype=torch.bfloat16),\n",
1006
+ " 'noise_pred_38': tensor([[[ 1.9766, -0.8477, 0.6484, ..., -1.7812, 0.9531, 0.3789],\n",
1007
+ " [ 1.5234, 0.0781, -0.3496, ..., 0.8281, -0.1973, -1.7734],\n",
1008
+ " [ 0.2617, 0.1660, 0.2949, ..., -1.0547, -1.7422, -2.6875],\n",
1009
+ " ...,\n",
1010
+ " [-1.1641, 0.4668, 1.6328, ..., -0.1836, -0.8125, 0.5547],\n",
1011
+ " [ 1.2344, 1.5312, 0.0752, ..., -0.8867, 1.0234, -1.5156],\n",
1012
+ " [ 0.6406, -0.9219, 2.6250, ..., -0.2988, -1.7812, -0.5117]]],\n",
1013
+ " dtype=torch.bfloat16),\n",
1014
+ " 't_39': tensor([0.3164], dtype=torch.bfloat16),\n",
1015
+ " 'latents_39_start': tensor([[[ 0.6367, -0.2275, 0.1914, ..., -0.5352, 0.3711, 0.1367],\n",
1016
+ " [ 0.5781, 0.0669, 0.0635, ..., 0.2217, -0.0425, -0.4336],\n",
1017
+ " [ 0.1963, 0.2266, -0.0114, ..., -0.1807, -0.1172, -0.5078],\n",
1018
+ " ...,\n",
1019
+ " [-0.6406, -0.0811, 0.2256, ..., 0.3105, -0.0023, 0.4043],\n",
1020
+ " [ 0.1641, 0.3027, -0.2178, ..., 0.6797, 0.7812, 0.6914],\n",
1021
+ " [ 0.0083, -0.4805, 0.5703, ..., 0.9922, 0.3926, 0.7656]]],\n",
1022
+ " dtype=torch.bfloat16),\n",
1023
+ " 'noise_pred_39': tensor([[[ 1.9531, -0.8320, 0.6523, ..., -1.7500, 0.9336, 0.3965],\n",
1024
+ " [ 1.5156, 0.0767, -0.3691, ..., 0.8281, -0.1777, -1.7500],\n",
1025
+ " [ 0.2637, 0.1729, 0.2734, ..., -1.0234, -1.6797, -2.6250],\n",
1026
+ " ...,\n",
1027
+ " [-1.1406, 0.4531, 1.6016, ..., -0.1973, -0.8164, 0.5234],\n",
1028
+ " [ 1.2344, 1.5078, 0.0593, ..., -0.8906, 1.0234, -1.4844],\n",
1029
+ " [ 0.6367, -0.9023, 2.5938, ..., -0.2910, -1.7500, -0.5078]]],\n",
1030
+ " dtype=torch.bfloat16),\n",
1031
+ " 't_40': tensor([0.2891], dtype=torch.bfloat16),\n",
1032
+ " 'latents_40_start': tensor([[[ 0.5859, -0.2061, 0.1738, ..., -0.4883, 0.3457, 0.1260],\n",
1033
+ " [ 0.5391, 0.0649, 0.0732, ..., 0.2002, -0.0378, -0.3867],\n",
1034
+ " [ 0.1895, 0.2217, -0.0186, ..., -0.1543, -0.0732, -0.4395],\n",
1035
+ " ...,\n",
1036
+ " [-0.6094, -0.0928, 0.1836, ..., 0.3164, 0.0192, 0.3906],\n",
1037
+ " [ 0.1318, 0.2637, -0.2197, ..., 0.7031, 0.7539, 0.7305],\n",
1038
+ " [-0.0084, -0.4570, 0.5039, ..., 1.0000, 0.4375, 0.7773]]],\n",
1039
+ " dtype=torch.bfloat16),\n",
1040
+ " 'noise_pred_40': tensor([[[ 1.9766, -0.8555, 0.6445, ..., -1.7500, 0.9414, 0.4004],\n",
1041
+ " [ 1.5234, 0.0693, -0.3613, ..., 0.8672, -0.1650, -1.7266],\n",
1042
+ " [ 0.2598, 0.1660, 0.2734, ..., -1.0156, -1.7031, -2.6250],\n",
1043
+ " ...,\n",
1044
+ " [-1.1641, 0.4531, 1.6016, ..., -0.1611, -0.8125, 0.5039],\n",
1045
+ " [ 1.2109, 1.5156, 0.0391, ..., -0.8906, 0.9883, -1.4688],\n",
1046
+ " [ 0.6250, -0.9102, 2.6094, ..., -0.2930, -1.7578, -0.4922]]],\n",
1047
+ " dtype=torch.bfloat16),\n",
1048
+ " 't_41': tensor([0.2617], dtype=torch.bfloat16),\n",
1049
+ " 'latents_41_start': tensor([[[ 0.5312, -0.1826, 0.1562, ..., -0.4414, 0.3203, 0.1152],\n",
1050
+ " [ 0.4980, 0.0630, 0.0830, ..., 0.1768, -0.0334, -0.3398],\n",
1051
+ " [ 0.1826, 0.2168, -0.0259, ..., -0.1270, -0.0273, -0.3691],\n",
1052
+ " ...,\n",
1053
+ " [-0.5781, -0.1050, 0.1406, ..., 0.3203, 0.0410, 0.3770],\n",
1054
+ " [ 0.0991, 0.2227, -0.2207, ..., 0.7266, 0.7266, 0.7695],\n",
1055
+ " [-0.0253, -0.4316, 0.4336, ..., 1.0078, 0.4844, 0.7891]]],\n",
1056
+ " dtype=torch.bfloat16),\n",
1057
+ " 'noise_pred_41': tensor([[[ 1.9375, -0.8555, 0.6367, ..., -1.7266, 0.9531, 0.4297],\n",
1058
+ " [ 1.5312, 0.0659, -0.3633, ..., 0.8711, -0.1660, -1.7266],\n",
1059
+ " [ 0.2520, 0.1826, 0.2676, ..., -0.9844, -1.6328, -2.5781],\n",
1060
+ " ...,\n",
1061
+ " [-1.1484, 0.4453, 1.5859, ..., -0.1562, -0.8281, 0.4922],\n",
1062
+ " [ 1.1797, 1.5078, 0.0229, ..., -0.9102, 0.9531, -1.4609],\n",
1063
+ " [ 0.6211, -0.9102, 2.5938, ..., -0.2832, -1.7266, -0.4727]]],\n",
1064
+ " dtype=torch.bfloat16),\n",
1065
+ " 't_42': tensor([0.2354], dtype=torch.bfloat16),\n",
1066
+ " 'latents_42_start': tensor([[[ 0.4785, -0.1592, 0.1387, ..., -0.3945, 0.2949, 0.1035],\n",
1067
+ " [ 0.4551, 0.0613, 0.0928, ..., 0.1523, -0.0288, -0.2930],\n",
1068
+ " [ 0.1758, 0.2119, -0.0332, ..., -0.0996, 0.0178, -0.2969],\n",
1069
+ " ...,\n",
1070
+ " [-0.5469, -0.1172, 0.0967, ..., 0.3242, 0.0640, 0.3633],\n",
1071
+ " [ 0.0664, 0.1816, -0.2217, ..., 0.7500, 0.6992, 0.8086],\n",
1072
+ " [-0.0425, -0.4062, 0.3613, ..., 1.0156, 0.5312, 0.8008]]],\n",
1073
+ " dtype=torch.bfloat16),\n",
1074
+ " 'noise_pred_42': tensor([[[ 1.9297, -0.8594, 0.6289, ..., -1.7031, 0.9609, 0.4297],\n",
1075
+ " [ 1.5000, 0.0811, -0.3652, ..., 0.8711, -0.1494, -1.7031],\n",
1076
+ " [ 0.2246, 0.1543, 0.2695, ..., -0.9492, -1.6328, -2.5312],\n",
1077
+ " ...,\n",
1078
+ " [-1.1328, 0.4395, 1.5781, ..., -0.1680, -0.8398, 0.4453],\n",
1079
+ " [ 1.1484, 1.4766, 0.0073, ..., -0.9648, 0.8984, -1.4219],\n",
1080
+ " [ 0.5938, -0.8672, 2.5312, ..., -0.2949, -1.7031, -0.4766]]],\n",
1081
+ " dtype=torch.bfloat16),\n",
1082
+ " 't_43': tensor([0.2070], dtype=torch.bfloat16),\n",
1083
+ " 'latents_43_start': tensor([[[ 0.4238, -0.1348, 0.1211, ..., -0.3457, 0.2676, 0.0913],\n",
1084
+ " [ 0.4121, 0.0591, 0.1030, ..., 0.1279, -0.0245, -0.2441],\n",
1085
+ " [ 0.1699, 0.2080, -0.0408, ..., -0.0728, 0.0640, -0.2246],\n",
1086
+ " ...,\n",
1087
+ " [-0.5156, -0.1299, 0.0520, ..., 0.3281, 0.0879, 0.3516],\n",
1088
+ " [ 0.0339, 0.1396, -0.2217, ..., 0.7773, 0.6719, 0.8477],\n",
1089
+ " [-0.0593, -0.3809, 0.2891, ..., 1.0234, 0.5781, 0.8125]]],\n",
1090
+ " dtype=torch.bfloat16),\n",
1091
+ " 'noise_pred_43': tensor([[[ 1.9141, -0.8711, 0.6328, ..., -1.6484, 0.9883, 0.4453],\n",
1092
+ " [ 1.4844, 0.0693, -0.4004, ..., 0.8867, -0.1152, -1.6797],\n",
1093
+ " [ 0.2432, 0.1484, 0.2090, ..., -0.8867, -1.5938, -2.4531],\n",
1094
+ " ...,\n",
1095
+ " [-1.1094, 0.4453, 1.5703, ..., -0.1865, -0.8594, 0.3906],\n",
1096
+ " [ 1.0938, 1.4375, -0.0374, ..., -0.9648, 0.8359, -1.4531],\n",
1097
+ " [ 0.5898, -0.8516, 2.4688, ..., -0.2773, -1.6484, -0.4746]]],\n",
1098
+ " dtype=torch.bfloat16),\n",
1099
+ " 't_44': tensor([0.1777], dtype=torch.bfloat16),\n",
1100
+ " 'latents_44_start': tensor([[[ 0.3672, -0.1094, 0.1025, ..., -0.2969, 0.2393, 0.0781],\n",
1101
+ " [ 0.3691, 0.0571, 0.1147, ..., 0.1021, -0.0212, -0.1953],\n",
1102
+ " [ 0.1631, 0.2041, -0.0469, ..., -0.0469, 0.1104, -0.1533],\n",
1103
+ " ...,\n",
1104
+ " [-0.4844, -0.1426, 0.0063, ..., 0.3340, 0.1128, 0.3398],\n",
1105
+ " [ 0.0022, 0.0977, -0.2207, ..., 0.8047, 0.6484, 0.8906],\n",
1106
+ " [-0.0762, -0.3555, 0.2168, ..., 1.0312, 0.6250, 0.8281]]],\n",
1107
+ " dtype=torch.bfloat16),\n",
1108
+ " 'noise_pred_44': tensor([[[ 1.8984, -0.8984, 0.5977, ..., -1.6094, 0.9805, 0.4551],\n",
1109
+ " [ 1.4609, 0.0474, -0.4160, ..., 0.9102, -0.1060, -1.6484],\n",
1110
+ " [ 0.2119, 0.0791, 0.2021, ..., -0.8359, -1.5391, -2.4219],\n",
1111
+ " ...,\n",
1112
+ " [-1.0859, 0.4395, 1.5391, ..., -0.2002, -0.8516, 0.3359],\n",
1113
+ " [ 1.0625, 1.4219, -0.0486, ..., -0.9609, 0.7930, -1.4453],\n",
1114
+ " [ 0.5898, -0.8008, 2.4219, ..., -0.3223, -1.5703, -0.4609]]],\n",
1115
+ " dtype=torch.bfloat16),\n",
1116
+ " 't_45': tensor([0.1484], dtype=torch.bfloat16),\n",
1117
+ " 'latents_45_start': tensor([[[ 0.3105, -0.0825, 0.0850, ..., -0.2490, 0.2100, 0.0645],\n",
1118
+ " [ 0.3262, 0.0557, 0.1270, ..., 0.0747, -0.0181, -0.1465],\n",
1119
+ " [ 0.1562, 0.2021, -0.0530, ..., -0.0219, 0.1562, -0.0811],\n",
1120
+ " ...,\n",
1121
+ " [-0.4512, -0.1553, -0.0398, ..., 0.3398, 0.1387, 0.3301],\n",
1122
+ " [-0.0295, 0.0552, -0.2197, ..., 0.8320, 0.6250, 0.9336],\n",
1123
+ " [-0.0938, -0.3320, 0.1445, ..., 1.0391, 0.6719, 0.8438]]],\n",
1124
+ " dtype=torch.bfloat16),\n",
1125
+ " 'noise_pred_45': tensor([[[ 1.8516, -0.8984, 0.5547, ..., -1.5547, 0.9688, 0.4727],\n",
1126
+ " [ 1.4219, 0.0500, -0.4258, ..., 0.9141, -0.0588, -1.5703],\n",
1127
+ " [ 0.1592, 0.0850, 0.1924, ..., -0.7500, -1.4766, -2.3125],\n",
1128
+ " ...,\n",
1129
+ " [-1.0469, 0.4414, 1.4766, ..., -0.1494, -0.8438, 0.3047],\n",
1130
+ " [ 0.9883, 1.4062, -0.0874, ..., -0.9844, 0.7422, -1.4297],\n",
1131
+ " [ 0.5742, -0.7578, 2.3125, ..., -0.3301, -1.3672, -0.4238]]],\n",
1132
+ " dtype=torch.bfloat16),\n",
1133
+ " 't_46': tensor([0.1172], dtype=torch.bfloat16),\n",
1134
+ " 'latents_46_start': tensor([[[ 0.2539, -0.0549, 0.0679, ..., -0.2012, 0.1807, 0.0500],\n",
1135
+ " [ 0.2832, 0.0542, 0.1396, ..., 0.0469, -0.0162, -0.0986],\n",
1136
+ " [ 0.1514, 0.1992, -0.0588, ..., 0.0011, 0.2012, -0.0103],\n",
1137
+ " ...,\n",
1138
+ " [-0.4199, -0.1689, -0.0850, ..., 0.3438, 0.1641, 0.3203],\n",
1139
+ " [-0.0598, 0.0122, -0.2168, ..., 0.8633, 0.6016, 0.9766],\n",
1140
+ " [-0.1113, -0.3086, 0.0737, ..., 1.0469, 0.7148, 0.8555]]],\n",
1141
+ " dtype=torch.bfloat16),\n",
1142
+ " 'noise_pred_46': tensor([[[ 1.7734, -0.9492, 0.5430, ..., -1.5312, 0.9727, 0.5273],\n",
1143
+ " [ 1.4219, 0.0054, -0.4844, ..., 0.9297, 0.0198, -1.4531],\n",
1144
+ " [ 0.1377, -0.0330, 0.1299, ..., -0.6914, -1.3906, -2.2188],\n",
1145
+ " ...,\n",
1146
+ " [-0.9727, 0.4297, 1.3906, ..., -0.1172, -0.8164, 0.2295],\n",
1147
+ " [ 0.8945, 1.3516, -0.1758, ..., -0.9961, 0.6445, -1.5078],\n",
1148
+ " [ 0.5234, -0.7422, 2.2031, ..., -0.3848, -1.1641, -0.4883]]],\n",
1149
+ " dtype=torch.bfloat16),\n",
1150
+ " 't_47': tensor([0.0854], dtype=torch.bfloat16),\n",
1151
+ " 'latents_47_start': tensor([[[ 0.1982, -0.0250, 0.0508, ..., -0.1523, 0.1504, 0.0334],\n",
1152
+ " [ 0.2383, 0.0540, 0.1553, ..., 0.0176, -0.0168, -0.0530],\n",
1153
+ " [ 0.1475, 0.2002, -0.0630, ..., 0.0228, 0.2451, 0.0596],\n",
1154
+ " ...,\n",
1155
+ " [-0.3887, -0.1826, -0.1289, ..., 0.3477, 0.1895, 0.3125],\n",
1156
+ " [-0.0879, -0.0303, -0.2109, ..., 0.8945, 0.5820, 1.0234],\n",
1157
+ " [-0.1279, -0.2852, 0.0044, ..., 1.0625, 0.7500, 0.8711]]],\n",
1158
+ " dtype=torch.bfloat16),\n",
1159
+ " 'noise_pred_47': tensor([[[ 1.6250, -0.9688, 0.4160, ..., -1.4609, 0.9961, 0.5742],\n",
1160
+ " [ 1.3594, 0.0728, -0.5430, ..., 0.9062, 0.0530, -1.3438],\n",
1161
+ " [ 0.1553, -0.1787, 0.0908, ..., -0.5820, -1.1875, -1.9688],\n",
1162
+ " ...,\n",
1163
+ " [-0.8281, 0.4160, 1.2422, ..., -0.0122, -0.7500, 0.1396],\n",
1164
+ " [ 0.7734, 1.2812, -0.2295, ..., -0.9883, 0.5039, -1.4844],\n",
1165
+ " [ 0.4453, -0.6719, 1.9688, ..., -0.4180, -0.9141, -0.6211]]],\n",
1166
+ " dtype=torch.bfloat16),\n",
1167
+ " 't_48': tensor([0.0532], dtype=torch.bfloat16),\n",
1168
+ " 'latents_48_start': tensor([[[ 0.1455, 0.0065, 0.0374, ..., -0.1050, 0.1182, 0.0148],\n",
1169
+ " [ 0.1943, 0.0515, 0.1729, ..., -0.0118, -0.0186, -0.0093],\n",
1170
+ " [ 0.1426, 0.2061, -0.0659, ..., 0.0417, 0.2832, 0.1235],\n",
1171
+ " ...,\n",
1172
+ " [-0.3613, -0.1963, -0.1689, ..., 0.3477, 0.2139, 0.3086],\n",
1173
+ " [-0.1133, -0.0718, -0.2031, ..., 0.9258, 0.5664, 1.0703],\n",
1174
+ " [-0.1426, -0.2637, -0.0596, ..., 1.0781, 0.7812, 0.8906]]],\n",
1175
+ " dtype=torch.bfloat16),\n",
1176
+ " 'noise_pred_48': tensor([[[ 1.2031, -0.9297, 0.2559, ..., -1.4141, 0.8906, 0.4258],\n",
1177
+ " [ 1.1016, 0.0625, -0.3242, ..., 0.8047, 0.1318, -1.1094],\n",
1178
+ " [ 0.1094, -0.2695, 0.2334, ..., -0.5820, -1.1016, -1.5234],\n",
1179
+ " ...,\n",
1180
+ " [-0.5938, 0.4551, 1.0938, ..., 0.0281, -0.6289, 0.1357],\n",
1181
+ " [ 0.6523, 1.0625, -0.2275, ..., -1.0938, 0.4297, -1.3984],\n",
1182
+ " [ 0.3906, -0.4180, 1.5391, ..., -0.5742, -0.6250, -0.7852]]],\n",
1183
+ " dtype=torch.bfloat16),\n",
1184
+ " 't_49': tensor([0.0200], dtype=torch.bfloat16),\n",
1185
+ " 'latents_49_start': tensor([[[ 1.0547e-01, 3.7354e-02, 2.8809e-02, ..., -5.8105e-02,\n",
1186
+ " 8.8867e-02, 6.1035e-04],\n",
1187
+ " [ 1.5820e-01, 4.9316e-02, 1.8359e-01, ..., -3.8574e-02,\n",
1188
+ " -2.2949e-02, 2.7588e-02],\n",
1189
+ " [ 1.3867e-01, 2.1484e-01, -7.3730e-02, ..., 6.1035e-02,\n",
1190
+ " 3.2031e-01, 1.7383e-01],\n",
1191
+ " ...,\n",
1192
+ " [-3.4180e-01, -2.1094e-01, -2.0508e-01, ..., 3.4766e-01,\n",
1193
+ " 2.3438e-01, 3.0469e-01],\n",
1194
+ " [-1.3477e-01, -1.0693e-01, -1.9531e-01, ..., 9.6094e-01,\n",
1195
+ " 5.5078e-01, 1.1172e+00],\n",
1196
+ " [-1.5527e-01, -2.5000e-01, -1.1035e-01, ..., 1.0938e+00,\n",
1197
+ " 8.0078e-01, 9.1797e-01]]], dtype=torch.bfloat16),\n",
1198
+ " 'noise_pred_49': tensor([[[ 0.7461, -0.5586, 0.2197, ..., -1.0469, 0.7109, 0.4902],\n",
1199
+ " [ 0.6094, 0.0464, -0.1650, ..., 0.4980, 0.2314, -0.9414],\n",
1200
+ " [ 0.1064, -0.2109, 0.1846, ..., -0.3633, -0.8086, -1.0234],\n",
1201
+ " ...,\n",
1202
+ " [-0.2559, 0.3711, 0.7461, ..., -0.2217, -0.2988, 0.0339],\n",
1203
+ " [ 0.4980, 0.5156, -0.0260, ..., -1.1250, 0.1064, -1.1250],\n",
1204
+ " [ 0.2471, 0.0179, 0.6875, ..., -0.7188, -0.5898, -0.8672]]],\n",
1205
+ " dtype=torch.bfloat16),\n",
1206
+ " 'output': tensor([[[ 0.0903, 0.0486, 0.0244, ..., -0.0371, 0.0747, -0.0092],\n",
1207
+ " [ 0.1465, 0.0483, 0.1865, ..., -0.0486, -0.0276, 0.0464],\n",
1208
+ " [ 0.1367, 0.2188, -0.0776, ..., 0.0684, 0.3359, 0.1943],\n",
1209
+ " ...,\n",
1210
+ " [-0.3359, -0.2188, -0.2197, ..., 0.3516, 0.2402, 0.3047],\n",
1211
+ " [-0.1445, -0.1172, -0.1943, ..., 0.9844, 0.5469, 1.1406],\n",
1212
+ " [-0.1602, -0.2500, -0.1240, ..., 1.1094, 0.8125, 0.9336]]],\n",
1213
+ " dtype=torch.bfloat16),\n",
1214
+ " 'height': 576,\n",
1215
+ " 'width': 448,\n",
1216
+ " 't': tensor([1.], dtype=torch.bfloat16),\n",
1217
+ " 'latents_start': tensor([[[ 1.9766, -0.8047, 0.6367, ..., -1.7422, 1.0469, 0.3809],\n",
1218
+ " [ 1.6562, 0.1147, -0.1562, ..., 0.7539, -0.1768, -1.6953],\n",
1219
+ " [ 0.3984, 0.3926, 0.1914, ..., -0.9258, -1.3281, -2.3281],\n",
1220
+ " ...,\n",
1221
+ " [-1.4766, 0.2539, 1.3359, ..., 0.1797, -0.6250, 0.7617],\n",
1222
+ " [ 1.0391, 1.3672, -0.1572, ..., 0.1152, 1.4688, -0.2852],\n",
1223
+ " [ 0.4941, -1.1094, 2.3438, ..., 0.8281, -0.8320, 0.4258]]],\n",
1224
+ " dtype=torch.bfloat16),\n",
1225
+ " 'noise_pred': tensor([[[ 1.8906, -0.8945, 0.5938, ..., -1.7578, 1.0078, 0.2539],\n",
1226
+ " [ 1.5781, 0.0278, -0.2793, ..., 0.7305, -0.1553, -1.7969],\n",
1227
+ " [ 0.3027, 0.2949, 0.1621, ..., -1.0625, -1.5938, -2.6406],\n",
1228
+ " ...,\n",
1229
+ " [-1.2578, 0.5352, 1.5859, ..., -0.2773, -1.0312, 0.3203],\n",
1230
+ " [ 1.2734, 1.5312, 0.0728, ..., -0.6211, 0.8984, -1.1562],\n",
1231
+ " [ 0.6172, -0.9336, 2.6719, ..., -0.1050, -1.8672, -0.3691]]],\n",
1232
+ " dtype=torch.bfloat16)}"
1233
+ ]
1234
+ },
1235
+ "execution_count": 18,
1236
+ "metadata": {},
1237
+ "output_type": "execute_result"
1238
+ }
1239
+ ],
1240
+ "source": [
1241
+ "src[0]"
1242
+ ]
1243
+ },
1244
+ {
1245
+ "cell_type": "code",
1246
+ "execution_count": null,
1247
+ "id": "22f19ae9",
1248
+ "metadata": {},
1249
+ "outputs": [],
1250
+ "source": []
1251
+ }
1252
+ ],
1253
+ "metadata": {
1254
+ "kernelspec": {
1255
+ "display_name": "Python 3",
1256
+ "language": "python",
1257
+ "name": "python3"
1258
+ },
1259
+ "language_info": {
1260
+ "codemirror_mode": {
1261
+ "name": "ipython",
1262
+ "version": 3
1263
+ },
1264
+ "file_extension": ".py",
1265
+ "mimetype": "text/x-python",
1266
+ "name": "python",
1267
+ "nbconvert_exporter": "python",
1268
+ "pygments_lexer": "ipython3",
1269
+ "version": "3.10.12"
1270
+ }
1271
+ },
1272
+ "nbformat": 4,
1273
+ "nbformat_minor": 5
1274
+ }
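The cell output above shows the structure that QwenImageFoundationSaveInterm.base_pipe returns and that scripts/save_regression_outputs.py (below) serializes: per-step timesteps t_i, the latents at the start of each step latents_i_start, the matching noise_pred_i, plus the final output, height, and width. A minimal inspection sketch, assuming only that one of the .pt files written by the script below exists; the helper name and the example path are illustrative, not part of the repo:

from pathlib import Path

import torch


def inspect_saved_output(path):
    """Print the per-step tensors stored in one saved regression output file."""
    saved = torch.load(Path(path), map_location="cpu")
    print("height/width:", saved["height"], saved["width"])
    # Keys follow the pattern t_<step>, latents_<step>_start, noise_pred_<step>.
    steps = sorted(int(k.split("_")[1]) for k in saved if k.startswith("t_"))
    for step in steps:
        t = saved[f"t_{step}"]
        latents = saved[f"latents_{step}_start"]
        noise_pred = saved[f"noise_pred_{step}"]
        print(
            f"step {step:02d}: t={t.item():.4f} "
            f"latents {tuple(latents.shape)} noise_pred {tuple(noise_pred.shape)}"
        )
    print("final output:", tuple(saved["output"].shape))


if __name__ == "__main__":
    inspect_saved_output("/data/regression_output/000000.pt")  # illustrative path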
scripts/save_regression_outputs.py CHANGED
@@ -1,94 +1,59 @@
1
- #
2
- # %cd /home/ubuntu/Qwen-Image-Edit-Angles
3
 
4
  import torch
5
- import huggingface_hub
6
  import tqdm
 
7
 
8
  from qwenimage.datamodels import QwenConfig
9
  from qwenimage.foundation import QwenImageFoundationSaveInterm
10
- from datasets import concatenate_datasets, load_dataset, interleave_datasets
11
-
12
- repo_tree = huggingface_hub.list_repo_tree(
13
- "WeiChow/CrispEdit-2M",
14
- "data",
15
- repo_type="dataset",
16
- )
17
-
18
- all_paths = []
19
- for i in repo_tree:
20
- all_paths.append(i.path)
21
-
22
- parquet_prefixes = set()
23
- for path in all_paths:
24
- if path.endswith('.parquet'):
25
- filename = path.split('/')[-1]
26
- if '_' in filename:
27
- prefix = filename.split('_')[0]
28
- parquet_prefixes.add(prefix)
29
-
30
- print("Found parquet prefixes:", sorted(parquet_prefixes))
31
-
32
-
33
-
34
- total_per = 10
35
-
36
- EDIT_TYPES = [
37
- "color",
38
- "style",
39
- "replace",
40
- "remove",
41
- "add",
42
- "motion change",
43
- "background change",
44
- ]
45
-
46
-
47
-
48
-
49
- all_edit_datasets = []
50
- for edit_type in EDIT_TYPES:
51
- to_concat = []
52
- for ds_n in range(total_per):
53
- ds = load_dataset("parquet", data_files=f"/data/CrispEdit/{edit_type}_{ds_n:05d}.parquet", split="train")
54
- to_concat.append(ds)
55
- edit_type_concat = concatenate_datasets(to_concat)
56
- all_edit_datasets.append(edit_type_concat)
57
-
58
- # consistent ordering for indexing, also allow extension
59
- join_ds = interleave_datasets(all_edit_datasets)
60
-
61
-
62
-
63
-
64
-
65
-
66
-
67
-
68
- from pathlib import Path
69
-
70
-
71
- save_base_dir = Path("/data/regression_output")
72
- save_base_dir.mkdir(exist_ok=True, parents=True)
73
-
74
-
75
 
76
 
 
 
 
 
77
 
78
- foundation = QwenImageFoundationSaveInterm(QwenConfig())
79
 
 
 
 
 
 
 
 
 
 
80
 
 
 
 
 
 
 
 
 
81
 
 
 
82
 
 
 
83
 
84
- for idx, input_data in enumerate(tqdm.tqdm(join_ds)):
85
 
86
- output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
87
- image=[input_data["input_img"]],
88
- prompt=input_data["instruction"],
89
- ))
90
 
91
- torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
 
 
 
92
 
 
93
 
94
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
 
4
  import torch
 
5
  import tqdm
6
+ from datasets import concatenate_datasets, load_dataset, interleave_datasets
7
 
8
  from qwenimage.datamodels import QwenConfig
9
  from qwenimage.foundation import QwenImageFoundationSaveInterm
 
10
 
11
 
12
+ def main():
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--start-index", type=int, default=0)
15
+ args = parser.parse_args()
16
 
17
+ total_per = 10
18
 
19
+ EDIT_TYPES = [
20
+ "color",
21
+ "style",
22
+ "replace",
23
+ "remove",
24
+ "add",
25
+ "motion change",
26
+ "background change",
27
+ ]
28
 
29
+ all_edit_datasets = []
30
+ for edit_type in EDIT_TYPES:
31
+ to_concat = []
32
+ for ds_n in range(total_per):
33
+ ds = load_dataset("parquet", data_files=f"/data/CrispEdit/{edit_type}_{ds_n:05d}.parquet", split="train")
34
+ to_concat.append(ds)
35
+ edit_type_concat = concatenate_datasets(to_concat)
36
+ all_edit_datasets.append(edit_type_concat)
37
 
38
+ # consistent ordering for indexing, also allow extension
39
+ join_ds = interleave_datasets(all_edit_datasets)
40
 
41
+ save_base_dir = Path("/data/regression_output")
42
+ save_base_dir.mkdir(exist_ok=True, parents=True)
43
 
44
+ foundation = QwenImageFoundationSaveInterm(QwenConfig())
45
 
46
+ dataset_to_process = join_ds.select(range(args.start_index, len(join_ds)))
47
+
48
+ for idx, input_data in enumerate(tqdm.tqdm(dataset_to_process), start=args.start_index):
 
49
 
50
+ output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
51
+ image=[input_data["input_img"]],
52
+ prompt=input_data["instruction"],
53
+ ))
54
 
55
+ torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
56
 
57
 
58
+ if __name__ == "__main__":
59
+ main()
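The rewrite above turns the one-off notebook-style script into a CLI with a --start-index flag, so a long generation run over the interleaved CrispEdit shards can be resumed after an interruption instead of restarting from index 0. A hedged convenience sketch (not part of the commit) for picking the resume point from the files already on disk; the helper name is illustrative and it assumes the /data/regression_output layout used by the script:

from pathlib import Path


def next_start_index(save_base_dir="/data/regression_output"):
    """Return the first index with no saved file, usable as --start-index."""
    existing = {int(p.stem) for p in Path(save_base_dir).glob("*.pt")}
    # Walk forward from 0 so a gap left by an interrupted write is not skipped.
    idx = 0
    while idx in existing:
        idx += 1
    return idx


if __name__ == "__main__":
    print(next_start_index())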