| | import torch |
| | import einops |
| | import torchvision |
| | import torch.nn as nn |
| | from typing import List, Tuple |
| |
|
| |
|
class MultiviewStack(nn.Module):
    """Run one encoder per view and stack the per-view outputs.

    The input's last four axes are treated as ``V C H W`` (presumably
    view, channel, height, width -- confirm against the caller); any
    leading axes are batch-like and are preserved in the output.

    Args:
        encoders: One module per view; ``encoders[i]`` is applied to view
            ``i``.
        normalizations: One ``(mean, std)`` pair per view, each passed to
            ``torchvision.transforms.Normalize``. Must provide at least as
            many entries as there are encoders.
        output_dim: Accepted for interface compatibility; not used by this
            class.

    Raises:
        ValueError: If fewer normalizations than encoders are supplied.
    """

    def __init__(
        self,
        encoders: List[nn.Module],
        normalizations: List[Tuple[List, List]],
        output_dim: int,
    ):
        super().__init__()
        self.encoders = nn.ModuleList(encoders)
        self.normalizations = [
            torchvision.transforms.Normalize(mean=mean, std=std)
            for mean, std in normalizations
        ]
        # forward() indexes normalizations by encoder position, so a short
        # list would only fail later with an opaque IndexError.
        if len(self.normalizations) < len(self.encoders):
            raise ValueError(
                f"got {len(self.normalizations)} normalizations for "
                f"{len(self.encoders)} encoders; need one per encoder"
            )

    def forward(self, x):
        """Encode every view of ``x`` and return the stacked features.

        Args:
            x: Tensor whose last four axes are ``V C H W``, with any number
                of leading axes. ``V`` must equal the number of encoders.

        Returns:
            Tensor of shape ``(*leading, V, F)`` where ``out[..., i, :]`` is
            the flattened output of ``encoders[i]`` on view ``i``.

        Raises:
            ValueError: If ``x`` has fewer than four axes or its view axis
                does not match the number of encoders.
        """
        if x.dim() < 4:
            raise ValueError(
                f"expected input with at least 4 dims (V C H W), got {x.dim()}"
            )
        orig_shape = x.shape
        num_views = orig_shape[-4]
        if num_views != len(self.encoders):
            raise ValueError(
                f"input has {num_views} views but there are "
                f"{len(self.encoders)} encoders"
            )
        # Collapse all leading axes into a single batch axis.
        x = x.reshape(-1, *orig_shape[-4:])
        outputs = [
            encoder(self.normalizations[i](x[:, i]))
            for i, encoder in enumerate(self.encoders)
        ]
        # Stack on dim=1 so the view axis sits right after the batch axis;
        # the reshape below restores the leading axes and keeps V, so each
        # output row holds exactly one view's features. (Stacking on the
        # last axis would interleave features of different views.)
        out = torch.stack(outputs, dim=1)
        return out.reshape(*orig_shape[:-3], -1)
| |
|