Elea Zhong committed
Commit · 49cbc74
1 Parent(s): 16d51ab

regression training
Files changed:
- configs/base.yaml +1 -1
- configs/regression/base-reg.yaml +21 -0
- configs/regression/reg-mse-triplet.yaml +11 -0
- configs/regression/reg-mse.yaml +9 -0
- qwenimage/datamodels.py +45 -9
- qwenimage/debug.py +8 -4
- qwenimage/foundation.py +215 -18
- qwenimage/loss.py +87 -0
- qwenimage/sources.py +110 -12
- qwenimage/task.py +14 -0
- qwenimage/training.py +87 -36
- qwenimage/types.py +3 -0
- scripts/edit_datasets.ipynb +1 -1
- scripts/process_regression_data.ipynb +1274 -0
- scripts/save_regression_outputs.py +40 -75
configs/base.yaml CHANGED
@@ -13,7 +13,7 @@ optim: "adamw"
 learning_rate: 1.0e-4
 num_workers: 4
 resume_from_checkpoint: null
-log_model_steps:
+log_model_steps: null
 preprocessing_epoch_len: 64
 preprocessing_epoch_repetitions: 1
 
configs/regression/base-reg.yaml ADDED
@@ -0,0 +1,21 @@
+wandb_run_name: "reg-base"
+output_dir: "/data/checkpoints/reg-base"
+
+learning_rate: 1e-4
+num_train_epochs: 1
+max_train_steps: null
+preprocessing_epoch_len: 0
+preprocessing_epoch_repetitions: 1
+num_validation_images: &val_num 32
+num_sample_images: 2
+train_range: [*val_num, null]
+val_range: [0, *val_num]
+test_range: [2, 4]
+regression_base_pipe_steps: 8
+
+training_type: "regression"
+
+regression_data_dir: "/data/regression_output"
+regression_gen_steps: 50
+editing_data_dir: "/data/CrispEdit"
+editing_total_per: 1
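
Note on this config: the `&val_num` anchor and its `*val_num` aliases keep the train/validation split in sync, so validation takes the first 32 trajectories and training starts exactly at index 32. A minimal check of how the anchor resolves, using plain PyYAML (nothing project-specific):

import yaml

cfg = yaml.safe_load("""
num_validation_images: &val_num 32
train_range: [*val_num, null]
val_range: [0, *val_num]
""")
print(cfg["val_range"])    # [0, 32]    -> first 32 items held out for validation
print(cfg["train_range"])  # [32, None] -> everything from index 32 onward trains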
configs/regression/reg-mse-triplet.yaml ADDED
@@ -0,0 +1,11 @@
+wandb_run_name: "reg-mse-triplet"
+output_dir: "/data/checkpoints/reg-mse-triplet"
+
+train_loss_terms:
+  mse: 1.0
+  triplet: 1.0
+
+validation_loss_terms:
+  mse: 1.0
+  triplet: 1.0
+
configs/regression/reg-mse.yaml ADDED
@@ -0,0 +1,9 @@
+wandb_run_name: "reg-mse"
+output_dir: "/data/checkpoints/reg-mse"
+
+train_loss_terms:
+  mse: 1.0
+
+validation_loss_terms:
+  mse: 1.0
+
qwenimage/datamodels.py CHANGED
@@ -1,10 +1,12 @@
 import enum
+from pathlib import Path
 from typing import Literal
 
 import torch
 from diffusers.image_processor import PipelineImageInput
 from pydantic import BaseModel, ConfigDict, Field
 
+from qwenimage.types import DataRange
 from wandml.foundation.datamodels import FluxInputs
 from wandml.trainers.datamodels import ExperimentTrainerParameters
 
@@ -27,12 +29,33 @@ class QwenInputs(BaseModel):
         # extra="allow",
     )
 
+class TrainingType(str, enum.Enum):
+    IM2IM = "im2im"
+    NAIVE = "naive"
+    REGRESSION = "regression"
+
+    @property
+    def is_style(self):
+        return self in [TrainingType.NAIVE, TrainingType.IM2IM]
 
 class QuantOptions(str, enum.Enum):
     INT8WO = "int8wo"
     INT4WO = "int4wo"
     FP8ROW = "fp8row"
 
+LossTermSpecType = int|float|dict[str,int|float]|None
+
+class QwenLossTerms(BaseModel):
+    mse: LossTermSpecType = 1.0
+    triplet: LossTermSpecType = 0.0
+    negative_mse: LossTermSpecType = 0.0
+    distribution_matching: LossTermSpecType = 0.0
+    negative_exponential: LossTermSpecType = 0.0
+    pixel_lpips: LossTermSpecType = 0.0
+    pixel_mse: LossTermSpecType = 0.0
+    adversarial: LossTermSpecType = 0.0
+
+    triplet_margin: float = 0.2
 
 class QwenConfig(ExperimentTrainerParameters):
     load_multi_view_lora: bool = False
@@ -49,14 +72,27 @@ class QwenConfig(ExperimentTrainerParameters):
     quantize_text_encoder: bool = False
     quantize_transformer: bool = False
 
-
+
+    train_loss_terms: QwenLossTerms = Field(default_factory=QwenLossTerms)
+    validation_loss_terms: QwenLossTerms = Field(default_factory=QwenLossTerms)
+
+    training_type: TrainingType
+    train_range: DataRange|None = None
+    val_range: DataRange|None = None
+    test_range: DataRange|None = None
+
     style_title: str|None = None
-
-
-
-
-
-
-
-
+    style_base_dir: str|None = None
+    style_csv_path: str|None = None
+    style_data_dir: str|None = None
+    style_ref_dir: str|None = None
+    style_val_with: str = "train"
+    naive_static_prompt: str|None = None
+
+    regression_data_dir: str|Path|None = None
+    regression_gen_steps: int = 50
+    editing_data_dir: str|Path|None = None
+    editing_total_per: int = 1
+    regression_base_pipe_steps: int = 8
 
qwenimage/debug.py CHANGED
@@ -241,12 +241,16 @@ def fretry(func=None, *, exceptions=(Exception,), mod_args:tuple[Callable|None,
     return decorator(func)
 
 
-def texam(t: torch.Tensor):
-
+def texam(t: torch.Tensor, name=None):
+    if name is None:
+        name = ""
+    else:
+        name += " "  # spacing
+    print(f"{name}Shape: {tuple(t.shape)}")
     if t.dtype.is_floating_point or t.dtype.is_complex:
         mean_val = t.mean().item()
     else:
         mean_val = "N/A"
-    print(f"Min: {t.min().item()}, Max: {t.max().item()}, Mean: {mean_val}")
-    print(f"Device: {t.device}, Dtype: {t.dtype}, Requires Grad: {t.requires_grad}")
+    print(f"{name}Min: {t.min().item()}, Max: {t.max().item()}, Mean: {mean_val}")
+    print(f"{name}Device: {t.device}, Dtype: {t.dtype}, Requires Grad: {t.requires_grad}")
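
texam is used throughout the new regression step to spot-check tensor stats; the optional name argument simply prefixes each printed line. A quick usage sketch (assumes this repo is on the Python path):

import torch
from qwenimage.debug import texam

texam(torch.randn(2, 8), name="demo")
# demo Shape: (2, 8)
# demo Min: -2.1..., Max: 1.9..., Mean: 0.03...   (values vary per run)
# demo Device: cpu, Dtype: torch.float32, Requires Grad: False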
qwenimage/foundation.py CHANGED
@@ -5,6 +5,7 @@ import warnings
 
 from PIL import Image
 from diffusers.pipelines.qwenimage.pipeline_qwenimage import QwenImagePipeline
+import lpips
 import torch
 from safetensors.torch import load_file, save_model
 import torch.nn.functional as F
@@ -12,15 +13,18 @@ import torchvision.transforms.v2.functional as TF
 from einops import rearrange
 
 from qwenimage.datamodels import QwenConfig, QwenInputs
-from qwenimage.debug import ctimed, ftimed, print_gpu_memory, texam
+from qwenimage.debug import clear_cuda_memory, ctimed, ftimed, print_gpu_memory, texam
 from qwenimage.experiments.quantize_text_encoder_experiments import quantize_text_encoder_int4wo_linear
 from qwenimage.experiments.quantize_experiments import quantize_transformer_fp8darow_nolast
+from qwenimage.loss import LossAccumulator
 from qwenimage.models.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, QwenImageEditPlusPipeline, calculate_dimensions
 from qwenimage.models.pipeline_qwenimage_edit_save_interm import QwenImageEditSaveIntermPipeline
 from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
 from qwenimage.optimization import simple_quantize_model
 from qwenimage.sampling import TimestepDistUtils
 from wandml import WandModel
+from wandml.core.logger import wand_logger
+from wandml.trainers.experiment_trainer import ExperimentTrainer
 
 
 class QwenImageFoundation(WandModel):
@@ -159,8 +163,13 @@ class QwenImageFoundation(WandModel):
         texam(latents)
         return latents.to(dtype=self.dtype)
 
-    def latents_to_pil(self, latents):
-
+    def latents_to_pil(self, latents, h=None, w=None, with_grad=False):
+        if not with_grad:
+            latents = latents.clone().detach()
+        if latents.dim() == 3:  # 1d latent
+            if h is None or w is None:
+                raise ValueError(f"auto unpack needs h,w, got {h=}, {w=}")
+            latents = self.unpack_latents(latents, h=h, w=w)
         latents = latents.unsqueeze(2)
 
         latents = latents.to(self.dtype)
@@ -178,6 +187,11 @@ class QwenImageFoundation(WandModel):
 
         latents = latents.to(device=self.device, dtype=self.dtype)
         image = self.pipe.vae.decode(latents, return_dict=False)[0][:, :, 0]  # F = 1
+
+        if with_grad:
+            texam(image, "latents_to_pil.image")
+            return image
+
         image = self.pipe.image_processor.postprocess(image)
         return image
 
@@ -191,6 +205,15 @@ class QwenImageFoundation(WandModel):
         latents = rearrange(packed, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=h, w=w)
         return latents
 
+    @ftimed
+    def offload_text_encoder(self, device=str|torch.device):
+        if self.text_encoder_device == device:
+            return
+        print(f"Moving text encoder to {device}")
+        self.text_encoder_device = device
+        self.text_encoder.to(device)
+        if device == "cpu" or device == torch.device("cpu"):
+            print_gpu_memory(clear_mem="pre")
 
     @ftimed
     def preprocess_batch(self, batch):
@@ -204,11 +227,7 @@ class QwenImageFoundation(WandModel):
         print("preprocess_batch.references")
         texam(references)
 
-
-        with ctimed("text_encoder.cuda()"):
-            self.text_encoder.to(device=self.device)
-        if self.text_encoder_device != "cuda":
-            self.text_encoder_device = "cuda"
+        self.offload_text_encoder("cuda")
 
         with torch.no_grad():
             prompt_embeds, prompt_embeds_mask = self.pipe.encode_prompt(
@@ -229,12 +248,7 @@ class QwenImageFoundation(WandModel):
 
     @ftimed
     def single_step(self, batch) -> torch.Tensor:
-
-        if self.config.offload_text_encoder:
-            self.text_encoder.to(device="cpu")  # offload
-        if self.text_encoder_device != "cpu":
-            self.text_encoder_device = "cpu"
-        print_gpu_memory()
+        self.offload_text_encoder("cpu")
 
         if "prompt_embeds" not in batch:
             batch = self.preprocess_batch(batch)
@@ -302,10 +316,7 @@ class QwenImageFoundation(WandModel):
 
     def base_pipe(self, inputs: QwenInputs) -> list[Image]:
         print(inputs)
-
-        self.text_encoder.to(device=self.device)
-        if self.text_encoder_device != "cuda":
-            self.text_encoder_device = "cuda"
+        self.offload_text_encoder("cuda")
         image = inputs.image[0]
         w,h = image.size
         h_r, w_r = calculate_dimensions(self.config.vae_image_size, h/w)
@@ -327,3 +338,189 @@ class QwenImageFoundationSaveInterm(QwenImageFoundation):
         inputs.image = [image]
         return self.pipe(**inputs.model_dump())
 
+
+class QwenImageRegressionFoundation(QwenImageFoundation):
+    def __init__(self, config:QwenConfig, device=None):
+        super().__init__(config, device=device)
+        self.lpips_fn = lpips.LPIPS(net='vgg').to(device=self.device)
+
+    def preprocess_batch(self, batch):
+        return batch
+
+    @ftimed
+    def single_step(self, batch) -> torch.Tensor:
+        self.offload_text_encoder("cpu")
+
+        out_dict = batch["data"]
+        assert len(out_dict) == 1
+        out_dict = out_dict[0]
+
+        prompt_embeds = out_dict["prompt_embeds"]
+        prompt_embeds_mask = out_dict["prompt_embeds_mask"]
+        prompt_embeds = prompt_embeds.to(device=self.device, dtype=self.dtype)
+        prompt_embeds_mask = prompt_embeds_mask.to(device=self.device, dtype=self.dtype)
+
+        h_f16 = out_dict["height"] // 16
+        w_f16 = out_dict["width"] // 16
+
+        refs_1d = out_dict["image_latents"].to(device=self.device, dtype=self.dtype)
+        t = out_dict["t"].to(device=self.device, dtype=self.dtype)
+        x_0_1d = out_dict["output"].to(device=self.device, dtype=self.dtype)
+        x_t_1d = out_dict["latents_start"].to(device=self.device, dtype=self.dtype)
+        v_neg_1d = out_dict["noise_pred"].to(device=self.device, dtype=self.dtype)
+
+
+        v_gt_1d = (x_t_1d - x_0_1d) / t
+
+        inp_1d = torch.cat([x_t_1d, refs_1d], dim=1)
+        print("inp_1d")
+        texam(inp_1d)
+
+        img_shapes = out_dict["img_shapes"]
+        txt_seq_lens = out_dict["txt_seq_lens"]
+        image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=self.device)
+
+        v_pred_1d = self.transformer(
+            hidden_states=inp_1d,
+            encoder_hidden_states=prompt_embeds,
+            encoder_hidden_states_mask=prompt_embeds_mask,
+            timestep=t,
+            image_rotary_emb=image_rotary_emb,
+            return_dict=False,
+        )[0]
+
+        v_pred_1d = v_pred_1d[:, : x_t_1d.size(1)]
+
+
+        split = batch["split"]
+        step = batch["step"]
+        if split == "train":
+            loss_terms = self.config.train_loss_terms.model_dump()
+        elif split == "validation":
+            loss_terms = self.config.validation_loss_terms.model_dump()
+        loss_accumulator = LossAccumulator(
+            terms=loss_terms,
+            step=step,
+            split=split,
+        )
+
+        if loss_accumulator.has("mse"):
+            if self.config.loss_weight_dist is not None:
+                mse_loss = F.mse_loss(v_pred_1d, v_gt_1d, reduction="none").mean(dim=[1,2,3])
+                weights = self.timestep_dist_utils.get_loss_weighting(t)
+                mse_loss = torch.mean(mse_loss * weights)
+            else:
+                mse_loss = F.mse_loss(v_pred_1d, v_gt_1d, reduction="mean")
+            loss_accumulator.accum("mse", mse_loss)
+
+        if loss_accumulator.has("triplet"):
+            eps = 1e-6
+            margin = loss_terms["triplet_margin"]
+            v_span = (v_gt_1d - v_neg_1d).pow(2).sum(dim=(-2,-1))
+            diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2).sum(dim=(-2,-1))
+            diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2).sum(dim=(-2,-1))
+            diffv_gt_pred_reg = diffv_gt_pred / (v_span + eps)
+            diffv_neg_pred_reg = diffv_neg_pred / (v_span + eps)
+
+            texam(v_span, name="v_span")
+            texam(diffv_gt_pred, name="diffv_gt_pred")
+            texam(diffv_neg_pred, name="diffv_neg_pred")
+            texam(diffv_gt_pred_reg, name="diffv_gt_pred_reg")
+            texam(diffv_neg_pred_reg, name="diffv_neg_pred_reg")
+            texam(diffv_gt_pred_reg - diffv_neg_pred_reg, name="diffv_gt_pred_reg - diffv_neg_pred_reg")
+
+            triplet_loss = F.relu(diffv_gt_pred_reg - diffv_neg_pred_reg + margin).mean()
+            loss_accumulator.accum("triplet_loss", triplet_loss)
+
+
+        if loss_accumulator.has("negative_mse"):
+            neg_mse_loss = -F.mse_loss(v_pred_1d, v_neg_1d, reduction="mean")
+            loss_accumulator.accum("negative_mse", neg_mse_loss)
+
+
+        if loss_accumulator.has("distribution_matching"):
+            dm_v = (v_pred_1d - v_neg_1d + v_gt_1d).detach()
+            dm_mse = F.mse_loss(v_pred_1d, dm_v, reduction="mean")
+            loss_accumulator.accum("distribution_matching", dm_mse)
+
+        if loss_accumulator.has("negative_exponential"):
+            raise NotImplementedError()
+
+        if loss_accumulator.has("pixel_lpips") or loss_accumulator.has("pixel_mse"):
+            x_0_pred = x_0_1d - t * v_pred_1d
+            pixel_values_x0_gt = self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16).detach()
+            pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16)
+
+            if loss_accumulator.has("pixel_lpips"):
+                lpips_loss = self.lpips_fn(pixel_values_x0_gt, pixel_values_x0_pred)
+                loss_accumulator.accum("pixel_lpips", lpips_loss)
+
+            if loss_accumulator.has("pixel_mse"):
+                pixel_mse_loss = F.mse_loss(pixel_values_x0_pred, pixel_values_x0_gt, reduction="mean")
+                loss_accumulator.accum("pixel_mse", pixel_mse_loss)
+
+        if loss_accumulator.has("adversarial"):
+            raise NotImplementedError()
+
+
+        loss = loss_accumulator.total
+
+        logs = loss_accumulator.logs()
+        wand_logger.log(logs, step=step, commit=False)
+
+        if self.should_log_training(step):
+            self.log_single_step_images(
+                h_f16,
+                w_f16,
+                t,
+                x_0_1d,
+                x_t_1d,
+                v_gt_1d,
+                v_neg_1d,
+                v_pred_1d,
+                visualize_velocities=True,
+            )
+
+        return loss
+
+    def should_log_training(self, step) -> bool:
+        return (
+            self.training  # don't log when validating
+            and ExperimentTrainer._is_step_trigger(step, self.config.log_batch_steps)
+        )
+
+    def log_single_step_images(
+        self,
+        h_f16,
+        w_f16,
+        t,
+        x_0_1d,
+        x_t_1d,
+        v_gt_1d,
+        v_neg_1d,
+        v_pred_1d,
+        visualize_velocities=True,
+    ):
+        x_0_pred = x_0_1d - t * v_pred_1d
+        x_0_neg = x_0_1d - t * v_neg_1d
+        log_pils = {
+            "x_t_1d": self.latents_to_pil(x_t_1d, h=h_f16, w=w_f16),
+            "x_0": self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16),
+            "x_0_pred": self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16),
+            "x_0_neg": self.latents_to_pil(x_0_neg, h=h_f16, w=w_f16),
+        }
+        if visualize_velocities:
+            log_pils.update({
+                "v_gt_1d": self.latents_to_pil(v_gt_1d, h=h_f16, w=w_f16),
+                "v_pred_1d": self.latents_to_pil(v_pred_1d, h=h_f16, w=w_f16),
+                "v_neg_1d": self.latents_to_pil(v_neg_1d, h=h_f16, w=w_f16),
+            })
+
+        wand_logger.log({
+            "train_images": log_pils,
+        }, commit=False)
+
+
+    def base_pipe(self, inputs: QwenInputs) -> list[Image]:
+        inputs.num_inference_steps = self.config.regression_base_pipe_steps  # override
+        return super().base_pipe(inputs)
qwenimage/loss.py ADDED
@@ -0,0 +1,87 @@
+import warnings
+from torch import Tensor
+import torch
+
+from wandml.core.wandmodel import WandModel
+
+
+class LossAccumulator:
+
+    def __init__(
+        self,
+        terms: dict[str, int|float|dict],
+        step: int|None=None,
+        split: str|None=None,
+    ):
+        self.terms = terms
+        self.step = step
+        if split is not None:
+            self.split = split
+            self.prefix = f"{self.split}_"
+        else:
+            self.split = ""
+            self.prefix = ""
+        self.unweighted: dict[str, Tensor] = {}
+        self.weighted: dict[str, Tensor] = {}
+
+    def resolve_weight(self, name: str, step: int|None = None) -> float:
+        """
+        loss weight spec:
+        - float | int
+        - dict: {"start": int, "end": int, "min": float, "max": float}
+        """
+        spec = self.terms.get(name, 0.0)
+
+        if isinstance(spec, (int, float)):
+            return float(spec)
+
+        if isinstance(spec, dict):
+            try:
+                start = int(spec.get("start", 0))
+                end = int(spec["end"])  # required
+                vmin = float(spec.get("min", 0.0))
+                vmax = float(spec["max"])  # required
+            except Exception:
+                warnings.warn(f"Malformed dict {spec}; treat as disabled")
+                return 0.0
+
+            if self.step <= start:
+                return vmin
+            if self.step >= end:
+                return vmax
+            span = max(1, end - start)
+            t = (self.step - start) / span
+            return vmin + (vmax - vmin) * t
+
+        warnings.warn(f"Unknown spec type {spec}; treat as disabled")
+        return 0.0
+
+    def has(self, name: str) -> bool:
+        return self.resolve_weight(name) > 0
+
+    def accum(self, name: str, loss_value: Tensor, extra_weight: float|None = None) -> Tensor:
+        self.unweighted[name] = loss_value
+
+        w = self.resolve_weight(name)
+        if extra_weight is not None:
+            w *= float(extra_weight)
+
+        weighted = loss_value * w
+        self.weighted[name] = weighted
+        return weighted
+
+    @property
+    def total(self):
+        weighted_losses = list(self.weighted.values())
+        return torch.stack(weighted_losses).sum()
+
+
+    def logs(self) -> dict[str, float]:
+        # append prefix and suffix for logs
+        logs: dict[str, float] = {}
+        for k, v in self.unweighted.items():
+            logs[f"{self.prefix}_{k}"] = float(v.detach().item())
+        for k, v in self.weighted.items():
+            logs[f"{self.prefix}_{k}_weighted"] = float(v.detach().item())
+        return logs
+
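
The dict spec gives a loss weight a linear ramp between two steps, and has() gates whether a term is computed at all (a ramp still sitting at min=0.0 is skipped entirely). A small usage sketch of the class above:

acc = LossAccumulator(
    terms={"mse": 1.0, "triplet": {"start": 100, "end": 200, "min": 0.0, "max": 1.0}},
    step=150,
    split="train",
)
print(acc.resolve_weight("mse"))      # 1.0, constant weight
print(acc.resolve_weight("triplet"))  # 0.5, halfway up the 100 -> 200 ramp
print(acc.has("triplet"))             # True at step 150; False for any step <= 100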
qwenimage/sources.py CHANGED
@@ -3,11 +3,35 @@
 import csv
 from pathlib import Path
 import random
+from typing import Literal
 
 from PIL import Image
+import torch
+from datasets import concatenate_datasets, load_dataset, interleave_datasets
+
+from qwenimage.types import DataRange
 from wandml.core.datamodels import SourceDataType
 from wandml.core.source import Source
 
+def parse_datarange(dr: DataRange, length: int, return_as: Literal['list', 'range']='list'):
+    if not isinstance(length, int):
+        raise ValueError()
+    left, right = dr
+    if left is None:
+        left = 0
+    if right is None:
+        right = length
+    if (isinstance(left, float) or isinstance(right, float)) and (left<1 and right<1):
+        left = left * length
+        right = right * length
+    if return_as=="list":
+        return list(range(left, right))
+    elif return_as=="range":
+        return range(left, right)
+    else:
+        raise ValueError()
+
+
 class StyleSource(Source):
     _data_types = [
         SourceDataType(name="image", type=Image.Image),
@@ -63,7 +87,7 @@ class StyleImagetoImageSource(Source):
         SourceDataType(name="image", type=Image.Image),
         SourceDataType(name="reference", type=Image.Image),
     ]
-    def __init__(self, csv_path, base_dir, style_title=None, data_range:
+    def __init__(self, csv_path, base_dir, style_title=None, data_range:DataRange|None=None):
         self.csv_path = Path(csv_path)
         self.base_dir = Path(base_dir)
         self.style_title = style_title
@@ -85,16 +109,9 @@ class StyleImagetoImageSource(Source):
         })
 
         if data_range is not None:
-
-
-
-            right = right * len(self.data)
-            remain_data = []
-            for i, d in enumerate(self.data):
-                if left <= i and i < right:
-                    remain_data.append(d)
-            self.data = remain_data
-
+            indexes = parse_datarange(data_range, len(self.data))
+            self.data = [self.data[i] for i in indexes]
+
         print(f"{self.__class__} of len{len(self)}")
 
 
@@ -106,4 +123,85 @@ class StyleImagetoImageSource(Source):
         prompt = item["prompt"]
         input_pil = Image.open(item['input_image']).convert("RGB")
         output_pil = Image.open(item['output_image']).convert("RGB")
-        return prompt, output_pil, input_pil
+        return prompt, output_pil, input_pil
+
+
+class RegressionSource(Source):
+    _data_types = [
+        SourceDataType(name="data", type=dict),
+    ]
+
+    def __init__(self, data_dir, gen_steps=50, data_range:DataRange|None=None):
+        if not isinstance(data_dir, Path):
+            data_dir = Path(data_dir)
+        self.data_paths = list(data_dir.glob("*.pt"))
+        if data_range is not None:
+            indexes = parse_datarange(data_range, len(self.data_paths))
+            self.data_paths = [self.data_paths[i] for i in indexes]
+        self.gen_steps = gen_steps
+        self._len = gen_steps * len(self.data_paths)
+        print(f"{self.__class__} of len{len(self)}")
+
+    def __len__(self):
+        return self._len
+
+    def __getitem__(self, idx):
+        data_idx = idx // self.gen_steps
+        step_idx = idx % self.gen_steps
+        out_dict = torch.load(self.data_paths[data_idx])
+        t = out_dict.pop(f"t_{step_idx}")
+        latents_start = out_dict.pop(f"latents_{step_idx}_start")
+        noise_pred = out_dict.pop(f"noise_pred_{step_idx}")
+        out_dict["t"] = t
+        out_dict["latents_start"] = latents_start
+        out_dict["noise_pred"] = noise_pred
+        return [out_dict,]
+
+
+class EditingSource(Source):
+    _data_types = [
+        SourceDataType(name="text", type=str),
+        SourceDataType(name="image", type=Image.Image),
+        SourceDataType(name="reference", type=Image.Image),
+    ]
+    EDIT_TYPES = [
+        "color",
+        "style",
+        "replace",
+        "remove",
+        "add",
+        "motion change",
+        "background change",
+    ]
+    def __init__(self, data_dir:Path, total_per=1, data_range:DataRange|None=None):
+        data_dir = Path(data_dir)
+        self.join_ds = self.build_dataset(data_dir, total_per)
+
+        if data_range is not None:
+            indexes = parse_datarange(data_range, len(self.join_ds))
+            self.join_ds = self.join_ds.select(indexes)
+
+        print(f"{self.__class__} of len{len(self)}")
+
+    def build_dataset(self, data_dir:Path, total_per:int):
+        all_edit_datasets = []
+        for edit_type in self.EDIT_TYPES:
+            to_concat = []
+            for ds_n in range(total_per):
+                ds = load_dataset("parquet", data_files=str(data_dir/f"{edit_type}_{ds_n:05d}.parquet"), split="train")
+                to_concat.append(ds)
+            edit_type_concat = concatenate_datasets(to_concat)
+            all_edit_datasets.append(edit_type_concat)
+        # consistent ordering for indexing, also allow extension by increasing total_per
+        join_ds = interleave_datasets(all_edit_datasets)
+        return join_ds
+
+    def __len__(self):
+        return len(self.join_ds)
+
+    def __getitem__(self, idx):
+        data = self.join_ds[idx]
+        reference = data["input_img"]
+        image = data["output_img"]
+        text = data["instruction"]
+        return text, image, reference
qwenimage/task.py CHANGED
@@ -45,3 +45,17 @@ class TextToImageWithRefTask(Task):
         "image": SourceDataType(name="reference", type=Image.Image),
     }
 
+class RegressionTask(Task):
+    data_types = [
+        SourceDataType(name="data", type=dict),
+    ]
+    type_transforms = [
+        Task.identity,
+    ]
+    batch_type_transforms = [
+        Task.identity,
+    ]
+    sample_type_transforms = [
+        Task.identity,
+    ]
+    sample_input_dict = {}
qwenimage/training.py CHANGED
@@ -2,6 +2,7 @@ import os
 import subprocess
 from pathlib import Path
 import argparse
+import warnings
 
 import yaml
 import diffusers
@@ -17,10 +18,10 @@ from wandml.trainers.experiment_trainer import ExperimentTrainer
 
 
 from qwenimage.finetuner import QwenLoraFinetuner
-from qwenimage.sources import StyleSourceWithRandomRef, StyleImagetoImageSource
-from qwenimage.task import TextToImageWithRefTask
+from qwenimage.sources import EditingSource, RegressionSource, StyleSourceWithRandomRef, StyleImagetoImageSource
+from qwenimage.task import RegressionTask, TextToImageWithRefTask
 from qwenimage.datamodels import QwenConfig
-from qwenimage.foundation import QwenImageFoundation
+from qwenimage.foundation import QwenImageFoundation, QwenImageRegressionFoundation
 
 
 wandml_utils.debug.DEBUG = True
@@ -33,65 +34,115 @@ def _deep_update(base: dict, updates: dict) -> dict:
             base[k] = v
     return base
 
-def run_training(config_path: Path | str, update_config_paths: list[Path] | None = None):
-    WandAuth(ignore=True)
-
-    with open(config_path, "r") as f:
-        config = yaml.safe_load(f)
-    if update_config_paths is not None:
-        for update_config_path in update_config_paths:
-            with open(update_config_path, "r") as uf:
-                update_cfg = yaml.safe_load(uf)
-            if isinstance(update_cfg, dict):
-                config = _deep_update(config if isinstance(config, dict) else {}, update_cfg)
-
-    config = QwenConfig(
-        **config,
-    )
-
-    # Data
+def styles_datapipe(config):
     dp = WandDataPipe()
     dp.set_task(TextToImageWithRefTask())
     dp_test = WandDataPipe()
     dp_test.set_task(TextToImageWithRefTask())
-    if config.
+    if config.training_type == "naive":
         src = StyleSourceWithRandomRef(
-            config.
+            config.style_data_dir,
+            config.naive_static_prompt,
+            config.style_ref_dir,
+            set_len=config.max_train_steps,
         )
         dp.add_source(src)
-    elif config.
+    elif config.training_type == "im2im":
         src = StyleImagetoImageSource(
-            csv_path=config.
-            base_dir=config.
+            csv_path=config.style_csv_path,
+            base_dir=config.style_base_dir,
            style_title=config.style_title,
             data_range=config.train_range,
         )
         dp.add_source(src)
         src_test = StyleImagetoImageSource(
-            csv_path=config.
-            base_dir=config.
+            csv_path=config.style_csv_path,
+            base_dir=config.style_base_dir,
             style_title=config.style_title,
             data_range=config.test_range,
         )
        dp_test.add_source(src_test)
     else:
         raise ValueError()
+    if config.style_val_with == "train":
+        dp_val = dp
+    elif config.style_val_with == "test":
+        dp_val = dp_test
+    else:
+        raise ValueError()
+    return dp, dp_val, dp_test
 
+def regression_datapipe(config):
+    dp = WandDataPipe()
+    dp.set_task(RegressionTask())
+    dp_val = WandDataPipe()
+    dp_val.set_task(RegressionTask())
+    dp_test = WandDataPipe()
+    dp_test.set_task(TextToImageWithRefTask())
+
+    src_train = RegressionSource(
+        data_dir=config.regression_data_dir,
+        gen_steps=config.regression_gen_steps,
+        data_range=config.train_range,
+    )
+    dp.add_source(src_train)
+
+    src_val = RegressionSource(
+        data_dir=config.regression_data_dir,
+        gen_steps=config.regression_gen_steps,
+        data_range=config.val_range,
+    )
+    dp_val.add_source(src_val)
+
+    src_test = EditingSource(
+        data_dir=config.editing_data_dir,
+        total_per=config.editing_total_per,
+        data_range=config.test_range,
+    )
+    dp_test.add_source(src_test)
+
+    return dp, dp_val, dp_test
+
+def run_training(config_path: Path | str, update_config_paths: list[Path] | None = None):
+    WandAuth(ignore=True)
+
+    with open(config_path, "r") as f:
+        config = yaml.safe_load(f)
+    if update_config_paths is not None:
+        for update_config_path in update_config_paths:
+            with open(update_config_path, "r") as uf:
+                update_cfg = yaml.safe_load(uf)
+            if isinstance(update_cfg, dict):
+                config = _deep_update(config if isinstance(config, dict) else {}, update_cfg)
+
+    config = QwenConfig(
+        **config,
+    )
+
+    # Data
+    if config.training_type.is_style:
+        dp, dp_val, dp_test = styles_datapipe(config)
+    elif config.training_type == "regression":
+        dp, dp_val, dp_test = regression_datapipe(config)
+    else:
+        raise ValueError(f"Got {config.training_type=}")
 
     # Model
-    foundation = QwenImageFoundation(config=config)
+    if config.training_type.is_style:
+        foundation = QwenImageFoundation(config=config)
+    elif config.training_type == "regression":
+        foundation = QwenImageRegressionFoundation(config=config)
+    else:
+        raise ValueError(f"Got {config.training_type=}")
     finetuner = QwenLoraFinetuner(foundation, config)
-    finetuner.load(
-
+    finetuner.load(config.resume_from_checkpoint, config.lora_rank)
 
     if len(dp_test) == 0:
+        warnings.warn("Test datapipe is removed for being len=0")
         dp_test = None
-    if
-
-
-        dp_val = dp_test
-    else:
-        raise ValueError()
+    if len(dp_val) == 0:
+        warnings.warn("Validation datapipe is removed for being len=0")
+        dp_val = None
    trainer = ExperimentTrainer(
         model=foundation,
         datapipe=dp,
qwenimage/types.py ADDED
@@ -0,0 +1,3 @@
+
+
+DataRange = tuple[int|float|None,int|float|None]
scripts/edit_datasets.ipynb CHANGED
@@ -89,7 +89,7 @@
     "    edit_type_concat = concatenate_datasets(to_concat)\n",
     "    all_edit_datasets.append(edit_type_concat)\n",
     "\n",
-    "# consistent ordering for indexing, also allow extension\n",
+    "# consistent ordering for indexing, also allow extension by increasing total_per\n",
     "join_ds = interleave_datasets(all_edit_datasets)"
    ]
   },
scripts/process_regression_data.ipynb
ADDED
|
@@ -0,0 +1,1274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "e1e781e9",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"%cd /home/ubuntu/Qwen-Image-Edit-Angles"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": null,
|
| 16 |
+
"id": "d6192ee5",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [
|
| 19 |
+
{
|
| 20 |
+
"data": {
|
| 21 |
+
"text/plain": [
|
| 22 |
+
"4941"
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
"execution_count": 11,
|
| 26 |
+
"metadata": {},
|
| 27 |
+
"output_type": "execute_result"
|
| 28 |
+
}
|
| 29 |
+
],
|
| 30 |
+
"source": [
|
| 31 |
+
"import glob\n",
|
| 32 |
+
"from pathlib import Path\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"base_data = Path(\"/data/regression_output\")\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"all_reg = list(base_data.glob(\"*.pt\"))\n",
|
| 37 |
+
"max_ind = max([int(reg_pth.stem) for reg_pth in all_reg])\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"max_ind"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": 14,
|
| 45 |
+
"id": "b5124900",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [
|
| 48 |
+
{
|
| 49 |
+
"name": "stdout",
|
| 50 |
+
"output_type": "stream",
|
| 51 |
+
"text": [
|
| 52 |
+
"prompt_embeds\n",
|
| 53 |
+
"prompt_embeds_mask\n",
|
| 54 |
+
"noise\n",
|
| 55 |
+
"image_latents\n",
|
| 56 |
+
"vae_image_sizes\n",
|
| 57 |
+
"img_shapes\n",
|
| 58 |
+
"txt_seq_lens\n",
|
| 59 |
+
"t_0\n",
|
| 60 |
+
"latents_0_start\n",
|
| 61 |
+
"noise_pred_0\n",
|
| 62 |
+
"t_1\n",
|
| 63 |
+
"latents_1_start\n",
|
| 64 |
+
"noise_pred_1\n",
|
| 65 |
+
"t_2\n",
|
| 66 |
+
"latents_2_start\n",
|
| 67 |
+
"noise_pred_2\n",
|
| 68 |
+
"t_3\n",
|
| 69 |
+
"latents_3_start\n",
|
| 70 |
+
"noise_pred_3\n",
|
| 71 |
+
"t_4\n",
|
| 72 |
+
"latents_4_start\n",
|
| 73 |
+
"noise_pred_4\n",
|
| 74 |
+
"t_5\n",
|
| 75 |
+
"latents_5_start\n",
|
| 76 |
+
"noise_pred_5\n",
|
| 77 |
+
"t_6\n",
|
| 78 |
+
"latents_6_start\n",
|
| 79 |
+
"noise_pred_6\n",
|
| 80 |
+
"t_7\n",
|
| 81 |
+
"latents_7_start\n",
|
| 82 |
+
"noise_pred_7\n",
|
| 83 |
+
"t_8\n",
|
| 84 |
+
"latents_8_start\n",
|
| 85 |
+
"noise_pred_8\n",
|
| 86 |
+
"t_9\n",
|
| 87 |
+
"latents_9_start\n",
|
| 88 |
+
"noise_pred_9\n",
|
| 89 |
+
"t_10\n",
|
| 90 |
+
"latents_10_start\n",
|
| 91 |
+
"noise_pred_10\n",
|
| 92 |
+
"t_11\n",
|
| 93 |
+
"latents_11_start\n",
|
| 94 |
+
"noise_pred_11\n",
|
| 95 |
+
"t_12\n",
|
| 96 |
+
"latents_12_start\n",
|
| 97 |
+
"noise_pred_12\n",
|
| 98 |
+
"t_13\n",
|
| 99 |
+
"latents_13_start\n",
|
| 100 |
+
"noise_pred_13\n",
|
| 101 |
+
"t_14\n",
|
| 102 |
+
"latents_14_start\n",
|
| 103 |
+
"noise_pred_14\n",
|
| 104 |
+
"t_15\n",
|
| 105 |
+
"latents_15_start\n",
|
| 106 |
+
"noise_pred_15\n",
|
| 107 |
+
"t_16\n",
|
| 108 |
+
"latents_16_start\n",
|
| 109 |
+
"noise_pred_16\n",
|
| 110 |
+
"t_17\n",
|
| 111 |
+
"latents_17_start\n",
|
| 112 |
+
"noise_pred_17\n",
|
| 113 |
+
"t_18\n",
|
| 114 |
+
"latents_18_start\n",
|
| 115 |
+
"noise_pred_18\n",
|
| 116 |
+
"t_19\n",
|
| 117 |
+
"latents_19_start\n",
|
| 118 |
+
"noise_pred_19\n",
|
| 119 |
+
"t_20\n",
|
| 120 |
+
"latents_20_start\n",
|
| 121 |
+
"noise_pred_20\n",
|
| 122 |
+
"t_21\n",
|
| 123 |
+
"latents_21_start\n",
|
| 124 |
+
"noise_pred_21\n",
|
| 125 |
+
"t_22\n",
|
| 126 |
+
"latents_22_start\n",
|
| 127 |
+
"noise_pred_22\n",
|
| 128 |
+
"t_23\n",
|
| 129 |
+
"latents_23_start\n",
|
| 130 |
+
"noise_pred_23\n",
|
| 131 |
+
"t_24\n",
|
| 132 |
+
"latents_24_start\n",
|
| 133 |
+
"noise_pred_24\n",
|
| 134 |
+
"t_25\n",
|
| 135 |
+
"latents_25_start\n",
|
| 136 |
+
"noise_pred_25\n",
|
| 137 |
+
"t_26\n",
|
| 138 |
+
"latents_26_start\n",
|
| 139 |
+
"noise_pred_26\n",
|
| 140 |
+
"t_27\n",
|
| 141 |
+
"latents_27_start\n",
|
| 142 |
+
"noise_pred_27\n",
|
| 143 |
+
"t_28\n",
|
| 144 |
+
"latents_28_start\n",
|
| 145 |
+
"noise_pred_28\n",
|
| 146 |
+
"t_29\n",
|
| 147 |
+
"latents_29_start\n",
|
| 148 |
+
"noise_pred_29\n",
|
| 149 |
+
"t_30\n",
|
| 150 |
+
"latents_30_start\n",
|
| 151 |
+
"noise_pred_30\n",
|
| 152 |
+
"t_31\n",
|
| 153 |
+
"latents_31_start\n",
|
| 154 |
+
"noise_pred_31\n",
|
| 155 |
+
"t_32\n",
|
| 156 |
+
"latents_32_start\n",
|
| 157 |
+
"noise_pred_32\n",
|
| 158 |
+
"t_33\n",
|
| 159 |
+
"latents_33_start\n",
|
| 160 |
+
"noise_pred_33\n",
|
| 161 |
+
"t_34\n",
|
| 162 |
+
"latents_34_start\n",
|
| 163 |
+
"noise_pred_34\n",
|
| 164 |
+
"t_35\n",
|
| 165 |
+
"latents_35_start\n",
|
| 166 |
+
"noise_pred_35\n",
|
| 167 |
+
"t_36\n",
|
| 168 |
+
"latents_36_start\n",
|
| 169 |
+
"noise_pred_36\n",
|
| 170 |
+
"t_37\n",
|
| 171 |
+
"latents_37_start\n",
|
| 172 |
+
"noise_pred_37\n",
|
| 173 |
+
"t_38\n",
|
| 174 |
+
"latents_38_start\n",
|
| 175 |
+
"noise_pred_38\n",
|
| 176 |
+
"t_39\n",
|
| 177 |
+
"latents_39_start\n",
|
| 178 |
+
"noise_pred_39\n",
|
| 179 |
+
"t_40\n",
|
| 180 |
+
"latents_40_start\n",
|
| 181 |
+
"noise_pred_40\n",
|
| 182 |
+
"t_41\n",
|
| 183 |
+
"latents_41_start\n",
|
| 184 |
+
"noise_pred_41\n",
|
| 185 |
+
"t_42\n",
|
| 186 |
+
"latents_42_start\n",
|
| 187 |
+
"noise_pred_42\n",
|
| 188 |
+
"t_43\n",
|
| 189 |
+
"latents_43_start\n",
|
| 190 |
+
"noise_pred_43\n",
|
| 191 |
+
"t_44\n",
|
| 192 |
+
"latents_44_start\n",
|
| 193 |
+
"noise_pred_44\n",
|
| 194 |
+
"t_45\n",
|
| 195 |
+
"latents_45_start\n",
|
| 196 |
+
"noise_pred_45\n",
|
| 197 |
+
"t_46\n",
|
| 198 |
+
"latents_46_start\n",
|
| 199 |
+
"noise_pred_46\n",
|
| 200 |
+
"t_47\n",
|
| 201 |
+
"latents_47_start\n",
|
| 202 |
+
"noise_pred_47\n",
|
| 203 |
+
"t_48\n",
|
| 204 |
+
"latents_48_start\n",
|
| 205 |
+
"noise_pred_48\n",
|
| 206 |
+
"t_49\n",
|
| 207 |
+
"latents_49_start\n",
|
| 208 |
+
"noise_pred_49\n",
|
| 209 |
+
"output\n",
|
| 210 |
+
"height\n",
|
| 211 |
+
"width\n"
|
| 212 |
+
]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "out = all_reg[0]\n",
    "out_dict = torch.load(out)\n",
    "for k in out_dict.keys():\n",
    "    print(k)"
   ]
  },
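Each saved `.pt` file is one full 50-step denoising trajectory: the shared conditioning tensors plus, per step N, the timestep `t_N`, the latents the step started from, and the transformer's noise prediction. Consecutive entries look consistent with a flow-matching Euler update, `latents_{N+1} = latents_N + (t_{N+1} - t_N) * noise_pred_N`; a minimal check of that relation, assuming `t_N` plays the role of the scheduler's sigma (hypothetical, not part of the commit):

```python
# Hypothetical sanity check: verify one Euler step from the saved
# trajectory, assuming t_N acts as the sigma schedule.
import torch

d = torch.load(all_reg[0])
step = 1
dt = d[f"t_{step + 1}"] - d[f"t_{step}"]  # negative: t decreases over steps
pred_next = d[f"latents_{step}_start"] + dt * d[f"noise_pred_{step}"]
err = (pred_next - d[f"latents_{step + 1}_start"]).abs().max()
print(f"max |error| = {err.item():.4f}")  # small, up to bf16 rounding
```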
|
| 224 |
+
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74f693db",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'003329'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "da107d9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "69G\t/data/regression_output\n"
     ]
    }
   ],
   "source": [
    "!du -h {base_data}"
   ]
  },
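69 GiB over on the order of 4,900 files works out to roughly 14 MiB per saved trajectory, i.e. the conditioning tensors plus fifty (t, latents, noise_pred) triples per prompt. A quick per-file breakdown, sketched here rather than taken from the commit:

```python
# Hypothetical sketch: average on-disk size per saved trajectory.
sizes = [p.stat().st_size for p in all_reg]
print(f"{len(sizes)} files, mean {sum(sizes) / len(sizes) / 2**20:.1f} MiB")
```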
|
| 261 |
+
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "269c0bfb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5964bf2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RegressionSource:\n",
    "    # WIP: one item per (trajectory file, denoising step)\n",
    "\n",
    "    def __init__(self, data_dir, gen_steps=50):\n",
    "        if not isinstance(data_dir, Path):\n",
    "            data_dir = Path(data_dir)\n",
    "        self.data_paths = list(data_dir.glob(\"*.pt\"))\n",
    "        self.gen_steps = gen_steps\n",
    "        self._len = gen_steps * len(self.data_paths)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self._len\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # flat index -> (file, step): idx = data_idx * gen_steps + step_idx\n",
    "        data_idx = idx // self.gen_steps\n",
    "        step_idx = idx % self.gen_steps\n",
    "        out_dict = torch.load(self.data_paths[data_idx])\n",
    "        t = out_dict.pop(f\"t_{step_idx}\")\n",
    "        latents_start = out_dict.pop(f\"latents_{step_idx}_start\")\n",
    "        noise_pred = out_dict.pop(f\"noise_pred_{step_idx}\")\n",
    "        out_dict[\"t\"] = t\n",
    "        out_dict[\"latents_start\"] = latents_start\n",
    "        out_dict[\"noise_pred\"] = noise_pred\n",
    "        return out_dict\n"
   ]
  },
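Since `__getitem__` deserializes the full ~14 MiB trajectory file for every step index, the fifty consecutive items of one trajectory each re-read the same file. A minimal caching variant, offered as a sketch (the `_load_trajectory` helper and `CachedRegressionSource` name are assumptions, not part of the commit):

```python
# Hypothetical variant: cache the most recent files so the 50 step
# indices of one trajectory trigger a single torch.load.
from functools import lru_cache

@lru_cache(maxsize=4)
def _load_trajectory(path_str):
    return torch.load(path_str)

class CachedRegressionSource(RegressionSource):
    def __getitem__(self, idx):
        data_idx = idx // self.gen_steps
        step_idx = idx % self.gen_steps
        # shallow-copy so the pops below never mutate the cached dict
        out_dict = dict(_load_trajectory(str(self.data_paths[data_idx])))
        out_dict["t"] = out_dict.pop(f"t_{step_idx}")
        out_dict["latents_start"] = out_dict.pop(f"latents_{step_idx}_start")
        out_dict["noise_pred"] = out_dict.pop(f"noise_pred_{step_idx}")
        return out_dict
```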
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b62e7bec",
   "metadata": {},
   "outputs": [],
   "source": [
    "src = RegressionSource(\"/data/regression_output\")"
   ]
  },
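A quick way to sanity-check the instantiated source, again only a sketch: the length should be `gen_steps` times the file count, and every item should expose the three renamed per-step keys.

```python
# Hypothetical check, not part of the commit.
assert len(src) == 50 * len(src.data_paths)
item = src[0]
assert {"t", "latents_start", "noise_pred"} <= set(item.keys())
print(len(src), item["t"], item["latents_start"].shape)
```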
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ee68ab3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9738e1d4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[output elided: the repr of src[0], a dict of torch.bfloat16 tensors -- 'prompt_embeds', 'prompt_embeds_mask', 'noise', 'image_latents', 'vae_image_sizes': [(448, 576)], 'img_shapes': [[(1, 36, 28), (1, 36, 28)]], 'txt_seq_lens': [228], then 't_N' / 'latents_N_start' / 'noise_pred_N' entries for the remaining steps, with t falling from 0.9883 toward 0; the dump is truncated here ...]"
|
| 1123 |
+
" [-0.0938, -0.3320, 0.1445, ..., 1.0391, 0.6719, 0.8438]]],\n",
|
| 1124 |
+
" dtype=torch.bfloat16),\n",
|
| 1125 |
+
" 'noise_pred_45': tensor([[[ 1.8516, -0.8984, 0.5547, ..., -1.5547, 0.9688, 0.4727],\n",
|
| 1126 |
+
" [ 1.4219, 0.0500, -0.4258, ..., 0.9141, -0.0588, -1.5703],\n",
|
| 1127 |
+
" [ 0.1592, 0.0850, 0.1924, ..., -0.7500, -1.4766, -2.3125],\n",
|
| 1128 |
+
" ...,\n",
|
| 1129 |
+
" [-1.0469, 0.4414, 1.4766, ..., -0.1494, -0.8438, 0.3047],\n",
|
| 1130 |
+
" [ 0.9883, 1.4062, -0.0874, ..., -0.9844, 0.7422, -1.4297],\n",
|
| 1131 |
+
" [ 0.5742, -0.7578, 2.3125, ..., -0.3301, -1.3672, -0.4238]]],\n",
|
| 1132 |
+
" dtype=torch.bfloat16),\n",
|
| 1133 |
+
" 't_46': tensor([0.1172], dtype=torch.bfloat16),\n",
|
| 1134 |
+
" 'latents_46_start': tensor([[[ 0.2539, -0.0549, 0.0679, ..., -0.2012, 0.1807, 0.0500],\n",
|
| 1135 |
+
" [ 0.2832, 0.0542, 0.1396, ..., 0.0469, -0.0162, -0.0986],\n",
|
| 1136 |
+
" [ 0.1514, 0.1992, -0.0588, ..., 0.0011, 0.2012, -0.0103],\n",
|
| 1137 |
+
" ...,\n",
|
| 1138 |
+
" [-0.4199, -0.1689, -0.0850, ..., 0.3438, 0.1641, 0.3203],\n",
|
| 1139 |
+
" [-0.0598, 0.0122, -0.2168, ..., 0.8633, 0.6016, 0.9766],\n",
|
| 1140 |
+
" [-0.1113, -0.3086, 0.0737, ..., 1.0469, 0.7148, 0.8555]]],\n",
|
| 1141 |
+
" dtype=torch.bfloat16),\n",
|
| 1142 |
+
" 'noise_pred_46': tensor([[[ 1.7734, -0.9492, 0.5430, ..., -1.5312, 0.9727, 0.5273],\n",
|
| 1143 |
+
" [ 1.4219, 0.0054, -0.4844, ..., 0.9297, 0.0198, -1.4531],\n",
|
| 1144 |
+
" [ 0.1377, -0.0330, 0.1299, ..., -0.6914, -1.3906, -2.2188],\n",
|
| 1145 |
+
" ...,\n",
|
| 1146 |
+
" [-0.9727, 0.4297, 1.3906, ..., -0.1172, -0.8164, 0.2295],\n",
|
| 1147 |
+
" [ 0.8945, 1.3516, -0.1758, ..., -0.9961, 0.6445, -1.5078],\n",
|
| 1148 |
+
" [ 0.5234, -0.7422, 2.2031, ..., -0.3848, -1.1641, -0.4883]]],\n",
|
| 1149 |
+
" dtype=torch.bfloat16),\n",
|
| 1150 |
+
" 't_47': tensor([0.0854], dtype=torch.bfloat16),\n",
|
| 1151 |
+
" 'latents_47_start': tensor([[[ 0.1982, -0.0250, 0.0508, ..., -0.1523, 0.1504, 0.0334],\n",
|
| 1152 |
+
" [ 0.2383, 0.0540, 0.1553, ..., 0.0176, -0.0168, -0.0530],\n",
|
| 1153 |
+
" [ 0.1475, 0.2002, -0.0630, ..., 0.0228, 0.2451, 0.0596],\n",
|
| 1154 |
+
" ...,\n",
|
| 1155 |
+
" [-0.3887, -0.1826, -0.1289, ..., 0.3477, 0.1895, 0.3125],\n",
|
| 1156 |
+
" [-0.0879, -0.0303, -0.2109, ..., 0.8945, 0.5820, 1.0234],\n",
|
| 1157 |
+
" [-0.1279, -0.2852, 0.0044, ..., 1.0625, 0.7500, 0.8711]]],\n",
|
| 1158 |
+
" dtype=torch.bfloat16),\n",
|
| 1159 |
+
" 'noise_pred_47': tensor([[[ 1.6250, -0.9688, 0.4160, ..., -1.4609, 0.9961, 0.5742],\n",
|
| 1160 |
+
" [ 1.3594, 0.0728, -0.5430, ..., 0.9062, 0.0530, -1.3438],\n",
|
| 1161 |
+
" [ 0.1553, -0.1787, 0.0908, ..., -0.5820, -1.1875, -1.9688],\n",
|
| 1162 |
+
" ...,\n",
|
| 1163 |
+
" [-0.8281, 0.4160, 1.2422, ..., -0.0122, -0.7500, 0.1396],\n",
|
| 1164 |
+
" [ 0.7734, 1.2812, -0.2295, ..., -0.9883, 0.5039, -1.4844],\n",
|
| 1165 |
+
" [ 0.4453, -0.6719, 1.9688, ..., -0.4180, -0.9141, -0.6211]]],\n",
|
| 1166 |
+
" dtype=torch.bfloat16),\n",
|
| 1167 |
+
" 't_48': tensor([0.0532], dtype=torch.bfloat16),\n",
|
| 1168 |
+
" 'latents_48_start': tensor([[[ 0.1455, 0.0065, 0.0374, ..., -0.1050, 0.1182, 0.0148],\n",
|
| 1169 |
+
" [ 0.1943, 0.0515, 0.1729, ..., -0.0118, -0.0186, -0.0093],\n",
|
| 1170 |
+
" [ 0.1426, 0.2061, -0.0659, ..., 0.0417, 0.2832, 0.1235],\n",
|
| 1171 |
+
" ...,\n",
|
| 1172 |
+
" [-0.3613, -0.1963, -0.1689, ..., 0.3477, 0.2139, 0.3086],\n",
|
| 1173 |
+
" [-0.1133, -0.0718, -0.2031, ..., 0.9258, 0.5664, 1.0703],\n",
|
| 1174 |
+
" [-0.1426, -0.2637, -0.0596, ..., 1.0781, 0.7812, 0.8906]]],\n",
|
| 1175 |
+
" dtype=torch.bfloat16),\n",
|
| 1176 |
+
" 'noise_pred_48': tensor([[[ 1.2031, -0.9297, 0.2559, ..., -1.4141, 0.8906, 0.4258],\n",
|
| 1177 |
+
" [ 1.1016, 0.0625, -0.3242, ..., 0.8047, 0.1318, -1.1094],\n",
|
| 1178 |
+
" [ 0.1094, -0.2695, 0.2334, ..., -0.5820, -1.1016, -1.5234],\n",
|
| 1179 |
+
" ...,\n",
|
| 1180 |
+
" [-0.5938, 0.4551, 1.0938, ..., 0.0281, -0.6289, 0.1357],\n",
|
| 1181 |
+
" [ 0.6523, 1.0625, -0.2275, ..., -1.0938, 0.4297, -1.3984],\n",
|
| 1182 |
+
" [ 0.3906, -0.4180, 1.5391, ..., -0.5742, -0.6250, -0.7852]]],\n",
|
| 1183 |
+
" dtype=torch.bfloat16),\n",
|
| 1184 |
+
" 't_49': tensor([0.0200], dtype=torch.bfloat16),\n",
|
| 1185 |
+
" 'latents_49_start': tensor([[[ 1.0547e-01, 3.7354e-02, 2.8809e-02, ..., -5.8105e-02,\n",
|
| 1186 |
+
" 8.8867e-02, 6.1035e-04],\n",
|
| 1187 |
+
" [ 1.5820e-01, 4.9316e-02, 1.8359e-01, ..., -3.8574e-02,\n",
|
| 1188 |
+
" -2.2949e-02, 2.7588e-02],\n",
|
| 1189 |
+
" [ 1.3867e-01, 2.1484e-01, -7.3730e-02, ..., 6.1035e-02,\n",
|
| 1190 |
+
" 3.2031e-01, 1.7383e-01],\n",
|
| 1191 |
+
" ...,\n",
|
| 1192 |
+
" [-3.4180e-01, -2.1094e-01, -2.0508e-01, ..., 3.4766e-01,\n",
|
| 1193 |
+
" 2.3438e-01, 3.0469e-01],\n",
|
| 1194 |
+
" [-1.3477e-01, -1.0693e-01, -1.9531e-01, ..., 9.6094e-01,\n",
|
| 1195 |
+
" 5.5078e-01, 1.1172e+00],\n",
|
| 1196 |
+
" [-1.5527e-01, -2.5000e-01, -1.1035e-01, ..., 1.0938e+00,\n",
|
| 1197 |
+
" 8.0078e-01, 9.1797e-01]]], dtype=torch.bfloat16),\n",
|
| 1198 |
+
" 'noise_pred_49': tensor([[[ 0.7461, -0.5586, 0.2197, ..., -1.0469, 0.7109, 0.4902],\n",
|
| 1199 |
+
" [ 0.6094, 0.0464, -0.1650, ..., 0.4980, 0.2314, -0.9414],\n",
|
| 1200 |
+
" [ 0.1064, -0.2109, 0.1846, ..., -0.3633, -0.8086, -1.0234],\n",
|
| 1201 |
+
" ...,\n",
|
| 1202 |
+
" [-0.2559, 0.3711, 0.7461, ..., -0.2217, -0.2988, 0.0339],\n",
|
| 1203 |
+
" [ 0.4980, 0.5156, -0.0260, ..., -1.1250, 0.1064, -1.1250],\n",
|
| 1204 |
+
" [ 0.2471, 0.0179, 0.6875, ..., -0.7188, -0.5898, -0.8672]]],\n",
|
| 1205 |
+
" dtype=torch.bfloat16),\n",
|
| 1206 |
+
" 'output': tensor([[[ 0.0903, 0.0486, 0.0244, ..., -0.0371, 0.0747, -0.0092],\n",
|
| 1207 |
+
" [ 0.1465, 0.0483, 0.1865, ..., -0.0486, -0.0276, 0.0464],\n",
|
| 1208 |
+
" [ 0.1367, 0.2188, -0.0776, ..., 0.0684, 0.3359, 0.1943],\n",
|
| 1209 |
+
" ...,\n",
|
| 1210 |
+
" [-0.3359, -0.2188, -0.2197, ..., 0.3516, 0.2402, 0.3047],\n",
|
| 1211 |
+
" [-0.1445, -0.1172, -0.1943, ..., 0.9844, 0.5469, 1.1406],\n",
|
| 1212 |
+
" [-0.1602, -0.2500, -0.1240, ..., 1.1094, 0.8125, 0.9336]]],\n",
|
| 1213 |
+
" dtype=torch.bfloat16),\n",
|
| 1214 |
+
" 'height': 576,\n",
|
| 1215 |
+
" 'width': 448,\n",
|
| 1216 |
+
" 't': tensor([1.], dtype=torch.bfloat16),\n",
|
| 1217 |
+
" 'latents_start': tensor([[[ 1.9766, -0.8047, 0.6367, ..., -1.7422, 1.0469, 0.3809],\n",
|
| 1218 |
+
" [ 1.6562, 0.1147, -0.1562, ..., 0.7539, -0.1768, -1.6953],\n",
|
| 1219 |
+
" [ 0.3984, 0.3926, 0.1914, ..., -0.9258, -1.3281, -2.3281],\n",
|
| 1220 |
+
" ...,\n",
|
| 1221 |
+
" [-1.4766, 0.2539, 1.3359, ..., 0.1797, -0.6250, 0.7617],\n",
|
| 1222 |
+
" [ 1.0391, 1.3672, -0.1572, ..., 0.1152, 1.4688, -0.2852],\n",
|
| 1223 |
+
" [ 0.4941, -1.1094, 2.3438, ..., 0.8281, -0.8320, 0.4258]]],\n",
|
| 1224 |
+
" dtype=torch.bfloat16),\n",
|
| 1225 |
+
" 'noise_pred': tensor([[[ 1.8906, -0.8945, 0.5938, ..., -1.7578, 1.0078, 0.2539],\n",
|
| 1226 |
+
" [ 1.5781, 0.0278, -0.2793, ..., 0.7305, -0.1553, -1.7969],\n",
|
| 1227 |
+
" [ 0.3027, 0.2949, 0.1621, ..., -1.0625, -1.5938, -2.6406],\n",
|
| 1228 |
+
" ...,\n",
|
| 1229 |
+
" [-1.2578, 0.5352, 1.5859, ..., -0.2773, -1.0312, 0.3203],\n",
|
| 1230 |
+
" [ 1.2734, 1.5312, 0.0728, ..., -0.6211, 0.8984, -1.1562],\n",
|
| 1231 |
+
" [ 0.6172, -0.9336, 2.6719, ..., -0.1050, -1.8672, -0.3691]]],\n",
|
| 1232 |
+
" dtype=torch.bfloat16)}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"src[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22f19ae9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
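For orientation, here is a minimal sketch of how one of the saved regression records could be loaded and walked step by step. The path and key names follow the src[0] dict printed above; the snippet is illustrative and not part of the repo.

from pathlib import Path

import torch

# Illustrative only: load one record written by scripts/save_regression_outputs.py.
record = torch.load(Path("/data/regression_output") / "000000.pt")

# Each record holds the full denoising trajectory: for step i there is a
# timestep 't_{i}', the latents entering the step 'latents_{i}_start', and the
# model's prediction 'noise_pred_{i}', all stored in bfloat16.
num_steps = sum(1 for key in record if key.startswith("t_"))
for i in range(num_steps):
    t = record[f"t_{i}"]
    latents = record[f"latents_{i}_start"]
    noise_pred = record[f"noise_pred_{i}"]
    print(i, round(float(t), 4), tuple(latents.shape), tuple(noise_pred.shape))

# Final latents plus the metadata needed to unpack them back into an image.
print(record["output"].shape, record["height"], record["width"])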
scripts/save_regression_outputs.py
CHANGED

@@ -1,94 +1,59 @@
+import argparse
+from pathlib import Path
+
 import torch
-import huggingface_hub
 import tqdm
+from datasets import concatenate_datasets, load_dataset, interleave_datasets
 
 from qwenimage.datamodels import QwenConfig
 from qwenimage.foundation import QwenImageFoundationSaveInterm
-from datasets import concatenate_datasets, load_dataset, interleave_datasets
-
-repo_tree = huggingface_hub.list_repo_tree(
-    "WeiChow/CrispEdit-2M",
-    "data",
-    repo_type="dataset",
-)
-
-all_paths = []
-for i in repo_tree:
-    all_paths.append(i.path)
-
-parquet_prefixes = set()
-for path in all_paths:
-    if path.endswith('.parquet'):
-        filename = path.split('/')[-1]
-        if '_' in filename:
-            prefix = filename.split('_')[0]
-            parquet_prefixes.add(prefix)
-
-print("Found parquet prefixes:", sorted(parquet_prefixes))
-
-total_per = 10
-
-EDIT_TYPES = [
-    "color",
-    "style",
-    "replace",
-    "remove",
-    "add",
-    "motion change",
-    "background change",
-]
-
-all_edit_datasets = []
-for edit_type in EDIT_TYPES:
-    to_concat = []
-    for ds_n in range(total_per):
-        ds = load_dataset("parquet", data_files=f"/data/CrispEdit/{edit_type}_{ds_n:05d}.parquet", split="train")
-        to_concat.append(ds)
-    edit_type_concat = concatenate_datasets(to_concat)
-    all_edit_datasets.append(edit_type_concat)
-
-# consistent ordering for indexing, also allow extension
-join_ds = interleave_datasets(all_edit_datasets)
-
-from pathlib import Path
-
-save_base_dir = Path("/data/regression_output")
-save_base_dir.mkdir(exist_ok=True, parents=True)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--start-index", type=int, default=0)
+    args = parser.parse_args()
+
+    total_per = 10
+
+    EDIT_TYPES = [
+        "color",
+        "style",
+        "replace",
+        "remove",
+        "add",
+        "motion change",
+        "background change",
+    ]
+
+    all_edit_datasets = []
+    for edit_type in EDIT_TYPES:
+        to_concat = []
+        for ds_n in range(total_per):
+            ds = load_dataset("parquet", data_files=f"/data/CrispEdit/{edit_type}_{ds_n:05d}.parquet", split="train")
+            to_concat.append(ds)
+        edit_type_concat = concatenate_datasets(to_concat)
+        all_edit_datasets.append(edit_type_concat)
+
+    # consistent ordering for indexing, also allow extension
+    join_ds = interleave_datasets(all_edit_datasets)
+
+    save_base_dir = Path("/data/regression_output")
+    save_base_dir.mkdir(exist_ok=True, parents=True)
+
+    foundation = QwenImageFoundationSaveInterm(QwenConfig())
+
+    dataset_to_process = join_ds.select(range(args.start_index, len(join_ds)))
+
+    for idx, input_data in enumerate(tqdm.tqdm(dataset_to_process), start=args.start_index):
+        output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
+            image=[input_data["input_img"]],
+            prompt=input_data["instruction"],
+        ))
+
+        torch.save(output_dict, save_base_dir/f"{idx:06d}.pt")
+
+
+if __name__ == "__main__":
+    main()
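Usage note: the new --start-index flag makes the export resumable. Records are written as zero-padded {idx:06d}.pt files, so after an interruption the run can be restarted at the first missing index. A hypothetical helper for computing that index (not part of the script):

from pathlib import Path

# Hypothetical resume helper: output files are named 000000.pt, 000001.pt, ...
existing = sorted(int(p.stem) for p in Path("/data/regression_output").glob("*.pt"))
next_index = existing[-1] + 1 if existing else 0
print(f"python scripts/save_regression_outputs.py --start-index {next_index}")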