Elea Zhong committed
Commit 702dbb9 · 1 Parent(s): 2daa93c

test inference

qwenimage/models/pipeline_qwenimage_edit_save_interm.py CHANGED
@@ -36,6 +36,7 @@ from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
 
 from qwenimage.debug import ctimed, ftimed, texam
 from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
+from qwenimage.sampling import TimestepDistUtils
 
 
 if is_torch_xla_available():
@@ -793,28 +794,14 @@ class QwenImageEditSaveIntermPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin
         ] * batch_size
 
         # 5. Prepare timesteps
-        print(f"{num_inference_steps=}")
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        image_seq_len = latents.shape[1]
-        print(f"{image_seq_len=}")
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.15),
+        t_utils = TimestepDistUtils(
+            min_seq_len=self.scheduler.config.get("base_image_seq_len", 256),
+            max_seq_len=self.scheduler.config.get("max_image_seq_len", 4096),
+            min_mu=self.scheduler.config.get("base_shift", 0.5),
+            max_mu=self.scheduler.config.get("max_shift", 1.15),
         )
-        print(f"{mu=}")
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            sigmas=sigmas,
-            mu=mu,
-        )
-        print(f"{timesteps=}")
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
+        ts = t_utils.get_inference_t(num_inference_steps, seq_len=t_utils.get_seq_len(latents)).to(device)
+        print(f"ts={ts}")
 
         # handle guidance
         if self.transformer.config.guidance_embeds and guidance_scale is None:
@@ -855,25 +842,26 @@ class QwenImageEditSaveIntermPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin
         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
+            for i in range(len(ts)-1):
+                t = ts[i]
                 with ctimed(f"loop {i}"):
                     if self.interrupt:
                         continue
 
-                    self._current_timestep = t
+                    # self._current_timestep = t
 
                     latent_model_input = latents
                     if image_latents is not None:
                         latent_model_input = torch.cat([latents, image_latents], dim=1)
 
                     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                    timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                    output_dict[f"t_{i}"] = (timestep / 1000).clone().cpu()
+                    in_t = t.expand(latents.shape[0]).to(latents.dtype)
+                    output_dict[f"t_{i}"] = in_t.clone().cpu()
                     output_dict[f"latents_{i}_start"] = latents.clone().cpu()
                     with self.transformer.cache_context("cond"):
                         noise_pred = self.transformer(
                             hidden_states=latent_model_input,
-                            timestep=timestep / 1000,
+                            timestep=in_t,
                             guidance=guidance,
                             encoder_hidden_states_mask=prompt_embeds_mask,
                             encoder_hidden_states=prompt_embeds,
@@ -890,7 +878,7 @@ class QwenImageEditSaveIntermPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin
                     with self.transformer.cache_context("uncond"):
                         neg_noise_pred = self.transformer(
                             hidden_states=latent_model_input,
-                            timestep=timestep / 1000,
+                            timestep=in_t,
                             guidance=guidance,
                             encoder_hidden_states_mask=negative_prompt_embeds_mask,
                             encoder_hidden_states=negative_prompt_embeds,
@@ -907,7 +895,7 @@ class QwenImageEditSaveIntermPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin
 
                     # compute the previous noisy sample x_t -> x_t-1
                     latents_dtype = latents.dtype
-                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                    latents = t_utils.inference_ode_step(noise_pred, latents, i, ts)
 
                     # output_dict[f"latents_{i}"] = latents.clone().cpu()
 
@@ -926,8 +914,7 @@ class QwenImageEditSaveIntermPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin
                         prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
 
                     # call the callback, if provided
-                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                        progress_bar.update()
+                    progress_bar.update()
 
                     if XLA_AVAILABLE:
                         xm.mark_step()
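The key change: timestep construction and the scheduler step now go through a project-local TimestepDistUtils instead of diffusers' calculate_shift / retrieve_timesteps / self.scheduler.step, and the transformer receives timesteps in [0, 1] directly instead of timestep / 1000. The new qwenimage/sampling.py is not part of this diff, so the following is only a sketch of what it plausibly contains: the constructor keywords and method signatures are copied from the call sites above, while the method bodies are assumptions modeled on the flow-matching shift logic the removed code used.

# Hypothetical sketch of qwenimage/sampling.py — not shown in this commit.
import math
from dataclasses import dataclass

import torch


@dataclass
class TimestepDistUtils:
    min_seq_len: int = 256
    max_seq_len: int = 4096
    min_mu: float = 0.5
    max_mu: float = 1.15

    def get_seq_len(self, latents: torch.Tensor) -> int:
        # Packed latents are (batch, seq_len, channels); the removed code read
        # image_seq_len = latents.shape[1] for the same purpose.
        return latents.shape[1]

    def get_mu(self, seq_len: int) -> float:
        # Linear interpolation between (min_seq_len, min_mu) and (max_seq_len,
        # max_mu), as in diffusers' calculate_shift() that this replaces.
        slope = (self.max_mu - self.min_mu) / (self.max_seq_len - self.min_seq_len)
        return self.min_mu + slope * (seq_len - self.min_seq_len)

    def get_inference_t(self, num_steps: int, seq_len: int) -> torch.Tensor:
        # num_steps + 1 boundary times from t=1 (pure noise) down to t=0 (clean),
        # warped by the resolution-dependent shift
        #     t_shifted = e^mu / (e^mu + (1/t - 1)),
        # which maps 1 -> 1 and 0 -> 0 but concentrates steps at high noise.
        mu = self.get_mu(seq_len)
        t = torch.linspace(1.0, 0.0, num_steps + 1)
        return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0))

    def inference_ode_step(
        self, noise_pred: torch.Tensor, latents: torch.Tensor, i: int, ts: torch.Tensor
    ) -> torch.Tensor:
        # Explicit Euler step of the rectified-flow ODE: the transformer output
        # approximates the velocity v = dx/dt, so
        #     x_{i+1} = x_i + (ts[i+1] - ts[i]) * v,  with ts[i+1] - ts[i] < 0.
        dt = ts[i + 1] - ts[i]
        return latents + dt.to(latents.dtype) * noise_pred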
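A quick illustrative check of the schedule, assuming the sketch above matches the real helper. With the new --steps 2 default (see the script diff below), get_inference_t returns three boundary times and the loop for i in range(len(ts)-1) performs two Euler steps:

utils = TimestepDistUtils()
ts = utils.get_inference_t(num_steps=2, seq_len=4096)
print(ts)  # ≈ tensor([1.0000, 0.7595, 0.0000]); mu = max_mu = 1.15 at seq_len 4096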
scripts/inf.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
scripts/save_model_outputs.py CHANGED
@@ -14,10 +14,10 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--start-index", type=int, default=0)
     parser.add_argument("--end-index", type=int, default=100)
-    parser.add_argument("--imsize", type=int, default=512)
+    parser.add_argument("--imsize", type=int, default=1024)
     parser.add_argument("--indir", type=str, default="/data/CrispEdit")
     parser.add_argument("--outdir", type=str, default="/data/model_output")
-    parser.add_argument("--steps", type=int, default=50)
+    parser.add_argument("--steps", type=int, default=2)
     parser.add_argument("--checkpoint", type=str)
     parser.add_argument("--lora-rank", type=int)
     args = parser.parse_args()
@@ -53,8 +53,11 @@ def main():
     finetuner = QwenLoraFinetuner(foundation, foundation.config)
     finetuner.load(args.checkpoint, lora_rank=args.lora_rank)
 
-    dataset_to_process = join_ds.select(range(args.start_index, len(join_ds)))
-
+    dataset_to_process = join_ds.select(range(args.start_index, args.end_index))
+
+    # foundation.scheduler.config["base_shift"] = 2.0
+    # foundation.scheduler.config["max_shift"] = 2.0
+
     for idx, input_data in enumerate(tqdm.tqdm(dataset_to_process), start=args.start_index):
 
         output_dict = foundation.base_pipe(foundation.INPUT_MODEL(
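The commented-out config overrides hint at a constant-shift experiment. As a hypothetical illustration (the helper name here is made up; the math mirrors the interpolation the pipeline's removed calculate_shift call performed): when base_shift == max_shift, the interpolated mu is the same at every resolution.

def interp_mu(seq_len, base_seq=256, max_seq=4096, base_shift=2.0, max_shift=2.0):
    # with equal endpoints the slope is zero, so mu is resolution-independent
    slope = (max_shift - base_shift) / (max_seq - base_seq)
    return base_shift + slope * (seq_len - base_seq)

assert interp_mu(1024) == interp_mu(4096) == 2.0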
scripts/straightness.ipynb CHANGED
The diff for this file is too large to render. See raw diff