File size: 5,190 Bytes
0ff8bab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
Shared notation parsing utilities for PSI2.

Parses Einstein-like notation strings (e.g. "rgb0,c01->rgb1") into
structured element dicts with modality class references via GlobalRegistry.
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import List, Dict, Set, Tuple, Any, Type, Union

from .modalities import (
    GlobalRegistry, SerializableModality, NodeModality, EdgeModality,
)


@dataclass
class ParsedNotation:
    """Summary of one group of parsed notation elements, as built by analyze_notation."""
    # Label for this group; parse_full_notation passes "input" or "output".
    name: str
    # Every modality class that appears at least once among the elements.
    modalities: Set[Type[SerializableModality]]
    # Single-index elements: modality class -> distinct time indices
    # (deduplicated through a set, so list order is not guaranteed sorted).
    nodes: Dict[Type[SerializableModality], List[int]]
    # Two-index elements: modality class -> distinct index pairs as 2-tuples
    # (deduplicated through a set, so list order is not guaranteed sorted).
    edges: Dict[Type[SerializableModality], Union[List[int], List[Tuple[int, int]]]]
    # One (modality class, index) entry per element, in notation order.
    # The index is an int for single-index elements, a 2-tuple for two-index
    # elements, and the placeholder 0 for elements without indices.
    elements: List[Tuple[Type[SerializableModality], Union[int, Tuple[int, int]]]]
    # The raw element strings, in notation order.
    element_names: List[str]


def parse_element(element_str: str) -> Dict[str, Any]:
    """
    Parse a single element string.

    An element is a modality name (everything before the first ASCII digit)
    followed by zero or more time indices, one digit per index.

    Format examples:
        rgb0 -> RGB frame at index 0
        f01  -> Flow from frame 0 to frame 1
        d0   -> Disparity at frame 0
        c01  -> Camera pose from frame 0 to frame 1

    Args:
        element_str: One element of a notation string (no surrounding
            whitespace or commas).

    Returns:
        Dict with keys: modality, modality_cls, indices, raw

    Raises:
        ValueError: If the modality name is not registered, the index part
            contains non-digit characters, or the number of indices does not
            match the modality kind (1 for nodes, 2 for edges, 0 otherwise).
    """
    # The pattern may match the empty string, so re.match always succeeds
    # here; the name is simply the (possibly empty) non-digit prefix.
    modality_name = re.match(r'^[^0-9]*', element_str).group()

    modality_cls = GlobalRegistry.get(modality_name)
    if modality_cls is None:
        raise ValueError(f"Unknown modality: {element_str} (valid names: {GlobalRegistry.names_str()})")

    times_str = element_str.removeprefix(modality_name)
    if times_str and not times_str.isdigit():
        # Common mistake: a missing comma glues two elements together
        # (e.g. "rgb0c01"); suggest the split when the shape allows it.
        match = re.match(r"^(\d+)(.+)$", times_str)
        if match:
            digits, trailing = match.groups()
            raise ValueError(
                f"Invalid element '{element_str}': time indices for '{modality_name}' must be digits only. "
                f"Did you mean '{modality_name}{digits},{trailing}'?"
            )
        raise ValueError(
            f"Invalid element '{element_str}': time indices for '{modality_name}' must be digits only."
        )

    # One character per index; an element without indices yields [].
    times = list(times_str)

    if issubclass(modality_cls, NodeModality):
        if len(times) != 1:
            raise ValueError(f"{modality_name} is a Node and expects 1 time index: {element_str}")
    elif issubclass(modality_cls, EdgeModality):
        if len(times) != 2:
            raise ValueError(f"{modality_name} is an Edge and expects 2 time indices: {element_str}")
    elif times:
        # Neither node nor edge: only reject when indices were actually
        # given. (The previous comparison against [''] could never be equal,
        # so index-less modalities were always rejected.)
        raise ValueError(
            f"{modality_name} does not take time indices, but got {times}: {element_str}"
        )

    indices = [int(t) for t in times]

    return {"modality": modality_name, "modality_cls": modality_cls, "indices": indices, "raw": element_str}


def parse_notation(notation: str) -> List[Dict[str, Any]]:
    """Split a comma-separated notation string and parse every non-empty element.

    Each piece is stripped of surrounding whitespace; pieces that are empty
    after stripping (for example from a trailing comma) are skipped. The
    remaining pieces are parsed in order with parse_element.
    """
    stripped_pieces = (piece.strip() for piece in notation.split(","))
    return [parse_element(piece) for piece in stripped_pieces if piece]


def analyze_notation(name: str, elements: List[Dict[str, Any]]) -> ParsedNotation:
    """Summarise parsed elements into a ParsedNotation.

    Args:
        name: Label stored on the result.
        elements: Parsed element dicts; each must provide the keys
            'modality_cls', 'indices' and 'raw'.

    Returns:
        ParsedNotation where single-index elements are collected under
        ``nodes``, two-index elements under ``edges`` (as tuples), and every
        element is listed in order in ``elements`` / ``element_names``.
        Elements with any other index count appear in ``elements`` with the
        placeholder index 0 and are not recorded in ``nodes`` or ``edges``.
    """
    node_indices = {}
    edge_indices = {}
    ordered = []

    for element in elements:
        cls = element['modality_cls']
        idx = element['indices']
        arity = len(idx)

        if arity == 1:
            if cls not in node_indices:
                node_indices[cls] = set()
            node_indices[cls].add(idx[0])
            ordered.append((cls, idx[0]))
        elif arity == 2:
            pair = tuple(idx)
            if cls not in edge_indices:
                edge_indices[cls] = set()
            edge_indices[cls].add(pair)
            ordered.append((cls, pair))
        else:
            ordered.append((cls, 0))

    # Sets deduplicate repeated indices; the resulting list order follows
    # set iteration order and is not guaranteed to be sorted.
    seen_modalities = {element['modality_cls'] for element in elements}
    node_lists = {cls: [*members] for cls, members in node_indices.items()}
    edge_lists = {cls: [*members] for cls, members in edge_indices.items()}
    raw_names = [element['raw'] for element in elements]

    return ParsedNotation(
        name=name,
        modalities=seen_modalities,
        nodes=node_lists,
        edges=edge_lists,
        elements=ordered,
        element_names=raw_names,
    )


def parse_full_notation(notation: str) -> Tuple[List[Dict], List[Dict], ParsedNotation, ParsedNotation]:
    """
    Parse full notation with input/output split on '->'.

    Args:
        notation: Full notation string of the form "<inputs>-><outputs>",
            where each side is a comma-separated list of elements
            (e.g. "rgb0,c01->rgb1"). The input side may be empty.

    Returns:
        input_elements: List of parsed element dicts
        output_elements: List of parsed element dicts
        input_analysis: ParsedNotation for inputs
        output_analysis: ParsedNotation for outputs

    Raises:
        ValueError: If the notation does not contain exactly one '->', if no
            output elements are given, or if any element fails to parse.
    """
    arrow_count = notation.count("->")
    if arrow_count == 0:
        raise ValueError(f"Notation must contain '->' to separate inputs from outputs: {notation}")
    if arrow_count > 1:
        # Without this guard the extra arrow ends up inside an output element
        # and surfaces as a confusing per-element parse error.
        raise ValueError(
            f"Notation must contain exactly one '->' separator, found {arrow_count}: {notation}"
        )

    input_str, output_str = notation.split("->")

    input_elements = parse_notation(input_str)
    output_elements = parse_notation(output_str)

    if not output_elements:
        raise ValueError(f"No output elements specified in notation: {notation}")

    input_analysis = analyze_notation("input", input_elements)
    output_analysis = analyze_notation("output", output_elements)

    return input_elements, output_elements, input_analysis, output_analysis