| """ |
| ExecutionEncoder: Graph-Based Transformer for Execution Plan Encoding |
| |
| This module implements the transformer-based encoder that maps |
| (plan_graph, provenance_metadata) → z_e ∈ R^1024. |
| |
| The ExecutionEncoder is the complementary half of the JEPA dual-encoder architecture, |
| encoding proposed execution plans into the same latent space as governance policies |
| to enable energy-based security validation. |
| |
| Architecture: |
| - Graph Neural Network for tool-call dependency encoding |
| - Provenance-aware attention mechanism |
| - Scope metadata integration |
| - Differentiable for end-to-end training with energy functions |
| |
| References: |
| - Graph Attention Networks: https://arxiv.org/abs/1710.10903 |
| - Relational Graph Convolutional Networks: https://arxiv.org/abs/1703.06103 |
| """ |
|
|
import zlib
from enum import IntEnum
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from pydantic import BaseModel, Field, field_validator
|
|
|
|
class TrustTier(IntEnum):
    """Trust levels for data provenance (as per Dawn Song's workstream).

    Values start at 1 so that index 0 stays free as a padding slot in
    ProvenanceEmbedding's tier embedding table (which allocates
    ``num_tiers + 1`` entries).
    """

    INTERNAL = 1        # first-party / internal instruction source
    SIGNED_PARTNER = 2  # cryptographically signed partner source
    PUBLIC_WEB = 3      # untrusted public-web content
|
|
|
|
class ToolCallNode(BaseModel):
    """A single tool invocation in the execution plan graph.

    Each node carries the tool identity and arguments, plus the provenance
    and scope metadata consumed downstream by ProvenanceEmbedding.
    """

    # Tool identity and payload.
    tool_name: str = Field(..., min_length=1, description="Name of the tool being invoked")
    arguments: dict[str, Any] = Field(default_factory=dict, description="Tool arguments")

    # Provenance: where the instruction came from and how much it is trusted.
    provenance_tier: TrustTier = Field(default=TrustTier.INTERNAL, description="Trust tier of instruction source")
    provenance_hash: str | None = Field(default=None, description="Cryptographic hash of source")

    # Scope: how much data is touched and how sensitive it is (1=public, 5=critical).
    scope_volume: int = Field(default=1, ge=1, description="Data volume (rows, records, files)")
    scope_sensitivity: int = Field(default=1, ge=1, le=5, description="Sensitivity level (1=public, 5=critical)")

    # Graph identity; referenced by ExecutionPlan.edges, so it must be
    # unique within a plan.
    node_id: str = Field(..., description="Unique node identifier")

    @field_validator('provenance_tier', mode='before')
    @classmethod
    def parse_trust_tier(cls, v):
        """Coerce a plain int (e.g. from JSON) into a TrustTier.

        Runs in ``mode='before'`` so raw ints are converted ahead of
        pydantic's own field validation; an out-of-range int raises
        ValueError via the TrustTier constructor.
        """
        if isinstance(v, int):
            return TrustTier(v)
        return v
|
|
|
|
class ExecutionPlan(BaseModel):
    """Complete execution plan represented as a typed tool-call graph."""

    nodes: list[ToolCallNode] = Field(..., min_length=1, description="Tool invocation nodes")
    edges: list[tuple[str, str]] = Field(default_factory=list, description="Data flow edges (src_id, dst_id)")

    @field_validator('edges')
    @classmethod
    def validate_edges(cls, v, info):
        """Reject the first edge whose endpoints are not declared in ``nodes``."""
        # ``nodes`` validates before ``edges``; if it failed, it is absent
        # from info.data and the referential check is skipped entirely.
        if 'nodes' not in info.data:
            return v

        known_ids = {n.node_id for n in info.data['nodes']}
        bad = next(
            ((s, d) for s, d in v if s not in known_ids or d not in known_ids),
            None,
        )
        if bad is not None:
            src, dst = bad
            raise ValueError(f"Edge ({src}, {dst}) references non-existent node")
        return v
|
|
|
|
class ProvenanceEmbedding(nn.Module):
    """Embeds provenance metadata (trust tier + scope) into node features.

    Trust tiers index a learned embedding table (index 0 is the padding
    slot), scope metadata is linearly projected, and the two streams are
    fused by a final linear layer back down to ``hidden_dim``.
    """

    def __init__(self, hidden_dim: int, num_tiers: int = 3):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Submodule creation order is kept stable for seeded reproducibility.
        self.tier_embedding = nn.Embedding(num_tiers + 1, hidden_dim)  # +1 for padding index 0
        self.scope_projection = nn.Linear(2, hidden_dim)
        self.fusion = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(
        self,
        tier_indices: torch.Tensor,
        scope_volume: torch.Tensor,
        scope_sensitivity: torch.Tensor
    ) -> torch.Tensor:
        """Fuse tier embeddings with projected scope features.

        Args:
            tier_indices: Integer tier ids, shape [..., num_nodes].
            scope_volume: Per-node data volume, same shape.
            scope_sensitivity: Per-node sensitivity level, same shape.

        Returns:
            Fused embedding of shape [..., num_nodes, hidden_dim].
        """
        tiers = self.tier_embedding(tier_indices)

        # Volume is heavy-tailed, so compress with log1p before projecting.
        scope_features = torch.stack(
            [torch.log1p(scope_volume.float()), scope_sensitivity.float()],
            dim=-1,
        )
        scope = self.scope_projection(scope_features)

        return self.fusion(torch.cat((tiers, scope), dim=-1))
|
|
|
|
class GraphAttention(nn.Module):
    """
    Graph Attention layer for encoding tool-call dependencies.

    Multi-head attention where node i may attend to node j only when the
    adjacency matrix marks an edge (i, j) or i == j (self-loop), so no
    softmax row is ever fully masked.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        # Separate Q/K/V/output projections (state_dict layout preserved).
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(
        self,
        x: torch.Tensor,
        adjacency: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply edge-masked multi-head attention.

        Args:
            x: Node features [batch_size, num_nodes, hidden_dim]
            adjacency: Adjacency matrix [batch_size, num_nodes, num_nodes];
                nonzero = edge exists, 0 = no edge.

        Returns:
            Updated node features [batch_size, num_nodes, hidden_dim].
        """
        bsz, n_nodes, _ = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # [B, N, H*D] -> [B, H, N, D]
            return t.view(bsz, n_nodes, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        logits = q @ k.transpose(-2, -1) * self.scale

        # Allow attention along graph edges plus an unconditional self-loop.
        self_loop = torch.eye(n_nodes, device=x.device, dtype=torch.bool)
        allowed = (adjacency.unsqueeze(1) > 0) | self_loop
        logits = logits.masked_fill(~allowed, float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))

        ctx = (weights @ v).transpose(1, 2).contiguous().view(bsz, n_nodes, self.hidden_dim)
        return self.out_proj(ctx)
|
|
|
|
class GraphTransformerBlock(nn.Module):
    """Pre-norm transformer block whose attention respects graph structure."""

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        dropout: float = 0.1
    ):
        super().__init__()
        # Sub-layer 1: graph-masked multi-head attention.
        self.attention = GraphAttention(hidden_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)

        # Sub-layer 2: position-wise feed-forward with 4x expansion.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        x: torch.Tensor,
        adjacency: torch.Tensor
    ) -> torch.Tensor:
        """Run one attention + FFN round, each with a residual connection."""
        attended = self.attention(self.norm1(x), adjacency)
        x = x + attended
        return x + self.ffn(self.norm2(x))
|
|
|
|
class ExecutionEncoder(nn.Module):
    """
    Graph-based transformer encoder mapping execution plans to z_e ∈ R^1024.

    Encodes:
    - Tool invocation sequences
    - Data flow dependencies (graph edges)
    - Provenance metadata (trust tiers)
    - Scope metadata (volume + sensitivity)

    Performance targets:
    - Latency: <100ms on CPU (pairs with GovernanceEncoder's 98ms)
    - Memory: <500MB
    - Differentiable: Yes
    """

    def __init__(
        self,
        latent_dim: int = 1024,
        hidden_dim: int = 512,
        num_layers: int = 4,
        num_heads: int = 8,
        max_nodes: int = 64,
        dropout: float = 0.1,
        vocab_size: int = 10000
    ):
        super().__init__()

        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.max_nodes = max_nodes
        # Bug fix: previously _tokenize hard-coded `% 10000`, so any
        # vocab_size < 10000 produced out-of-range embedding indices.
        self.vocab_size = vocab_size

        # Node content + position embeddings.
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.position_embedding = nn.Embedding(max_nodes, hidden_dim)

        # Provenance + scope metadata embedding.
        self.provenance_embedding = ProvenanceEmbedding(hidden_dim)

        # Graph transformer stack.
        self.layers = nn.ModuleList([
            GraphTransformerBlock(hidden_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])

        # Attention pooling over node slots, then projection to latent space.
        self.attention_pool = nn.Linear(hidden_dim, 1)
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, latent_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(latent_dim * 2, latent_dim),
            nn.LayerNorm(latent_dim)
        )

        self.input_norm = nn.LayerNorm(hidden_dim)

    def _tokenize(self, text: str) -> int:
        """
        Hash-based tokenization (v0.1.0).

        Bug fix: uses CRC32 instead of the builtin hash(). Python salts
        str hashing per process (PYTHONHASHSEED), so hash()-derived token
        ids changed between runs, silently invalidating any trained
        checkpoint. CRC32 is stable across runs and platforms. Also
        bounded by the configured vocab_size rather than a literal 10000.

        Future: Replace with BPE tokenizer (v0.2.0) to reduce collisions.
        """
        # NOTE: token id 0 doubles as the padding id, so a real token can
        # collide with padding — acceptable at v0.1.0 collision rates.
        return zlib.crc32(text.encode("utf-8")) % self.vocab_size

    def _create_adjacency_matrix(
        self,
        num_nodes: int,
        edges: list[tuple[int, int]],
        device: torch.device
    ) -> torch.Tensor:
        """Build a dense [num_nodes, num_nodes] adjacency matrix.

        Edges whose endpoints fall outside the matrix (nodes dropped by
        max_nodes truncation) are silently skipped.
        """
        adjacency = torch.zeros(num_nodes, num_nodes, device=device)
        for src, dst in edges:
            if src < num_nodes and dst < num_nodes:
                adjacency[src, dst] = 1
        return adjacency

    def forward(
        self,
        plan: ExecutionPlan | dict[str, Any]
    ) -> torch.Tensor:
        """
        Encode execution plan into latent vector.

        Args:
            plan: ExecutionPlan or dict conforming to ExecutionPlan schema

        Returns:
            z_e: Latent vector [1, latent_dim]
        """
        # Accept raw dicts for convenience; pydantic validates the schema.
        if not isinstance(plan, ExecutionPlan):
            plan = ExecutionPlan(**plan)

        nodes = plan.nodes
        edges = plan.edges

        # Map string node ids to integer indices. Done before truncation;
        # edges into dropped nodes are filtered by _create_adjacency_matrix.
        node_id_to_idx = {node.node_id: i for i, node in enumerate(nodes)}
        edge_indices = [(node_id_to_idx[src], node_id_to_idx[dst]) for src, dst in edges]

        # Truncate plans larger than the encoder's node budget.
        num_nodes = min(len(nodes), self.max_nodes)
        nodes = nodes[:num_nodes]

        # One token per node: tool name plus canonically-sorted arguments,
        # so argument dict ordering does not change the token.
        tool_tokens = []
        for node in nodes:
            arg_str = ",".join(f"{k}={v}" for k, v in sorted(node.arguments.items()))
            combined = f"{node.tool_name}({arg_str})"
            tool_tokens.append(self._tokenize(combined))

        # Right-pad with token id 0 up to max_nodes.
        if len(tool_tokens) < self.max_nodes:
            tool_tokens.extend([0] * (self.max_nodes - len(tool_tokens)))

        device = next(self.parameters()).device

        token_ids = torch.tensor(tool_tokens[:self.max_nodes], device=device).unsqueeze(0)
        position_ids = torch.arange(self.max_nodes, device=device).unsqueeze(0)

        # Provenance/scope tensors, padded with neutral values: tier 0 is
        # the padding embedding slot; volume/sensitivity floors are 1.
        pad = self.max_nodes - num_nodes
        tier_indices = torch.tensor(
            [node.provenance_tier for node in nodes] + [0] * pad, device=device
        ).unsqueeze(0)
        scope_volume = torch.tensor(
            [node.scope_volume for node in nodes] + [1] * pad, device=device
        ).unsqueeze(0)
        scope_sensitivity = torch.tensor(
            [node.scope_sensitivity for node in nodes] + [1] * pad, device=device
        ).unsqueeze(0)

        adjacency = self._create_adjacency_matrix(
            self.max_nodes,
            edge_indices,
            device
        ).unsqueeze(0)

        # Sum the three embedding streams, normalize, then message-pass.
        token_emb = self.token_embedding(token_ids)
        pos_emb = self.position_embedding(position_ids)
        prov_emb = self.provenance_embedding(tier_indices, scope_volume, scope_sensitivity)

        x = self.input_norm(token_emb + pos_emb + prov_emb)

        for layer in self.layers:
            x = layer(x, adjacency)

        # Attention pooling over all node slots.
        # NOTE(review): padding slots also receive pooling weight; masking
        # them out may sharpen the representation — confirm before changing,
        # as it would alter trained-model outputs.
        attn_scores = self.attention_pool(x).squeeze(-1)
        attn_weights = F.softmax(attn_scores, dim=-1).unsqueeze(1)
        pooled = torch.matmul(attn_weights, x).squeeze(1)

        # Project into the shared governance/execution latent space.
        z_e = self.projection(pooled)

        return z_e

    def encode_batch(self, plans: list[ExecutionPlan]) -> torch.Tensor:
        """
        Batch encoding of multiple execution plans.

        Plans are encoded one at a time (each forward() pads to max_nodes)
        and the [1, latent_dim] results are concatenated.

        Args:
            plans: List of ExecutionPlan objects

        Returns:
            z_e: Latent vectors [batch_size, latent_dim]
        """
        latents = [self.forward(plan) for plan in plans]
        return torch.cat(latents, dim=0)
|
|
|
|
def create_execution_encoder(
    latent_dim: int = 1024,
    checkpoint_path: str | None = None,
    device: str = "cpu"
) -> ExecutionEncoder:
    """
    Factory function to create ExecutionEncoder.

    Args:
        latent_dim: Dimension of output latent vector (must match GovernanceEncoder)
        checkpoint_path: Optional path to pretrained weights
        device: Device to load model on

    Returns:
        Initialized ExecutionEncoder in inference mode (dropout disabled)

    Raises:
        RuntimeError: if the checkpoint's state_dict does not match the model.
    """
    model = ExecutionEncoder(latent_dim=latent_dim)

    if checkpoint_path is not None:
        # weights_only=True avoids arbitrary pickle code execution when
        # loading an untrusted checkpoint file.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

    model = model.to(device)

    # Bug fix: the previous `model.training = False` only flipped the flag
    # on the root module, leaving every nn.Dropout submodule in training
    # mode. eval() recursively switches all submodules to inference mode.
    model.eval()

    return model
|
|