Elea Zhong commited on
Commit
6064267
·
1 Parent(s): 789676e

triplet loss experiments (prelim)

Browse files
configs/base.yaml CHANGED
@@ -11,12 +11,12 @@ gradient_accumulation_steps: 1
11
  train_batch_size: 1
12
  optim: "adamw"
13
  learning_rate: 1.0e-4
14
- num_workers: 4
15
  resume_from_checkpoint: null
16
  log_model_steps: null
17
  preprocessing_epoch_len: 64
18
  preprocessing_epoch_repetitions: 1
19
-
20
 
21
  # Logging
22
  record_training: true
@@ -33,5 +33,8 @@ sample_steps:
33
  every: 500
34
  global_step: 0
35
  save_steps: 1000
36
- log_batch_steps: 500
 
 
37
  seed: 67
 
 
11
  train_batch_size: 1
12
  optim: "adamw"
13
  learning_rate: 1.0e-4
14
+ num_workers: 8
15
  resume_from_checkpoint: null
16
  log_model_steps: null
17
  preprocessing_epoch_len: 64
18
  preprocessing_epoch_repetitions: 1
19
+ lora_rank: 16
20
 
21
  # Logging
22
  record_training: true
 
33
  every: 500
34
  global_step: 0
35
  save_steps: 1000
36
+ log_batch_steps:
37
+ "on": [0,1,100]
38
+ every: 500
39
  seed: 67
40
+
configs/regression/base.yaml CHANGED
@@ -19,3 +19,9 @@ regression_data_dir: "/data/regression_output"
19
  regression_gen_steps: 50
20
  editing_data_dir: "/data/CrispEdit"
21
  editing_total_per: 1
 
 
 
 
 
 
 
19
  regression_gen_steps: 50
20
  editing_data_dir: "/data/CrispEdit"
21
  editing_total_per: 1
22
+
23
+
24
+ validation_loss_terms:
25
+ mse: 1.0
26
+ pixel_mse: 1.0
27
+ pixel_lpips: 1.0
configs/regression/modal-datadirs.yaml DELETED
@@ -1,2 +0,0 @@
1
- regression_data_dir: "/data/regression_data/regression_output"
2
- editing_data_dir: "/data/edit_data/CrispEdit"
 
 
 
configs/regression/modal.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ regression_data_dir: "/data/regression_data/regression_output_1024"
2
+ editing_data_dir: "/data/edit_data/CrispEdit"
3
+
4
+ lora_rank: 32
5
+ vae_image_size: 1048576 # 1024 * 1024
configs/regression/mse-dm.yaml DELETED
@@ -1,11 +0,0 @@
1
- wandb_run_name: "reg-mse-dm"
2
- output_dir: "/data/checkpoints/reg-mse-dm"
3
-
4
- train_loss_terms:
5
- mse: 1.0
6
- distribution_matching: 1.0
7
-
8
- validation_loss_terms:
9
- mse: 1.0
10
- distribution_matching: 1.0
11
-
 
 
 
 
 
 
 
 
 
 
 
 
configs/regression/mse-neg-mse.yaml CHANGED
@@ -5,7 +5,4 @@ train_loss_terms:
5
  mse: 1.0
6
  negative_mse: 0.1
7
 
8
- validation_loss_terms:
9
- mse: 1.0
10
- negative_mse: 0.1
11
 
 
5
  mse: 1.0
6
  negative_mse: 0.1
7
 
 
 
 
8
 
configs/regression/mse-pixel-lpips.yaml CHANGED
@@ -4,7 +4,3 @@ output_dir: "/data/checkpoints/reg-mse-pixel-lpips"
4
  train_loss_terms:
5
  mse: 1.0
6
  pixel_lpips: 1.0
7
-
8
- validation_loss_terms:
9
- mse: 1.0
10
- pixel_lpips: 1.0
 
4
  train_loss_terms:
5
  mse: 1.0
6
  pixel_lpips: 1.0
 
 
 
 
configs/regression/mse-pixel-mse.yaml CHANGED
@@ -5,7 +5,3 @@ train_loss_terms:
5
  mse: 1.0
6
  pixel_mse: 1.0
7
 
8
- validation_loss_terms:
9
- mse: 1.0
10
- pixel_mse: 1.0
11
-
 
5
  mse: 1.0
6
  pixel_mse: 1.0
7
 
 
 
 
 
configs/regression/mse-triplet.yaml DELETED
@@ -1,13 +0,0 @@
1
- wandb_run_name: "reg-mse-triplet"
2
- output_dir: "/data/checkpoints/reg-mse-triplet"
3
-
4
- train_loss_terms:
5
- mse: 1.0
6
- triplet: 1.0
7
-
8
- validation_loss_terms:
9
- mse: 1.0
10
- triplet: 1.0
11
-
12
-
13
- triplet_margin: -500 # tune
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/regression/mse.yaml CHANGED
@@ -4,6 +4,4 @@ output_dir: "/data/checkpoints/reg-mse"
4
  train_loss_terms:
5
  mse: 1.0
6
 
7
- validation_loss_terms:
8
- mse: 1.0
9
 
 
4
  train_loss_terms:
5
  mse: 1.0
6
 
 
 
7
 
configs/regression/triplet/mse-triplet-a.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ wandb_run_name: "reg-mse-triplet-a"
2
+ output_dir: "/data/checkpoints/reg-mse-triplet-a"
3
+
4
+ train_loss_terms:
5
+ mse: 1.0
6
+ triplet: 1.0
7
+
8
+ triplet_margin: 0.0
9
+ triplet_min_abs_diff: 0.0
configs/regression/triplet/mse-triplet-b.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ wandb_run_name: "reg-mse-triplet-b"
2
+ output_dir: "/data/checkpoints/reg-mse-triplet-b"
3
+
4
+ train_loss_terms:
5
+ mse: 1.0
6
+ triplet: 1.0
7
+
8
+ triplet_margin: 0.0
9
+ triplet_min_abs_diff: 0.1
configs/regression/triplet/mse-triplet-c.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ wandb_run_name: "reg-mse-triplet-c"
2
+ output_dir: "/data/checkpoints/reg-mse-triplet-c"
3
+
4
+ train_loss_terms:
5
+ mse: 1.0
6
+ triplet: 1.0
7
+
8
+ triplet_margin: 0.1
9
+ triplet_min_abs_diff: 0.1
configs/regression/triplet/mse-triplet-d.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ wandb_run_name: "reg-mse-triplet-c"
2
+ output_dir: "/data/checkpoints/reg-mse-triplet-c"
3
+
4
+ train_loss_terms:
5
+ mse: 1.0
6
+ triplet: 1.0
7
+
8
+ triplet_margin: 0.5
9
+ triplet_min_abs_diff: 0.1
configs/regression/triplet/mse-triplet-e.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ wandb_run_name: "reg-mse-triplet-e"
2
+ output_dir: "/data/checkpoints/reg-mse-triplet-e"
3
+
4
+ train_loss_terms:
5
+ mse: 1.0
6
+ triplet: 1.0
7
+
8
+ triplet_margin: 0.1
9
+ triplet_min_abs_diff: 0.25
configs/regression/val_metrics.yaml DELETED
@@ -1,9 +0,0 @@
1
-
2
-
3
-
4
- validation_loss_terms:
5
- mse: 1.0
6
- pixel_mse: 1.0
7
- pixel_lpips: 1.0
8
-
9
-
 
 
 
 
 
 
 
 
 
 
qwenimage/datamodels.py CHANGED
@@ -50,12 +50,20 @@ class QwenLossTerms(BaseModel):
50
  triplet: LossTermSpecType = 0.0
51
  negative_mse: LossTermSpecType = 0.0
52
  distribution_matching: LossTermSpecType = 0.0
53
- negative_exponential: LossTermSpecType = 0.0
54
  pixel_lpips: LossTermSpecType = 0.0
55
  pixel_mse: LossTermSpecType = 0.0
 
56
  adversarial: LossTermSpecType = 0.0
 
57
 
58
- triplet_margin: float = 0.2
 
 
 
 
 
 
59
 
60
  class QwenConfig(ExperimentTrainerParameters):
61
  load_multi_view_lora: bool = False
 
50
  triplet: LossTermSpecType = 0.0
51
  negative_mse: LossTermSpecType = 0.0
52
  distribution_matching: LossTermSpecType = 0.0
53
+ pixel_triplet: LossTermSpecType = 0.0
54
  pixel_lpips: LossTermSpecType = 0.0
55
  pixel_mse: LossTermSpecType = 0.0
56
+ pixel_distribution_matching: LossTermSpecType = 0.0
57
  adversarial: LossTermSpecType = 0.0
58
+ teacher: LossTermSpecType = 0.0
59
 
60
+ triplet_margin: float = 0.0
61
+ triplet_min_abs_diff: float = 0.0
62
+ teacher_steps: int = 4
63
+
64
+ @property
65
+ def pixel_terms(self) -> bool:
66
+ return ("pixel_lpips", "pixel_mse", "pixel_triplet", "pixel_distribution_matching",)
67
 
68
  class QwenConfig(ExperimentTrainerParameters):
69
  load_multi_view_lora: bool = False
qwenimage/foundation.py CHANGED
@@ -395,13 +395,14 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
395
  split = batch["split"]
396
  step = batch["step"]
397
  if split == "train":
398
- loss_terms = self.config.train_loss_terms.model_dump()
399
  elif split == "validation":
400
- loss_terms = self.config.validation_loss_terms.model_dump()
401
  loss_accumulator = LossAccumulator(
402
- terms=loss_terms,
403
  step=step,
404
  split=split,
 
405
  )
406
 
407
  if loss_accumulator.has("mse"):
@@ -414,31 +415,42 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
414
  loss_accumulator.accum("mse", mse_loss)
415
 
416
  if loss_accumulator.has("triplet"):
417
- # eps = 1e-6
418
- margin = loss_terms["triplet_margin"]
419
- # v_span = (v_gt_1d - v_neg_1d).pow(2).sum(dim=(-2,-1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
 
421
- # triplet_weight = (v_gt_1d - v_neg_1d).pow(2).mean(dim=(-2,-1))
422
-
423
- diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2).sum(dim=(-2,-1))
424
- diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2).sum(dim=(-2,-1))
425
- # diffv_gt_pred_reg = diffv_gt_pred # / (v_span + eps)
426
- # diffv_neg_pred_reg = diffv_neg_pred # / (v_span + eps)
427
-
428
- # texam(v_span, name="v_span")
429
- # texam(triplet_weight, name="triplet_weight")
430
- texam(diffv_gt_pred, name="diffv_gt_pred")
431
- texam(diffv_neg_pred, name="diffv_neg_pred")
432
- # texam(diffv_gt_pred_reg, name="diffv_gt_pred_reg")
433
- # texam(diffv_neg_pred_reg, name="diffv_neg_pred_reg")
434
- # texam(diffv_gt_pred_reg - diffv_neg_pred_reg, name="diffv_gt_pred_reg - diffv_neg_pred_reg")
435
-
436
- triplet_loss = F.relu(diffv_gt_pred - diffv_neg_pred + margin).mean()
437
- #
438
- # triplet_loss = F.relu(diffv_gt_pred_reg - diffv_neg_pred_reg + margin).mean()
439
- # triplet_loss = torch.mean(triplet_loss_batched * triplet_weight)
440
  loss_accumulator.accum("triplet", triplet_loss)
441
 
 
 
 
 
 
 
 
 
442
 
443
  if loss_accumulator.has("negative_mse"):
444
  neg_mse_loss = -F.mse_loss(v_pred_1d, v_neg_1d, reduction="mean")
@@ -453,7 +465,7 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
453
  if loss_accumulator.has("negative_exponential"):
454
  raise NotImplementedError()
455
 
456
- if loss_accumulator.has("pixel_lpips") or loss_accumulator.has("pixel_mse"):
457
  x_0_pred = x_t_1d - t * v_pred_1d
458
  pixel_values_x0_gt = self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16, with_grad=True).detach()
459
  pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16, with_grad=True)
@@ -468,6 +480,14 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
468
  if loss_accumulator.has("pixel_mse"):
469
  pixel_mse_loss = F.mse_loss(pixel_values_x0_pred, pixel_values_x0_gt, reduction="mean")
470
  loss_accumulator.accum("pixel_mse", pixel_mse_loss)
 
 
 
 
 
 
 
 
471
 
472
  if loss_accumulator.has("adversarial"):
473
  raise NotImplementedError()
@@ -492,7 +512,7 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
492
  v_gt_1d,
493
  v_neg_1d,
494
  v_pred_1d,
495
- visualize_velocities=True,
496
  )
497
 
498
  return loss
@@ -513,7 +533,7 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
513
  v_gt_1d,
514
  v_neg_1d,
515
  v_pred_1d,
516
- visualize_velocities=True,
517
  ):
518
  t_float = t.float().cpu().item()
519
  x_0_pred = x_t_1d - t * v_pred_1d
@@ -526,18 +546,49 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
526
  "x_0_pred": self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16),
527
  "x_0_neg": self.latents_to_pil(x_0_neg, h=h_f16, w=w_f16),
528
  }
529
- if visualize_velocities:
530
  log_pils.update({
531
  "v_gt_1d": self.latents_to_pil(v_gt_1d, h=h_f16, w=w_f16),
532
  "v_pred_1d": self.latents_to_pil(v_pred_1d, h=h_f16, w=w_f16),
533
  "v_neg_1d": self.latents_to_pil(v_neg_1d, h=h_f16, w=w_f16),
534
  })
535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
  wand_logger.log({
537
  "train_images": log_pils,
538
  }, commit=False)
539
 
540
 
541
  def base_pipe(self, inputs: QwenInputs) -> list[Image]:
542
- inputs.num_inference_steps = self.config.regression_base_pipe_steps # override
 
 
 
 
 
 
 
 
 
 
543
  return super().base_pipe(inputs)
 
395
  split = batch["split"]
396
  step = batch["step"]
397
  if split == "train":
398
+ loss_terms = self.config.train_loss_terms
399
  elif split == "validation":
400
+ loss_terms = self.config.validation_loss_terms
401
  loss_accumulator = LossAccumulator(
402
+ terms=loss_terms.model_dump(),
403
  step=step,
404
  split=split,
405
+ term_groups={"pixel":loss_terms.pixel_terms}
406
  )
407
 
408
  if loss_accumulator.has("mse"):
 
415
  loss_accumulator.accum("mse", mse_loss)
416
 
417
  if loss_accumulator.has("triplet"):
418
+ # 1d, B,L,C
419
+ margin = loss_terms.triplet_margin
420
+ triplet_min_abs_diff = loss_terms.triplet_min_abs_diff
421
+ print(f"{triplet_min_abs_diff=}")
422
+ v_gt_neg_diff = (v_gt_1d - v_neg_1d).abs().mean(dim=2, keepdim=True)
423
+ zero_weight = torch.zeros_like(v_gt_neg_diff)
424
+ v_weight = torch.where(v_gt_neg_diff > triplet_min_abs_diff, v_gt_neg_diff, zero_weight)
425
+ ones = torch.ones_like(v_gt_neg_diff)
426
+ filtered_nums = torch.sum(torch.where(v_gt_neg_diff > triplet_min_abs_diff, ones, zero_weight))
427
+ wand_logger.log({
428
+ "filtered_nums": filtered_nums,
429
+ }, commit=False)
430
+
431
+
432
+ diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2)
433
+ diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2)
434
+ loss_unreduced = diffv_gt_pred - diffv_neg_pred
435
+ loss_weighted = (loss_unreduced * v_weight).sum(dim=2)
436
+ triplet_loss = F.relu(loss_weighted + margin).mean()
437
+ ones = torch.ones_like(loss_weighted)
438
+ zeros = torch.zeros_like(loss_weighted)
439
+ loss_nonzero_nums = torch.sum(torch.where((loss_weighted + margin)>0, ones, zeros))
440
+ wand_logger.log({
441
+ "loss_nonzero_nums": loss_nonzero_nums,
442
+ }, commit=False)
443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
  loss_accumulator.accum("triplet", triplet_loss)
445
 
446
+ texam(v_gt_neg_diff, "v_gt_neg_diff")
447
+ texam(v_weight, "v_weight")
448
+ texam(diffv_gt_pred, "diffv_gt_pred")
449
+ texam(diffv_neg_pred, "diffv_neg_pred")
450
+ texam(loss_unreduced, "loss_unreduced")
451
+ texam(loss_weighted, "loss_weighted")
452
+
453
+
454
 
455
  if loss_accumulator.has("negative_mse"):
456
  neg_mse_loss = -F.mse_loss(v_pred_1d, v_neg_1d, reduction="mean")
 
465
  if loss_accumulator.has("negative_exponential"):
466
  raise NotImplementedError()
467
 
468
+ if loss_accumulator.has_group("pixel"):
469
  x_0_pred = x_t_1d - t * v_pred_1d
470
  pixel_values_x0_gt = self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16, with_grad=True).detach()
471
  pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16, with_grad=True)
 
480
  if loss_accumulator.has("pixel_mse"):
481
  pixel_mse_loss = F.mse_loss(pixel_values_x0_pred, pixel_values_x0_gt, reduction="mean")
482
  loss_accumulator.accum("pixel_mse", pixel_mse_loss)
483
+
484
+ if loss_accumulator.has("pixel_triplet"):
485
+ raise NotImplementedError()
486
+ loss_accumulator.accum("pixel_triplet", pixel_triplet_loss)
487
+
488
+ if loss_accumulator.has("pixel_distribution_matching"):
489
+ raise NotImplementedError()
490
+ loss_accumulator.accum("pixel_distribution_matching", pixel_distribution_matching_loss)
491
 
492
  if loss_accumulator.has("adversarial"):
493
  raise NotImplementedError()
 
512
  v_gt_1d,
513
  v_neg_1d,
514
  v_pred_1d,
515
+ visualize_velocities=False,
516
  )
517
 
518
  return loss
 
533
  v_gt_1d,
534
  v_neg_1d,
535
  v_pred_1d,
536
+ visualize_velocities=False,
537
  ):
538
  t_float = t.float().cpu().item()
539
  x_0_pred = x_t_1d - t * v_pred_1d
 
546
  "x_0_pred": self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16),
547
  "x_0_neg": self.latents_to_pil(x_0_neg, h=h_f16, w=w_f16),
548
  }
549
+ if visualize_velocities: # naively visualizing through vae (works with flux)
550
  log_pils.update({
551
  "v_gt_1d": self.latents_to_pil(v_gt_1d, h=h_f16, w=w_f16),
552
  "v_pred_1d": self.latents_to_pil(v_pred_1d, h=h_f16, w=w_f16),
553
  "v_neg_1d": self.latents_to_pil(v_neg_1d, h=h_f16, w=w_f16),
554
  })
555
 
556
+ # create gt-neg difference maps
557
+ v_pred_2d = self.unpack_latents(v_pred_1d, h_f16, w_f16)
558
+ v_gt_2d = self.unpack_latents(v_gt_1d, h_f16, w_f16)
559
+ v_neg_2d = self.unpack_latents(v_neg_1d, h_f16, w_f16)
560
+ gt_neg_diff_map_2d = (v_gt_2d - v_neg_2d).pow(2).mean(dim=1, keepdim=True)
561
+ gt_pred_diff_map_2d = (v_gt_2d - v_pred_2d).pow(2).mean(dim=1, keepdim=True)
562
+ neg_pred_diff_map_2d = (v_neg_2d - v_pred_2d).pow(2).mean(dim=1, keepdim=True)
563
+ diff_max = torch.max(torch.stack([gt_neg_diff_map_2d, gt_pred_diff_map_2d, neg_pred_diff_map_2d]))
564
+ diff_min = torch.min(torch.stack([gt_neg_diff_map_2d, gt_pred_diff_map_2d, neg_pred_diff_map_2d]))
565
+ print(f"{diff_min}, {diff_max}")
566
+ # norms to 0-1
567
+ diff_span = diff_max - diff_min
568
+ gt_neg_diff_map_2d = (gt_neg_diff_map_2d - diff_min) / diff_span
569
+ gt_pred_diff_map_2d = (gt_pred_diff_map_2d - diff_min) / diff_span
570
+ neg_pred_diff_map_2d = (neg_pred_diff_map_2d - diff_min) / diff_span
571
+ log_pils.update({
572
+ "gt-neg":gt_neg_diff_map_2d.float().cpu(),
573
+ "gt-pred":gt_pred_diff_map_2d.float().cpu(),
574
+ "neg-pred":neg_pred_diff_map_2d.float().cpu(),
575
+ })
576
+
577
  wand_logger.log({
578
  "train_images": log_pils,
579
  }, commit=False)
580
 
581
 
582
  def base_pipe(self, inputs: QwenInputs) -> list[Image]:
583
+ # config overrides
584
+ inputs.num_inference_steps = self.config.regression_base_pipe_steps
585
+ inputs.latent_size_override = self.config.vae_image_size
586
+ inputs.vae_image_override = self.config.vae_image_size
587
+ image = inputs.image[0]
588
+ w,h = image.size
589
+ h_r, w_r = calculate_dimensions(self.config.vae_image_size, h/w)
590
+ image = TF.resize(image, (h_r, w_r))
591
+ inputs.image = [image]
592
+ inputs.height = h_r
593
+ inputs.width = w_r
594
  return super().base_pipe(inputs)
qwenimage/loss.py CHANGED
@@ -12,9 +12,11 @@ class LossAccumulator:
12
  terms: dict[str, int|float|dict],
13
  step: int|None=None,
14
  split: str|None=None,
 
15
  ):
16
  self.terms = terms
17
  self.step = step
 
18
  if split is not None:
19
  self.split = split
20
  self.prefix = f"{self.split}_"
@@ -55,7 +57,13 @@ class LossAccumulator:
55
 
56
  warnings.warn(f"Unknown spec type {spec}; treat as disabled")
57
  return 0.0
58
-
 
 
 
 
 
 
59
  def has(self, name: str) -> bool:
60
  return self.resolve_weight(name) > 0
61
 
 
12
  terms: dict[str, int|float|dict],
13
  step: int|None=None,
14
  split: str|None=None,
15
+ term_groups: dict[str, tuple[str, ...]]|None = None,
16
  ):
17
  self.terms = terms
18
  self.step = step
19
+ self.term_groups = term_groups
20
  if split is not None:
21
  self.split = split
22
  self.prefix = f"{self.split}_"
 
57
 
58
  warnings.warn(f"Unknown spec type {spec}; treat as disabled")
59
  return 0.0
60
+
61
+ def has_group(self, name: str):
62
+ if name not in self.term_groups:
63
+ return False
64
+ all_group_terms = self.term_groups[name]
65
+ return any([self.resolve_weight(tn) > 0 for tn in all_group_terms])
66
+
67
  def has(self, name: str) -> bool:
68
  return self.resolve_weight(name) > 0
69
 
scripts/logit_normal_dist.ipynb CHANGED
@@ -76,17 +76,17 @@
76
  },
77
  {
78
  "cell_type": "code",
79
- "execution_count": 9,
80
  "id": "aec3ae8f",
81
  "metadata": {},
82
  "outputs": [
83
  {
84
  "data": {
85
  "text/plain": [
86
- "[<matplotlib.lines.Line2D at 0x74338838e380>]"
87
  ]
88
  },
89
- "execution_count": 9,
90
  "metadata": {},
91
  "output_type": "execute_result"
92
  },
@@ -117,7 +117,7 @@
117
  },
118
  {
119
  "cell_type": "code",
120
- "execution_count": 16,
121
  "id": "3bc68e7c",
122
  "metadata": {},
123
  "outputs": [
@@ -127,7 +127,7 @@
127
  "1.0986122886681098"
128
  ]
129
  },
130
- "execution_count": 16,
131
  "metadata": {},
132
  "output_type": "execute_result"
133
  }
@@ -139,7 +139,7 @@
139
  },
140
  {
141
  "cell_type": "code",
142
- "execution_count": 17,
143
  "id": "facb782e",
144
  "metadata": {},
145
  "outputs": [
@@ -149,7 +149,7 @@
149
  "tensor([1.0000, 0.8808, 0.0000])"
150
  ]
151
  },
152
- "execution_count": 17,
153
  "metadata": {},
154
  "output_type": "execute_result"
155
  }
@@ -166,11 +166,30 @@
166
  },
167
  {
168
  "cell_type": "code",
169
- "execution_count": null,
170
  "id": "f006f2fa",
171
  "metadata": {},
172
- "outputs": [],
173
- "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  },
175
  {
176
  "cell_type": "code",
 
76
  },
77
  {
78
  "cell_type": "code",
79
+ "execution_count": 6,
80
  "id": "aec3ae8f",
81
  "metadata": {},
82
  "outputs": [
83
  {
84
  "data": {
85
  "text/plain": [
86
+ "[<matplotlib.lines.Line2D at 0x7ea4980d6b30>]"
87
  ]
88
  },
89
+ "execution_count": 6,
90
  "metadata": {},
91
  "output_type": "execute_result"
92
  },
 
117
  },
118
  {
119
  "cell_type": "code",
120
+ "execution_count": 7,
121
  "id": "3bc68e7c",
122
  "metadata": {},
123
  "outputs": [
 
127
  "1.0986122886681098"
128
  ]
129
  },
130
+ "execution_count": 7,
131
  "metadata": {},
132
  "output_type": "execute_result"
133
  }
 
139
  },
140
  {
141
  "cell_type": "code",
142
+ "execution_count": 8,
143
  "id": "facb782e",
144
  "metadata": {},
145
  "outputs": [
 
149
  "tensor([1.0000, 0.8808, 0.0000])"
150
  ]
151
  },
152
+ "execution_count": 8,
153
  "metadata": {},
154
  "output_type": "execute_result"
155
  }
 
166
  },
167
  {
168
  "cell_type": "code",
169
+ "execution_count": 9,
170
  "id": "f006f2fa",
171
  "metadata": {},
172
+ "outputs": [
173
+ {
174
+ "data": {
175
+ "text/plain": [
176
+ "tensor([1.0000, 0.7484, 0.0000])"
177
+ ]
178
+ },
179
+ "execution_count": 9,
180
+ "metadata": {},
181
+ "output_type": "execute_result"
182
+ }
183
+ ],
184
+ "source": [
185
+ "t = torch.tensor([1, 0.5, 0])\n",
186
+ "t_i = TimestepDistUtils.t_shift(\n",
187
+ " mu=torch.tensor(1.09),\n",
188
+ " sigma=1.0,\n",
189
+ " t=t\n",
190
+ ")\n",
191
+ "t_i"
192
+ ]
193
  },
194
  {
195
  "cell_type": "code",
scripts/save_regression_outputs.py CHANGED
@@ -15,6 +15,7 @@ def main():
15
  parser.add_argument("--imsize", type=int, default=512)
16
  parser.add_argument("--indir", type=str, default="/data/CrispEdit")
17
  parser.add_argument("--outdir", type=str, default="/data/regression_output")
 
18
  args = parser.parse_args()
19
 
20
  total_per = 10
@@ -54,6 +55,8 @@ def main():
54
  image=[input_data["input_img"]],
55
  prompt=input_data["instruction"],
56
  vae_image_override=args.imsize * args.imsize,
 
 
57
  ))
58
 
59
  torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
 
15
  parser.add_argument("--imsize", type=int, default=512)
16
  parser.add_argument("--indir", type=str, default="/data/CrispEdit")
17
  parser.add_argument("--outdir", type=str, default="/data/regression_output")
18
+ parser.add_argument("--steps", type=int, default=50)
19
  args = parser.parse_args()
20
 
21
  total_per = 10
 
55
  image=[input_data["input_img"]],
56
  prompt=input_data["instruction"],
57
  vae_image_override=args.imsize * args.imsize,
58
+ latent_size_override=args.imsize * args.imsize,
59
+ num_inference_steps=args.steps,
60
  ))
61
 
62
  torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
scripts/save_regression_outputs_modal.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from pathlib import Path
4
+ import sys
5
+ sys.path.append(str(Path(__file__).parent.parent))
6
+
7
+ import fal
8
+ import modal
9
+ import torch
10
+ import tqdm
11
+ from datasets import concatenate_datasets, load_dataset, interleave_datasets
12
+
13
+ from qwenimage.datamodels import QwenConfig
14
+ from qwenimage.foundation import QwenImageFoundationSaveInterm
15
+
16
+ REQUIREMENTS_PATH = os.path.abspath("requirements.txt")
17
+ WAND_REQUIREMENTS_PATH = os.path.abspath("scripts/wand_requirements.txt")
18
+
19
+ local_modules = ["qwenimage","wandml","scripts"]
20
+
21
+
22
+
23
+ EDIT_TYPES = [
24
+ "color",
25
+ "style",
26
+ "replace",
27
+ "remove",
28
+ "add",
29
+ "motion change",
30
+ "background change",
31
+ ]
32
+
33
+ modalapp = modal.App("next-stroke")
34
+ modalapp.image = (
35
+ modal.Image.debian_slim(python_version="3.10")
36
+ .apt_install("git", "ffmpeg", "libsm6", "libxext6")
37
+ .pip_install_from_requirements(REQUIREMENTS_PATH)
38
+ .pip_install_from_requirements(WAND_REQUIREMENTS_PATH)
39
+ .add_local_python_source(*local_modules)
40
+ )
41
+
42
+ @modalapp.function(
43
+ gpu="H100",
44
+ max_containers=8,
45
+ timeout=1 * 60 * 60,
46
+ volumes={
47
+ "/data/wand_cache": modal.Volume.from_name("FLUX_MODELS"),
48
+ "/data/checkpoints": modal.Volume.from_name("training_checkpoints", create_if_missing=True),
49
+ "/root/.cache/torch/hub/checkpoints": modal.Volume.from_name("torch_hub_checkpoints", create_if_missing=True),
50
+
51
+ "/root/.cache/huggingface/hub": modal.Volume.from_name("hf_cache", create_if_missing=True),
52
+ "/root/.cache/huggingface/datasets": modal.Volume.from_name("hf_cache_datasets", create_if_missing=True),
53
+
54
+ "/data/regression_data": modal.Volume.from_name("regression_data"),
55
+ "/data/edit_data": modal.Volume.from_name("edit_data"),
56
+ },
57
+ secrets=[
58
+ modal.Secret.from_name("wand-modal-gcloud-keyfile"),
59
+ modal.Secret.from_name("elea-huggingface-secret"),
60
+ ],
61
+ )
62
+ def generate_regression_data(start_index=0, end_index=None, imsize=1024, indir="/data/edit_data/CrispEdit", outdir="/data/regression_data/regression_output_1024", total_per=10):
63
+
64
+ all_edit_datasets = []
65
+ for edit_type in EDIT_TYPES:
66
+ to_concat = []
67
+ for ds_n in range(total_per):
68
+ ds = load_dataset("parquet", data_files=f"{indir}/{edit_type}_{ds_n:05d}.parquet", split="train")
69
+ to_concat.append(ds)
70
+ edit_type_concat = concatenate_datasets(to_concat)
71
+ all_edit_datasets.append(edit_type_concat)
72
+ join_ds = interleave_datasets(all_edit_datasets)
73
+
74
+ save_base_dir = Path(outdir)
75
+ save_base_dir.mkdir(exist_ok=True, parents=True)
76
+
77
+ foundation = QwenImageFoundationSaveInterm(QwenConfig(vae_image_size=imsize * imsize))
78
+
79
+ if end_index is None:
80
+ end_index = len(join_ds)
81
+ dataset_to_process = join_ds.select(range(start_index, end_index))
82
+
83
+ for idx, input_data in enumerate(tqdm.tqdm(dataset_to_process), start=start_index):
84
+
85
+ output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
86
+ image=[input_data["input_img"]],
87
+ prompt=input_data["instruction"],
88
+ vae_image_override=imsize * imsize,
89
+ latent_size_override=imsize * imsize,
90
+ ))
91
+
92
+ torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
93
+
94
+
95
+ @modalapp.local_entrypoint()
96
+ def main(start:int, end:int, num_workers:int):
97
+ per_worker_load = (end - start) // num_workers
98
+ remainder = (end - start) % num_workers
99
+ if remainder > 0:
100
+ per_worker_load += 1
101
+ worker_load_starts = []
102
+ worker_load_ends = []
103
+ cur_start = start
104
+ for worker_idx in range(num_workers):
105
+ if worker_idx < num_workers -1:
106
+ cur_end = cur_start + per_worker_load
107
+ else:
108
+ cur_end = end # pass last worker less
109
+ worker_load_starts.append(cur_start)
110
+ worker_load_ends.append(cur_end)
111
+ cur_start += per_worker_load
112
+
113
+ print(f"loads: {list(zip(worker_load_starts, worker_load_ends))}")
114
+ outputs = list(generate_regression_data.map(worker_load_starts, worker_load_ends))
115
+ print(outputs)
116
+
117
+
118
+
119
+
scripts/straightness.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
scripts/train.py CHANGED
@@ -70,6 +70,7 @@ modalapp.image = (
70
  "/root/.cache/torch/hub/checkpoints": modal.Volume.from_name("torch_hub_checkpoints", create_if_missing=True),
71
 
72
  "/root/.cache/huggingface/hub": modal.Volume.from_name("hf_cache", create_if_missing=True),
 
73
 
74
  "/data/regression_data": modal.Volume.from_name("regression_data"),
75
  "/data/edit_data": modal.Volume.from_name("edit_data"),
 
70
  "/root/.cache/torch/hub/checkpoints": modal.Volume.from_name("torch_hub_checkpoints", create_if_missing=True),
71
 
72
  "/root/.cache/huggingface/hub": modal.Volume.from_name("hf_cache", create_if_missing=True),
73
+ "/root/.cache/huggingface/datasets": modal.Volume.from_name("hf_cache_datasets", create_if_missing=True),
74
 
75
  "/data/regression_data": modal.Volume.from_name("regression_data"),
76
  "/data/edit_data": modal.Volume.from_name("edit_data"),
scripts/train_multi copy.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+
4
+ nohup python scripts/train.py configs/base.yaml --where modal \
5
+ --update configs/regression/base.yaml \
6
+ --update configs/regression/modal.yaml \
7
+ --update configs/regression/dm/mse-dm-a.yaml \
8
+ --update configs/compare/5k_steps.yaml \
9
+ > logs/mse-dm-a.log 2>&1 &
10
+
11
+ nohup python scripts/train.py configs/base.yaml --where modal \
12
+ --update configs/regression/base.yaml \
13
+ --update configs/regression/modal.yaml \
14
+ --update configs/regression/dm/mse-dm-b.yaml \
15
+ --update configs/compare/5k_steps.yaml \
16
+ > logs/mse-dm-b.log 2>&1 &
17
+
18
+ nohup python scripts/train.py configs/base.yaml --where modal \
19
+ --update configs/regression/base.yaml \
20
+ --update configs/regression/modal.yaml \
21
+ --update configs/regression/dm/mse-dm-c.yaml \
22
+ --update configs/compare/5k_steps.yaml \
23
+ > logs/mse-dm-c.log 2>&1 &
24
+
25
+ nohup python scripts/train.py configs/base.yaml --where modal \
26
+ --update configs/regression/base.yaml \
27
+ --update configs/regression/modal.yaml \
28
+ --update configs/regression/dm/mse-dm-d.yaml \
29
+ --update configs/compare/5k_steps.yaml \
30
+ > logs/mse-dm-d.log 2>&1 &
scripts/train_multi.sh CHANGED
@@ -1,61 +1,38 @@
1
  #!/bin/bash
2
 
 
3
  # nohup python scripts/train.py configs/base.yaml --where modal \
4
  # --update configs/regression/base.yaml \
5
- # --update configs/regression/modal-datadirs.yaml \
6
  # --update configs/regression/mse.yaml \
7
- # --update configs/regression/val_metrics.yaml \
8
  # --update configs/compare/5k_steps.yaml \
9
- # --update configs/optim/cosine.yaml \
10
- # --update configs/regression/lo_mse.yaml \
11
  # > logs/mse.log 2>&1 &
12
 
 
13
  nohup python scripts/train.py configs/base.yaml --where modal \
14
  --update configs/regression/base.yaml \
15
- --update configs/regression/modal-datadirs.yaml \
16
- --update configs/regression/mse-triplet.yaml \
17
- --update configs/regression/val_metrics.yaml \
18
  --update configs/compare/5k_steps.yaml \
19
- --update configs/optim/cosine.yaml \
20
- --update configs/regression/lo_mse.yaml \
21
- > logs/mse-triplet.log 2>&1 &
22
-
23
- # nohup python scripts/train.py configs/base.yaml --where modal \
24
- # --update configs/regression/base.yaml \
25
- # --update configs/regression/modal-datadirs.yaml \
26
- # --update configs/regression/mse-neg-mse.yaml \
27
- # --update configs/regression/val_metrics.yaml \
28
- # --update configs/compare/5k_steps.yaml \
29
- # --update configs/optim/cosine.yaml \
30
- # --update configs/regression/lo_mse.yaml \
31
- # > logs/mse-neg-mse.log 2>&1 &
32
 
33
  nohup python scripts/train.py configs/base.yaml --where modal \
34
  --update configs/regression/base.yaml \
35
- --update configs/regression/modal-datadirs.yaml \
36
- --update configs/regression/mse-pixel-mse.yaml \
37
- --update configs/regression/val_metrics.yaml \
38
  --update configs/compare/5k_steps.yaml \
39
- --update configs/optim/cosine.yaml \
40
- --update configs/regression/lo_mse.yaml \
41
- > logs/mse-pixel-mse.log 2>&1 &
42
 
43
  nohup python scripts/train.py configs/base.yaml --where modal \
44
  --update configs/regression/base.yaml \
45
- --update configs/regression/modal-datadirs.yaml \
46
- --update configs/regression/mse-pixel-lpips.yaml \
47
- --update configs/regression/val_metrics.yaml \
48
  --update configs/compare/5k_steps.yaml \
49
- --update configs/optim/cosine.yaml \
50
- --update configs/regression/lo_mse.yaml \
51
- > logs/mse-pixel-lpips.log 2>&1 &
52
 
53
  nohup python scripts/train.py configs/base.yaml --where modal \
54
  --update configs/regression/base.yaml \
55
- --update configs/regression/modal-datadirs.yaml \
56
- --update configs/regression/mse-dm.yaml \
57
- --update configs/regression/val_metrics.yaml \
58
  --update configs/compare/5k_steps.yaml \
59
- --update configs/optim/cosine.yaml \
60
- --update configs/regression/lo_mse.yaml \
61
- > logs/mse-pixel-lpips.log 2>&1 &
 
1
  #!/bin/bash
2
 
3
+
4
  # nohup python scripts/train.py configs/base.yaml --where modal \
5
  # --update configs/regression/base.yaml \
6
+ # --update configs/regression/modal.yaml \
7
  # --update configs/regression/mse.yaml \
 
8
  # --update configs/compare/5k_steps.yaml \
 
 
9
  # > logs/mse.log 2>&1 &
10
 
11
+
12
  nohup python scripts/train.py configs/base.yaml --where modal \
13
  --update configs/regression/base.yaml \
14
+ --update configs/regression/modal.yaml \
15
+ --update configs/regression/triplet/mse-triplet-b.yaml \
 
16
  --update configs/compare/5k_steps.yaml \
17
+ > logs/mse-triplet-b.log 2>&1 &
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  nohup python scripts/train.py configs/base.yaml --where modal \
20
  --update configs/regression/base.yaml \
21
+ --update configs/regression/modal.yaml \
22
+ --update configs/regression/triplet/mse-triplet-c.yaml \
 
23
  --update configs/compare/5k_steps.yaml \
24
+ > logs/mse-triplet-c.log 2>&1 &
 
 
25
 
26
  nohup python scripts/train.py configs/base.yaml --where modal \
27
  --update configs/regression/base.yaml \
28
+ --update configs/regression/modal.yaml \
29
+ --update configs/regression/triplet/mse-triplet-d.yaml \
 
30
  --update configs/compare/5k_steps.yaml \
31
+ > logs/mse-triplet-d.log 2>&1 &
 
 
32
 
33
  nohup python scripts/train.py configs/base.yaml --where modal \
34
  --update configs/regression/base.yaml \
35
+ --update configs/regression/modal.yaml \
36
+ --update configs/regression/triplet/mse-triplet-e.yaml \
 
37
  --update configs/compare/5k_steps.yaml \
38
+ > logs/mse-triplet-e.log 2>&1 &