|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial |
|
|
from typing import List |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
def inflate_array_like(array, target):
    """ (tested)

    Inflate ``array`` with trailing singleton axes so it broadcasts against ``target``.

    ``array`` normally carries only the leading (batch) axes of ``target``,
    e.g. shape ``(batch_size,)``, and gets ``target.ndim - array.ndim``
    singleton axes appended. If it already has ``target.ndim`` axes
    (e.g. ``(batch_size, 1, ..., 1)``) it is returned unchanged. Python
    ``int``/``float`` scalars broadcast trivially and are returned as-is.

    Args:
        array: (B, ) tensor, or a Python int/float scalar.
        target: (B, ...) tensor whose rank ``array`` is inflated to.

    Returns:
        array: (B, 1, ..., 1) with ``target.ndim`` axes, or the scalar unchanged.

    Raises:
        AssertionError: if ``array`` has more axes than ``target``, or (when the
            ranks differ) ``array.shape`` does not equal the matching leading
            part of ``target.shape``.
    """
    # Python scalars broadcast against any tensor. Previously only float was
    # accepted, so an int fell through to ``array.ndim`` (AttributeError).
    if isinstance(array, (int, float)):
        return array

    diff_dims = target.ndim - array.ndim
    assert diff_dims >= 0, f'Error: target.ndim {target.ndim} < array.ndim {array.ndim}'
    if diff_dims == 0:
        return array
    assert target.shape[:array.ndim] == array.shape[:array.ndim], f'Error: target.shape[:array.ndim] {target.shape[:array.ndim]} != array.shape[:array.ndim] {array.shape[:array.ndim]}'
    # Append ``diff_dims`` new singleton axes: (B,) -> (B, 1, ..., 1).
    return array[(...,) + (None,) * diff_dims]
|
|
|
|
|
|
|
|
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute only the trailing ``len(inds)`` dimensions of ``tensor``.

    The leading dimensions keep their order. ``inds`` is a permutation of
    ``range(len(inds))`` expressed relative to the trailing block. For
    example, ``inds=[1, 0]`` swaps the last two dimensions. An empty ``inds``
    returns the tensor unpermuted.

    Args:
        tensor: input tensor.
        inds: permutation of the final ``len(inds)`` dimensions.

    Returns:
        The permuted tensor (a view, as returned by ``Tensor.permute``).
    """
    # Count leading dims directly. The previous ``tensor.shape[:-len(inds)]``
    # became ``shape[:0] == ()`` for empty ``inds`` and produced an invalid
    # empty permutation.
    n_lead = tensor.dim() - len(inds)
    lead = list(range(n_lead))
    return tensor.permute(lead + [n_lead + i for i in inds])
|
|
|
|
|
def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Merge the last ``no_dims`` dimensions of ``t`` into a single dimension.

    E.g. shape ``(A, B, C, D)`` with ``no_dims=2`` becomes ``(A, B, C * D)``.

    The merged size is computed explicitly instead of being passed as ``-1``,
    for two reasons. First, ``reshape`` cannot infer a ``-1`` dimension for a
    tensor with zero elements. Second, slicing with ``t.shape[:-no_dims]``
    would silently drop every dimension when ``no_dims == 0``. With
    ``no_dims == 0`` a trailing size-1 dimension is appended, since the
    product of zero dimensions is 1.

    Raises:
        ValueError: if ``no_dims`` is negative or larger than ``t.dim()``.
    """
    if not 0 <= no_dims <= t.dim():
        raise ValueError(f"no_dims must be in [0, {t.dim()}], got {no_dims}")
    split = t.dim() - no_dims
    merged = t.shape[split:].numel()
    return t.reshape(t.shape[:split] + (merged,))
|
|
|
|
|
def sum_except_batch(t: torch.Tensor, batch_dims: int=1):
    """Sum ``t`` over every dimension except the leading ``batch_dims``.

    The trailing dimensions are flattened into one and summed. For example,
    ``(B, C, H)`` with ``batch_dims=1`` gives shape ``(B,)``. With
    ``batch_dims=0`` the result is the full sum as a 0-d tensor.

    The flattened size is passed explicitly rather than as ``-1``, because
    ``reshape`` cannot infer ``-1`` for tensors with zero elements (e.g. an
    empty batch).
    """
    feature_numel = t.shape[batch_dims:].numel()
    return t.reshape(t.shape[:batch_dims] + (feature_numel,)).sum(dim=-1)
|
|
|
|
|
def masked_mean(mask, value, dim, eps=1e-4):
    """Mean of ``value`` along ``dim``, counting only entries selected by ``mask``.

    ``mask`` is first expanded to ``value``'s shape, so it must be expandable
    to it. Masked values are summed and divided by the (expanded) mask sum.
    ``eps`` is added to that denominator so an all-zero mask along ``dim``
    does not divide by zero.
    """
    full_mask = mask.expand(*value.shape)
    weighted_total = (full_mask * value).sum(dim=dim)
    weight_count = full_mask.sum(dim=dim)
    return weighted_total / (eps + weight_count)
|
|
|
|
|
|
|
|
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    """Bucketize all pairwise Euclidean distances between points.

    Args:
        pts: tensor whose last axis holds coordinates, with points indexed
            along the second-to-last axis (e.g. ``(..., N, 3)``).
        min_bin: lowest bin boundary.
        max_bin: highest bin boundary.
        no_bins: number of bins. ``no_bins - 1`` evenly spaced boundaries
            are placed between ``min_bin`` and ``max_bin``.

    Returns:
        Integer tensor of shape ``(..., N, N)``. Each entry is the bin index
        of a pairwise distance, in ``[0, no_bins - 1]``, per
        ``torch.bucketize`` semantics.
    """
    edges = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
    # (..., N, 1, D) - (..., 1, N, D) broadcasts to (..., N, N, D)
    offsets = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    dists = offsets.pow(2).sum(dim=-1).sqrt()
    return torch.bucketize(dists, edges)
|
|
|
|
|
|
|
|
def dict_multimap(fn, dicts):
    """Combine a list of identically structured (possibly nested) dicts key by key.

    For each key of ``dicts[0]``, the values from every dict are gathered
    into a list. If the first dict's value is exactly a ``dict``, the function
    recurses into that nested level. Otherwise the result is
    ``fn(list_of_values)``.

    Raises:
        IndexError: if ``dicts`` is empty.
        KeyError: if a later dict lacks a key present in ``dicts[0]``.
    """
    template = dicts[0]
    merged = {}
    for key, sample in template.items():
        gathered = [d[key] for d in dicts]
        merged[key] = (
            dict_multimap(fn, gathered) if type(sample) is dict else fn(gathered)
        )
    return merged
|
|
|
|
|
|
|
|
def one_hot(x, v_bins):
    """Encode each element of ``x`` as a one-hot vector over its nearest bin value.

    Args:
        x: tensor of values.
        v_bins: tensor of bin values (presumably 1-D; it is viewed as
            ``len(v_bins)`` entries).

    Returns:
        float tensor of shape ``(*x.shape, len(v_bins))`` with a 1 at the
        index of the closest bin. Ties go to the first such bin, following
        ``argmin``.
    """
    n_bins = len(v_bins)
    bin_shape = (1,) * x.dim() + (n_bins,)
    distance = (x.unsqueeze(-1) - v_bins.view(bin_shape)).abs()
    nearest = distance.argmin(dim=-1)
    return nn.functional.one_hot(nearest, num_classes=n_bins).float()
|
|
|
|
|
|
|
|
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Gather entries of ``data`` along ``dim`` with per-batch indices ``inds``.

    The first ``no_batch_dims`` axes of ``data`` are treated as batch axes.
    Each batch axis is indexed by an ``arange`` reshaped to broadcast against
    ``inds``, and axis ``dim`` is indexed by ``inds`` itself. This is
    ordinary advanced indexing, so output shape follows PyTorch/NumPy
    advanced-indexing rules. For example, ``data`` of shape ``(B, N, C)``
    with ``inds`` of shape ``(B, K)``, ``dim=1`` and ``no_batch_dims=1``
    gives ``(B, K, C)``, where ``out[b, k] == data[b, inds[b, k]]``.
    Presumably the leading axes of ``inds`` match the batch axes of
    ``data``; TODO confirm with callers.

    Args:
        data: tensor to gather from.
        inds: integer index tensor.
        dim: axis of ``data`` to gather along (negative counts from the end).
        no_batch_dims: number of leading batch axes.
    """
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        # Create the arange on ``inds``' device so all index tensors share one
        # device. Previously it was always built on CPU.
        r = torch.arange(s, device=inds.device)
        # Shape (1, ..., -1, ..., 1) with -1 at position i, broadcastable with inds.
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple. A plain list relies on legacy list-as-tuple
    # heuristics, which NumPy no longer applies: it treats the list as one
    # array index instead.
    return data[tuple(ranges)]
|
|
|
|
|
|
|
|
|
|
|
def dict_map(fn, dic, leaf_type):
    """Map ``fn`` over the leaves of a (possibly nested) dict.

    Values whose exact type is ``dict`` are recursed into. Every other value
    is handed to ``tree_map``, which handles lists, tuples and ``leaf_type``
    leaves. A new plain dict with the same keys, in the same order, is
    returned.
    """
    return {
        key: (
            dict_map(fn, val, leaf_type)
            if type(val) is dict
            else tree_map(fn, val, leaf_type)
        )
        for key, val in dic.items()
    }
|
|
|
|
|
|
|
|
def tree_map(fn, tree, leaf_type):
    """Recursively apply ``fn`` to every leaf of a nested container.

    Containers are rebuilt as follows:

    - dicts (including subclasses) are delegated to ``dict_map``;
    - lists become new lists;
    - tuples become plain tuples.

    Any value that is an instance of ``leaf_type`` is replaced by
    ``fn(value)``.

    Raises:
        ValueError: if a value is neither a dict/list/tuple nor a ``leaf_type``
            instance. The offending type is included in the message.
    """
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple([tree_map(fn, x, leaf_type) for x in tree])
    elif isinstance(tree, leaf_type):
        return fn(tree)
    # Report the type in the exception instead of printing it to stdout, so
    # callers that catch the error get no stray output.
    raise ValueError(f"Not supported: {type(tree)}")
|
|
|
|
|
|
|
|
# ``tree_map`` specialised to tensor leaves: ``tensor_tree_map(fn, tree)``
# applies ``fn`` to every ``torch.Tensor`` in a nested dict/list/tuple and
# rebuilds the containers. Because this is a ``functools.partial``,
# ``leaf_type`` can still be overridden by keyword.
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
|
|
|
|
|
def _fetch_dims(tree): |
|
|
shapes = [] |
|
|
tree_type = type(tree) |
|
|
if tree_type is dict: |
|
|
for v in tree.values(): |
|
|
shapes.extend(_fetch_dims(v)) |
|
|
elif tree_type is list or tree_type is tuple: |
|
|
for t in tree: |
|
|
shapes.extend(_fetch_dims(t)) |
|
|
elif tree_type is torch.Tensor: |
|
|
shapes.append(tree.shape) |
|
|
else: |
|
|
raise ValueError("Not supported") |
|
|
|
|
|
return shapes |
|
|
|