import gradio as gr
import sys
import os
import tqdm

# Put the parent of the current working directory on the import path.
# NOTE(review): presumably this is what lets the project-local `utils` and
# `sampling` modules imported below resolve -- confirm against the repo layout.
sys.path.append(os.path.abspath(os.path.join("", "..")))

import torch
import gc
import warnings

# Silence all warnings for everything imported/executed after this point.
warnings.filterwarnings("ignore")

from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from sampling import sample_weights
from huggingface_hub import snapshot_download


# ---------------------------------------------------------------------------
# Module-level state shared by the Gradio callbacks below.
# (Names assigned at module scope are already global; no `global` statement is
# needed here.)
# ---------------------------------------------------------------------------
device = "cuda:0"
generator = torch.Generator(device=device)

# Download (or reuse the local cache of) the w2w assets from the Hugging Face Hub.
models_path = snapshot_download(repo_id="Snapchat/w2w")

# NOTE(review): torch.load unpickles the files it reads; these come from a
# remote repository, so only point `repo_id` at a source you trust.
mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
df = torch.load(f"{models_path}/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)

# Set by `sample_model()`; stays None until a model has been sampled.
network = None


def sample_model():
    """Reload the UNet and sample a new set of w2w weights for it.

    Rebinds the module-level ``unet`` to a freshly loaded model and the
    module-level ``network`` to the object returned by ``sample_weights``.
    Takes no arguments and returns ``None``.
    """
    global unet
    global network

    # Drop our references to the previous model *before* loading the next one
    # so two copies do not have to coexist on the GPU. Rebinding to None
    # (instead of `del`) keeps the names defined even if the reload below fails.
    # NOTE(review): `network` presumably keeps references into the old UNet,
    # which is why it is released here as well -- confirm in `sampling`.
    unet = None
    network = None
    gc.collect()
    # Hand cached, now-unused CUDA blocks back to the driver.
    torch.cuda.empty_cache()

    # Only the UNet is replaced; the other components returned by load_models
    # are discarded.
    unet, _, _, _, _ = load_models(device)
    # Uses only the first 1000 columns of `v`.
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)


def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    """Generate one image from the currently sampled model.

    Runs a classifier-free-guidance denoising loop with the module-level
    ``unet``, ``vae``, ``text_encoder``, ``tokenizer`` and ``noise_scheduler``,
    applying the module-level ``network`` around every UNet call.

    Args:
        prompt: Text prompt for the conditional branch.
        negative_prompt: Text prompt used for the unconditional branch.
        guidance_scale: Weight applied to (conditional - unconditional) noise.
        ddim_steps: Number of scheduler steps; coerced to an int >= 1.
        seed: Seed for the shared random generator; coerced to int.

    Returns:
        A one-element list containing the generated ``PIL.Image.Image``.

    Raises:
        gr.Error: If no model has been sampled yet (``network`` is unset).
    """
    # `network` only exists after "Sample New Model" has run; fail with a
    # readable UI message instead of a NameError / broken `with` statement.
    active_network = globals().get("network")
    if active_network is None:
        raise gr.Error("No model has been sampled yet. Click 'Sample New Model' first.")

    # The UI slider allows 0 steps, which would skip the denoising loop
    # entirely; always run at least one step.
    num_steps = max(1, int(ddim_steps))

    # manual_seed mutates the shared generator in place.
    generator.manual_seed(int(seed))

    # Pure inference: do not record an autograd graph across the UNet calls.
    with torch.no_grad():
        # NOTE(review): 512 // 8 presumably reflects a 512px image and an 8x
        # VAE downsampling factor -- confirm against load_models.
        latents = torch.randn(
            (1, unet.in_channels, 512 // 8, 512 // 8),
            generator=generator,
            device=device,
        ).bfloat16()

        text_input = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

        # Truncate the negative prompt as well so both embeddings end up with
        # the same sequence length and can be concatenated below.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = tokenizer(
            [negative_prompt],
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

        # Batch order is [unconditional, conditional]; `chunk(2)` below relies on it.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        noise_scheduler.set_timesteps(num_steps)
        latents = latents * noise_scheduler.init_noise_sigma

        for t in tqdm.tqdm(noise_scheduler.timesteps):
            # Duplicate the latents so both branches run in one UNet call.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

            with active_network:
                noise_pred = unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings,
                    timestep_cond=None,
                ).sample

            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        # NOTE(review): 0.18215 is presumably the VAE latent scaling factor -- confirm.
        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample

    # Map from [-1, 1] to [0, 1], then to an HWC uint8 array for PIL.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    return [image]


# `css` was referenced without ever being defined in this file, which raises
# NameError at start-up. An empty stylesheet keeps the hook for later styling.
css = ""

# NOTE(review): the original indentation was lost; the widget nesting below is
# reconstructed from the statement order -- confirm the intended layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# <em>weights2weights</em> Demo")
    with gr.Row():
        # Left column: model selection and generation controls.
        with gr.Column():
            files = gr.Files(
                label="Upload a photo of your face to invert, or sample a new model",
                file_types=["image"]
            )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)

            sample = gr.Button("Sample New Model")

            # Created hidden; nothing in this file makes it visible.
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")

            prompt = gr.Textbox(label="Prompt",
                                info="Make sure to include 'sks person'",
                                placeholder="sks person",
                                value="sks person")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
            seed = gr.Number(value=5, precision=0, label="Seed", interactive=True)
            cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
            steps = gr.Slider(label="Inference Steps", precision=0, value=50, step=1, minimum=0, maximum=100, interactive=True)

            submit = gr.Button("Submit")

        # Right column: generated output.
        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")

    # Sampling takes no inputs and updates no components; it only swaps the
    # module-level model state used by `inference`.
    sample.click(fn=sample_model)

    # Input order must match inference(prompt, negative_prompt,
    # guidance_scale, ddim_steps, seed).
    submit.click(fn=inference,
                 inputs=[prompt, negative_prompt, cfg, steps, seed],
                 outputs=gallery)

demo.launch(share=True)