jamesaasher committed
Commit c997f38 · verified · 1 Parent(s): 2ccfe3a

Upload generate.py with huggingface_hub

Files changed (1)
  1. generate.py +122 -0
generate.py ADDED
@@ -0,0 +1,122 @@
+ """Generation script for text-conditional diffusion model."""
+ import torch
+ import argparse
+ import os
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+ import config
+ from model import TextConditionedUNet
+ from scheduler import SimpleDDPMScheduler
+ from text_encoder import CLIPTextEncoder
+
+
+ def tensor_to_image(tensor):
+     """Convert tensor to PIL Image."""
+     # tensor is in range [-1, 1], convert to [0, 1]
+     tensor = (tensor + 1.0) / 2.0
+     tensor = torch.clamp(tensor, 0, 1)
+
+     # Convert to PIL
+     transform = transforms.ToPILImage()
+     return transform(tensor.squeeze(0))
+
+
+ def generate_samples(checkpoint_path, prompt="a drawing of a cat", num_samples=4, guidance_scale=3.0, device='cuda'):
+     """Generate samples using text prompts with classifier-free guidance.
+
+     Args:
+         checkpoint_path: Path to model checkpoint
+         prompt: Text prompt for generation
+         num_samples: Number of samples to generate
+         guidance_scale: CFG scale (1.0 = no guidance, 3.0-7.0 typical, higher = stronger)
+         device: Device to use
+     """
+     print(f"🎨 Generating {num_samples} samples with prompt: '{prompt}'")
+     print(f"📊 Guidance scale: {guidance_scale}")
+
+     # Load checkpoint
+     if not os.path.exists(checkpoint_path):
+         print(f"❌ Checkpoint not found: {checkpoint_path}")
+         return
+
+     print(f"📂 Loading checkpoint: {checkpoint_path}")
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     # Get config from checkpoint
+     ckpt_config = checkpoint.get('config', {})
+     text_dim = ckpt_config.get('text_dim', config.TEXT_DIM)
+     clip_model = ckpt_config.get('clip_model', config.CLIP_MODEL)
+
+     # Create model
+     model = TextConditionedUNet(text_dim=text_dim).to(device)
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+
+     # Create text encoder
+     text_encoder = CLIPTextEncoder(model_name=clip_model, freeze=True).to(device)
+     text_encoder.eval()
+
+     # Create scheduler
+     scheduler = SimpleDDPMScheduler(config.TIMESTEPS)
+
+     print(f"📊 Model loaded (text_dim={text_dim})")
+     print(f"📊 CLIP model: {clip_model}")
+
+     # Encode the text prompt once
+     with torch.no_grad():
+         text_embedding = text_encoder(prompt)
+         # Repeat for batch generation
+         text_embeddings = text_embedding.repeat(num_samples, 1)
+
+     # Create outputs directory
+     os.makedirs("outputs", exist_ok=True)
+
+     # Generate samples
+     print(f"🎨 Generating {num_samples} samples...")
+     with torch.no_grad():
+         # Generate all samples in a batch
+         shape = (num_samples, 1, config.IMAGE_SIZE, config.IMAGE_SIZE)
+         samples = scheduler.sample_text(model, shape, text_embeddings, device, guidance_scale)
+
+     # Save each sample
+     for i in range(num_samples):
+         # Create safe filename from prompt
+         safe_prompt = "".join(c if c.isalnum() or c in " _-" else "" for c in prompt)
+         safe_prompt = safe_prompt.replace(" ", "_")[:50]  # Limit length
+         sample_name = f"text_sample_{i+1}_{safe_prompt}"
+
+         # Convert to image and save
+         img = tensor_to_image(samples[i])
+         img_path = f"outputs/{sample_name}.png"
+         img.save(img_path)
+         print(f"💾 Saved: {img_path}")
+
+     print("✅ Generation complete!")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description='Generate samples from text-conditional diffusion model')
+     parser.add_argument('--checkpoint', type=str, required=True,
+                         help='Path to checkpoint file')
+     parser.add_argument('--prompt', type=str, default="a drawing of a cat and dog",
+                         help='Text prompt for generation')
+     parser.add_argument('--num-samples', type=int, default=4,
+                         help='Number of samples to generate (default: 4)')
+     parser.add_argument('--guidance-scale', type=float, default=config.CFG_GUIDANCE_SCALE,
+                         help=f'Classifier-free guidance scale (1.0 = no guidance, 3.0-7.0 typical, default: {config.CFG_GUIDANCE_SCALE})')
+     parser.add_argument('--device', type=str, default='cuda',
+                         help='Device to use (default: cuda)')
+
+     args = parser.parse_args()
+
+     # Check device availability
+     if args.device == 'cuda' and not torch.cuda.is_available():
+         print("⚠️ CUDA not available, using CPU")
+         args.device = 'cpu'
+
+     generate_samples(args.checkpoint, args.prompt, args.num_samples, args.guidance_scale, args.device)
+
+
+ if __name__ == "__main__":
+     main()
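Example invocation (not part of the commit; the flags follow the argparse definitions above, and checkpoints/model.pt stands in for a real checkpoint path):

python generate.py --checkpoint checkpoints/model.pt --prompt "a drawing of a cat" --num-samples 4 --guidance-scale 3.0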
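On the guidance_scale argument: scheduler.py is not included in this commit, so the behaviour of SimpleDDPMScheduler.sample_text is not shown here. The sketch below is only the standard classifier-free-guidance blend that such a sampler typically applies at each denoising step; the function name cfg_noise_prediction, the call signature model(x_t, t, text_emb), and the null_emb argument are illustrative assumptions, not code from this repository.

import torch

def cfg_noise_prediction(model, x_t, t, text_emb, null_emb, guidance_scale):
    """Blend conditional and unconditional noise predictions (standard CFG)."""
    eps_cond = model(x_t, t, text_emb)    # prediction conditioned on the prompt embedding
    eps_uncond = model(x_t, t, null_emb)  # prediction with a null/empty embedding
    # guidance_scale = 1.0 reproduces the conditional prediction;
    # larger values push samples more strongly toward the prompt.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)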