rahul7star committed · Commit 8cca793 (verified) · Parent: 2013a08

Create optimization.py

Files changed (1): optimization.py (+130, -0)
optimization.py ADDED
"""LoRA fusion, float8 quantization, and ahead-of-time (AoT) transformer
compilation for a two-transformer Wan image-to-video pipeline on ZeroGPU.
"""

from typing import Any
from typing import Callable
from typing import ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import ZeroGPUCompiledModel
from optimization_utils import drain_module_parameters


P = ParamSpec('P')


# The latent num_frames axis is the only dimension left dynamic in the exported graphs.
TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)

TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: TRANSFORMER_NUM_FRAMES_DIM,
    },
}

INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):

    @spaces.GPU(duration=1500)
    def compile_transformer():

        # Fuse the lightx2v distillation LoRA into both transformers, then
        # unload the adapters so only the fused weights remain.
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
            adapter_name="lightx2v",
        )
        pipeline.load_lora_weights(
            "Kijai/WanVideo_comfy",
            weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
            adapter_name="lightx2v_2",
            load_into_transformer_2=True,
        )
        pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
        pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
        pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
        pipeline.unload_lora_weights()

        # Run the pipeline once to capture the exact args/kwargs the
        # transformer is called with; these seed torch.export below.
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)

        # Mark every captured input as static (None), then re-mark the latent
        # num_frames dimension of hidden_states as dynamic.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        # Float8 dynamic activation/weight quantization on both transformers.
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())

        # Derive landscape and portrait variants of the captured latents so a
        # single compile session can serve both orientations.
        hidden_states: torch.Tensor = call.kwargs['hidden_states']
        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            hidden_states_landscape = hidden_states
            hidden_states_portrait = hidden_states_transposed
        else:
            hidden_states_landscape = hidden_states_transposed
            hidden_states_portrait = hidden_states

        exported_landscape_1 = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )

        exported_portrait_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )

        compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
        compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)

        # The two transformers share an architecture, so each orientation is
        # compiled only once: the remaining two variants pair one compiled
        # archive with the other transformer's weights.
        compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
        compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)

        return (
            compiled_landscape_1,
            compiled_landscape_2,
            compiled_portrait_1,
            compiled_portrait_2,
        )

    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
    cl1, cl2, cp1, cp2 = compile_transformer()

    # Dispatch to the landscape or portrait artifact based on the latent
    # aspect ratio of the incoming hidden_states.
    def combined_transformer_1(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs['hidden_states']
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            return cl1(*args, **kwargs)
        else:
            return cp1(*args, **kwargs)

    def combined_transformer_2(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs['hidden_states']
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            return cl2(*args, **kwargs)
        else:
            return cp2(*args, **kwargs)

    # Swap in the compiled forwards and release the eager parameters, whose
    # data now lives in the compiled weight archives.
    pipeline.transformer.forward = combined_transformer_1
    drain_module_parameters(pipeline.transformer)

    pipeline.transformer_2.forward = combined_transformer_2
    drain_module_parameters(pipeline.transformer_2)
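
For context, a minimal usage sketch (not part of this commit). It assumes a Diffusers Wan 2.2 image-to-video pipeline; the checkpoint ID, image URL, and generation arguments below are illustrative placeholders, not values taken from this repository:

# Hypothetical usage sketch: model ID, image URL, and call arguments are
# placeholders; optimize_pipeline_ forwards them to one real pipeline call.
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import load_image

from optimization import optimize_pipeline_

pipeline = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
)
pipeline.to('cuda')

first_frame = load_image("https://example.com/first_frame.png")  # placeholder

# One representative call seeds the export. Only the latent num_frames
# dimension (3 to 21) stays dynamic afterwards, so later calls should keep
# the same spatial resolution (or its transpose).
optimize_pipeline_(
    pipeline,
    image=first_frame,
    prompt='a placeholder prompt',
    height=480,
    width=832,
    num_frames=81,
)

Because both orientations are exported up front and the wrappers dispatch on the latent aspect ratio, the optimized pipeline can serve landscape and portrait requests without recompiling.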