from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import torch
import torch.nn as nn


def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
    """Permutes the final len(inds) dimensions of tensor according to inds."""
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

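# A minimal usage sketch (illustrative shapes, not part of the module):
#
#   >>> x = torch.randn(2, 3, 4, 5)
#   >>> permute_final_dims(x, [2, 0, 1]).shape  # last three dims reordered
#   torch.Size([2, 5, 3, 4])

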
def flatten_final_dims(t: torch.Tensor, no_dims: int) -> torch.Tensor:
    """Flattens the final no_dims dimensions of t into a single dimension."""
    return t.reshape(t.shape[:-no_dims] + (-1,))

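# A minimal usage sketch (illustrative shapes, not part of the module):
#
#   >>> t = torch.randn(2, 3, 4, 5)
#   >>> flatten_final_dims(t, 2).shape  # merges the last two dims (4 * 5)
#   torch.Size([2, 3, 20])

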
def masked_mean(mask, value, dim, eps=1e-4):
    """Computes the mean of value over dim, counting only positions where
    mask is nonzero. eps guards against division by zero for empty masks."""
    mask = mask.expand(*value.shape)
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))

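# A minimal usage sketch (values chosen by hand, not part of the module):
#
#   >>> value = torch.tensor([1.0, 2.0, 100.0])
#   >>> mask = torch.tensor([1.0, 1.0, 0.0])
#   >>> masked_mean(mask, value, dim=-1)  # (1 + 2) / (2 + eps) ~= 1.5

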
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    """Bins the pairwise Euclidean distances between the points in pts into
    no_bins distance bins."""
    boundaries = torch.linspace(
        min_bin, max_bin, no_bins - 1, device=pts.device
    )
    dists = torch.sqrt(
        torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
    )
    return torch.bucketize(dists, boundaries)

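# A minimal usage sketch (illustrative shapes, not part of the module);
# the defaults match the distogram binning used by AlphaFold:
#
#   >>> pts = torch.randn(10, 3)      # 10 points in 3-D space
#   >>> pts_to_distogram(pts).shape   # one bin index per point pair
#   torch.Size([10, 10])

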
def dict_multimap(fn, dicts):
    """Applies fn to lists of corresponding leaves across a list of
    identically structured (possibly nested) dicts."""
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            new_dict[k] = fn(all_v)

    return new_dict

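# A minimal usage sketch (values chosen by hand, not part of the module):
#
#   >>> d1 = {"a": torch.zeros(2), "b": {"c": torch.ones(2)}}
#   >>> d2 = {"a": torch.ones(2), "b": {"c": torch.ones(2)}}
#   >>> stacked = dict_multimap(torch.stack, [d1, d2])
#   >>> stacked["a"].shape
#   torch.Size([2, 2])

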
def one_hot(x, v_bins):
    """One-hot encodes each value of x against the nearest entry in v_bins."""
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).float()

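# A minimal usage sketch (values chosen by hand, not part of the module):
#
#   >>> v_bins = torch.tensor([0.0, 1.0, 2.0])
#   >>> one_hot(torch.tensor([0.2, 1.9]), v_bins)  # nearest bins: 0.0 and 2.0
#   tensor([[1., 0., 0.],
#           [0., 0., 1.]])

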
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Gathers entries of data along dim using inds, broadcasting over the
    first no_batch_dims batch dimensions."""
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple; indexing with a plain list is deprecated behavior
    return data[tuple(ranges)]

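# A minimal usage sketch (illustrative shapes, not part of the module):
#
#   >>> data = torch.randn(4, 10, 7)      # one batch dim of size 4
#   >>> inds = torch.randint(10, (4, 3))  # 3 gathers per batch entry
#   >>> batched_gather(data, inds, dim=1, no_batch_dims=1).shape
#   torch.Size([4, 3, 7])

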
def dict_map(fn, dic, leaf_type):
    """Recursively applies fn to every leaf_type leaf of a (nested) dict."""
    new_dict = {}
    for k, v in dic.items():
        if type(v) is dict:
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)

    return new_dict


def tree_map(fn, tree, leaf_type):
    """Recursively applies fn to every leaf_type leaf of a pytree built from
    dicts, lists, and tuples."""
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple(tree_map(fn, x, leaf_type) for x in tree)
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        raise ValueError(f"Tree of type {type(tree)} not supported")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)

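# A minimal usage sketch (values chosen by hand, not part of the module):
#
#   >>> tree = {"a": torch.zeros(2), "b": (torch.ones(3), torch.ones(4))}
#   >>> shapes = tensor_tree_map(lambda t: t.shape, tree)
#   >>> shapes["b"]
#   (torch.Size([3]), torch.Size([4]))

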
def _fetch_dims(tree):
    """Collects the shapes of every tensor leaf in a pytree."""
    shapes = []
    tree_type = type(tree)
    if tree_type is dict:
        for v in tree.values():
            shapes.extend(_fetch_dims(v))
    elif tree_type is list or tree_type is tuple:
        for t in tree:
            shapes.extend(_fetch_dims(t))
    elif tree_type is torch.Tensor:
        shapes.append(tree.shape)
    else:
        raise ValueError(f"Tree of type {tree_type} not supported")

    return shapes


@torch.jit.ignore
def _flat_idx_to_idx(
    flat_idx: int,
    dims: Tuple[int, ...],
) -> Tuple[int, ...]:
    """Converts a flat index into a multi-dimensional index for shape dims."""
    idx = []
    for d in reversed(dims):
        idx.append(flat_idx % d)
        flat_idx = flat_idx // d

    return tuple(reversed(idx))

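# A worked example (values chosen by hand, not part of the module): in a
# tensor of shape (2, 3), flat index 5 is the last element, i.e. index (1, 2):
#
#   >>> _flat_idx_to_idx(5, (2, 3))
#   (1, 2)

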
@torch.jit.ignore
def _get_minimal_slice_set(
    start: Sequence[int],
    end: Sequence[int],
    dims: Sequence[int],
    start_edges: Optional[Sequence[bool]] = None,
    end_edges: Optional[Sequence[bool]] = None,
) -> List[Tuple[slice, ...]]:
    """
    Produces an ordered sequence of tensor slices that, when used in
    sequence on a tensor with shape dims, yields tensors that contain every
    leaf in the contiguous range [start, end]. Care is taken to yield a
    short sequence of slices, and perhaps even the shortest possible (I'm
    pretty sure it's the latter).

    end is INCLUSIVE.
    """
    # start_edges and end_edges indicate whether, from each dimension onward,
    # the start/end index lies at the first/last position of that dimension
    def reduce_edge_list(l):
        tally = 1
        for i in range(len(l)):
            reversed_idx = -1 * (i + 1)
            l[reversed_idx] *= tally
            tally = l[reversed_idx]

    if start_edges is None:
        start_edges = [s == 0 for s in start]
        reduce_edge_list(start_edges)
    if end_edges is None:
        end_edges = [e == (d - 1) for e, d in zip(end, dims)]
        reduce_edge_list(end_edges)

    # Base cases: either start/end are empty and we're done, or the final,
    # one-dimensional tensor can be simply sliced
    if len(start) == 0:
        return [tuple()]
    elif len(start) == 1:
        return [(slice(start[0], end[0] + 1),)]

    slices = []
    path = []

    # Dimensions common to start and end can be selected directly
    for s, e in zip(start, end):
        if s == e:
            path.append(slice(s, s + 1))
        else:
            break

    path = tuple(path)
    divergence_idx = len(path)

    # start == end, and we're done
    if divergence_idx == len(dims):
        return [tuple(path)]

    def upper():
        # Slices covering everything from start to the bottom of the subtree
        # rooted at the divergence point
        sdi = start[divergence_idx]
        return [
            path + (slice(sdi, sdi + 1),) + s
            for s in _get_minimal_slice_set(
                start[divergence_idx + 1:],
                [d - 1 for d in dims[divergence_idx + 1:]],
                dims[divergence_idx + 1:],
                start_edges=start_edges[divergence_idx + 1:],
                end_edges=[True for _ in end_edges[divergence_idx + 1:]],
            )
        ]

    def lower():
        # Slices covering everything from the top of the subtree rooted at
        # the divergence point down to end
        edi = end[divergence_idx]
        return [
            path + (slice(edi, edi + 1),) + s
            for s in _get_minimal_slice_set(
                [0 for _ in start[divergence_idx + 1:]],
                end[divergence_idx + 1:],
                dims[divergence_idx + 1:],
                start_edges=[True for _ in start_edges[divergence_idx + 1:]],
                end_edges=end_edges[divergence_idx + 1:],
            )
        ]

    # If both start and end are at the edges of the subtree rooted at
    # divergence_idx, we can select the whole subtree at once
    if start_edges[divergence_idx] and end_edges[divergence_idx]:
        slices.append(
            path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
        )
    # If just start is at the edge, we can grab almost all of the subtree,
    # treating only the ragged bottom edge as an edge case
    elif start_edges[divergence_idx]:
        slices.append(
            path + (slice(start[divergence_idx], end[divergence_idx]),)
        )
        slices.extend(lower())
    # Analogous to the previous case, but the top is ragged this time
    elif end_edges[divergence_idx]:
        slices.extend(upper())
        slices.append(
            path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
        )
    # If both sides of the range are ragged, handle each side separately; if
    # there is a contiguous middle between them, grab it in one big slice
    else:
        slices.extend(upper())
        middle_ground = end[divergence_idx] - start[divergence_idx]
        if middle_ground > 1:
            slices.append(
                path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
            )
        slices.extend(lower())

    return [tuple(s) for s in slices]

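# A worked example (values chosen by hand, not part of the module): covering
# the flat range [1, 7] (inclusive) of a tensor t of shape (3, 3) takes three
# slices, equivalent to t[0:1, 1:3], t[1:2, :], and t[2:3, 0:2]:
#
#   >>> _get_minimal_slice_set([0, 1], [2, 1], (3, 3))

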
@torch.jit.ignore
def _chunk_slice(
    t: torch.Tensor,
    flat_start: int,
    flat_end: int,
    no_batch_dims: int,
) -> torch.Tensor:
    """
    Equivalent to

        t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

    but without the need for the initial reshape call, which can be
    memory-intensive in certain situations. The only reshape operations
    in this function are performed on sub-tensors that scale with
    (flat_end - flat_start), the chunk size.
    """
    batch_dims = t.shape[:no_batch_dims]
    start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
    # _get_minimal_slice_set is inclusive-inclusive
    end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))

    # Get an ordered list of slices to perform
    slices = _get_minimal_slice_set(
        start_idx,
        end_idx,
        batch_dims,
    )

    sliced_tensors = [t[s] for s in slices]

    return torch.cat(
        [s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
    )

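# A minimal usage sketch (illustrative shapes, not part of the module):
#
#   >>> t = torch.randn(2, 3, 7)   # two batch dims, one feature dim
#   >>> a = _chunk_slice(t, flat_start=1, flat_end=5, no_batch_dims=2)
#   >>> b = t.reshape(-1, 7)[1:5]  # the memory-hungrier equivalent
#   >>> torch.equal(a, b)
#   True

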
def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    no_batch_dims: int,
    low_mem: bool = False,
) -> Any:
    """
    Implements the "chunking" procedure described in section 1.11.8 of the
    AlphaFold 2 supplement.

    Layer outputs and inputs are assumed to be simple "pytrees,"
    consisting only of (arbitrarily nested) lists, tuples, and dicts with
    torch.Tensor leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (non-nested) dictionary of keyworded inputs. All leaves must
            be tensors and must share the same batch dimensions.
        chunk_size:
            The number of sub-batches per chunk. If multiple batch
            dimensions are specified, a "sub-batch" is defined as a single
            indexing of all batch dimensions simultaneously (s.t. the
            number of sub-batches is the product of the batch dimensions).
        no_batch_dims:
            How many of the initial dimensions of each input tensor can
            be considered batch dimensions.
        low_mem:
            Avoids flattening potentially large input tensors. Unnecessary
            in most cases, and is ever so slightly slower than the default
            setting.
    Returns:
        The reassembled output of the layer on the inputs.
    """
    if len(inputs) == 0:
        raise ValueError("Must provide at least one input")

    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

    def _prep_inputs(t):
        # In the default setting, expand singleton batch dims (the sum check
        # is true iff every batch dim is 1) and flatten them into one dim
        if not low_mem:
            if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
                t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
            t = t.reshape(-1, *t.shape[no_batch_dims:])
        else:
            t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
        return t

    prepped_inputs = tensor_tree_map(_prep_inputs, inputs)

    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d

    no_chunks = flat_batch_dim // chunk_size + (
        flat_batch_dim % chunk_size != 0
    )

    i = 0
    out = None
    for _ in range(no_chunks):
        # Select the next chunk of each input
        if not low_mem:
            select_chunk = (
                lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
            )
        else:
            select_chunk = partial(
                _chunk_slice,
                flat_start=i,
                flat_end=min(flat_batch_dim, i + chunk_size),
                no_batch_dims=len(orig_batch_dims),
            )

        chunks = tensor_tree_map(select_chunk, prepped_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output on the first iteration
        if out is None:
            allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space
        out_type = type(output_chunk)
        if out_type is dict:
            def assign(d1, d2):
                for k, v in d1.items():
                    if type(v) is dict:
                        assign(v, d2[k])
                    else:
                        v[i : i + chunk_size] = d2[k]

            assign(out, output_chunk)
        elif out_type is tuple:
            for x1, x2 in zip(out, output_chunk):
                x1[i : i + chunk_size] = x2
        elif out_type is torch.Tensor:
            out[i : i + chunk_size] = output_chunk
        else:
            raise ValueError(f"Output of type {out_type} not supported")

        i += chunk_size

    # Restore the original batch dimensions
    reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
    out = tensor_tree_map(reshape, out)

    return out

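# A minimal usage sketch (hypothetical layer and shapes, not part of the
# module): apply a Linear layer to a large activation four rows at a time
# instead of all 128 rows at once:
#
#   >>> linear = nn.Linear(8, 16)
#   >>> out = chunk_layer(
#   ...     lambda z: linear(z),
#   ...     {"z": torch.randn(128, 128, 8)},
#   ...     chunk_size=4,
#   ...     no_batch_dims=1,
#   ... )
#   >>> out.shape
#   torch.Size([128, 128, 16])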