fix: stock SD3Transformer2DModel lacks the custom attributes (rebuild it as the subclassed model instead)
#2
by
KevinZonda - opened
- deepgen_pipeline.py +25 -6
deepgen_pipeline.py
CHANGED
|
@@ -961,12 +961,31 @@ class DeepGenPipeline(DiffusionPipeline):
|
|
| 961 |
with cond_latents support for image editing. No weight copying needed."""
|
| 962 |
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel as _OrigSD3
|
| 963 |
if isinstance(self.transformer, _OrigSD3) and not isinstance(self.transformer, SD3Transformer2DModel):
|
| 964 |
-
self.transformer.
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
|
| 968 |
-
|
| 969 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 970 |
|
| 971 |
def _resolve_pretrained_path(self):
|
| 972 |
path = self.config._name_or_path
|
|
|
|
| 961 |
with cond_latents support for image editing. No weight copying needed."""
|
| 962 |
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel as _OrigSD3
|
| 963 |
if isinstance(self.transformer, _OrigSD3) and not isinstance(self.transformer, SD3Transformer2DModel):
|
| 964 |
+
state_dict = self.transformer.state_dict()
|
| 965 |
+
config = self.transformer.config
|
| 966 |
+
custom_config = {
|
| 967 |
+
'sample_size': config.sample_size,
|
| 968 |
+
'patch_size': config.patch_size,
|
| 969 |
+
'in_channels': config.in_channels,
|
| 970 |
+
'num_layers': config.num_layers,
|
| 971 |
+
'attention_head_dim': config.attention_head_dim,
|
| 972 |
+
'num_attention_heads': config.num_attention_heads,
|
| 973 |
+
'joint_attention_dim': config.joint_attention_dim,
|
| 974 |
+
'caption_projection_dim': config.caption_projection_dim,
|
| 975 |
+
'pooled_projection_dim': config.pooled_projection_dim,
|
| 976 |
+
'out_channels': config.out_channels,
|
| 977 |
+
'pos_embed_max_size': getattr(config, 'pos_embed_max_size', 96),
|
| 978 |
+
'dual_attention_layers': getattr(config, 'dual_attention_layers', ()),
|
| 979 |
+
'qk_norm': getattr(config, 'qk_norm', None),
|
| 980 |
+
}
|
| 981 |
+
device = self.transformer.device
|
| 982 |
+
dtype = self.transformer.dtype
|
| 983 |
+
self.transformer = SD3Transformer2DModel(**custom_config).to(device=device, dtype=dtype)
|
| 984 |
+
self.transformer.load_state_dict(state_dict, strict=False)
|
| 985 |
+
|
| 986 |
+
# NOTE(review): the line below assigns gradient_checkpointing to itself on the NEW transformer — a no-op. The flag should be read from the OLD transformer before it is replaced above, then applied here; as written, a previously enabled setting is silently lost.
|
| 987 |
+
if hasattr(self.transformer, 'gradient_checkpointing'):
|
| 988 |
+
self.transformer.gradient_checkpointing = self.transformer.gradient_checkpointing
|
| 989 |
|
| 990 |
def _resolve_pretrained_path(self):
|
| 991 |
path = self.config._name_or_path
|