Elea Zhong committed
Commit 0365768 · 1 Parent(s): f9abc90

image context training
configs/base.yaml CHANGED
@@ -15,6 +15,7 @@ num_workers: 4
  resume_from_checkpoint: null
  log_model_steps: 100
  preprocessing_epoch_len: 64
+ preprocessing_epoch_repetitions: 1
 
 
  # Logging
configs/style/lora-im2im.yaml ADDED
@@ -0,0 +1,18 @@
+ wandb_run_name: "lora-im2im"
+ output_dir: "/data/checkpoints/lora-im2im"
+
+ learning_rate: 1e-4
+ num_train_epochs: 1
+ max_train_steps: 1000
+ preprocessing_epoch_len: 33
+ preprocessing_epoch_repetitions: 31
+ num_validation_images: 2
+ num_sample_images: 2
+
+ source_type: "im2im"
+ style_title: "Simpsons"
+ csv_path: "/data/chatgpt-style-transfer-data/output/results.csv"
+ base_dir: "/data/chatgpt-style-transfer-data"
+ train_range: [2, 35]
+ test_range: [0, 2]
+ val_with: test
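
For orientation, a minimal sketch of how these keys end up on the trainer config, assuming the YAML is read with PyYAML and that QwenConfig (the ExperimentTrainerParameters subclass extended in this commit) accepts the merged keys as keyword arguments; the real merge of base and style configs happens inside run_training and may differ:

import yaml
from qwenimage.datamodels import QwenConfig

# Hypothetical two-file merge: style overrides layered over the base config.
with open("configs/base.yaml") as f:
    params = yaml.safe_load(f)
with open("configs/style/lora-im2im.yaml") as f:
    params.update(yaml.safe_load(f))

config = QwenConfig(**params)
print(config.source_type, config.train_range, config.test_range, config.val_with)
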
configs/style/{lora-1.yaml → lora-naive.yaml} RENAMED
@@ -1,4 +1,9 @@
  wandb_run_name: "lora-naive"
  output_dir: "/data/checkpoints/lora-naive"
 
- learning_rate: 4e-4
+ learning_rate: 4e-4
+
+ source_type: "naive"
+ data_dir: "/data/styles-finetune-data-artistic/tarot"
+ prompt: "<0001>"
+ ref_dir: "/data/image"
qwenimage/datamodels.py CHANGED
@@ -43,7 +43,19 @@ class QwenConfig(ExperimentTrainerParameters):
  static_mu: float | None = None
  loss_weight_dist: str | None = None # "scaled_clipped_gaussian", "logit-normal"
 
- vae_image_size: int = 1024 * 1024
+ vae_image_size: int = 512 * 512
  offload_text_encoder: bool = True
  quantize_text_encoder: bool = False
  quantize_transformer: bool = False
+
+ source_type: str = "im2im"
+ style_title: str|None = None
+ base_dir: str|None = None
+ csv_path: str|None = None
+ data_dir: str|None = None
+ ref_dir: str|None = None
+ prompt: str|None = None
+ train_range: tuple[int|float,int|float]|None=None
+ test_range: tuple[int|float,int|float]|None=None
+ val_with: str = "train"
+
qwenimage/foundation.py CHANGED
@@ -8,14 +8,14 @@ from diffusers.pipelines.qwenimage.pipeline_qwenimage import QwenImagePipeline
  import torch
  from safetensors.torch import load_file, save_model
  import torch.nn.functional as F
+ import torchvision.transforms.v2.functional as TF
  from einops import rearrange
 
  from qwenimage.datamodels import QwenConfig, QwenInputs
  from qwenimage.debug import ctimed, ftimed, print_gpu_memory, texam
  from qwenimage.experiments.quantize_text_encoder_experiments import quantize_text_encoder_int4wo_linear
  from qwenimage.experiments.quantize_experiments import quantize_transformer_fp8darow_nolast
- from qwenimage.models.encode_prompt import encode_prompt
- from qwenimage.models.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
+ from qwenimage.models.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, QwenImageEditPlusPipeline, calculate_dimensions
  from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
  from qwenimage.optimization import simple_quantize_model
  from qwenimage.sampling import TimestepDistUtils
@@ -90,7 +90,6 @@ class QwenImageFoundation(WandModel):
  static_mu=self.config.static_mu,
  loss_weight_dist=self.config.loss_weight_dist,
  )
- self.static_prompt_embeds = None
 
  if self.config.quantize_text_encoder:
  quantize_text_encoder_int4wo_linear(self.text_encoder)
@@ -131,8 +130,14 @@ class QwenImageFoundation(WandModel):
 
  def pil_to_latents(self, images):
  image = self.pipe.image_processor.preprocess(images)
- print("pil_to_latents, image")
+
+ h,w = image.shape[-2:]
+ h_r, w_r = calculate_dimensions(self.config.vae_image_size, h/w)
+ image = TF.resize(image, (h_r, w_r))
+
+ print("pil_to_latents.image")
  texam(image)
+
  image = image.unsqueeze(2) # N, C, F=1, H, W
  image = image.to(device=self.device, dtype=self.dtype)
  latents = self.pipe.vae.encode(image).latent_dist.mode() # argmax
@@ -149,7 +154,7 @@ class QwenImageFoundation(WandModel):
  )
  latents = (latents - latents_mean) / latents_std
  latents = latents.squeeze(2)
- print("pil_to_latents, latents")
+ print("pil_to_latents.latents")
  texam(latents)
  return latents.to(dtype=self.dtype)
 
@@ -185,29 +190,19 @@ class QwenImageFoundation(WandModel):
  latents = rearrange(packed, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=h, w=w)
  return latents
 
- def set_static_prompt(self, prompt:str):
- self.text_encoder.to(device=self.device)
- if self.text_encoder_device != "cuda":
- self.text_encoder_device = "cuda"
- with torch.no_grad():
- prompt_embeds, prompt_embeds_mask = encode_prompt(
- self.text_encoder,
- self.pipe.tokenizer,
- prompt,
- device=self.device,
- dtype=self.dtype,
- max_sequence_length = self.config.train_max_sequence_length,
- )
- prompt_embeds = prompt_embeds.cpu().clone().detach()
- prompt_embeds_mask = prompt_embeds_mask.cpu().clone().detach()
- self.static_prompt_embeds = (prompt_embeds, prompt_embeds_mask)
 
  @ftimed
  def preprocess_batch(self, batch):
  prompts = batch["text"]
+ references = batch["reference"]
+
+ h,w = references.shape[-2:]
+ h_r, w_r = calculate_dimensions(CONDITION_IMAGE_SIZE, h/w)
+ references = TF.resize(references, (h_r, w_r))
 
- if self.static_prompt_embeds is not None:
- prompt_embeds, prompt_embeds_mask = self.static_prompt_embeds
+ print("preprocess_batch.references")
+ texam(references)
+
 
  with ctimed("text_encoder.cuda()"):
  self.text_encoder.to(device=self.device)
@@ -215,26 +210,19 @@ class QwenImageFoundation(WandModel):
  self.text_encoder_device = "cuda"
 
  with torch.no_grad():
- prompt_embeds, prompt_embeds_mask = encode_prompt(
- self.text_encoder,
- self.pipe.tokenizer,
+ prompt_embeds, prompt_embeds_mask = self.pipe.encode_prompt(
  prompts,
- device=self.device,
- dtype=self.dtype,
+ references.mul(255), # scaled to RGB
+ device="cuda",
  max_sequence_length = self.config.train_max_sequence_length,
  )
- # prompt_embeds, prompt_embeds_mask = foundation.pipe.encode_prompt(
- # inps[i]["prompt"],
- # _transforms(inps[i]["image"][0]).mul(255),
- # device="cuda",
- # # dtype=foundation.dtype,
- # max_sequence_length = foundation.config.train_max_sequence_length,
- # )
  prompt_embeds = prompt_embeds.cpu().clone().detach()
  prompt_embeds_mask = prompt_embeds_mask.cpu().clone().detach()
 
 
  batch["prompt_embeds"] = (prompt_embeds, prompt_embeds_mask)
+ batch["reference"] = batch["reference"].cpu()
+ batch["image"] = batch["image"].cpu()
 
  return batch
 
@@ -253,25 +241,42 @@ class QwenImageFoundation(WandModel):
  prompt_embeds = prompt_embeds.to(device=self.device)
  prompt_embeds_mask = prompt_embeds_mask.to(device=self.device)
 
- images = batch["image"]
+ images = batch["image"].to(device=self.device, dtype=self.dtype)
  x_0 = self.pil_to_latents(images).to(device=self.device, dtype=self.dtype)
  x_1 = torch.randn_like(x_0).to(device=self.device, dtype=self.dtype)
  seq_len = self.timestep_dist_utils.get_seq_len(x_0)
  batch_size = x_0.shape[0]
  t = self.timestep_dist_utils.get_train_t([batch_size], seq_len=seq_len).to(device=self.device, dtype=self.dtype)
  x_t = (1.0 - t) * x_0 + t * x_1
-
  x_t_1d = self.pack_latents(x_t)
 
+ references = batch["reference"].to(device=self.device, dtype=self.dtype)
+ print("references")
+ texam(references)
+ assert references.shape[0] == 1
+ refs = self.pil_to_latents(references).to(device=self.device, dtype=self.dtype)
+ refs_1d = self.pack_latents(refs)
+ print("refs refs_1d")
+ texam(refs)
+ texam(refs_1d)
+
+ inp_1d = torch.cat([x_t_1d, refs_1d], dim=1)
+ print("inp_1d")
+ texam(inp_1d)
+
  l_height, l_width = x_0.shape[-2:]
+ ref_height, ref_width = refs.shape[-2:]
  img_shapes = [
- [(1, l_height // 2, l_width // 2), ]
+ [
+ (1, l_height // 2, l_width // 2),
+ (1, ref_height // 2, ref_width // 2),
+ ]
  ] * batch_size
  txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
  image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=x_0.device)
 
  v_pred_1d = self.transformer(
- hidden_states=x_t_1d,
+ hidden_states=inp_1d,
  encoder_hidden_states=prompt_embeds,
  encoder_hidden_states_mask=prompt_embeds_mask,
  timestep=t,
@@ -279,6 +284,8 @@ class QwenImageFoundation(WandModel):
  return_dict=False,
  )[0]
 
+ v_pred_1d = v_pred_1d[:, : x_t_1d.size(1)]
+
  v_pred_2d = self.unpack_latents(v_pred_1d, h=l_height//2, w=l_width//2)
  v_gt_2d = x_1 - x_0
 
@@ -298,6 +305,11 @@ class QwenImageFoundation(WandModel):
  self.text_encoder.to(device=self.device)
  if self.text_encoder_device != "cuda":
  self.text_encoder_device = "cuda"
+ image = inputs.image[0]
+ w,h = image.size
+ h_r, w_r = calculate_dimensions(self.config.vae_image_size, h/w)
+ image = TF.resize(image, (h_r, w_r))
+ inputs.image = [image]
  return self.pipe(**inputs.model_dump()).images
 
 
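The heart of this change is reference conditioning in the training step: noisy image latents and reference latents are each packed into 1-D token sequences, concatenated along the sequence axis, described to the rotary embedding via a two-entry img_shapes list, and the transformer output is sliced back down to the image tokens before unpacking and computing the flow loss. A shape-only toy sketch of that concat-and-slice flow (the pack stand-in below is hypothetical, not the repo's pack_latents):

import torch

def pack(x):
    # (B, C, H, W) -> (B, H/2 * W/2, C * 4), mimicking 2x2 patch packing
    b, c, h, w = x.shape
    x = x.reshape(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

x_t = torch.randn(1, 16, 64, 64)   # noisy image latents (B, C, H, W)
ref = torch.randn(1, 16, 32, 32)   # reference latents; a different size is allowed

x_t_1d, ref_1d = pack(x_t), pack(ref)
inp_1d = torch.cat([x_t_1d, ref_1d], dim=1)   # image tokens first, then reference tokens

# One entry per stream, mirroring img_shapes in the training step.
img_shapes = [[(1, 64 // 2, 64 // 2), (1, 32 // 2, 32 // 2)]]

v_pred_1d = inp_1d                             # stand-in for the transformer output (same token layout)
v_pred_1d = v_pred_1d[:, : x_t_1d.size(1)]     # drop reference tokens; only image tokens enter the loss
print(inp_1d.shape, v_pred_1d.shape)           # torch.Size([1, 1280, 64]) torch.Size([1, 1024, 64])
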
qwenimage/{datasets.py → sources.py} RENAMED
@@ -1,5 +1,6 @@
 
 
+ import csv
  from pathlib import Path
  import random
 
@@ -54,3 +55,55 @@ class StyleSourceWithRandomRef(Source):
  rand_ref = random.choice(self.ref_images)
  ref_pil = Image.open(rand_ref).convert("RGB")
  return im_pil, self.prompt, ref_pil
+
+
+ class StyleImagetoImageSource(Source):
+ _data_types = [
+ SourceDataType(name="text", type=str),
+ SourceDataType(name="image", type=Image.Image),
+ SourceDataType(name="reference", type=Image.Image),
+ ]
+ def __init__(self, csv_path, base_dir, style_title=None, data_range:tuple[int|float,int|float]|None=None):
+ self.csv_path = Path(csv_path)
+ self.base_dir = Path(base_dir)
+ self.style_title = style_title
+ self.data = []
+
+ with open(self.csv_path, 'r', encoding='utf-8') as f:
+ reader = csv.DictReader(f)
+ for row in reader:
+ if self.style_title is not None and row['style_title'] != self.style_title:
+ continue
+
+ input_image = self.base_dir / row['input_image']
+ output_image = self.base_dir / row['output_image_path']
+ self.data.append({
+ 'input_image': input_image,
+ 'output_image': output_image,
+ 'style_title': row['style_title'],
+ 'prompt': row['prompt']
+ })
+
+ if data_range is not None:
+ left, right = data_range
+ if (isinstance(left, float) or isinstance(right, float)) and (left<1 and right<1):
+ left = left * len(self.data)
+ right = right * len(self.data)
+ remain_data = []
+ for i, d in enumerate(self.data):
+ if left <= i and i < right:
+ remain_data.append(d)
+ self.data = remain_data
+
+ print(f"{self.__class__} of len{len(self)}")
+
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ item = self.data[idx]
+ prompt = item["prompt"]
+ input_pil = Image.open(item['input_image']).convert("RGB")
+ output_pil = Image.open(item['output_image']).convert("RGB")
+ return prompt, output_pil, input_pil
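
The data_range argument keeps rows whose index falls in [left, right); if both bounds are floats below 1 they are first scaled by the dataset length, so the same argument covers fractional and absolute splits. A small standalone illustration of that selection rule (not importing the class):

def select(rows, data_range):
    # Mirrors the data_range filtering in StyleImagetoImageSource.__init__.
    if data_range is None:
        return rows
    left, right = data_range
    if (isinstance(left, float) or isinstance(right, float)) and (left < 1 and right < 1):
        left, right = left * len(rows), right * len(rows)  # fractional bounds -> absolute indices
    return [r for i, r in enumerate(rows) if left <= i < right]

rows = list(range(35))
print(len(select(rows, (2, 35))))     # 33 training rows, matching train_range: [2, 35]
print(len(select(rows, (0, 2))))      # 2 test rows, matching test_range: [0, 2]
print(len(select(rows, (0.0, 0.8))))  # 28 rows, i.e. the first 80%
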
qwenimage/task.py CHANGED
@@ -11,7 +11,7 @@ image_transforms = T.Compose([
  RemoveAlphaTransform(bg_color_rgb=(34, 34, 34)),
  T.ToImage(),
  T.RGB(),
- RandomDownsize(sizes=(384, 512, 768)),
+ # RandomDownsize(sizes=(384, 512, 768)),
  T.ToDtype(torch.float, scale=True),
  ])
 
qwenimage/training.py CHANGED
@@ -17,7 +17,7 @@ from wandml.trainers.experiment_trainer import ExperimentTrainer
 
 
  from qwenimage.finetuner import QwenLoraFinetuner
- from qwenimage.datasets import StyleSourceWithRandomRef
+ from qwenimage.sources import StyleSourceWithRandomRef, StyleImagetoImageSource
  from qwenimage.task import TextToImageWithRefTask
  from qwenimage.datamodels import QwenConfig
  from qwenimage.foundation import QwenImageFoundation
@@ -50,11 +50,32 @@ def run_training(config_path: Path | str, update_config_paths: list[Path] | None
  )
 
  # Data
- src = StyleSourceWithRandomRef("/data/styles-finetune-data-artistic/tarot", "<0001>", "/data/image", set_len=1000)
- task = TextToImageWithRefTask()
  dp = WandDataPipe()
- dp.add_source(src)
- dp.set_task(task)
+ dp.set_task(TextToImageWithRefTask())
+ dp_test = WandDataPipe()
+ dp_test.set_task(TextToImageWithRefTask())
+ if config.source_type == "naive":
+ src = StyleSourceWithRandomRef(
+ config.data_dir, config.prompt, config.ref_dir, set_len=config.max_train_steps
+ )
+ dp.add_source(src)
+ elif config.source_type == "im2im":
+ src = StyleImagetoImageSource(
+ csv_path=config.csv_path,
+ base_dir=config.base_dir,
+ style_title=config.style_title,
+ data_range=config.train_range,
+ )
+ dp.add_source(src)
+ src_test = StyleImagetoImageSource(
+ csv_path=config.csv_path,
+ base_dir=config.base_dir,
+ style_title=config.style_title,
+ data_range=config.test_range,
+ )
+ dp_test.add_source(src_test)
+ else:
+ raise ValueError()
 
 
  # Model
@@ -63,7 +84,21 @@ def run_training(config_path: Path | str, update_config_paths: list[Path] | None
  finetuner.load(None)
 
 
- trainer = ExperimentTrainer(foundation,dp,config)
+ if len(dp_test) == 0:
+ dp_test = None
+ if config.val_with == "train":
+ dp_val = dp
+ elif config.val_with == "test":
+ dp_val = dp_test
+ else:
+ raise ValueError()
+ trainer = ExperimentTrainer(
+ model=foundation,
+ datapipe=dp,
+ args=config,
+ validation_datapipe=dp_val,
+ test_datapipe=dp_test,
+ )
  trainer.train()
 
 
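With the source selection in place, kicking off a run for the new config would look roughly like this; it assumes config_path takes the base config and update_config_paths layers the style override on top, which matches the signature shown above but is not verified here:

from qwenimage.training import run_training

# Hypothetical invocation: base config first, im2im style config as an override.
run_training(
    "configs/base.yaml",
    update_config_paths=["configs/style/lora-im2im.yaml"],
)
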
scripts/train.ipynb CHANGED
@@ -30,16 +30,16 @@
  "text": [
  "/usr/lib/python3/dist-packages/sklearn/utils/fixes.py:25: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n",
  " from pkg_resources import parse_version # type: ignore\n",
- "2025-11-22 18:13:10.673389: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
- "2025-11-22 18:13:10.687858: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "2025-11-23 10:48:20.190181: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+ "2025-11-23 10:48:20.204255: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
  "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
- "E0000 00:00:1763835190.705243 2236633 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
- "E0000 00:00:1763835190.710795 2236633 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
- "W0000 00:00:1763835190.724588 2236633 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
- "W0000 00:00:1763835190.724603 2236633 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
- "W0000 00:00:1763835190.724605 2236633 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
- "W0000 00:00:1763835190.724607 2236633 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
- "2025-11-22 18:13:10.729261: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "E0000 00:00:1763894900.221429 2465541 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "E0000 00:00:1763894900.227066 2465541 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "W0000 00:00:1763894900.240375 2465541 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "W0000 00:00:1763894900.240390 2465541 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "W0000 00:00:1763894900.240392 2465541 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "W0000 00:00:1763894900.240394 2465541 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
+ "2025-11-23 10:48:20.244577: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
  "To enable the following instructions: AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
  ]
  },
@@ -129,6 +129,20 @@
  "and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n",
  "\n"
  ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f70e31b9ba79496a921f0e7d0cddfed4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
  }
  ],
  "source": [
@@ -137,15 +151,70 @@
  "from pathlib import Path\n",
  "import argparse\n",
  "\n",
- "from ruamel.yaml import YAML\n",
+ "import yaml\n",
  "import diffusers\n",
  "\n",
  "\n",
  "from wandml.trainers.experiment_trainer import ExperimentTrainer\n",
  "from wandml import WandDataPipe\n",
  "import wandml\n",
+ "from wandml import WandAuth\n",
+ "from wandml import utils as wandml_utils\n",
+ "from wandml.trainers.datamodels import ExperimentTrainerParameters\n",
+ "from wandml.trainers.experiment_trainer import ExperimentTrainer\n",
+ "\n",
  "\n",
- "from qwenimage.finetuner import QwenLoraFinetuner\n"
+ "from qwenimage.finetuner import QwenLoraFinetuner\n",
+ "from qwenimage.sources import StyleSourceWithRandomRef, StyleImagetoImageSource\n",
+ "from qwenimage.task import TextToImageWithRefTask\n",
+ "from qwenimage.datamodels import QwenConfig\n",
+ "from qwenimage.foundation import QwenImageFoundation\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "18bf116a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "<class 'qwenimage.sources.StyleImagetoImageSource'> of len2\n"
+ ]
+ }
+ ],
+ "source": [
+ "src = StyleImagetoImageSource(\n",
+ " csv_path=\"/data/chatgpt-style-transfer-data/output/results.csv\",\n",
+ " base_dir=\"/data/chatgpt-style-transfer-data\",\n",
+ " style_title=\"Simpsons\",\n",
+ " data_range=[2, 35],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "<class 'qwenimage.sources.StyleImagetoImageSource'> of len33\n"
+ ]
+ }
+ ],
+ "source": [
+ "src = StyleImagetoImageSource(\n",
+ " csv_path=\"/data/chatgpt-style-transfer-data/output/results.csv\",\n",
+ " base_dir=\"/data/chatgpt-style-transfer-data\",\n",
+ " style_title=\"Simpsons\",\n",
+ " data_range=[0, 2],\n",
+ ")"
  ]
  },
  {
@@ -521,7 +590,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.9.6"
+ "version": "3.10.12"
  }
  },
  "nbformat": 4,