Spaces:
Running
on
Zero
Running
on
Zero
Elea Zhong
commited on
Commit
·
6064267
1
Parent(s):
789676e
triplet loss experiments (prelim)
Browse files- configs/base.yaml +6 -3
- configs/regression/base.yaml +6 -0
- configs/regression/modal-datadirs.yaml +0 -2
- configs/regression/modal.yaml +5 -0
- configs/regression/mse-dm.yaml +0 -11
- configs/regression/mse-neg-mse.yaml +0 -3
- configs/regression/mse-pixel-lpips.yaml +0 -4
- configs/regression/mse-pixel-mse.yaml +0 -4
- configs/regression/mse-triplet.yaml +0 -13
- configs/regression/mse.yaml +0 -2
- configs/regression/triplet/mse-triplet-a.yaml +9 -0
- configs/regression/triplet/mse-triplet-b.yaml +9 -0
- configs/regression/triplet/mse-triplet-c.yaml +9 -0
- configs/regression/triplet/mse-triplet-d.yaml +9 -0
- configs/regression/triplet/mse-triplet-e.yaml +9 -0
- configs/regression/val_metrics.yaml +0 -9
- qwenimage/datamodels.py +10 -2
- qwenimage/foundation.py +81 -30
- qwenimage/loss.py +9 -1
- scripts/logit_normal_dist.ipynb +29 -10
- scripts/save_regression_outputs.py +3 -0
- scripts/save_regression_outputs_modal.py +119 -0
- scripts/straightness.ipynb +0 -0
- scripts/train.py +1 -0
- scripts/train_multi copy.sh +30 -0
- scripts/train_multi.sh +15 -38
configs/base.yaml
CHANGED
|
@@ -11,12 +11,12 @@ gradient_accumulation_steps: 1
|
|
| 11 |
train_batch_size: 1
|
| 12 |
optim: "adamw"
|
| 13 |
learning_rate: 1.0e-4
|
| 14 |
-
num_workers:
|
| 15 |
resume_from_checkpoint: null
|
| 16 |
log_model_steps: null
|
| 17 |
preprocessing_epoch_len: 64
|
| 18 |
preprocessing_epoch_repetitions: 1
|
| 19 |
-
|
| 20 |
|
| 21 |
# Logging
|
| 22 |
record_training: true
|
|
@@ -33,5 +33,8 @@ sample_steps:
|
|
| 33 |
every: 500
|
| 34 |
global_step: 0
|
| 35 |
save_steps: 1000
|
| 36 |
-
log_batch_steps:
|
|
|
|
|
|
|
| 37 |
seed: 67
|
|
|
|
|
|
| 11 |
train_batch_size: 1
|
| 12 |
optim: "adamw"
|
| 13 |
learning_rate: 1.0e-4
|
| 14 |
+
num_workers: 8
|
| 15 |
resume_from_checkpoint: null
|
| 16 |
log_model_steps: null
|
| 17 |
preprocessing_epoch_len: 64
|
| 18 |
preprocessing_epoch_repetitions: 1
|
| 19 |
+
lora_rank: 16
|
| 20 |
|
| 21 |
# Logging
|
| 22 |
record_training: true
|
|
|
|
| 33 |
every: 500
|
| 34 |
global_step: 0
|
| 35 |
save_steps: 1000
|
| 36 |
+
log_batch_steps:
|
| 37 |
+
"on": [0,1,100]
|
| 38 |
+
every: 500
|
| 39 |
seed: 67
|
| 40 |
+
|
configs/regression/base.yaml
CHANGED
|
@@ -19,3 +19,9 @@ regression_data_dir: "/data/regression_output"
|
|
| 19 |
regression_gen_steps: 50
|
| 20 |
editing_data_dir: "/data/CrispEdit"
|
| 21 |
editing_total_per: 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
regression_gen_steps: 50
|
| 20 |
editing_data_dir: "/data/CrispEdit"
|
| 21 |
editing_total_per: 1
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
validation_loss_terms:
|
| 25 |
+
mse: 1.0
|
| 26 |
+
pixel_mse: 1.0
|
| 27 |
+
pixel_lpips: 1.0
|
configs/regression/modal-datadirs.yaml
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
regression_data_dir: "/data/regression_data/regression_output"
|
| 2 |
-
editing_data_dir: "/data/edit_data/CrispEdit"
|
|
|
|
|
|
|
|
|
configs/regression/modal.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
regression_data_dir: "/data/regression_data/regression_output_1024"
|
| 2 |
+
editing_data_dir: "/data/edit_data/CrispEdit"
|
| 3 |
+
|
| 4 |
+
lora_rank: 32
|
| 5 |
+
vae_image_size: 1048576 # 1024 * 1024
|
configs/regression/mse-dm.yaml
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
wandb_run_name: "reg-mse-dm"
|
| 2 |
-
output_dir: "/data/checkpoints/reg-mse-dm"
|
| 3 |
-
|
| 4 |
-
train_loss_terms:
|
| 5 |
-
mse: 1.0
|
| 6 |
-
distribution_matching: 1.0
|
| 7 |
-
|
| 8 |
-
validation_loss_terms:
|
| 9 |
-
mse: 1.0
|
| 10 |
-
distribution_matching: 1.0
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configs/regression/mse-neg-mse.yaml
CHANGED
|
@@ -5,7 +5,4 @@ train_loss_terms:
|
|
| 5 |
mse: 1.0
|
| 6 |
negative_mse: 0.1
|
| 7 |
|
| 8 |
-
validation_loss_terms:
|
| 9 |
-
mse: 1.0
|
| 10 |
-
negative_mse: 0.1
|
| 11 |
|
|
|
|
| 5 |
mse: 1.0
|
| 6 |
negative_mse: 0.1
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
|
configs/regression/mse-pixel-lpips.yaml
CHANGED
|
@@ -4,7 +4,3 @@ output_dir: "/data/checkpoints/reg-mse-pixel-lpips"
|
|
| 4 |
train_loss_terms:
|
| 5 |
mse: 1.0
|
| 6 |
pixel_lpips: 1.0
|
| 7 |
-
|
| 8 |
-
validation_loss_terms:
|
| 9 |
-
mse: 1.0
|
| 10 |
-
pixel_lpips: 1.0
|
|
|
|
| 4 |
train_loss_terms:
|
| 5 |
mse: 1.0
|
| 6 |
pixel_lpips: 1.0
|
|
|
|
|
|
|
|
|
|
|
|
configs/regression/mse-pixel-mse.yaml
CHANGED
|
@@ -5,7 +5,3 @@ train_loss_terms:
|
|
| 5 |
mse: 1.0
|
| 6 |
pixel_mse: 1.0
|
| 7 |
|
| 8 |
-
validation_loss_terms:
|
| 9 |
-
mse: 1.0
|
| 10 |
-
pixel_mse: 1.0
|
| 11 |
-
|
|
|
|
| 5 |
mse: 1.0
|
| 6 |
pixel_mse: 1.0
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
configs/regression/mse-triplet.yaml
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
wandb_run_name: "reg-mse-triplet"
|
| 2 |
-
output_dir: "/data/checkpoints/reg-mse-triplet"
|
| 3 |
-
|
| 4 |
-
train_loss_terms:
|
| 5 |
-
mse: 1.0
|
| 6 |
-
triplet: 1.0
|
| 7 |
-
|
| 8 |
-
validation_loss_terms:
|
| 9 |
-
mse: 1.0
|
| 10 |
-
triplet: 1.0
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
triplet_margin: -500 # tune
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configs/regression/mse.yaml
CHANGED
|
@@ -4,6 +4,4 @@ output_dir: "/data/checkpoints/reg-mse"
|
|
| 4 |
train_loss_terms:
|
| 5 |
mse: 1.0
|
| 6 |
|
| 7 |
-
validation_loss_terms:
|
| 8 |
-
mse: 1.0
|
| 9 |
|
|
|
|
| 4 |
train_loss_terms:
|
| 5 |
mse: 1.0
|
| 6 |
|
|
|
|
|
|
|
| 7 |
|
configs/regression/triplet/mse-triplet-a.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb_run_name: "reg-mse-triplet-a"
|
| 2 |
+
output_dir: "/data/checkpoints/reg-mse-triplet-a"
|
| 3 |
+
|
| 4 |
+
train_loss_terms:
|
| 5 |
+
mse: 1.0
|
| 6 |
+
triplet: 1.0
|
| 7 |
+
|
| 8 |
+
triplet_margin: 0.0
|
| 9 |
+
triplet_min_abs_diff: 0.0
|
configs/regression/triplet/mse-triplet-b.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb_run_name: "reg-mse-triplet-b"
|
| 2 |
+
output_dir: "/data/checkpoints/reg-mse-triplet-b"
|
| 3 |
+
|
| 4 |
+
train_loss_terms:
|
| 5 |
+
mse: 1.0
|
| 6 |
+
triplet: 1.0
|
| 7 |
+
|
| 8 |
+
triplet_margin: 0.0
|
| 9 |
+
triplet_min_abs_diff: 0.1
|
configs/regression/triplet/mse-triplet-c.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb_run_name: "reg-mse-triplet-c"
|
| 2 |
+
output_dir: "/data/checkpoints/reg-mse-triplet-c"
|
| 3 |
+
|
| 4 |
+
train_loss_terms:
|
| 5 |
+
mse: 1.0
|
| 6 |
+
triplet: 1.0
|
| 7 |
+
|
| 8 |
+
triplet_margin: 0.1
|
| 9 |
+
triplet_min_abs_diff: 0.1
|
configs/regression/triplet/mse-triplet-d.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb_run_name: "reg-mse-triplet-c"
|
| 2 |
+
output_dir: "/data/checkpoints/reg-mse-triplet-c"
|
| 3 |
+
|
| 4 |
+
train_loss_terms:
|
| 5 |
+
mse: 1.0
|
| 6 |
+
triplet: 1.0
|
| 7 |
+
|
| 8 |
+
triplet_margin: 0.5
|
| 9 |
+
triplet_min_abs_diff: 0.1
|
configs/regression/triplet/mse-triplet-e.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
wandb_run_name: "reg-mse-triplet-e"
|
| 2 |
+
output_dir: "/data/checkpoints/reg-mse-triplet-e"
|
| 3 |
+
|
| 4 |
+
train_loss_terms:
|
| 5 |
+
mse: 1.0
|
| 6 |
+
triplet: 1.0
|
| 7 |
+
|
| 8 |
+
triplet_margin: 0.1
|
| 9 |
+
triplet_min_abs_diff: 0.25
|
configs/regression/val_metrics.yaml
DELETED
|
@@ -1,9 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
validation_loss_terms:
|
| 5 |
-
mse: 1.0
|
| 6 |
-
pixel_mse: 1.0
|
| 7 |
-
pixel_lpips: 1.0
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qwenimage/datamodels.py
CHANGED
|
@@ -50,12 +50,20 @@ class QwenLossTerms(BaseModel):
|
|
| 50 |
triplet: LossTermSpecType = 0.0
|
| 51 |
negative_mse: LossTermSpecType = 0.0
|
| 52 |
distribution_matching: LossTermSpecType = 0.0
|
| 53 |
-
|
| 54 |
pixel_lpips: LossTermSpecType = 0.0
|
| 55 |
pixel_mse: LossTermSpecType = 0.0
|
|
|
|
| 56 |
adversarial: LossTermSpecType = 0.0
|
|
|
|
| 57 |
|
| 58 |
-
triplet_margin: float = 0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
class QwenConfig(ExperimentTrainerParameters):
|
| 61 |
load_multi_view_lora: bool = False
|
|
|
|
| 50 |
triplet: LossTermSpecType = 0.0
|
| 51 |
negative_mse: LossTermSpecType = 0.0
|
| 52 |
distribution_matching: LossTermSpecType = 0.0
|
| 53 |
+
pixel_triplet: LossTermSpecType = 0.0
|
| 54 |
pixel_lpips: LossTermSpecType = 0.0
|
| 55 |
pixel_mse: LossTermSpecType = 0.0
|
| 56 |
+
pixel_distribution_matching: LossTermSpecType = 0.0
|
| 57 |
adversarial: LossTermSpecType = 0.0
|
| 58 |
+
teacher: LossTermSpecType = 0.0
|
| 59 |
|
| 60 |
+
triplet_margin: float = 0.0
|
| 61 |
+
triplet_min_abs_diff: float = 0.0
|
| 62 |
+
teacher_steps: int = 4
|
| 63 |
+
|
| 64 |
+
@property
|
| 65 |
+
def pixel_terms(self) -> bool:
|
| 66 |
+
return ("pixel_lpips", "pixel_mse", "pixel_triplet", "pixel_distribution_matching",)
|
| 67 |
|
| 68 |
class QwenConfig(ExperimentTrainerParameters):
|
| 69 |
load_multi_view_lora: bool = False
|
qwenimage/foundation.py
CHANGED
|
@@ -395,13 +395,14 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
|
|
| 395 |
split = batch["split"]
|
| 396 |
step = batch["step"]
|
| 397 |
if split == "train":
|
| 398 |
-
loss_terms = self.config.train_loss_terms
|
| 399 |
elif split == "validation":
|
| 400 |
-
loss_terms = self.config.validation_loss_terms
|
| 401 |
loss_accumulator = LossAccumulator(
|
| 402 |
-
terms=loss_terms,
|
| 403 |
step=step,
|
| 404 |
split=split,
|
|
|
|
| 405 |
)
|
| 406 |
|
| 407 |
if loss_accumulator.has("mse"):
|
|
@@ -414,31 +415,42 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
|
|
| 414 |
loss_accumulator.accum("mse", mse_loss)
|
| 415 |
|
| 416 |
if loss_accumulator.has("triplet"):
|
| 417 |
-
#
|
| 418 |
-
margin = loss_terms
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
-
# triplet_weight = (v_gt_1d - v_neg_1d).pow(2).mean(dim=(-2,-1))
|
| 422 |
-
|
| 423 |
-
diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2).sum(dim=(-2,-1))
|
| 424 |
-
diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2).sum(dim=(-2,-1))
|
| 425 |
-
# diffv_gt_pred_reg = diffv_gt_pred # / (v_span + eps)
|
| 426 |
-
# diffv_neg_pred_reg = diffv_neg_pred # / (v_span + eps)
|
| 427 |
-
|
| 428 |
-
# texam(v_span, name="v_span")
|
| 429 |
-
# texam(triplet_weight, name="triplet_weight")
|
| 430 |
-
texam(diffv_gt_pred, name="diffv_gt_pred")
|
| 431 |
-
texam(diffv_neg_pred, name="diffv_neg_pred")
|
| 432 |
-
# texam(diffv_gt_pred_reg, name="diffv_gt_pred_reg")
|
| 433 |
-
# texam(diffv_neg_pred_reg, name="diffv_neg_pred_reg")
|
| 434 |
-
# texam(diffv_gt_pred_reg - diffv_neg_pred_reg, name="diffv_gt_pred_reg - diffv_neg_pred_reg")
|
| 435 |
-
|
| 436 |
-
triplet_loss = F.relu(diffv_gt_pred - diffv_neg_pred + margin).mean()
|
| 437 |
-
#
|
| 438 |
-
# triplet_loss = F.relu(diffv_gt_pred_reg - diffv_neg_pred_reg + margin).mean()
|
| 439 |
-
# triplet_loss = torch.mean(triplet_loss_batched * triplet_weight)
|
| 440 |
loss_accumulator.accum("triplet", triplet_loss)
|
| 441 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
|
| 443 |
if loss_accumulator.has("negative_mse"):
|
| 444 |
neg_mse_loss = -F.mse_loss(v_pred_1d, v_neg_1d, reduction="mean")
|
|
@@ -453,7 +465,7 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
|
|
| 453 |
if loss_accumulator.has("negative_exponential"):
|
| 454 |
raise NotImplementedError()
|
| 455 |
|
| 456 |
-
if loss_accumulator.
|
| 457 |
x_0_pred = x_t_1d - t * v_pred_1d
|
| 458 |
pixel_values_x0_gt = self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16, with_grad=True).detach()
|
| 459 |
pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16, with_grad=True)
|
|
@@ -468,6 +480,14 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
|
|
| 468 |
if loss_accumulator.has("pixel_mse"):
|
| 469 |
pixel_mse_loss = F.mse_loss(pixel_values_x0_pred, pixel_values_x0_gt, reduction="mean")
|
| 470 |
loss_accumulator.accum("pixel_mse", pixel_mse_loss)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
|
| 472 |
if loss_accumulator.has("adversarial"):
|
| 473 |
raise NotImplementedError()
|
|
@@ -492,7 +512,7 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
|
|
| 492 |
v_gt_1d,
|
| 493 |
v_neg_1d,
|
| 494 |
v_pred_1d,
|
| 495 |
-
visualize_velocities=
|
| 496 |
)
|
| 497 |
|
| 498 |
return loss
|
|
@@ -513,7 +533,7 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
|
|
| 513 |
v_gt_1d,
|
| 514 |
v_neg_1d,
|
| 515 |
v_pred_1d,
|
| 516 |
-
visualize_velocities=
|
| 517 |
):
|
| 518 |
t_float = t.float().cpu().item()
|
| 519 |
x_0_pred = x_t_1d - t * v_pred_1d
|
|
@@ -526,18 +546,49 @@ class QwenImageRegressionFoundation(QwenImageFoundation):
|
|
| 526 |
"x_0_pred": self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16),
|
| 527 |
"x_0_neg": self.latents_to_pil(x_0_neg, h=h_f16, w=w_f16),
|
| 528 |
}
|
| 529 |
-
if visualize_velocities:
|
| 530 |
log_pils.update({
|
| 531 |
"v_gt_1d": self.latents_to_pil(v_gt_1d, h=h_f16, w=w_f16),
|
| 532 |
"v_pred_1d": self.latents_to_pil(v_pred_1d, h=h_f16, w=w_f16),
|
| 533 |
"v_neg_1d": self.latents_to_pil(v_neg_1d, h=h_f16, w=w_f16),
|
| 534 |
})
|
| 535 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
wand_logger.log({
|
| 537 |
"train_images": log_pils,
|
| 538 |
}, commit=False)
|
| 539 |
|
| 540 |
|
| 541 |
def base_pipe(self, inputs: QwenInputs) -> list[Image]:
|
| 542 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
return super().base_pipe(inputs)
|
|
|
|
| 395 |
split = batch["split"]
|
| 396 |
step = batch["step"]
|
| 397 |
if split == "train":
|
| 398 |
+
loss_terms = self.config.train_loss_terms
|
| 399 |
elif split == "validation":
|
| 400 |
+
loss_terms = self.config.validation_loss_terms
|
| 401 |
loss_accumulator = LossAccumulator(
|
| 402 |
+
terms=loss_terms.model_dump(),
|
| 403 |
step=step,
|
| 404 |
split=split,
|
| 405 |
+
term_groups={"pixel":loss_terms.pixel_terms}
|
| 406 |
)
|
| 407 |
|
| 408 |
if loss_accumulator.has("mse"):
|
|
|
|
| 415 |
loss_accumulator.accum("mse", mse_loss)
|
| 416 |
|
| 417 |
if loss_accumulator.has("triplet"):
|
| 418 |
+
# 1d, B,L,C
|
| 419 |
+
margin = loss_terms.triplet_margin
|
| 420 |
+
triplet_min_abs_diff = loss_terms.triplet_min_abs_diff
|
| 421 |
+
print(f"{triplet_min_abs_diff=}")
|
| 422 |
+
v_gt_neg_diff = (v_gt_1d - v_neg_1d).abs().mean(dim=2, keepdim=True)
|
| 423 |
+
zero_weight = torch.zeros_like(v_gt_neg_diff)
|
| 424 |
+
v_weight = torch.where(v_gt_neg_diff > triplet_min_abs_diff, v_gt_neg_diff, zero_weight)
|
| 425 |
+
ones = torch.ones_like(v_gt_neg_diff)
|
| 426 |
+
filtered_nums = torch.sum(torch.where(v_gt_neg_diff > triplet_min_abs_diff, ones, zero_weight))
|
| 427 |
+
wand_logger.log({
|
| 428 |
+
"filtered_nums": filtered_nums,
|
| 429 |
+
}, commit=False)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2)
|
| 433 |
+
diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2)
|
| 434 |
+
loss_unreduced = diffv_gt_pred - diffv_neg_pred
|
| 435 |
+
loss_weighted = (loss_unreduced * v_weight).sum(dim=2)
|
| 436 |
+
triplet_loss = F.relu(loss_weighted + margin).mean()
|
| 437 |
+
ones = torch.ones_like(loss_weighted)
|
| 438 |
+
zeros = torch.zeros_like(loss_weighted)
|
| 439 |
+
loss_nonzero_nums = torch.sum(torch.where((loss_weighted + margin)>0, ones, zeros))
|
| 440 |
+
wand_logger.log({
|
| 441 |
+
"loss_nonzero_nums": loss_nonzero_nums,
|
| 442 |
+
}, commit=False)
|
| 443 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
loss_accumulator.accum("triplet", triplet_loss)
|
| 445 |
|
| 446 |
+
texam(v_gt_neg_diff, "v_gt_neg_diff")
|
| 447 |
+
texam(v_weight, "v_weight")
|
| 448 |
+
texam(diffv_gt_pred, "diffv_gt_pred")
|
| 449 |
+
texam(diffv_neg_pred, "diffv_neg_pred")
|
| 450 |
+
texam(loss_unreduced, "loss_unreduced")
|
| 451 |
+
texam(loss_weighted, "loss_weighted")
|
| 452 |
+
|
| 453 |
+
|
| 454 |
|
| 455 |
if loss_accumulator.has("negative_mse"):
|
| 456 |
neg_mse_loss = -F.mse_loss(v_pred_1d, v_neg_1d, reduction="mean")
|
|
|
|
| 465 |
if loss_accumulator.has("negative_exponential"):
|
| 466 |
raise NotImplementedError()
|
| 467 |
|
| 468 |
+
if loss_accumulator.has_group("pixel"):
|
| 469 |
x_0_pred = x_t_1d - t * v_pred_1d
|
| 470 |
pixel_values_x0_gt = self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16, with_grad=True).detach()
|
| 471 |
pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16, with_grad=True)
|
|
|
|
| 480 |
if loss_accumulator.has("pixel_mse"):
|
| 481 |
pixel_mse_loss = F.mse_loss(pixel_values_x0_pred, pixel_values_x0_gt, reduction="mean")
|
| 482 |
loss_accumulator.accum("pixel_mse", pixel_mse_loss)
|
| 483 |
+
|
| 484 |
+
if loss_accumulator.has("pixel_triplet"):
|
| 485 |
+
raise NotImplementedError()
|
| 486 |
+
loss_accumulator.accum("pixel_triplet", pixel_triplet_loss)
|
| 487 |
+
|
| 488 |
+
if loss_accumulator.has("pixel_distribution_matching"):
|
| 489 |
+
raise NotImplementedError()
|
| 490 |
+
loss_accumulator.accum("pixel_distribution_matching", pixel_distribution_matching_loss)
|
| 491 |
|
| 492 |
if loss_accumulator.has("adversarial"):
|
| 493 |
raise NotImplementedError()
|
|
|
|
| 512 |
v_gt_1d,
|
| 513 |
v_neg_1d,
|
| 514 |
v_pred_1d,
|
| 515 |
+
visualize_velocities=False,
|
| 516 |
)
|
| 517 |
|
| 518 |
return loss
|
|
|
|
| 533 |
v_gt_1d,
|
| 534 |
v_neg_1d,
|
| 535 |
v_pred_1d,
|
| 536 |
+
visualize_velocities=False,
|
| 537 |
):
|
| 538 |
t_float = t.float().cpu().item()
|
| 539 |
x_0_pred = x_t_1d - t * v_pred_1d
|
|
|
|
| 546 |
"x_0_pred": self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16),
|
| 547 |
"x_0_neg": self.latents_to_pil(x_0_neg, h=h_f16, w=w_f16),
|
| 548 |
}
|
| 549 |
+
if visualize_velocities: # naively visualizing through vae (works with flux)
|
| 550 |
log_pils.update({
|
| 551 |
"v_gt_1d": self.latents_to_pil(v_gt_1d, h=h_f16, w=w_f16),
|
| 552 |
"v_pred_1d": self.latents_to_pil(v_pred_1d, h=h_f16, w=w_f16),
|
| 553 |
"v_neg_1d": self.latents_to_pil(v_neg_1d, h=h_f16, w=w_f16),
|
| 554 |
})
|
| 555 |
|
| 556 |
+
# create gt-neg difference maps
|
| 557 |
+
v_pred_2d = self.unpack_latents(v_pred_1d, h_f16, w_f16)
|
| 558 |
+
v_gt_2d = self.unpack_latents(v_gt_1d, h_f16, w_f16)
|
| 559 |
+
v_neg_2d = self.unpack_latents(v_neg_1d, h_f16, w_f16)
|
| 560 |
+
gt_neg_diff_map_2d = (v_gt_2d - v_neg_2d).pow(2).mean(dim=1, keepdim=True)
|
| 561 |
+
gt_pred_diff_map_2d = (v_gt_2d - v_pred_2d).pow(2).mean(dim=1, keepdim=True)
|
| 562 |
+
neg_pred_diff_map_2d = (v_neg_2d - v_pred_2d).pow(2).mean(dim=1, keepdim=True)
|
| 563 |
+
diff_max = torch.max(torch.stack([gt_neg_diff_map_2d, gt_pred_diff_map_2d, neg_pred_diff_map_2d]))
|
| 564 |
+
diff_min = torch.min(torch.stack([gt_neg_diff_map_2d, gt_pred_diff_map_2d, neg_pred_diff_map_2d]))
|
| 565 |
+
print(f"{diff_min}, {diff_max}")
|
| 566 |
+
# norms to 0-1
|
| 567 |
+
diff_span = diff_max - diff_min
|
| 568 |
+
gt_neg_diff_map_2d = (gt_neg_diff_map_2d - diff_min) / diff_span
|
| 569 |
+
gt_pred_diff_map_2d = (gt_pred_diff_map_2d - diff_min) / diff_span
|
| 570 |
+
neg_pred_diff_map_2d = (neg_pred_diff_map_2d - diff_min) / diff_span
|
| 571 |
+
log_pils.update({
|
| 572 |
+
"gt-neg":gt_neg_diff_map_2d.float().cpu(),
|
| 573 |
+
"gt-pred":gt_pred_diff_map_2d.float().cpu(),
|
| 574 |
+
"neg-pred":neg_pred_diff_map_2d.float().cpu(),
|
| 575 |
+
})
|
| 576 |
+
|
| 577 |
wand_logger.log({
|
| 578 |
"train_images": log_pils,
|
| 579 |
}, commit=False)
|
| 580 |
|
| 581 |
|
| 582 |
def base_pipe(self, inputs: QwenInputs) -> list[Image]:
|
| 583 |
+
# config overrides
|
| 584 |
+
inputs.num_inference_steps = self.config.regression_base_pipe_steps
|
| 585 |
+
inputs.latent_size_override = self.config.vae_image_size
|
| 586 |
+
inputs.vae_image_override = self.config.vae_image_size
|
| 587 |
+
image = inputs.image[0]
|
| 588 |
+
w,h = image.size
|
| 589 |
+
h_r, w_r = calculate_dimensions(self.config.vae_image_size, h/w)
|
| 590 |
+
image = TF.resize(image, (h_r, w_r))
|
| 591 |
+
inputs.image = [image]
|
| 592 |
+
inputs.height = h_r
|
| 593 |
+
inputs.width = w_r
|
| 594 |
return super().base_pipe(inputs)
|
qwenimage/loss.py
CHANGED
|
@@ -12,9 +12,11 @@ class LossAccumulator:
|
|
| 12 |
terms: dict[str, int|float|dict],
|
| 13 |
step: int|None=None,
|
| 14 |
split: str|None=None,
|
|
|
|
| 15 |
):
|
| 16 |
self.terms = terms
|
| 17 |
self.step = step
|
|
|
|
| 18 |
if split is not None:
|
| 19 |
self.split = split
|
| 20 |
self.prefix = f"{self.split}_"
|
|
@@ -55,7 +57,13 @@ class LossAccumulator:
|
|
| 55 |
|
| 56 |
warnings.warn(f"Unknown spec type {spec}; treat as disabled")
|
| 57 |
return 0.0
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
def has(self, name: str) -> bool:
|
| 60 |
return self.resolve_weight(name) > 0
|
| 61 |
|
|
|
|
| 12 |
terms: dict[str, int|float|dict],
|
| 13 |
step: int|None=None,
|
| 14 |
split: str|None=None,
|
| 15 |
+
term_groups: dict[str, tuple[str, ...]]|None = None,
|
| 16 |
):
|
| 17 |
self.terms = terms
|
| 18 |
self.step = step
|
| 19 |
+
self.term_groups = term_groups
|
| 20 |
if split is not None:
|
| 21 |
self.split = split
|
| 22 |
self.prefix = f"{self.split}_"
|
|
|
|
| 57 |
|
| 58 |
warnings.warn(f"Unknown spec type {spec}; treat as disabled")
|
| 59 |
return 0.0
|
| 60 |
+
|
| 61 |
+
def has_group(self, name: str):
|
| 62 |
+
if name not in self.term_groups:
|
| 63 |
+
return False
|
| 64 |
+
all_group_terms = self.term_groups[name]
|
| 65 |
+
return any([self.resolve_weight(tn) > 0 for tn in all_group_terms])
|
| 66 |
+
|
| 67 |
def has(self, name: str) -> bool:
|
| 68 |
return self.resolve_weight(name) > 0
|
| 69 |
|
scripts/logit_normal_dist.ipynb
CHANGED
|
@@ -76,17 +76,17 @@
|
|
| 76 |
},
|
| 77 |
{
|
| 78 |
"cell_type": "code",
|
| 79 |
-
"execution_count":
|
| 80 |
"id": "aec3ae8f",
|
| 81 |
"metadata": {},
|
| 82 |
"outputs": [
|
| 83 |
{
|
| 84 |
"data": {
|
| 85 |
"text/plain": [
|
| 86 |
-
"[<matplotlib.lines.Line2D at
|
| 87 |
]
|
| 88 |
},
|
| 89 |
-
"execution_count":
|
| 90 |
"metadata": {},
|
| 91 |
"output_type": "execute_result"
|
| 92 |
},
|
|
@@ -117,7 +117,7 @@
|
|
| 117 |
},
|
| 118 |
{
|
| 119 |
"cell_type": "code",
|
| 120 |
-
"execution_count":
|
| 121 |
"id": "3bc68e7c",
|
| 122 |
"metadata": {},
|
| 123 |
"outputs": [
|
|
@@ -127,7 +127,7 @@
|
|
| 127 |
"1.0986122886681098"
|
| 128 |
]
|
| 129 |
},
|
| 130 |
-
"execution_count":
|
| 131 |
"metadata": {},
|
| 132 |
"output_type": "execute_result"
|
| 133 |
}
|
|
@@ -139,7 +139,7 @@
|
|
| 139 |
},
|
| 140 |
{
|
| 141 |
"cell_type": "code",
|
| 142 |
-
"execution_count":
|
| 143 |
"id": "facb782e",
|
| 144 |
"metadata": {},
|
| 145 |
"outputs": [
|
|
@@ -149,7 +149,7 @@
|
|
| 149 |
"tensor([1.0000, 0.8808, 0.0000])"
|
| 150 |
]
|
| 151 |
},
|
| 152 |
-
"execution_count":
|
| 153 |
"metadata": {},
|
| 154 |
"output_type": "execute_result"
|
| 155 |
}
|
|
@@ -166,11 +166,30 @@
|
|
| 166 |
},
|
| 167 |
{
|
| 168 |
"cell_type": "code",
|
| 169 |
-
"execution_count":
|
| 170 |
"id": "f006f2fa",
|
| 171 |
"metadata": {},
|
| 172 |
-
"outputs": [
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
},
|
| 175 |
{
|
| 176 |
"cell_type": "code",
|
|
|
|
| 76 |
},
|
| 77 |
{
|
| 78 |
"cell_type": "code",
|
| 79 |
+
"execution_count": 6,
|
| 80 |
"id": "aec3ae8f",
|
| 81 |
"metadata": {},
|
| 82 |
"outputs": [
|
| 83 |
{
|
| 84 |
"data": {
|
| 85 |
"text/plain": [
|
| 86 |
+
"[<matplotlib.lines.Line2D at 0x7ea4980d6b30>]"
|
| 87 |
]
|
| 88 |
},
|
| 89 |
+
"execution_count": 6,
|
| 90 |
"metadata": {},
|
| 91 |
"output_type": "execute_result"
|
| 92 |
},
|
|
|
|
| 117 |
},
|
| 118 |
{
|
| 119 |
"cell_type": "code",
|
| 120 |
+
"execution_count": 7,
|
| 121 |
"id": "3bc68e7c",
|
| 122 |
"metadata": {},
|
| 123 |
"outputs": [
|
|
|
|
| 127 |
"1.0986122886681098"
|
| 128 |
]
|
| 129 |
},
|
| 130 |
+
"execution_count": 7,
|
| 131 |
"metadata": {},
|
| 132 |
"output_type": "execute_result"
|
| 133 |
}
|
|
|
|
| 139 |
},
|
| 140 |
{
|
| 141 |
"cell_type": "code",
|
| 142 |
+
"execution_count": 8,
|
| 143 |
"id": "facb782e",
|
| 144 |
"metadata": {},
|
| 145 |
"outputs": [
|
|
|
|
| 149 |
"tensor([1.0000, 0.8808, 0.0000])"
|
| 150 |
]
|
| 151 |
},
|
| 152 |
+
"execution_count": 8,
|
| 153 |
"metadata": {},
|
| 154 |
"output_type": "execute_result"
|
| 155 |
}
|
|
|
|
| 166 |
},
|
| 167 |
{
|
| 168 |
"cell_type": "code",
|
| 169 |
+
"execution_count": 9,
|
| 170 |
"id": "f006f2fa",
|
| 171 |
"metadata": {},
|
| 172 |
+
"outputs": [
|
| 173 |
+
{
|
| 174 |
+
"data": {
|
| 175 |
+
"text/plain": [
|
| 176 |
+
"tensor([1.0000, 0.7484, 0.0000])"
|
| 177 |
+
]
|
| 178 |
+
},
|
| 179 |
+
"execution_count": 9,
|
| 180 |
+
"metadata": {},
|
| 181 |
+
"output_type": "execute_result"
|
| 182 |
+
}
|
| 183 |
+
],
|
| 184 |
+
"source": [
|
| 185 |
+
"t = torch.tensor([1, 0.5, 0])\n",
|
| 186 |
+
"t_i = TimestepDistUtils.t_shift(\n",
|
| 187 |
+
" mu=torch.tensor(1.09),\n",
|
| 188 |
+
" sigma=1.0,\n",
|
| 189 |
+
" t=t\n",
|
| 190 |
+
")\n",
|
| 191 |
+
"t_i"
|
| 192 |
+
]
|
| 193 |
},
|
| 194 |
{
|
| 195 |
"cell_type": "code",
|
scripts/save_regression_outputs.py
CHANGED
|
@@ -15,6 +15,7 @@ def main():
|
|
| 15 |
parser.add_argument("--imsize", type=int, default=512)
|
| 16 |
parser.add_argument("--indir", type=str, default="/data/CrispEdit")
|
| 17 |
parser.add_argument("--outdir", type=str, default="/data/regression_output")
|
|
|
|
| 18 |
args = parser.parse_args()
|
| 19 |
|
| 20 |
total_per = 10
|
|
@@ -54,6 +55,8 @@ def main():
|
|
| 54 |
image=[input_data["input_img"]],
|
| 55 |
prompt=input_data["instruction"],
|
| 56 |
vae_image_override=args.imsize * args.imsize,
|
|
|
|
|
|
|
| 57 |
))
|
| 58 |
|
| 59 |
torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
|
|
|
|
| 15 |
parser.add_argument("--imsize", type=int, default=512)
|
| 16 |
parser.add_argument("--indir", type=str, default="/data/CrispEdit")
|
| 17 |
parser.add_argument("--outdir", type=str, default="/data/regression_output")
|
| 18 |
+
parser.add_argument("--steps", type=int, default=50)
|
| 19 |
args = parser.parse_args()
|
| 20 |
|
| 21 |
total_per = 10
|
|
|
|
| 55 |
image=[input_data["input_img"]],
|
| 56 |
prompt=input_data["instruction"],
|
| 57 |
vae_image_override=args.imsize * args.imsize,
|
| 58 |
+
latent_size_override=args.imsize * args.imsize,
|
| 59 |
+
num_inference_steps=args.steps,
|
| 60 |
))
|
| 61 |
|
| 62 |
torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
|
scripts/save_regression_outputs_modal.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append(str(Path(__file__).parent.parent))
|
| 6 |
+
|
| 7 |
+
import fal
|
| 8 |
+
import modal
|
| 9 |
+
import torch
|
| 10 |
+
import tqdm
|
| 11 |
+
from datasets import concatenate_datasets, load_dataset, interleave_datasets
|
| 12 |
+
|
| 13 |
+
from qwenimage.datamodels import QwenConfig
|
| 14 |
+
from qwenimage.foundation import QwenImageFoundationSaveInterm
|
| 15 |
+
|
| 16 |
+
REQUIREMENTS_PATH = os.path.abspath("requirements.txt")
|
| 17 |
+
WAND_REQUIREMENTS_PATH = os.path.abspath("scripts/wand_requirements.txt")
|
| 18 |
+
|
| 19 |
+
local_modules = ["qwenimage","wandml","scripts"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
EDIT_TYPES = [
|
| 24 |
+
"color",
|
| 25 |
+
"style",
|
| 26 |
+
"replace",
|
| 27 |
+
"remove",
|
| 28 |
+
"add",
|
| 29 |
+
"motion change",
|
| 30 |
+
"background change",
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
modalapp = modal.App("next-stroke")
|
| 34 |
+
modalapp.image = (
|
| 35 |
+
modal.Image.debian_slim(python_version="3.10")
|
| 36 |
+
.apt_install("git", "ffmpeg", "libsm6", "libxext6")
|
| 37 |
+
.pip_install_from_requirements(REQUIREMENTS_PATH)
|
| 38 |
+
.pip_install_from_requirements(WAND_REQUIREMENTS_PATH)
|
| 39 |
+
.add_local_python_source(*local_modules)
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
@modalapp.function(
|
| 43 |
+
gpu="H100",
|
| 44 |
+
max_containers=8,
|
| 45 |
+
timeout=1 * 60 * 60,
|
| 46 |
+
volumes={
|
| 47 |
+
"/data/wand_cache": modal.Volume.from_name("FLUX_MODELS"),
|
| 48 |
+
"/data/checkpoints": modal.Volume.from_name("training_checkpoints", create_if_missing=True),
|
| 49 |
+
"/root/.cache/torch/hub/checkpoints": modal.Volume.from_name("torch_hub_checkpoints", create_if_missing=True),
|
| 50 |
+
|
| 51 |
+
"/root/.cache/huggingface/hub": modal.Volume.from_name("hf_cache", create_if_missing=True),
|
| 52 |
+
"/root/.cache/huggingface/datasets": modal.Volume.from_name("hf_cache_datasets", create_if_missing=True),
|
| 53 |
+
|
| 54 |
+
"/data/regression_data": modal.Volume.from_name("regression_data"),
|
| 55 |
+
"/data/edit_data": modal.Volume.from_name("edit_data"),
|
| 56 |
+
},
|
| 57 |
+
secrets=[
|
| 58 |
+
modal.Secret.from_name("wand-modal-gcloud-keyfile"),
|
| 59 |
+
modal.Secret.from_name("elea-huggingface-secret"),
|
| 60 |
+
],
|
| 61 |
+
)
|
| 62 |
+
def generate_regression_data(start_index=0, end_index=None, imsize=1024, indir="/data/edit_data/CrispEdit", outdir="/data/regression_data/regression_output_1024", total_per=10):
|
| 63 |
+
|
| 64 |
+
all_edit_datasets = []
|
| 65 |
+
for edit_type in EDIT_TYPES:
|
| 66 |
+
to_concat = []
|
| 67 |
+
for ds_n in range(total_per):
|
| 68 |
+
ds = load_dataset("parquet", data_files=f"{indir}/{edit_type}_{ds_n:05d}.parquet", split="train")
|
| 69 |
+
to_concat.append(ds)
|
| 70 |
+
edit_type_concat = concatenate_datasets(to_concat)
|
| 71 |
+
all_edit_datasets.append(edit_type_concat)
|
| 72 |
+
join_ds = interleave_datasets(all_edit_datasets)
|
| 73 |
+
|
| 74 |
+
save_base_dir = Path(outdir)
|
| 75 |
+
save_base_dir.mkdir(exist_ok=True, parents=True)
|
| 76 |
+
|
| 77 |
+
foundation = QwenImageFoundationSaveInterm(QwenConfig(vae_image_size=imsize * imsize))
|
| 78 |
+
|
| 79 |
+
if end_index is None:
|
| 80 |
+
end_index = len(join_ds)
|
| 81 |
+
dataset_to_process = join_ds.select(range(start_index, end_index))
|
| 82 |
+
|
| 83 |
+
for idx, input_data in enumerate(tqdm.tqdm(dataset_to_process), start=start_index):
|
| 84 |
+
|
| 85 |
+
output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
|
| 86 |
+
image=[input_data["input_img"]],
|
| 87 |
+
prompt=input_data["instruction"],
|
| 88 |
+
vae_image_override=imsize * imsize,
|
| 89 |
+
latent_size_override=imsize * imsize,
|
| 90 |
+
))
|
| 91 |
+
|
| 92 |
+
torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@modalapp.local_entrypoint()
|
| 96 |
+
def main(start:int, end:int, num_workers:int):
|
| 97 |
+
per_worker_load = (end - start) // num_workers
|
| 98 |
+
remainder = (end - start) % num_workers
|
| 99 |
+
if remainder > 0:
|
| 100 |
+
per_worker_load += 1
|
| 101 |
+
worker_load_starts = []
|
| 102 |
+
worker_load_ends = []
|
| 103 |
+
cur_start = start
|
| 104 |
+
for worker_idx in range(num_workers):
|
| 105 |
+
if worker_idx < num_workers -1:
|
| 106 |
+
cur_end = cur_start + per_worker_load
|
| 107 |
+
else:
|
| 108 |
+
cur_end = end # pass last worker less
|
| 109 |
+
worker_load_starts.append(cur_start)
|
| 110 |
+
worker_load_ends.append(cur_end)
|
| 111 |
+
cur_start += per_worker_load
|
| 112 |
+
|
| 113 |
+
print(f"loads: {list(zip(worker_load_starts, worker_load_ends))}")
|
| 114 |
+
outputs = list(generate_regression_data.map(worker_load_starts, worker_load_ends))
|
| 115 |
+
print(outputs)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
|
scripts/straightness.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/train.py
CHANGED
|
@@ -70,6 +70,7 @@ modalapp.image = (
|
|
| 70 |
"/root/.cache/torch/hub/checkpoints": modal.Volume.from_name("torch_hub_checkpoints", create_if_missing=True),
|
| 71 |
|
| 72 |
"/root/.cache/huggingface/hub": modal.Volume.from_name("hf_cache", create_if_missing=True),
|
|
|
|
| 73 |
|
| 74 |
"/data/regression_data": modal.Volume.from_name("regression_data"),
|
| 75 |
"/data/edit_data": modal.Volume.from_name("edit_data"),
|
|
|
|
| 70 |
"/root/.cache/torch/hub/checkpoints": modal.Volume.from_name("torch_hub_checkpoints", create_if_missing=True),
|
| 71 |
|
| 72 |
"/root/.cache/huggingface/hub": modal.Volume.from_name("hf_cache", create_if_missing=True),
|
| 73 |
+
"/root/.cache/huggingface/datasets": modal.Volume.from_name("hf_cache_datasets", create_if_missing=True),
|
| 74 |
|
| 75 |
"/data/regression_data": modal.Volume.from_name("regression_data"),
|
| 76 |
"/data/edit_data": modal.Volume.from_name("edit_data"),
|
scripts/train_multi copy.sh
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 5 |
+
--update configs/regression/base.yaml \
|
| 6 |
+
--update configs/regression/modal.yaml \
|
| 7 |
+
--update configs/regression/dm/mse-dm-a.yaml \
|
| 8 |
+
--update configs/compare/5k_steps.yaml \
|
| 9 |
+
> logs/mse-dm-a.log 2>&1 &
|
| 10 |
+
|
| 11 |
+
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 12 |
+
--update configs/regression/base.yaml \
|
| 13 |
+
--update configs/regression/modal.yaml \
|
| 14 |
+
--update configs/regression/dm/mse-dm-b.yaml \
|
| 15 |
+
--update configs/compare/5k_steps.yaml \
|
| 16 |
+
> logs/mse-dm-b.log 2>&1 &
|
| 17 |
+
|
| 18 |
+
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 19 |
+
--update configs/regression/base.yaml \
|
| 20 |
+
--update configs/regression/modal.yaml \
|
| 21 |
+
--update configs/regression/dm/mse-dm-c.yaml \
|
| 22 |
+
--update configs/compare/5k_steps.yaml \
|
| 23 |
+
> logs/mse-dm-c.log 2>&1 &
|
| 24 |
+
|
| 25 |
+
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 26 |
+
--update configs/regression/base.yaml \
|
| 27 |
+
--update configs/regression/modal.yaml \
|
| 28 |
+
--update configs/regression/dm/mse-dm-d.yaml \
|
| 29 |
+
--update configs/compare/5k_steps.yaml \
|
| 30 |
+
> logs/mse-dm-d.log 2>&1 &
|
scripts/train_multi.sh
CHANGED
|
@@ -1,61 +1,38 @@
|
|
| 1 |
#!/bin/bash
|
| 2 |
|
|
|
|
| 3 |
# nohup python scripts/train.py configs/base.yaml --where modal \
|
| 4 |
# --update configs/regression/base.yaml \
|
| 5 |
-
# --update configs/regression/modal
|
| 6 |
# --update configs/regression/mse.yaml \
|
| 7 |
-
# --update configs/regression/val_metrics.yaml \
|
| 8 |
# --update configs/compare/5k_steps.yaml \
|
| 9 |
-
# --update configs/optim/cosine.yaml \
|
| 10 |
-
# --update configs/regression/lo_mse.yaml \
|
| 11 |
# > logs/mse.log 2>&1 &
|
| 12 |
|
|
|
|
| 13 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 14 |
--update configs/regression/base.yaml \
|
| 15 |
-
--update configs/regression/modal
|
| 16 |
-
--update configs/regression/mse-triplet.yaml \
|
| 17 |
-
--update configs/regression/val_metrics.yaml \
|
| 18 |
--update configs/compare/5k_steps.yaml \
|
| 19 |
-
|
| 20 |
-
--update configs/regression/lo_mse.yaml \
|
| 21 |
-
> logs/mse-triplet.log 2>&1 &
|
| 22 |
-
|
| 23 |
-
# nohup python scripts/train.py configs/base.yaml --where modal \
|
| 24 |
-
# --update configs/regression/base.yaml \
|
| 25 |
-
# --update configs/regression/modal-datadirs.yaml \
|
| 26 |
-
# --update configs/regression/mse-neg-mse.yaml \
|
| 27 |
-
# --update configs/regression/val_metrics.yaml \
|
| 28 |
-
# --update configs/compare/5k_steps.yaml \
|
| 29 |
-
# --update configs/optim/cosine.yaml \
|
| 30 |
-
# --update configs/regression/lo_mse.yaml \
|
| 31 |
-
# > logs/mse-neg-mse.log 2>&1 &
|
| 32 |
|
| 33 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 34 |
--update configs/regression/base.yaml \
|
| 35 |
-
--update configs/regression/modal
|
| 36 |
-
--update configs/regression/mse-
|
| 37 |
-
--update configs/regression/val_metrics.yaml \
|
| 38 |
--update configs/compare/5k_steps.yaml \
|
| 39 |
-
|
| 40 |
-
--update configs/regression/lo_mse.yaml \
|
| 41 |
-
> logs/mse-pixel-mse.log 2>&1 &
|
| 42 |
|
| 43 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 44 |
--update configs/regression/base.yaml \
|
| 45 |
-
--update configs/regression/modal
|
| 46 |
-
--update configs/regression/mse-
|
| 47 |
-
--update configs/regression/val_metrics.yaml \
|
| 48 |
--update configs/compare/5k_steps.yaml \
|
| 49 |
-
|
| 50 |
-
--update configs/regression/lo_mse.yaml \
|
| 51 |
-
> logs/mse-pixel-lpips.log 2>&1 &
|
| 52 |
|
| 53 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 54 |
--update configs/regression/base.yaml \
|
| 55 |
-
--update configs/regression/modal
|
| 56 |
-
--update configs/regression/mse-
|
| 57 |
-
--update configs/regression/val_metrics.yaml \
|
| 58 |
--update configs/compare/5k_steps.yaml \
|
| 59 |
-
|
| 60 |
-
--update configs/regression/lo_mse.yaml \
|
| 61 |
-
> logs/mse-pixel-lpips.log 2>&1 &
|
|
|
|
| 1 |
#!/bin/bash
|
| 2 |
|
| 3 |
+
|
| 4 |
# nohup python scripts/train.py configs/base.yaml --where modal \
|
| 5 |
# --update configs/regression/base.yaml \
|
| 6 |
+
# --update configs/regression/modal.yaml \
|
| 7 |
# --update configs/regression/mse.yaml \
|
|
|
|
| 8 |
# --update configs/compare/5k_steps.yaml \
|
|
|
|
|
|
|
| 9 |
# > logs/mse.log 2>&1 &
|
| 10 |
|
| 11 |
+
|
| 12 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 13 |
--update configs/regression/base.yaml \
|
| 14 |
+
--update configs/regression/modal.yaml \
|
| 15 |
+
--update configs/regression/triplet/mse-triplet-b.yaml \
|
|
|
|
| 16 |
--update configs/compare/5k_steps.yaml \
|
| 17 |
+
> logs/mse-triplet-b.log 2>&1 &
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 20 |
--update configs/regression/base.yaml \
|
| 21 |
+
--update configs/regression/modal.yaml \
|
| 22 |
+
--update configs/regression/triplet/mse-triplet-c.yaml \
|
|
|
|
| 23 |
--update configs/compare/5k_steps.yaml \
|
| 24 |
+
> logs/mse-triplet-c.log 2>&1 &
|
|
|
|
|
|
|
| 25 |
|
| 26 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 27 |
--update configs/regression/base.yaml \
|
| 28 |
+
--update configs/regression/modal.yaml \
|
| 29 |
+
--update configs/regression/triplet/mse-triplet-d.yaml \
|
|
|
|
| 30 |
--update configs/compare/5k_steps.yaml \
|
| 31 |
+
> logs/mse-triplet-d.log 2>&1 &
|
|
|
|
|
|
|
| 32 |
|
| 33 |
nohup python scripts/train.py configs/base.yaml --where modal \
|
| 34 |
--update configs/regression/base.yaml \
|
| 35 |
+
--update configs/regression/modal.yaml \
|
| 36 |
+
--update configs/regression/triplet/mse-triplet-e.yaml \
|
|
|
|
| 37 |
--update configs/compare/5k_steps.yaml \
|
| 38 |
+
> logs/mse-triplet-e.log 2>&1 &
|
|
|
|
|
|