Spaces:
Running
on
Zero
Running
on
Zero
update
Browse files- hf_demo.py +4 -3
hf_demo.py
CHANGED
|
@@ -11,8 +11,9 @@ from PIL import Image
|
|
| 11 |
|
| 12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 13 |
dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
|
|
|
|
| 14 |
pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
|
| 15 |
-
|
| 16 |
|
| 17 |
from inference import get_lora_network, inference, get_validation_dataloader
|
| 18 |
lora_map = {
|
|
@@ -42,7 +43,7 @@ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0,
|
|
| 42 |
adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
|
| 43 |
|
| 44 |
prompts = [prompt]*samples
|
| 45 |
-
infer_loader = get_validation_dataloader(prompts)
|
| 46 |
network = get_lora_network(pipe.unet, adapter_path)["network"]
|
| 47 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
| 48 |
height=512, width=512, scales=[1.0],
|
|
@@ -52,7 +53,7 @@ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0,
|
|
| 52 |
return pred_images
|
| 53 |
@spaces.GPU
|
| 54 |
def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
|
| 55 |
-
infer_loader = get_validation_dataloader(prompts, image)
|
| 56 |
network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
|
| 57 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
| 58 |
height=512, width=512, scales=[0.,1.],
|
|
|
|
| 11 |
|
| 12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 13 |
dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
|
| 14 |
+
print(f"Using {device} device, dtype={dtype}")
|
| 15 |
pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
|
| 16 |
+
torch_dtype=dtype).to(device)
|
| 17 |
|
| 18 |
from inference import get_lora_network, inference, get_validation_dataloader
|
| 19 |
lora_map = {
|
|
|
|
| 43 |
adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
|
| 44 |
|
| 45 |
prompts = [prompt]*samples
|
| 46 |
+
infer_loader = get_validation_dataloader(prompts,num_workers=0)
|
| 47 |
network = get_lora_network(pipe.unet, adapter_path)["network"]
|
| 48 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
| 49 |
height=512, width=512, scales=[1.0],
|
|
|
|
| 53 |
return pred_images
|
| 54 |
@spaces.GPU
|
| 55 |
def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
|
| 56 |
+
infer_loader = get_validation_dataloader(prompts, image,num_workers=0)
|
| 57 |
network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
|
| 58 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
| 59 |
height=512, width=512, scales=[0.,1.],
|