Spaces: Running on Zero

Elea Zhong committed
Commit · 702dbb9
1 Parent(s): 2daa93c

test inference
qwenimage/models/pipeline_qwenimage_edit_save_interm.py
CHANGED
@@ -36,6 +36,7 @@ from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
 
 from qwenimage.debug import ctimed, ftimed, texam
 from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
+from qwenimage.sampling import TimestepDistUtils
 
 
 if is_torch_xla_available():
@@ -793,28 +794,14 @@ class QwenImageEditSaveIntermPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         ] * batch_size
 
         # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        image_seq_len = latents.shape[1]
-
-
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.15),
-        )
-
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            sigmas=sigmas,
-            mu=mu,
-        )
-        print(f"{timesteps=}")
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
+        t_utils = TimestepDistUtils(
+            min_seq_len=self.scheduler.config.get("base_image_seq_len", 256),
+            max_seq_len=self.scheduler.config.get("max_image_seq_len", 4096),
+            min_mu=self.scheduler.config.get("base_shift", 0.5),
+            max_mu=self.scheduler.config.get("max_shift", 1.15),
+        )
+        ts = t_utils.get_inference_t(num_inference_steps, seq_len=t_utils.get_seq_len(latents)).to(device)
+        print(f"ts={ts}")
 
         # handle guidance
         if self.transformer.config.guidance_embeds and guidance_scale is None:
@@ -855,25 +842,26 @@ class QwenImageEditSaveIntermPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         # 6. Denoising loop
        self.scheduler.set_begin_index(0)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
+            for i in range(len(ts)-1):
+                t = ts[i]
                 with ctimed(f"loop {i}"):
                     if self.interrupt:
                         continue
 
-                    self._current_timestep = t
+                    # self._current_timestep = t
 
                     latent_model_input = latents
                     if image_latents is not None:
                         latent_model_input = torch.cat([latents, image_latents], dim=1)
 
                     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                    timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                    output_dict[f"t_{i}"] = timestep.clone().cpu()
+                    in_t = t.expand(latents.shape[0]).to(latents.dtype)
+                    output_dict[f"t_{i}"] = in_t.clone().cpu()
                     output_dict[f"latents_{i}_start"] = latents.clone().cpu()
                     with self.transformer.cache_context("cond"):
                         noise_pred = self.transformer(
                             hidden_states=latent_model_input,
-                            timestep=timestep / 1000,
+                            timestep=in_t,
                             guidance=guidance,
                             encoder_hidden_states_mask=prompt_embeds_mask,
                             encoder_hidden_states=prompt_embeds,
@@ -890,7 +878,7 @@ class QwenImageEditSaveIntermPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                     with self.transformer.cache_context("uncond"):
                         neg_noise_pred = self.transformer(
                             hidden_states=latent_model_input,
-                            timestep=timestep / 1000,
+                            timestep=in_t,
                             guidance=guidance,
                             encoder_hidden_states_mask=negative_prompt_embeds_mask,
                             encoder_hidden_states=negative_prompt_embeds,
@@ -907,7 +895,7 @@ class QwenImageEditSaveIntermPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
 
                     # compute the previous noisy sample x_t -> x_t-1
                     latents_dtype = latents.dtype
-                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                    latents = t_utils.inference_ode_step(noise_pred, latents, i, ts)
 
                     # output_dict[f"latents_{i}"] = latents.clone().cpu()
 
@@ -926,8 +914,7 @@ class QwenImageEditSaveIntermPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
                         prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
 
                     # call the callback, if provided
-                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                        progress_bar.update()
+                    progress_bar.update()
 
                     if XLA_AVAILABLE:
                         xm.mark_step()
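Note on the new dependency: qwenimage/sampling.py is not part of this commit, so TimestepDistUtils is only visible through its call sites above (the four constructor kwargs, get_seq_len, get_inference_t, and inference_ode_step). The sketch below is a plausible reconstruction of that interface, assuming the same resolution-shifted flow-matching schedule that diffusers' calculate_shift and FlowMatchEulerDiscreteScheduler implement, stepped with plain Euler. Every formula, default, and helper beyond the call sites is an assumption, not this repo's code.

# Hypothetical sketch of the TimestepDistUtils interface the diff calls into.
# Only the constructor kwargs and the three methods used above are grounded in
# this commit; the formulas and defaults are assumptions.
import math
from dataclasses import dataclass

import torch


@dataclass
class TimestepDistUtils:
    min_seq_len: int = 256
    max_seq_len: int = 4096
    min_mu: float = 0.5
    max_mu: float = 1.15

    def get_seq_len(self, latents: torch.Tensor) -> int:
        # Packed latents are (batch, seq_len, channels).
        return latents.shape[1]

    def get_mu(self, seq_len: int) -> float:
        # Linear interpolation of the shift in sequence length,
        # mirroring diffusers' calculate_shift().
        slope = (self.max_mu - self.min_mu) / (self.max_seq_len - self.min_seq_len)
        return slope * (seq_len - self.min_seq_len) + self.min_mu

    def get_inference_t(self, num_steps: int, seq_len: int) -> torch.Tensor:
        # num_steps + 1 knots from t=1 (pure noise) to t=0 (clean), warped by
        # the resolution-dependent shift: t' = e^mu / (e^mu + (1/t - 1)).
        t = torch.linspace(1.0, 0.0, num_steps + 1)
        e_mu = math.exp(self.get_mu(seq_len))
        shifted = e_mu / (e_mu + (1.0 / t.clamp(min=1e-6) - 1.0))
        shifted[-1] = 0.0
        return shifted

    def inference_ode_step(self, velocity, latents, i, ts):
        # One explicit Euler step of the flow-matching ODE:
        # x_{t_{i+1}} = x_{t_i} + (t_{i+1} - t_i) * v_theta(x_{t_i}, t_i).
        return latents + (ts[i + 1] - ts[i]) * velocity

Read this way, the diff is self-consistent: get_inference_t returns num_inference_steps + 1 knots in [0, 1], so for i in range(len(ts)-1) performs exactly num_inference_steps Euler steps, the transformer is fed in_t directly instead of timestep / 1000, and the retrieve_timesteps / num_warmup_steps bookkeeping together with the warmup-gated progress_bar.update() becomes dead code.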
scripts/inf.ipynb
CHANGED
The diff for this file is too large to render. See raw diff.
scripts/save_model_outputs.py
CHANGED
@@ -14,10 +14,10 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--start-index", type=int, default=0)
     parser.add_argument("--end-index", type=int, default=100)
-    parser.add_argument("--imsize", type=int, default=
+    parser.add_argument("--imsize", type=int, default=1024)
     parser.add_argument("--indir", type=str, default="/data/CrispEdit")
     parser.add_argument("--outdir", type=str, default="/data/model_output")
-    parser.add_argument("--steps", type=int, default=
+    parser.add_argument("--steps", type=int, default=2)
     parser.add_argument("--checkpoint", type=str)
     parser.add_argument("--lora-rank", type=int)
     args = parser.parse_args()
@@ -53,8 +53,11 @@ def main():
     finetuner = QwenLoraFinetuner(foundation, foundation.config)
     finetuner.load(args.checkpoint, lora_rank=args.lora_rank)
 
-    dataset_to_process = join_ds.select(range(args.start_index,
-
+    dataset_to_process = join_ds.select(range(args.start_index, args.end_index))
+
+    # foundation.scheduler.config["base_shift"] = 2.0
+    # foundation.scheduler.config["max_shift"] = 2.0
+
     for idx, input_data in enumerate(tqdm.tqdm(dataset_to_process), start=args.start_index):
 
         output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
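The new defaults (--imsize 1024, --steps 2) suggest this commit is exercising few-step inference, and the commented-out overrides would pin base_shift and max_shift to a constant 2.0 instead of the resolution-interpolated value. Under the TimestepDistUtils sketch above (an assumption; the real schedule lives in qwenimage/sampling.py), a two-step run at 1024x1024 would use a schedule like:

# Assumes 1024x1024 -> 128x128 VAE latents -> 64 * 64 = 4096 packed tokens,
# so the interpolated shift hits max_mu = 1.15 (e^1.15 ≈ 3.158).
utils = TimestepDistUtils()
ts = utils.get_inference_t(2, seq_len=4096)
print(ts)  # ≈ tensor([1.0000, 0.7595, 0.0000]): one large jump, then a cleanup step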
scripts/straightness.ipynb
CHANGED
The diff for this file is too large to render. See raw diff.