|
|
import torch |
|
|
import torchvision |
|
|
import tqdm |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
import warnings |
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def invert(network, unet, vae, text_encoder, tokenizer, prompt, noise_scheduler, epochs, image_path, mask_path, device, weight_decay = 1e-10, lr=1e-1): |
|
|
|
|
|
if mask_path: |
|
|
mask = Image.open(mask_path) |
|
|
mask = transforms.Resize((64,64), interpolation=transforms.InterpolationMode.BILINEAR)(mask) |
|
|
mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16() |
|
|
else: |
|
|
mask = torch.ones((1,1,64,64)).to(device).bfloat16() |
|
|
|
|
|
|
|
|
image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), |
|
|
transforms.RandomCrop(512), |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize([0.5], [0.5])]) |
|
|
|
|
|
|
|
|
train_dataset = torchvision.datasets.ImageFolder(root=image_path, transform = image_transforms) |
|
|
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True) |
|
|
|
|
|
|
|
|
optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay) |
|
|
|
|
|
|
|
|
unet.train() |
|
|
for epoch in tqdm.tqdm(range(epochs)): |
|
|
for batch,_ in train_dataloader: |
|
|
|
|
|
batch = batch.to(device).bfloat16() |
|
|
latents = vae.encode(batch).latent_dist.sample() |
|
|
latents = latents*0.18215 |
|
|
noise = torch.randn_like(latents) |
|
|
bsz = latents.shape[0] |
|
|
|
|
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) |
|
|
timesteps = timesteps.long() |
|
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
|
|
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") |
|
|
text_embeddings = text_encoder(text_input.input_ids.to(device))[0] |
|
|
|
|
|
|
|
|
with network: |
|
|
model_pred = unet(noisy_latents, timesteps, text_embeddings).sample |
|
|
loss = torch.nn.functional.mse_loss(mask*model_pred.float(), mask*noise.float(), reduction="mean") |
|
|
optim.zero_grad() |
|
|
loss.backward() |
|
|
optim.step() |
|
|
|
|
|
|
|
|
return network |
|
|
|
|
|
|
|
|
|