| | import os |
| | import os.path as osp |
| | import PIL |
| | from PIL import Image |
| | from pathlib import Path |
| | import numpy as np |
| | import numpy.random as npr |
| |
|
| | import torch |
| | import torchvision.transforms as tvtrans |
| | from lib.cfg_helper import model_cfg_bank |
| | from lib.model_zoo import get_model |
| | from lib.model_zoo.ddim_dualcontext import DDIMSampler_DualContext |
| | from lib.experiments.sd_default import color_adjust, auto_merge_imlist |
| |
|
| | import argparse |
| |
|
n_sample_image_default = 2  # default number of images generated per inference call
n_sample_text_default = 4  # default number of text samples generated per inference call
| |
|
def highlight_print(info):
    """Print `info` framed by a '#' banner so it stands out in console logs."""
    border = '#' * (len(info) + 4)
    print('')
    print(border)
    print('# ' + info + ' #')
    print(border)
    print('')
| |
|
class vd_inference(object):
    """Inference wrapper around the Versatile Diffusion four-flow model.

    Loads the pretrained 'vd_noema' network, owns its DDIM sampler and the
    device/precision settings, and exposes the generation entry points:
    image/text inference, disentanglement, dual-guided generation, and
    image -> text -> image editing.
    """

    def __init__(self, pth='pretrained/vd1.0-four-flow.pth', fp16=False, device=0):
        """Load config, build the network, load weights and move to device.

        pth: checkpoint path.
        fp16: run the network in half precision.
        device: device string (e.g. 'cpu') or a CUDA index.
        """
        cfgm_name = 'vd_noema'
        cfgm = model_cfg_bank()(cfgm_name)
        # Accept either an explicit device string or a CUDA ordinal.
        device_str = device if isinstance(device, str) else 'cuda:{}'.format(device)
        cfgm.args.autokl_cfg.map_location = device_str
        cfgm.args.optimus_cfg.map_location = device_str
        net = get_model()(cfgm)
        if fp16:
            highlight_print('Running in FP16')
            net.clip.fp16 = True
            net = net.half()
        sd = torch.load(pth, map_location=device_str)
        # strict=False: checkpoint may omit buffers not needed at inference.
        net.load_state_dict(sd, strict=False)
        print('Load pretrained weight from {}'.format(pth))
        net.to(device)

        self.device = device
        self.model_name = cfgm_name
        self.net = net
        self.fp16 = fp16
        from lib.model_zoo.ddim_vd import DDIMSampler_VD
        self.sampler = DDIMSampler_VD(net)

    def _half_if_needed(self, u, c):
        """Cast the (unconditional, conditional) pair to fp16 when enabled.

        `u` is None when the guidance scale is 1.0; the previous code called
        `.half()` on it unconditionally and crashed in that case.
        Returns the (possibly converted) pair.
        """
        if not self.fp16:
            return u, c
        return (None if u is None else u.half()), c.half()

    def regularize_image(self, x):
        """Normalize an image input to a 512x512 float tensor on self.device.

        x: file path, PIL image, HxWxC uint8 ndarray, or a CxHxW tensor that
            is already 512x512 (tensors are passed through unresized).
        Returns a tensor in [0, 1], half precision if fp16 is enabled.
        Raises TypeError for unsupported input types.
        """
        BICUBIC = PIL.Image.Resampling.BICUBIC
        if isinstance(x, str):
            x = Image.open(x).resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, PIL.Image.Image):
            x = x.resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, np.ndarray):
            x = PIL.Image.fromarray(x).resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, torch.Tensor):
            pass
        else:
            # raise (not assert) so validation survives `python -O`
            raise TypeError('Unknown image type')

        assert (x.shape[1] == 512) and (x.shape[2] == 512), \
            'Wrong image size'
        x = x.to(self.device)
        if self.fp16:
            x = x.half()
        return x

    def decode(self, z, xtype, ctype, color_adj='None', color_adj_to=None):
        """Decode latents `z` into PIL images (xtype=='image') or strings
        (xtype=='text').

        color_adj: None/'none'/'None' disables color matching; 'simple' or
            'Simple' selects the simple matching mode of `color_adjust`.
        color_adj_to: reference image the decoded images are matched to
            (only used for vision-conditioned image decoding).
        """
        net = self.net
        if xtype == 'image':
            x = net.autokl_decode(z)

            color_adj_flag = (color_adj != 'none') and (color_adj != 'None') and (color_adj is not None)
            color_adj_simple = (color_adj == 'Simple') or color_adj == 'simple'
            # Blend ratio between the color-matched and the original image.
            color_adj_keep_ratio = 0.5

            if color_adj_flag and (ctype == 'vision'):
                x_adj = []
                for xi in x:
                    # Model output is in [-1, 1]; map to [0, 1] before matching.
                    color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to)
                    xi_adj = color_adj_f((xi+1)/2, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                x = x_adj
            else:
                x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
                x = [tvtrans.ToPILImage()(xi) for xi in x]
            return x

        elif xtype == 'text':
            prompt_temperature = 1.0
            prompt_merge_same_adj_word = True
            x = net.optimus_decode(z, temperature=prompt_temperature)
            if prompt_merge_same_adj_word:
                # Collapse immediate word repetitions ("a a cat" -> "a cat").
                xnew = []
                for xi in x:
                    xi_split = xi.split()
                    xinew = []
                    for idxi, wi in enumerate(xi_split):
                        if idxi != 0 and wi == xi_split[idxi-1]:
                            continue
                        xinew.append(wi)
                    xnew.append(' '.join(xinew))
                x = xnew
            return x

    def inference(self, xtype, cin, ctype, scale=7.5, n_samples=None, color_adj=None,):
        """Run one conditional generation.

        xtype: output modality, 'image' or 'text'.
        cin: conditioning input (prompt string or image, per ctype).
        ctype: 'prompt'/'text' or 'vision'/'image'.
        scale: classifier-free guidance scale; 1.0 disables guidance.
        Raises ValueError for an unknown ctype.
        """
        net = self.net
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0

        if xtype == 'image':
            n_samples = n_sample_image_default if n_samples is None else n_samples
        elif xtype == 'text':
            n_samples = n_sample_text_default if n_samples is None else n_samples

        if ctype in ['prompt', 'text']:
            c = net.clip_encode_text(n_samples * [cin])
            u = None
            if scale != 1.0:
                u = net.clip_encode_text(n_samples * [""])

        elif ctype in ['vision', 'image']:
            cin = self.regularize_image(cin)
            # [0, 1] -> [-1, 1] model range, replicated per sample.
            ctemp = cin*2 - 1
            ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
            c = net.clip_encode_vision(ctemp)
            u = None
            if scale != 1.0:
                dummy = torch.zeros_like(ctemp)
                u = net.clip_encode_vision(dummy)
        else:
            # Previously fell through with c/u undefined -> NameError.
            raise ValueError('Unknown ctype {}'.format(ctype))

        # None-safe fp16 cast (u is None when scale == 1.0).
        u, c = self._half_if_needed(u, c)

        if xtype == 'image':
            h, w = [512, 512]
            # Latents are 1/8 spatial resolution with 4 channels.
            shape = [n_samples, 4, h//8, w//8]
            z, _ = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=u,
                xtype=xtype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
            x = self.decode(z, xtype, ctype, color_adj=color_adj, color_adj_to=cin)
            return x

        elif xtype == 'text':
            n = 768
            shape = [n_samples, n]
            z, _ = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=u,
                xtype=xtype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
            x = self.decode(z, xtype, ctype)
            return x

    def application_disensemble(self, cin, n_samples=None, level=0, color_adj=None,):
        """Image variation with semantic/style disentanglement.

        level: 0 = plain variation; 1/2 keep only a low-rank part of the
            local CLIP tokens (focus on semantics); -1/-2 remove the top
            principal components instead (focus on style).
        """
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        cin = self.regularize_image(cin)
        ctemp = cin*2 - 1
        ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
        c = net.clip_encode_vision(ctemp)
        u = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp)
            u = net.clip_encode_vision(dummy)
        u, c = self._half_if_needed(u, c)

        if level == 0:
            pass
        else:
            # Token 0 is the global embedding; the rest are local tokens.
            c_glb = c[:, 0:1]
            c_loc = c[:, 1: ]
            u_glb = u[:, 0:1]
            u_loc = u[:, 1: ]

            if level == -1:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=1)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=1)
            if level == -2:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=2)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=2)
            if level == 1:
                c_loc = self.find_low_rank(c_loc, demean=True, q=10)
                u_loc = self.find_low_rank(u_loc, demean=True, q=10)
            if level == 2:
                c_loc = self.find_low_rank(c_loc, demean=True, q=2)
                u_loc = self.find_low_rank(u_loc, demean=True, q=2)

            c = torch.cat([c_glb, c_loc], dim=1)
            u = torch.cat([u_glb, u_loc], dim=1)

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample(
            steps=ddim_steps,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=u,
            xtype='image', ctype='vision',
            eta=ddim_eta,
            verbose=False,)
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=cin)
        return x

    def find_low_rank(self, x, demean=True, q=20, niter=10):
        """Project batched token matrices onto their top-q principal
        components and reconstruct.

        x: tensor of shape [batch, tokens, dim]; fp16 input is upcast for
            the PCA and cast back on return.
        Returns the low-rank reconstruction with the same shape as x.
        """
        if demean:
            x_mean = x.mean(-1, keepdim=True)
            x_input = x - x_mean
        else:
            x_input = x

        # torch.pca_lowrank does not support half precision.
        if x_input.dtype == torch.float16:
            fp16 = True
            x_input = x_input.float()
        else:
            fp16 = False

        u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
        ss = torch.stack([torch.diag(si) for si in s])
        x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))

        if fp16:
            x_lowrank = x_lowrank.half()

        if demean:
            x_lowrank += x_mean
        return x_lowrank

    def remove_low_rank(self, x, demean=True, q=20, niter=10, q_remove=10):
        """Reconstruct batched token matrices with their top `q_remove`
        principal components zeroed out.

        x: tensor of shape [batch, tokens, dim]; fp16 input is upcast for
            the PCA and cast back on return.
        Returns the residual reconstruction with the same shape as x.
        """
        if demean:
            x_mean = x.mean(-1, keepdim=True)
            x_input = x - x_mean
        else:
            x_input = x

        # torch.pca_lowrank does not support half precision.
        if x_input.dtype == torch.float16:
            fp16 = True
            x_input = x_input.float()
        else:
            fp16 = False

        u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
        # Zero the dominant singular values, keeping only the residual.
        s[:, 0:q_remove] = 0
        ss = torch.stack([torch.diag(si) for si in s])
        x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))

        if fp16:
            x_lowrank = x_lowrank.half()

        if demean:
            x_lowrank += x_mean
        return x_lowrank

    def application_dualguided(self, cim, ctx, n_samples=None, mixing=0.5, color_adj=None, ):
        """Generate images guided jointly by an image `cim` and a text
        prompt `ctx`, blended by `mixing` (0 = image only, 1 = text only).
        """
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        ctemp0 = self.regularize_image(cim)
        ctemp1 = ctemp0*2 - 1
        ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
        cim = net.clip_encode_vision(ctemp1)
        uim = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp1)
            uim = net.clip_encode_vision(dummy)

        ctx = net.clip_encode_text(n_samples * [ctx])
        utx = None
        if scale != 1.0:
            utx = net.clip_encode_text(n_samples * [""])

        uim, cim = self._half_if_needed(uim, cim)
        utx, ctx = self._half_if_needed(utx, ctx)

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]

        z, _ = sampler.sample_dc(
            steps=ddim_steps,
            shape=shape,
            first_conditioning=[uim, cim],
            second_conditioning=[utx, ctx],
            unconditional_guidance_scale=scale,
            xtype='image',
            first_ctype='vision',
            second_ctype='prompt',
            eta=ddim_eta,
            verbose=False,
            mixed_ratio=(1-mixing), )
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
        return x

    def application_i2t2i(self, cim, ctx_n, ctx_p, n_samples=None, color_adj=None,):
        """Edit image `cim` through its text latent: remove the concept in
        `ctx_n`, add the concept in `ctx_p`, then regenerate the image.
        """
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        prompt_temperature = 1.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        ctemp0 = self.regularize_image(cim)
        ctemp1 = ctemp0*2 - 1
        ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
        cim = net.clip_encode_vision(ctemp1)
        uim = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp1)
            uim = net.clip_encode_vision(dummy)

        uim, cim = self._half_if_needed(uim, cim)

        # Sample a text latent describing the input image.
        n = 768
        shape = [n_samples, n]
        zt, _ = sampler.sample(
            steps=ddim_steps,
            shape=shape,
            conditioning=cim,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uim,
            xtype='text', ctype='vision',
            eta=ddim_eta,
            verbose=False,)
        ztn = net.optimus_encode([ctx_n])
        ztp = net.optimus_encode([ctx_p])

        # Remove the component along the negative prompt, then add the
        # positive prompt's latent.
        ztn_norm = ztn / ztn.norm(dim=1)
        zt_proj_mag = torch.matmul(zt, ztn_norm[0])
        zt_perp = zt - zt_proj_mag[:, None] * ztn_norm
        zt_newd = zt_perp + ztp
        ctx_new = net.optimus_decode(zt_newd, temperature=prompt_temperature)

        ctx_new = net.clip_encode_text(ctx_new)
        ctx_p = net.clip_encode_text([ctx_p])
        ctx_new = torch.cat([ctx_new, ctx_p.repeat(n_samples, 1, 1)], dim=1)
        utx_new = net.clip_encode_text(n_samples * [""])
        utx_new = torch.cat([utx_new, utx_new], dim=1)

        # Drop the global token and keep a low-rank sketch of the local
        # image tokens so the text edit can dominate.
        cim_loc = cim[:, 1: ]
        cim_loc_new = self.find_low_rank(cim_loc, demean=True, q=10)
        cim_new = cim_loc_new
        uim_new = uim[:, 1:]

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample_dc(
            steps=ddim_steps,
            shape=shape,
            first_conditioning=[uim_new, cim_new],
            second_conditioning=[utx_new, ctx_new],
            unconditional_guidance_scale=scale,
            xtype='image',
            first_ctype='vision',
            second_ctype='prompt',
            eta=ddim_eta,
            verbose=False,
            mixed_ratio=0.33, )

        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
        return x
| |
|
def main(netwrapper,
         app,
         image=None,
         prompt=None,
         nprompt=None,
         pprompt=None,
         color_adj=None,
         disentanglement_level=None,
         dual_guided_mixing=None,
         n_samples=4,
         seed=0,):
    """Dispatch one generation request to `netwrapper` based on `app`.

    Returns an (images, text) pair; at most one of the two is populated.
    When required inputs for the chosen app are missing, (None, None) is
    returned without running the model.

    netwrapper: a vd_inference instance.
    app: one of 'text-to-image', 'image-variation', 'image-to-text',
        'text-variation', 'disentanglement', 'dual-guided', 'i2t2i'.
    seed: RNG seed; negative values are clamped to 0, None skips seeding.
    Raises ValueError for an unknown `app`.
    """

    if seed is not None:
        # Clamp negatives; torch gets a shifted seed so the numpy and torch
        # RNG streams are not identical.
        seed = 0 if seed < 0 else seed
        np.random.seed(seed)
        torch.manual_seed(seed+100)

    if app == 'text-to-image':
        print('Running [{}] with prompt [{}], n_samples [{}], seed [{}].'.format(
            app, prompt, n_samples, seed))
        if (prompt is None) or (prompt == ""):
            return None, None
        with torch.no_grad():
            rv = netwrapper.inference(
                xtype = 'image',
                cin = prompt,
                ctype = 'prompt',
                n_samples = n_samples, )
        return rv, None

    elif app == 'image-variation':
        print('Running [{}] with image [{}], color_adj [{}], n_samples [{}], seed [{}].'.format(
            app, image, color_adj, n_samples, seed))
        if image is None:
            return None, None
        with torch.no_grad():
            rv = netwrapper.inference(
                xtype = 'image',
                cin = image,
                ctype = 'vision',
                color_adj = color_adj,
                n_samples = n_samples, )
        return rv, None

    elif app == 'image-to-text':
        # "iamge" typo in the log message fixed.
        print('Running [{}] with image [{}], n_samples [{}], seed [{}].'.format(
            app, image, n_samples, seed))
        if image is None:
            return None, None
        with torch.no_grad():
            rv = netwrapper.inference(
                xtype = 'text',
                cin = image,
                ctype = 'vision',
                n_samples = n_samples, )
        return None, '\n'.join(rv)

    elif app == 'text-variation':
        print('Running [{}] with prompt [{}], n_samples [{}], seed [{}].'.format(
            app, prompt, n_samples, seed))
        if prompt is None:
            return None, None
        with torch.no_grad():
            rv = netwrapper.inference(
                xtype = 'text',
                cin = prompt,
                ctype = 'prompt',
                n_samples = n_samples, )
        return None, '\n'.join(rv)

    elif app == 'disentanglement':
        print('Running [{}] with image [{}], color_adj [{}], disentanglement_level [{}], n_samples [{}], seed [{}].'.format(
            app, image, color_adj, disentanglement_level, n_samples, seed))
        if image is None:
            return None, None
        with torch.no_grad():
            rv = netwrapper.application_disensemble(
                cin = image,
                level = disentanglement_level,
                color_adj = color_adj,
                n_samples = n_samples, )
        return rv, None

    elif app == 'dual-guided':
        print('Running [{}] with image [{}], prompt [{}], color_adj [{}], dual_guided_mixing [{}], n_samples [{}], seed [{}].'.format(
            app, image, prompt, color_adj, dual_guided_mixing, n_samples, seed))
        if (image is None) or (prompt is None) or (prompt == ""):
            return None, None
        with torch.no_grad():
            rv = netwrapper.application_dualguided(
                cim = image,
                ctx = prompt,
                mixing = dual_guided_mixing,
                color_adj = color_adj,
                n_samples = n_samples, )
        return rv, None

    elif app == 'i2t2i':
        print('Running [{}] with image [{}], nprompt [{}], pprompt [{}], color_adj [{}], n_samples [{}], seed [{}].'.format(
            app, image, nprompt, pprompt, color_adj, n_samples, seed))
        if (image is None) or (nprompt is None) or (nprompt == "") \
                or (pprompt is None) or (pprompt == ""):
            return None, None
        with torch.no_grad():
            rv = netwrapper.application_i2t2i(
                cim = image,
                ctx_n = nprompt,
                ctx_p = pprompt,
                color_adj = color_adj,
                n_samples = n_samples, )
        return rv, None

    else:
        # raise (not assert) so the check survives `python -O`.
        raise ValueError("No such mode!")
| |
|
if __name__ == '__main__':
    # ---- command-line interface ----
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--app", type=str, default="text-to-image",
        help="Choose the application from [text-to-image, image-variation, "
             "image-to-text, text-variation, "
             "disentanglement, dual-guided, i2t2i]")
    parser.add_argument(
        "--model", type=str, default="official",
        help="Choose the model type from [dc, official]")
    parser.add_argument(
        "--prompt", type=str,
        default="a dream of a village in china, by Caspar David Friedrich, "
                "matte painting trending on artstation HQ")
    parser.add_argument("--image", type=str)
    parser.add_argument("--nprompt", type=str)
    parser.add_argument("--pprompt", type=str)
    parser.add_argument("--coloradj", type=str, default='simple')
    parser.add_argument("--dislevel", type=int, default=0)
    parser.add_argument("--dgmixing", type=float, default=0.7)
    parser.add_argument("--nsample", type=int, default=4)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--save", type=str, default='log',
                        help="The path or file the result will save into")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--fp16", action="store_true")

    args = parser.parse_args()

    valid_apps = [
        "text-to-image", "image-variation",
        "image-to-text", "text-variation",
        "disentanglement", "dual-guided", "i2t2i"]
    assert args.app in valid_apps, \
        "Unknown app! Select from [text-to-image, image-variation, "\
        "image-to-text, text-variation, "\
        "disentanglement, dual-guided, i2t2i]"

    # Fall back to CPU when no CUDA device is available.
    run_device = args.gpu if torch.cuda.is_available() else 'cpu'

    if args.model in ['4-flow', 'official']:
        ckpt_path = ('pretrained/vd-four-flow-v1-0-fp16.pth' if args.fp16
                     else 'pretrained/vd-four-flow-v1-0.pth')
        wrapper = vd_inference(pth=ckpt_path, fp16=args.fp16, device=run_device)
    elif args.model in ['2-flow', 'dc']:
        raise NotImplementedError
    elif args.model in ['1-flow', 'basic']:
        raise NotImplementedError
    else:
        assert False, "No such model! Select model from [4-flow(official), 2-flow(dc), 1-flow(basic)]"

    imout, txtout = main(
        netwrapper=wrapper,
        app=args.app,
        image=args.image,
        prompt=args.prompt,
        nprompt=args.nprompt,
        pprompt=args.pprompt,
        color_adj=args.coloradj,
        disentanglement_level=args.dislevel,
        dual_guided_mixing=args.dgmixing,
        n_samples=args.nsample,
        seed=args.seed,)

    if imout is not None:
        # Tile the returned images into one sheet and write it out.
        merged = PIL.Image.fromarray(auto_merge_imlist([np.array(i) for i in imout]))
        if osp.isdir(args.save):
            target = osp.join(args.save, 'imout.png')
            merged.save(target)
            print('Output image saved to {}.'.format(target))
        else:
            merged.save(osp.join(args.save))
            print('Output image saved to {}.'.format(args.save))

    if txtout is not None:
        print(txtout)
|