psi0_5 / notation.py
klemenk's picture
Upload notation.py with huggingface_hub
0ff8bab verified
"""
Shared notation parsing utilities for PSI2.
Parses Einstein-like notation strings (e.g. "rgb0,c01->rgb1") into
structured element dicts with modality class references via GlobalRegistry.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import List, Dict, Set, Tuple, Any, Type, Union
from .modalities import (
GlobalRegistry, SerializableModality, NodeModality, EdgeModality,
)
@dataclass
class ParsedNotation:
    """Summary of one side (input or output) of a parsed notation string.

    Instances are built by ``analyze_notation`` from the element dicts that
    ``parse_element`` produces.
    """

    # Label for this side of the notation (e.g. "input" or "output").
    name: str
    # Every distinct modality class referenced by the elements.
    modalities: set[type[SerializableModality]]
    # Modality class -> time indices of its single-index elements.
    nodes: dict[type[SerializableModality], list[int]]
    # Modality class -> (from, to) index pairs of its two-index elements.
    edges: dict[type[SerializableModality], Union[list[int], list[tuple[int, int]]]]
    # One (modality class, index or index pair) entry per element, in
    # notation order.
    elements: list[tuple[type[SerializableModality], Union[int, tuple[int, int]]]]
    # The raw element strings, in notation order.
    element_names: list[str]
def parse_element(element_str: str) -> Dict[str, Any]:
    """
    Parse a single element string into its modality and time indices.

    An element is a modality name (every leading non-digit character)
    followed by zero or more single-digit time indices.

    Format examples:
        rgb0 -> RGB frame at index 0
        f01  -> Flow from frame 0 to frame 1
        d0   -> Disparity at frame 0
        c01  -> Camera pose from frame 0 to frame 1

    Args:
        element_str: One element of a notation string, e.g. "rgb0".

    Returns:
        Dict with keys:
            modality: the modality name prefix
            modality_cls: whatever GlobalRegistry.get returned for that name
            indices: list of int time indices (empty when none were given)
            raw: the unmodified element string

    Raises:
        ValueError: If the element has no modality name, the name is not
            known to GlobalRegistry, the index part contains anything other
            than ASCII digits, or the number of indices does not fit the
            modality (1 for a NodeModality, 2 for an EdgeModality, 0 for
            anything else).
    """
    # The pattern can match the empty string, so the match itself always
    # succeeds; an empty *name* is what signals a malformed element.
    modality_name = re.match(r'^[^0-9]*', element_str).group()
    if not modality_name:
        raise ValueError(f"Invalid element string: {element_str}")

    modality_cls = GlobalRegistry.get(modality_name)
    if modality_cls is None:
        raise ValueError(f"Unknown modality: {element_str} (valid names: {GlobalRegistry.names_str()})")

    times_str = element_str[len(modality_name):]
    # str.isdigit() alone accepts characters such as superscript digits that
    # int() cannot convert, so restrict the index part to ASCII digits.
    if times_str and not (times_str.isascii() and times_str.isdigit()):
        match = re.match(r"^(\d+)(.+)$", times_str)
        if match:
            # Usually a missing comma between two elements; suggest the split.
            digits, trailing = match.groups()
            raise ValueError(
                f"Invalid element '{element_str}': time indices for '{modality_name}' must be digits only. "
                f"Did you mean '{modality_name}{digits},{trailing}'?"
            )
        raise ValueError(
            f"Invalid element '{element_str}': time indices for '{modality_name}' must be digits only."
        )

    # Each character is one single-digit time index.
    times = list(times_str)
    if issubclass(modality_cls, NodeModality):
        if len(times) != 1:
            raise ValueError(f"{modality_name} is a Node and expects 1 time index: {element_str}")
    elif issubclass(modality_cls, EdgeModality):
        if len(times) != 2:
            raise ValueError(f"{modality_name} is an Edge and expects 2 time indices: {element_str}")
    elif times:
        # Only reject when indices were actually supplied; an index-less
        # element yields an empty list here and is valid.
        raise ValueError(
            f"{modality_name} does not take time indices, but got {times}: {element_str}"
        )

    return {
        "modality": modality_name,
        "modality_cls": modality_cls,
        "indices": [int(t) for t in times],
        "raw": element_str,
    }
def parse_notation(notation: str) -> List[Dict[str, Any]]:
    """Split a comma-separated notation string and parse each element.

    Surrounding whitespace is stripped from every piece, and pieces that are
    empty after stripping are ignored. Each remaining piece is passed to
    ``parse_element``; any error it raises propagates to the caller.

    Returns:
        The parsed element dicts, in the order they appear in ``notation``.
    """
    stripped_pieces = (piece.strip() for piece in notation.split(","))
    return [parse_element(token) for token in stripped_pieces if token]
def analyze_notation(name: str, elements: List[Dict[str, Any]]) -> ParsedNotation:
    """Group parsed elements by modality and collect their time indices.

    Args:
        name: Label stored on the result (e.g. "input" or "output").
        elements: Element dicts as produced by ``parse_element``; each must
            provide the 'modality_cls', 'indices' and 'raw' keys.

    Returns:
        A ParsedNotation in which two-index elements are recorded under
        ``edges``, one-index elements under ``nodes``, and every element
        appears in ``elements`` in its original order. Elements with any
        other number of indices are listed with index 0 and contribute to
        neither ``nodes`` nor ``edges``.
    """
    node_indices = {}
    edge_indices = {}
    ordered = []

    for element in elements:
        cls, idx = element['modality_cls'], element['indices']
        arity = len(idx)
        if arity == 2:
            pair = tuple(idx)
            edge_indices.setdefault(cls, set()).add(pair)
            ordered.append((cls, pair))
            continue
        if arity == 1:
            single = idx[0]
            node_indices.setdefault(cls, set()).add(single)
            ordered.append((cls, single))
            continue
        # No usable index: keep the element with a placeholder index of 0.
        ordered.append((cls, 0))

    # Sets de-duplicate repeated indices; they are exposed as plain lists.
    return ParsedNotation(
        name=name,
        modalities={element['modality_cls'] for element in elements},
        nodes={cls: list(found) for cls, found in node_indices.items()},
        edges={cls: list(found) for cls, found in edge_indices.items()},
        elements=ordered,
        element_names=[element['raw'] for element in elements],
    )
def parse_full_notation(notation: str) -> Tuple[List[Dict], List[Dict], ParsedNotation, ParsedNotation]:
    """
    Parse a complete "inputs->outputs" notation string.

    The string is split at the first '->'; each side is parsed with
    ``parse_notation`` and summarized with ``analyze_notation``.

    Returns:
        input_elements: List of parsed element dicts
        output_elements: List of parsed element dicts
        input_analysis: ParsedNotation for inputs
        output_analysis: ParsedNotation for outputs

    Raises:
        ValueError: If the notation contains no '->' or the output side
            has no elements. Errors raised while parsing individual
            elements propagate unchanged.
    """
    if "->" not in notation:
        raise ValueError(f"Notation must contain '->' to separate inputs from outputs: {notation}")

    # Only the first arrow separates the two sides.
    lhs, _, rhs = notation.partition("->")
    parsed_inputs = parse_notation(lhs)
    parsed_outputs = parse_notation(rhs)

    # An empty input side is accepted; an empty output side is not.
    if not parsed_outputs:
        raise ValueError(f"No output elements specified in notation: {notation}")

    return (
        parsed_inputs,
        parsed_outputs,
        analyze_notation("input", parsed_inputs),
        analyze_notation("output", parsed_outputs),
    )