jamesaasher committed
Commit 2ccfe3a · verified · 1 Parent(s): 369d21b

Upload text_encoder.py with huggingface_hub

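For context, a minimal sketch of how an upload like this is typically done with the huggingface_hub client; the repo id below is a placeholder, not the actual repository for this commit:

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="text_encoder.py",
    path_in_repo="text_encoder.py",
    repo_id="username/model-repo",  # placeholder; the real repo id is not shown in this commit view
    commit_message="Upload text_encoder.py with huggingface_hub",
)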
Files changed (1)
  1. text_encoder.py +59 -0
text_encoder.py ADDED
@@ -0,0 +1,59 @@
+"""CLIP Text Encoder for text-conditional diffusion."""
+import torch
+import torch.nn as nn
+from transformers import CLIPTextModel, CLIPTokenizer
+
+
+class CLIPTextEncoder(nn.Module):
+    """Wrapper around the CLIP text encoder for diffusion conditioning.
+
+    CLIP maps images and text into a shared embedding space.
+
+    """
+
+    def __init__(self, model_name="openai/clip-vit-base-patch32", freeze=True):
+        super().__init__()
+        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
+        self.text_model = CLIPTextModel.from_pretrained(model_name)
+
+        if freeze:
+            for param in self.text_model.parameters():
+                param.requires_grad = False
+
+        self.embedding_dim = self.text_model.config.hidden_size  # 512 for the base model
+
+    def forward(self, text_prompts):
+        """
+        Encode text prompts to embeddings.
+
+        Args:
+            text_prompts: List of strings or a single string
+
+        Returns:
+            Text embeddings of shape [batch_size, embedding_dim]
+        """
+        if isinstance(text_prompts, str):
+            text_prompts = [text_prompts]
+
+        tokens = self.tokenizer(
+            text_prompts,
+            padding=True,
+            truncation=True,
+            max_length=77,
+            return_tensors="pt"
+        ).to(self.text_model.device)
+
+        with torch.set_grad_enabled(self.text_model.training):  # grads only tracked in training mode
+            outputs = self.text_model(**tokens)
+            embeddings = outputs.pooler_output  # [batch_size, embedding_dim]
+
+        return embeddings
+
+    def encode_batch(self, text_prompts):
+        """Convenience method for batch encoding."""
+        return self.forward(text_prompts)
+
+    @property
+    def device(self):
+        return self.text_model.device
+
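For reference, a minimal usage sketch of the class added above (assumes torch and transformers are installed and the CLIP weights can be downloaded from the Hub; the prompts are illustrative):

from text_encoder import CLIPTextEncoder

encoder = CLIPTextEncoder()   # frozen by default
encoder.eval()                # eval mode, so forward() runs without grad tracking
embeddings = encoder(["a photo of a cat", "an oil painting of a ship"])
print(embeddings.shape)       # torch.Size([2, 512]) with the default base model

Since freeze=True leaves the encoder's parameters with requires_grad=False, the returned embeddings can be used as conditioning for a diffusion model without updating CLIP itself.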