fix: old SD3Transformer2DModel has no attributes #2
by KevinZonda - opened
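
The failing path reassigned `__class__` on an already-constructed transformer. That changes the instance's type but never re-runs the subclass `__init__`, so any attribute that only the custom class creates (such as `attn2` on the custom joint blocks) is simply absent on the old object, which is the "has no attributes" error this PR fixes. A minimal sketch of the failure mode, using hypothetical stand-in classes rather than the real diffusers modules:

    class Base:
        def __init__(self):
            self.attn = "attn"        # attribute every version defines

    class Custom(Base):
        def __init__(self):
            super().__init__()
            self.attn2 = None         # created only by Custom.__init__

    m = Base()
    m.__class__ = Custom              # type swap; __init__ does not run again
    print(isinstance(m, Custom))      # True
    print(hasattr(m, "attn2"))        # False -> later AttributeError

Rebuilding the model from its config and reloading the state dict, as the diff below does, avoids this entirely.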
Files changed (1)
  1. deepgen_pipeline.py +25 -6
deepgen_pipeline.py CHANGED
@@ -961,12 +961,31 @@ class DeepGenPipeline(DiffusionPipeline):
         with cond_latents support for image editing. No weight copying needed."""
         from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel as _OrigSD3
         if isinstance(self.transformer, _OrigSD3) and not isinstance(self.transformer, SD3Transformer2DModel):
-            self.transformer.__class__ = SD3Transformer2DModel
-            for block in self.transformer.transformer_blocks:
-                block.__class__ = CustomJointTransformerBlock
-                block.attn.set_processor(CustomJointAttnProcessor2_0())
-                if block.attn2 is not None:
-                    block.attn2.set_processor(CustomJointAttnProcessor2_0())
+            state_dict = self.transformer.state_dict()
+            config = self.transformer.config
+            custom_config = {
+                'sample_size': config.sample_size,
+                'patch_size': config.patch_size,
+                'in_channels': config.in_channels,
+                'num_layers': config.num_layers,
+                'attention_head_dim': config.attention_head_dim,
+                'num_attention_heads': config.num_attention_heads,
+                'joint_attention_dim': config.joint_attention_dim,
+                'caption_projection_dim': config.caption_projection_dim,
+                'pooled_projection_dim': config.pooled_projection_dim,
+                'out_channels': config.out_channels,
+                'pos_embed_max_size': getattr(config, 'pos_embed_max_size', 96),
+                'dual_attention_layers': getattr(config, 'dual_attention_layers', ()),
+                'qk_norm': getattr(config, 'qk_norm', None),
+            }
+            device = self.transformer.device
+            dtype = self.transformer.dtype
+            # Capture the old flag before replacing the module, so it can be restored
+            grad_checkpointing = getattr(self.transformer, 'gradient_checkpointing', False)
+            self.transformer = SD3Transformer2DModel(**custom_config).to(device=device, dtype=dtype)
+            self.transformer.load_state_dict(state_dict, strict=False)
+            # Re-enable gradient checkpointing if it was set on the old transformer
+            if grad_checkpointing:
+                self.transformer.enable_gradient_checkpointing()

     def _resolve_pretrained_path(self):
         path = self.config._name_or_path
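
Because the state dict is loaded with `strict=False`, all original SD3 weights carry over unchanged, while parameters that exist only in the custom class keep their fresh initialization. A quick sanity check, with a hypothetical helper that is not part of this PR, for confirming the rebuild preserved the weights:

    import torch

    def weights_preserved(old_state_dict, new_model):
        # Every key shared by the old and new models should be bit-identical.
        new_sd = new_model.state_dict()
        return all(
            torch.equal(tensor, new_sd[name])
            for name, tensor in old_state_dict.items()
            if name in new_sd
        )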