import gc
import os
import warnings

import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import tqdm
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from peft import LoraConfig, PeftModel
from peft.utils.save_and_load import load_peft_weights, set_peft_model_state_dict
from PIL import Image
from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig

from lora_w2w import LoRAw2w

warnings.filterwarnings("ignore")

def sample_weights(unet, proj, mean, std, v, device, factor=1.0):
    # Note: `unet`, `mean`, `std`, and `v` are accepted but unused in this
    # snippet; only the projected weights `proj` drive the sampling.

    # Per-dimension mean and standard deviation of the projected weights.
    m = torch.mean(proj, 0)
    standev = torch.std(proj, 0)
    del proj
    torch.cuda.empty_cache()

    # Draw each of the 1000 principal-direction coefficients independently
    # from a Gaussian; `factor` widens or narrows the sampling distribution.
    sample = torch.zeros([1, 1000]).to(device)
    for i in range(1000):
        sample[0, i] = torch.normal(m[i], factor * standev[i], (1, 1))

    return sample
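

# A minimal usage sketch, assuming `proj` holds PCA coefficients of flattened
# LoRA weights with shape (N, 1000), `v` the corresponding PCA basis, and
# `mean`/`std` their statistics. The file paths and model ID below are
# placeholders, not names from the original source:
#
#     unet = UNet2DConditionModel.from_pretrained(
#         "runwayml/stable-diffusion-v1-5", subfolder="unet"
#     ).to("cuda")
#     proj = torch.load("proj.pt").to("cuda")
#     mean, std, v = torch.load("mean.pt"), torch.load("std.pt"), torch.load("V.pt")
#     sample = sample_weights(unet, proj, mean, std, v, "cuda")
#
# Since `m` and `standev` are elementwise statistics, the sampling loop could
# also be vectorized as `torch.normal(m, factor * standev)` (placed on `m`'s
# device), which draws the whole 1000-dimensional vector in one call.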