#!/usr/bin/env python3 """ Convert SimpleTuner LoRA weights to diffusers-compatible format for AuraFlow. This script converts LoRA weights saved by SimpleTuner into a format that can be directly loaded by diffusers' load_lora_weights() method. Usage: python convert_simpletuner_lora.py Example: python convert_simpletuner_lora.py input_lora.safetensors diffusers_compatible_lora.safetensors """ import argparse import sys from pathlib import Path from typing import Dict import safetensors.torch import torch def detect_lora_format(state_dict: Dict[str, torch.Tensor]) -> str: """ Detect the format of the LoRA state dict. Returns: "peft" if already in PEFT/diffusers format "mixed" if mixed format (some lora_A/B, some lora.down/up) "simpletuner_transformer" if in SimpleTuner format with transformer prefix "simpletuner_auraflow" if in SimpleTuner AuraFlow format "kohya" if in Kohya format "unknown" otherwise """ keys = list(state_dict.keys()) # Check the actual weight naming convention (lora_A/lora_B vs lora_down/lora_up) has_lora_a_b = any((".lora_A." in k or ".lora_B." in k) for k in keys) has_lora_down_up = any((".lora_down." in k or ".lora_up." in k) for k in keys) has_lora_dot_down_up = any((".lora.down." in k or ".lora.up." in k) for k in keys) # Check prefixes has_transformer_prefix = any(k.startswith("transformer.") for k in keys) has_lora_transformer_prefix = any(k.startswith("lora_transformer_") for k in keys) has_lora_unet_prefix = any(k.startswith("lora_unet_") for k in keys) # Mixed format: has both lora_A/B AND lora.down/up (SimpleTuner hybrid) if has_transformer_prefix and has_lora_a_b and (has_lora_down_up or has_lora_dot_down_up): return "mixed" # Pure PEFT format: transformer.* with ONLY lora_A/lora_B if has_transformer_prefix and has_lora_a_b and not has_lora_down_up and not has_lora_dot_down_up: return "peft" # SimpleTuner with transformer prefix but old naming: transformer.* with lora_down/lora_up if has_transformer_prefix and (has_lora_down_up or has_lora_dot_down_up): return "simpletuner_transformer" # SimpleTuner AuraFlow format: lora_transformer_* with lora_down/lora_up if has_lora_transformer_prefix and has_lora_down_up: return "simpletuner_auraflow" # Traditional Kohya format: lora_unet_* with lora_down/lora_up if has_lora_unet_prefix and has_lora_down_up: return "kohya" return "unknown" def convert_mixed_lora_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Convert mixed LoRA format to pure PEFT format. SimpleTuner sometimes saves a hybrid format where some layers use lora_A/lora_B and others use .lora.down./.lora.up. This converts all to lora_A/lora_B. """ new_state_dict = {} converted_count = 0 kept_count = 0 skipped_count = 0 renames = [] # Get all keys all_keys = sorted(state_dict.keys()) print("\nProcessing keys:") print("-" * 80) for key in all_keys: # Already in correct format (lora_A or lora_B) if ".lora_A." in key or ".lora_B." in key: new_state_dict[key] = state_dict[key] kept_count += 1 # Needs conversion: .lora.down. -> .lora_A. elif ".lora.down.weight" in key: new_key = key.replace(".lora.down.weight", ".lora_A.weight") new_state_dict[new_key] = state_dict[key] renames.append((key, new_key)) converted_count += 1 # Needs conversion: .lora.up. -> .lora_B. elif ".lora.up.weight" in key: new_key = key.replace(".lora.up.weight", ".lora_B.weight") new_state_dict[new_key] = state_dict[key] renames.append((key, new_key)) converted_count += 1 # Skip alpha keys (not used in PEFT format) elif ".alpha" in key: skipped_count += 1 continue # Other keys (shouldn't happen, but keep them just in case) else: new_state_dict[key] = state_dict[key] print(f"⚠ Warning: Unexpected key format: {key}") print(f"\nSummary:") print(f" ✓ Kept {kept_count} keys already in correct format (lora_A/lora_B)") print(f" ✓ Converted {converted_count} keys from .lora.down/.lora.up to lora_A/lora_B") print(f" ✓ Skipped {skipped_count} alpha keys") if renames: print(f"\nRenames applied ({len(renames)} conversions):") print("-" * 80) for old_key, new_key in renames: # Show the difference more clearly if ".lora.down.weight" in old_key: layer = old_key.replace(".lora.down.weight", "") print(f" {layer}") print(f" .lora.down.weight → .lora_A.weight") elif ".lora.up.weight" in old_key: layer = old_key.replace(".lora.up.weight", "") print(f" {layer}") print(f" .lora.up.weight → .lora_B.weight") return new_state_dict def convert_simpletuner_transformer_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Convert SimpleTuner transformer format (already has transformer. prefix but uses lora_down/lora_up) to diffusers PEFT format (transformer. prefix with lora_A/lora_B). This is a simpler conversion since the key structure is already correct. """ new_state_dict = {} renames = [] # Get all unique LoRA layer base names (without .lora_down/.lora_up/.alpha suffix) all_keys = list(state_dict.keys()) base_keys = set() for key in all_keys: if ".lora_down.weight" in key: base_key = key.replace(".lora_down.weight", "") base_keys.add(base_key) print(f"\nFound {len(base_keys)} LoRA layers to convert") print("-" * 80) # Convert each layer for base_key in sorted(base_keys): down_key = f"{base_key}.lora_down.weight" up_key = f"{base_key}.lora_up.weight" alpha_key = f"{base_key}.alpha" if down_key not in state_dict or up_key not in state_dict: print(f"⚠ Warning: Missing weights for {base_key}") continue down_weight = state_dict.pop(down_key) up_weight = state_dict.pop(up_key) # Handle alpha scaling has_alpha = False if alpha_key in state_dict: alpha = state_dict.pop(alpha_key) lora_rank = down_weight.shape[0] scale = alpha / lora_rank # Calculate scale_down and scale_up to preserve the scale value scale_down = scale scale_up = 1.0 while scale_down * 2 < scale_up: scale_down *= 2 scale_up /= 2 down_weight = down_weight * scale_down up_weight = up_weight * scale_up has_alpha = True # Store in PEFT format (lora_A = down, lora_B = up) new_down_key = f"{base_key}.lora_A.weight" new_up_key = f"{base_key}.lora_B.weight" new_state_dict[new_down_key] = down_weight new_state_dict[new_up_key] = up_weight renames.append((down_key, new_down_key, has_alpha)) renames.append((up_key, new_up_key, has_alpha)) # Check for any remaining keys remaining = [k for k in state_dict.keys() if not k.startswith("text_encoder")] if remaining: print(f"⚠ Warning: {len(remaining)} keys were not converted: {remaining[:5]}") print(f"\nRenames applied ({len(renames)} conversions):") print("-" * 80) # Group by layer current_layer = None for old_key, new_key, has_alpha in renames: layer = old_key.replace(".lora_down.weight", "").replace(".lora_up.weight", "") if layer != current_layer: alpha_str = " (alpha scaled)" if has_alpha else "" print(f"\n {layer}{alpha_str}") current_layer = layer if ".lora_down.weight" in old_key: print(f" .lora_down.weight → .lora_A.weight") elif ".lora_up.weight" in old_key: print(f" .lora_up.weight → .lora_B.weight") return new_state_dict def convert_simpletuner_auraflow_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Convert SimpleTuner AuraFlow LoRA format to diffusers PEFT format. SimpleTuner typically saves LoRAs in a format similar to Kohya's sd-scripts, but for transformer-based models like AuraFlow, the keys may differ. """ new_state_dict = {} def _convert(original_key, diffusers_key, state_dict, new_state_dict): """Helper to convert a single LoRA layer.""" down_key = f"{original_key}.lora_down.weight" if down_key not in state_dict: return False down_weight = state_dict.pop(down_key) lora_rank = down_weight.shape[0] up_weight_key = f"{original_key}.lora_up.weight" up_weight = state_dict.pop(up_weight_key) # Handle alpha scaling alpha_key = f"{original_key}.alpha" if alpha_key in state_dict: alpha = state_dict.pop(alpha_key) scale = alpha / lora_rank # Calculate scale_down and scale_up to preserve the scale value scale_down = scale scale_up = 1.0 while scale_down * 2 < scale_up: scale_down *= 2 scale_up /= 2 down_weight = down_weight * scale_down up_weight = up_weight * scale_up # Store in PEFT format (lora_A = down, lora_B = up) diffusers_down_key = f"{diffusers_key}.lora_A.weight" new_state_dict[diffusers_down_key] = down_weight new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight return True # Get all unique LoRA layer names all_unique_keys = { k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict if ".lora_down.weight" in k or ".lora_up.weight" in k or ".alpha" in k } # Process transformer blocks for original_key in sorted(all_unique_keys): if original_key.startswith("lora_transformer_single_transformer_blocks_"): # Single transformer blocks parts = original_key.split("lora_transformer_single_transformer_blocks_")[-1].split("_") block_idx = int(parts[0]) diffusers_key = f"single_transformer_blocks.{block_idx}" # Map the rest of the key remaining = "_".join(parts[1:]) if "attn_to_q" in remaining: diffusers_key += ".attn.to_q" elif "attn_to_k" in remaining: diffusers_key += ".attn.to_k" elif "attn_to_v" in remaining: diffusers_key += ".attn.to_v" elif "proj_out" in remaining: diffusers_key += ".proj_out" elif "proj_mlp" in remaining: diffusers_key += ".proj_mlp" elif "norm_linear" in remaining: diffusers_key += ".norm.linear" else: print(f"Warning: Unhandled single block key pattern: {original_key}") continue elif original_key.startswith("lora_transformer_transformer_blocks_"): # Double transformer blocks parts = original_key.split("lora_transformer_transformer_blocks_")[-1].split("_") block_idx = int(parts[0]) diffusers_key = f"transformer_blocks.{block_idx}" # Map the rest of the key remaining = "_".join(parts[1:]) if "attn_to_out_0" in remaining: diffusers_key += ".attn.to_out.0" elif "attn_to_add_out" in remaining: diffusers_key += ".attn.to_add_out" elif "attn_to_q" in remaining: diffusers_key += ".attn.to_q" elif "attn_to_k" in remaining: diffusers_key += ".attn.to_k" elif "attn_to_v" in remaining: diffusers_key += ".attn.to_v" elif "attn_add_q_proj" in remaining: diffusers_key += ".attn.add_q_proj" elif "attn_add_k_proj" in remaining: diffusers_key += ".attn.add_k_proj" elif "attn_add_v_proj" in remaining: diffusers_key += ".attn.add_v_proj" elif "ff_net_0_proj" in remaining: diffusers_key += ".ff.net.0.proj" elif "ff_net_2" in remaining: diffusers_key += ".ff.net.2" elif "ff_context_net_0_proj" in remaining: diffusers_key += ".ff_context.net.0.proj" elif "ff_context_net_2" in remaining: diffusers_key += ".ff_context.net.2" elif "norm1_linear" in remaining: diffusers_key += ".norm1.linear" elif "norm1_context_linear" in remaining: diffusers_key += ".norm1_context.linear" else: print(f"Warning: Unhandled double block key pattern: {original_key}") continue elif original_key.startswith("lora_te1_") or original_key.startswith("lora_te_"): # Text encoder keys - handle separately print(f"Found text encoder key: {original_key}") continue else: print(f"Warning: Unknown key pattern: {original_key}") continue # Perform the conversion _convert(original_key, diffusers_key, state_dict, new_state_dict) # Add "transformer." prefix to all keys transformer_state_dict = { f"transformer.{k}": v for k, v in new_state_dict.items() if not k.startswith("text_model.") } # Check for remaining unconverted keys if len(state_dict) > 0: remaining_keys = [k for k in state_dict.keys() if not k.startswith("lora_te")] if remaining_keys: print(f"Warning: Some keys were not converted: {remaining_keys[:10]}") return transformer_state_dict def convert_lora(input_path: str, output_path: str) -> None: """ Main conversion function. Args: input_path: Path to input LoRA safetensors file output_path: Path to output diffusers-compatible safetensors file """ print(f"Loading LoRA from: {input_path}") state_dict = safetensors.torch.load_file(input_path) print(f"Detecting LoRA format...") format_type = detect_lora_format(state_dict) print(f"Detected format: {format_type}") if format_type == "peft": print("LoRA is already in diffusers-compatible PEFT format!") print("No conversion needed. Copying file...") import shutil shutil.copy(input_path, output_path) return elif format_type == "mixed": print("Converting MIXED format LoRA to pure PEFT format...") print("(Some layers use lora_A/B, others use .lora.down/.lora.up)") converted_state_dict = convert_mixed_lora_to_diffusers(state_dict.copy()) elif format_type == "simpletuner_transformer": print("Converting SimpleTuner transformer format to diffusers...") print("(has transformer. prefix but uses lora_down/lora_up naming)") converted_state_dict = convert_simpletuner_transformer_to_diffusers(state_dict.copy()) elif format_type == "simpletuner_auraflow": print("Converting SimpleTuner AuraFlow format to diffusers...") converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy()) elif format_type == "kohya": print("Note: Detected Kohya format. This converter is optimized for AuraFlow.") print("For other models, diffusers has built-in conversion.") converted_state_dict = convert_simpletuner_auraflow_to_diffusers(state_dict.copy()) else: print("Error: Unknown LoRA format!") print("Sample keys from the state dict:") for i, key in enumerate(list(state_dict.keys())[:20]): print(f" {key}") sys.exit(1) print(f"Saving converted LoRA to: {output_path}") safetensors.torch.save_file(converted_state_dict, output_path) print("\nConversion complete!") print(f"Original keys: {len(state_dict)}") print(f"Converted keys: {len(converted_state_dict)}") def main(): parser = argparse.ArgumentParser( description="Convert SimpleTuner LoRA to diffusers-compatible format", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # Convert a SimpleTuner LoRA for AuraFlow python convert_simpletuner_lora.py my_lora.safetensors diffusers_lora.safetensors # Check format without converting python convert_simpletuner_lora.py my_lora.safetensors /tmp/test.safetensors """ ) parser.add_argument( "input", type=str, help="Input LoRA file (SimpleTuner format)" ) parser.add_argument( "output", type=str, help="Output LoRA file (diffusers-compatible format)" ) parser.add_argument( "--dry-run", action="store_true", help="Only detect format, don't convert" ) args = parser.parse_args() # Validate input file exists if not Path(args.input).exists(): print(f"Error: Input file not found: {args.input}") sys.exit(1) if args.dry_run: print(f"Loading LoRA from: {args.input}") state_dict = safetensors.torch.load_file(args.input) format_type = detect_lora_format(state_dict) print(f"Detected format: {format_type}") print(f"\nSample keys ({min(10, len(state_dict))} of {len(state_dict)}):") for key in list(state_dict.keys())[:10]: print(f" {key}") return convert_lora(args.input, args.output) if __name__ == "__main__": main()