Image-to-Video
Transformers
psi
feature-extraction
world-model
video-generation
multimodal
physical-world-model
controllable-generation
custom_code
Instructions to use StanfordNeuroAILab/psi0_5 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use StanfordNeuroAILab/psi0_5 with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("StanfordNeuroAILab/psi0_5", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Shared notation parsing utilities for PSI2. | |
| Parses Einstein-like notation strings (e.g. "rgb0,c01->rgb1") into | |
| structured element dicts with modality class references via GlobalRegistry. | |
| """ | |
| from __future__ import annotations | |
| import re | |
| from dataclasses import dataclass | |
| from typing import List, Dict, Set, Tuple, Any, Type, Union | |
| from .modalities import ( | |
| GlobalRegistry, SerializableModality, NodeModality, EdgeModality, | |
| ) | |
@dataclass
class ParsedNotation:
    """
    Summary of one side (inputs or outputs) of a parsed notation string.

    Instances are built by analyze_notation() with keyword arguments, so the
    class must be a dataclass to get the generated __init__.
    """
    # Label for this side of the notation (e.g. "input" or "output").
    name: str
    # Every modality class that appears at least once.
    modalities: Set[Type[SerializableModality]]
    # Modality class -> time indices of its single-index elements.
    nodes: Dict[Type[SerializableModality], List[int]]
    # Modality class -> (from, to) index pairs of its two-index elements.
    edges: Dict[Type[SerializableModality], Union[List[int], List[Tuple[int, int]]]]
    # (modality class, index or index pair) per element, in notation order.
    elements: List[Tuple[Type[SerializableModality], Union[int, Tuple[int, int]]]]
    # Raw element strings, in notation order.
    element_names: List[str]
def parse_element(element_str: str) -> Dict[str, Any]:
    """
    Parse a single element string.

    The modality name is everything before the first ASCII digit; every
    character after it must be a digit, and each digit is one time index.

    Format examples:
        rgb0 -> RGB frame at index 0
        f01  -> Flow from frame 0 to frame 1
        d0   -> Disparity at frame 0
        c01  -> Camera pose from frame 0 to frame 1

    Args:
        element_str: One comma-free element of a notation string.

    Returns:
        Dict with keys: modality (name), modality_cls (class looked up in
        GlobalRegistry), indices (list of ints, empty for modalities that
        take no time index), raw (the original element string).

    Raises:
        ValueError: If the modality is unknown, the suffix contains
            non-digit characters, or the number of indices does not match
            the modality kind (1 for nodes, 2 for edges, 0 otherwise).
    """
    # The pattern accepts an empty prefix, so it matches every string and
    # the match object never needs a None check.
    modality_name = re.match(r'^[^0-9]*', element_str).group()

    modality_cls = GlobalRegistry.get(modality_name)
    if modality_cls is None:
        raise ValueError(f"Unknown modality: {element_str} (valid names: {GlobalRegistry.names_str()})")

    times_str = element_str.removeprefix(modality_name)
    if times_str and not times_str.isdigit():
        # A digit run followed by more text usually means a missing comma
        # between two elements, so suggest the split in the error.
        match = re.match(r"^(\d+)(.+)$", times_str)
        if match:
            digits, trailing = match.groups()
            raise ValueError(
                f"Invalid element '{element_str}': time indices for '{modality_name}' must be digits only. "
                f"Did you mean '{modality_name}{digits},{trailing}'?"
            )
        raise ValueError(
            f"Invalid element '{element_str}': time indices for '{modality_name}' must be digits only."
        )

    # One character per time index; an element without a suffix yields [].
    index_chars = list(times_str)

    if issubclass(modality_cls, NodeModality):
        if len(index_chars) != 1:
            raise ValueError(f"{modality_name} is a Node and expects 1 time index: {element_str}")
    elif issubclass(modality_cls, EdgeModality):
        if len(index_chars) != 2:
            raise ValueError(f"{modality_name} is an Edge and expects 2 time indices: {element_str}")
    elif index_chars:
        # Index-free modality: only reject when indices were actually given.
        # (Comparing against [''] rejected every such element, because an
        # empty suffix produces [] rather than [''].)
        raise ValueError(
            f"{modality_name} does not take time indices, but got {index_chars}: {element_str}"
        )

    return {
        "modality": modality_name,
        "modality_cls": modality_cls,
        "indices": [int(ch) for ch in index_chars],
        "raw": element_str,
    }
def parse_notation(notation: str) -> List[Dict[str, Any]]:
    """
    Split a comma-separated notation string and parse each element.

    Surrounding whitespace is stripped from every piece, and pieces that are
    empty after stripping are ignored.

    Returns:
        The parse_element() result for each remaining piece, in order.
    """
    tokens = (piece.strip() for piece in notation.split(","))
    return [parse_element(token) for token in tokens if token]
def analyze_notation(name: str, elements: List[Dict[str, Any]]) -> ParsedNotation:
    """
    Analyze parsed notation elements to determine requirements.

    Args:
        name: Label stored on the result.
        elements: Element dicts with 'modality_cls', 'indices' and 'raw'
            keys, as produced by parse_element().

    Returns:
        ParsedNotation in which two-index elements are grouped under
        ``edges`` and one-index elements under ``nodes``, each deduplicated
        per modality class and sorted ascending so the result is
        reproducible. ``elements`` and ``element_names`` keep notation order.
    """
    nodes: Dict[Any, Set[int]] = {}
    edges: Dict[Any, Set[Tuple[int, ...]]] = {}
    out_elements = []

    for e in elements:
        modality_cls, indices = e['modality_cls'], e['indices']
        if len(indices) == 2:
            pair = tuple(indices)
            edges.setdefault(modality_cls, set()).add(pair)
            out_elements.append((modality_cls, pair))
        elif len(indices) == 1:
            nodes.setdefault(modality_cls, set()).add(indices[0])
            out_elements.append((modality_cls, indices[0]))
        else:
            # Elements without time indices are recorded with placeholder 0.
            out_elements.append((modality_cls, 0))

    return ParsedNotation(
        name=name,
        modalities={e['modality_cls'] for e in elements},
        # sorted() instead of list(): set iteration order is arbitrary.
        nodes={k: sorted(v) for k, v in nodes.items()},
        edges={k: sorted(v) for k, v in edges.items()},
        elements=out_elements,
        element_names=[e['raw'] for e in elements],
    )
def parse_full_notation(notation: str) -> Tuple[List[Dict], List[Dict], ParsedNotation, ParsedNotation]:
    """
    Parse a complete "inputs->outputs" notation string.

    The string is split at the first '->'; each side is parsed as a
    comma-separated element list and then analyzed.

    Returns:
        A 4-tuple of (input element dicts, output element dicts,
        ParsedNotation for the inputs, ParsedNotation for the outputs).

    Raises:
        ValueError: If '->' is missing or the output side has no elements.
    """
    if "->" not in notation:
        raise ValueError(f"Notation must contain '->' to separate inputs from outputs: {notation}")

    lhs, _, rhs = notation.partition("->")
    parsed_inputs = parse_notation(lhs)
    parsed_outputs = parse_notation(rhs)

    if not parsed_outputs:
        raise ValueError(f"No output elements specified in notation: {notation}")

    return (
        parsed_inputs,
        parsed_outputs,
        analyze_notation("input", parsed_inputs),
        analyze_notation("output", parsed_outputs),
    )