Image-to-Video
Transformers
psi
feature-extraction
world-model
video-generation
multimodal
physical-world-model
controllable-generation
custom_code
Instructions to use StanfordNeuroAILab/psi0_5 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use StanfordNeuroAILab/psi0_5 with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("StanfordNeuroAILab/psi0_5", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 5,190 Bytes
0ff8bab
"""
Shared notation parsing utilities for PSI2.
Parses Einstein-like notation strings (e.g. "rgb0,c01->rgb1") into
structured element dicts with modality class references via GlobalRegistry.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import List, Dict, Set, Tuple, Any, Type, Union
from .modalities import (
GlobalRegistry, SerializableModality, NodeModality, EdgeModality,
)
@dataclass
class ParsedNotation:
    """
    Summary of one side (inputs or outputs) of a notation string.

    Instances are produced by ``analyze_notation`` from the element dicts
    returned by ``parse_element``.
    """

    # Label for this side of the notation (parse_full_notation uses
    # "input" / "output").
    name: str
    # Every distinct modality class that appears on this side.
    modalities: Set[Type[SerializableModality]]
    # Modality class -> distinct indices, for elements written with one index.
    nodes: Dict[Type[SerializableModality], List[int]]
    # Modality class -> distinct (i, j) pairs, for elements written with two
    # indices.
    edges: Dict[Type[SerializableModality], Union[List[int], List[Tuple[int, int]]]]
    # One (modality class, index key) pair per element, in notation order.
    elements: List[Tuple[Type[SerializableModality], Union[int, Tuple[int, int]]]]
    # The raw element strings, in notation order.
    element_names: List[str]
def parse_element(element_str: str) -> Dict[str, Any]:
    """
    Parse a single element string.

    An element is a modality name (every character before the first ASCII
    digit) followed by its time indices, written as one digit per index.

    Format examples:
        rgb0 -> RGB frame at index 0
        f01  -> Flow from frame 0 to frame 1
        d0   -> Disparity at frame 0
        c01  -> Camera pose from frame 0 to frame 1

    Args:
        element_str: The element text, e.g. ``"rgb0"`` or ``"c01"``.

    Returns:
        Dict with keys:
            modality: the modality name string,
            modality_cls: the class ``GlobalRegistry.get`` returns for it,
            indices: list of ints -- exactly one for a NodeModality subclass,
                exactly two for an EdgeModality subclass, none otherwise,
            raw: the original ``element_str``.

    Raises:
        ValueError: if there is no modality name, the name is unknown to
            GlobalRegistry, anything other than ASCII digits follows the name,
            or the number of indices does not fit the modality kind.
    """
    # Everything up to the first ASCII digit is the modality name. The pattern
    # always matches (possibly the empty string), so test the captured text.
    modality_name = re.match(r'^[^0-9]*', element_str).group()
    if not modality_name:
        raise ValueError(f"Invalid element string: {element_str}")
    modality_cls = GlobalRegistry.get(modality_name)
    if modality_cls is None:
        raise ValueError(f"Unknown modality: {element_str} (valid names: {GlobalRegistry.names_str()})")

    times_str = element_str.removeprefix(modality_name)
    # Accept ASCII digits only: str.isdigit() is also true for characters such
    # as '²' that int() cannot convert.
    if not re.fullmatch(r'[0-9]*', times_str):
        match = re.match(r"^([0-9]+)(.+)$", times_str)
        if match:
            # Suggest the comma-separated form, e.g. "rgb0c01" -> "rgb0,c01".
            digits, trailing = match.groups()
            raise ValueError(
                f"Invalid element '{element_str}': time indices for '{modality_name}' must be digits only. "
                f"Did you mean '{modality_name}{digits},{trailing}'?"
            )
        raise ValueError(
            f"Invalid element '{element_str}': time indices for '{modality_name}' must be digits only."
        )

    # One character per time index, so every index is a single digit.
    times = list(times_str)
    if issubclass(modality_cls, NodeModality):
        if len(times) != 1:
            raise ValueError(f"{modality_name} is a Node and expects 1 time index: {element_str}")
    elif issubclass(modality_cls, EdgeModality):
        if len(times) != 2:
            raise ValueError(f"{modality_name} is an Edge and expects 2 time indices: {element_str}")
    elif times:
        # Neither Node nor Edge: only an empty index string is valid.
        raise ValueError(
            f"{modality_name} does not take time indices, but got {times}: {element_str}"
        )
    indices = [int(t) for t in times]
    return {"modality": modality_name, "modality_cls": modality_cls, "indices": indices, "raw": element_str}
def parse_notation(notation: str) -> List[Dict[str, Any]]:
    """
    Parse a comma-separated notation string into a list of element dicts.

    Each comma-separated part is stripped of surrounding whitespace. Empty
    parts (e.g. from ``"rgb0,,d0"`` or a trailing comma) are skipped. Every
    remaining part is handed to ``parse_element``, in order.
    """
    stripped_parts = (chunk.strip() for chunk in notation.split(","))
    return [parse_element(chunk) for chunk in stripped_parts if chunk]
def analyze_notation(name: str, elements: List[Dict[str, Any]]) -> ParsedNotation:
    """
    Analyze parsed notation elements to determine requirements.

    Args:
        name: Label stored on the result (parse_full_notation passes
            "input" or "output").
        elements: Element dicts as returned by parse_element; only the
            'modality_cls', 'indices' and 'raw' keys are read.

    Returns:
        A ParsedNotation whose
          * ``modalities`` is the set of every modality class present,
          * ``nodes`` maps each class seen with exactly one index to its
            distinct indices,
          * ``edges`` maps each class seen with exactly two indices to its
            distinct ``(i, j)`` tuples,
          * ``elements`` holds one ``(modality_cls, key)`` pair per element,
            in input order. ``key`` is the single index, the index tuple, or
            0 when the element has neither one nor two indices,
          * ``element_names`` is each element's raw string, in input order.

        The per-class index lists are built from sets, so duplicates are
        dropped and their order is the set's iteration order.
    """
    node_index_sets = {}
    edge_index_sets = {}
    keyed_elements = []
    for entry in elements:
        cls = entry['modality_cls']
        idx = entry['indices']
        if len(idx) == 1:
            key = idx[0]
            node_index_sets.setdefault(cls, set()).add(key)
        elif len(idx) == 2:
            key = tuple(idx)
            edge_index_sets.setdefault(cls, set()).add(key)
        else:
            # Neither a single index nor a pair: 0 serves as the placeholder key.
            key = 0
        keyed_elements.append((cls, key))

    return ParsedNotation(
        name=name,
        modalities={entry['modality_cls'] for entry in elements},
        nodes={cls: list(found) for cls, found in node_index_sets.items()},
        edges={cls: list(found) for cls, found in edge_index_sets.items()},
        elements=keyed_elements,
        element_names=[entry['raw'] for entry in elements],
    )
def parse_full_notation(notation: str) -> Tuple[List[Dict], List[Dict], ParsedNotation, ParsedNotation]:
    """
    Split a notation on its first '->' and parse both sides.

    The text before the arrow is the input side and the text after it is the
    output side. The input side may be empty; the output side may not.

    Returns:
        input_elements: List of parsed element dicts
        output_elements: List of parsed element dicts
        input_analysis: ParsedNotation for inputs
        output_analysis: ParsedNotation for outputs

    Raises:
        ValueError: if the notation contains no '->', if the output side has
            no elements, or if any element fails to parse.
    """
    lhs, arrow, rhs = notation.partition("->")
    if not arrow:
        raise ValueError(f"Notation must contain '->' to separate inputs from outputs: {notation}")

    input_elements = parse_notation(lhs)
    output_elements = parse_notation(rhs)
    if not output_elements:
        raise ValueError(f"No output elements specified in notation: {notation}")

    return (
        input_elements,
        output_elements,
        analyze_notation("input", input_elements),
        analyze_notation("output", output_elements),
    )
|