| | """Contains utility functionality to modify torch modules. |
| | |
| | For licensing see accompanying LICENSE file. |
| | Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | from typing import Any |
| |
|
| | from torch import nn |
| |
|
# Tuples of torch.nn layer classes used for isinstance() checks when freezing
# normalization layers. Built by scanning torch.nn's namespace by class name,
# so newly added norm variants are picked up automatically across versions.
# The isinstance(module_type, type) guard keeps non-class namespace entries
# out of the tuples: isinstance(x, (...,)) raises TypeError if the tuple
# contains anything that is not a class.
NORM_LAYER_TYPES = tuple(
    module_type
    for name, module_type in nn.__dict__.items()
    if "Norm" in name and isinstance(module_type, type)
)
BATCH_NORM_LAYER_TYPES = tuple(
    module_type
    for name, module_type in nn.__dict__.items()
    if "BatchNorm" in name and isinstance(module_type, type)
)
| |
|
| |
|
def freeze_norm_layer(module: nn.Module) -> nn.Module:
    """Freeze all normalization layers.

    Disables gradients on every normalization submodule and registers a
    pre-forward hook that forces the layer into eval mode right before each
    forward call, so frozen running statistics are not updated even if the
    parent module is later switched back to training mode.

    Args:
        module: Module whose normalization submodules should be frozen.

    Returns:
        The same module instance, modified in place.
    """

    def _force_eval(norm_layer: nn.Module, _: Any) -> None:
        # Fires immediately before the layer's forward(), overriding any
        # intervening .train() toggles on the parent module.
        norm_layer.eval()

    norm_layers = (m for m in module.modules() if isinstance(m, NORM_LAYER_TYPES))
    for norm_layer in norm_layers:
        norm_layer.requires_grad_(False)
        norm_layer.register_forward_pre_hook(_force_eval)

    return module
| |
|