Update ace_inference.py
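The change qualifies the bare `use_dynamic_model` name as the instance attribute `self.use_dynamic_model` in all ten dynamic load/unload guards (a bare name at this point in the method is neither a local nor a global, so it would raise `NameError` at call time, which would surface as the Space's runtime error) and adds a `print(kwargs)` debug line at the top of the call.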
ace_inference.py  CHANGED  (+11 -10)
@@ -330,6 +330,7 @@ class ACEInference(DiffusionInference):
                  history_io=None,
                  tar_index=0,
                  **kwargs):
+        print(kwargs)
         input_image, input_mask = image, mask
         g = torch.Generator(device=we.device_id)
         seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
@@ -396,9 +397,9 @@ class ACEInference(DiffusionInference):
         if use_ace and (not is_txt_image or refiner_scale <= 0):
             ctx, null_ctx = {}, {}
             # Get Noise Shape
-            if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
+            if self.use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
             x = self.encode_first_stage(image)
-            if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
+            if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
                                 'first_stage_model',
                                 skip_loaded=True)
             noise = [
@@ -414,7 +415,7 @@ class ACEInference(DiffusionInference):
             ctx['x_mask'] = null_ctx['x_mask'] = cond_mask

             # Encode Prompt
-            if use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
+            if self.use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
             function_name, dtype = self.get_function_info(self.cond_stage_model)
             cont, cont_mask = getattr(get_model(self.cond_stage_model),
                                       function_name)(prompt)
@@ -424,14 +425,14 @@ class ACEInference(DiffusionInference):
                                                 function_name)(n_prompt)
             null_cont, null_cont_mask = self.cond_stage_embeddings(
                 prompt, edit_image, null_cont, null_cont_mask)
-            if use_dynamic_model: self.dynamic_unload(self.cond_stage_model,
+            if self.use_dynamic_model: self.dynamic_unload(self.cond_stage_model,
                                 'cond_stage_model',
                                 skip_loaded=False)
             ctx['crossattn'] = cont
             null_ctx['crossattn'] = null_cont

             # Encode Edit Images
-            if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
+            if self.use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
             edit_image = [to_device(i, strict=False) for i in edit_image]
             edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
             e_img, e_mask = [], []
@@ -442,14 +443,14 @@ class ACEInference(DiffusionInference):
                     m = [None] * len(u)
                 e_img.append(self.encode_first_stage(u, **kwargs))
                 e_mask.append([self.interpolate_func(i) for i in m])
-            if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
+            if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
                                 'first_stage_model',
                                 skip_loaded=True)
             null_ctx['edit'] = ctx['edit'] = e_img
             null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask

             # Diffusion Process
-            if use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
+            if self.use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
             function_name, dtype = self.get_function_info(self.diffusion_model)
             with torch.autocast('cuda',
                                 enabled=dtype in ('float16', 'bfloat16'),
@@ -490,15 +491,15 @@ class ACEInference(DiffusionInference):
                 guide_rescale=guide_rescale,
                 return_intermediate=None,
                 **kwargs)
-            if use_dynamic_model: self.dynamic_unload(self.diffusion_model,
+            if self.use_dynamic_model: self.dynamic_unload(self.diffusion_model,
                                 'diffusion_model',
                                 skip_loaded=False)

             # Decode to Pixel Space
-            if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
+            if self.use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
             samples = unpack_tensor_into_imagelist(latent, x_shapes)
             x_samples = self.decode_first_stage(samples)
-            if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
+            if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
                                 'first_stage_model',
                                 skip_loaded=False)
             x_samples = [x.squeeze(0) for x in x_samples]
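For context, here is a minimal runnable sketch of the guard pattern the hunks above patch. InferenceSketch and the print-based dynamic_load/dynamic_unload bodies are stand-ins, not the real scepter API; the point is only the attribute-versus-bare-name bug.

class InferenceSketch:
    def __init__(self, use_dynamic_model):
        # The flag is stored on the instance at init time. Inside a method,
        # the bare name `use_dynamic_model` is neither a local nor a global,
        # so referencing it raises NameError; `self.use_dynamic_model` works.
        self.use_dynamic_model = use_dynamic_model
        self.first_stage_model = object()  # placeholder for the real VAE

    def dynamic_load(self, model, name):
        print(f'load {name} onto the device')

    def dynamic_unload(self, model, name, skip_loaded=False):
        print(f'unload {name} (skip_loaded={skip_loaded})')

    def encode(self, image):
        if self.use_dynamic_model:
            self.dynamic_load(self.first_stage_model, 'first_stage_model')
        latent = image  # stand-in for self.encode_first_stage(image)
        if self.use_dynamic_model:
            self.dynamic_unload(self.first_stage_model,
                                'first_stage_model',
                                skip_loaded=True)
        return latent

InferenceSketch(use_dynamic_model=True).encode('img')

The pattern itself (load a stage, run it, unload it before the next stage) keeps only one large model resident at a time, which is presumably what lets the Space fit ACE's VAE, text encoder, and diffusion model within limited GPU memory.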
|