| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from argparse import Namespace |
| |
|
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| |
|
| | from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput |
| | from .adaptor_mlp import create_mlp_from_state |
| |
|
| |
|
class GenericAdaptor(AdaptorBase):
    """Adaptor that passes the summary and feature inputs through two separate MLPs.

    Both MLPs are rebuilt from ``state`` via ``create_mlp_from_state``. The
    summary MLP is built with the ``'summary.'`` prefix and the feature MLP
    with the ``'feature.'`` prefix.
    """

    def __init__(self, main_config: Namespace, adaptor_config, state):
        """Construct the summary ("head") and feature MLPs.

        Args:
            main_config: Configuration whose ``mlp_version`` attribute is
                forwarded to ``create_mlp_from_state`` for both MLPs.
            adaptor_config: Accepted for interface compatibility; unused here.
            state: Forwarded to ``create_mlp_from_state``. Presumably a
                state dict whose keys carry ``'summary.'`` / ``'feature.'``
                prefixes — TODO confirm against the loader.
        """
        # Base-class init must run before submodules are assigned as attributes.
        super().__init__()

        version = main_config.mlp_version
        self.head_mlp = create_mlp_from_state(version, state, 'summary.')
        self.feat_mlp = create_mlp_from_state(version, state, 'feature.')

    def forward(self, input: AdaptorInput) -> RadioOutput:
        """Project ``input.summary`` and ``input.features`` through their MLPs.

        Returns:
            A ``RadioOutput`` built from (projected summary, projected features).
        """
        # Arguments are evaluated left to right, so head_mlp still runs before feat_mlp.
        return RadioOutput(
            self.head_mlp(input.summary),
            self.feat_mlp(input.features),
        )
| |
|