"""S23DR 2026 submission: learned wireframe prediction from fused point clouds.

Pipeline: raw sample -> point fusion -> priority sample (SEQ_LEN points)
          -> model -> post-process -> wireframe
"""

import json
import os
import subprocess
import sys
import time
import traceback
import urllib.request
from pathlib import Path

# Must be set before numpy / torch are imported further down.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def install_if_missing(package):
    """Import ``package``; if that fails, ``pip install`` it.

    ``package`` may carry a ``==version`` pin; only the part before ``==`` is
    used as the module name for the import probe, while the full string is
    handed to pip.
    """
    try:
        __import__(package.split("==")[0])
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])


# These run at import time, before the third-party imports below.
install_if_missing("scipy")
install_if_missing("pandas")

from tqdm import tqdm
import numpy as np
import torch


def empty_solution():
    """Return the fallback wireframe: two vertices at the origin and one edge."""
    return np.zeros((2, 3)), [(0, 1)]


# ---------------------------------------------------------------------------
# Point fusion + sampling (from cache_scenes.py / make_sampled_cache.py)
# ---------------------------------------------------------------------------

# Add our package to path
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SCRIPT_DIR))

from s23dr_2026_example.point_fusion import build_compact_scene, FuserConfig
from s23dr_2026_example.cache_scenes import (
    _compute_group_and_class,
    _compute_smart_center_scale,
)
from s23dr_2026_example.make_sampled_cache import _priority_sample

# Tokenizer / model imports
from s23dr_2026_example.tokenizer import EdgeDepthSequenceConfig
from s23dr_2026_example.model import EdgeDepthSegmentsModel
from s23dr_2026_example.segment_postprocess import merge_vertices_iterative
from s23dr_2026_example.varifold import segments_to_vertices_edges
from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal

SEQ_LEN = 4096
COLMAP_QUOTA = 3072
DEPTH_QUOTA = 1024
CONF_THRESH = 0.5
# NOTE(review): MERGE_THRESH is not referenced anywhere in this file;
# merge_vertices_iterative is called with its own defaults.
MERGE_THRESH = 0.4
SNAP_RADIUS = 0.5


def fuse_and_sample(sample, cfg, rng):
    """Run point fusion + priority sampling on a raw dataset sample.

    Returns a dict with xyz_norm, class_id, source, mask, center, scale, etc.
    ready for model inference. Returns None if ``build_compact_scene`` raises
    or if the fused scene has fewer than 10 points. Errors raised by the later
    steps are NOT caught here and propagate to the caller.
    """
    try:
        scene = build_compact_scene(sample, cfg, rng)
    except Exception as e:
        print(f" Fusion failed: {e}")
        return None

    xyz = scene["xyz"]
    source = scene["source"]
    if len(xyz) < 10:
        return None

    # Compute group_id and class_id (same as cache_scenes.py)
    behind_id = scene.get("behind_gest_id", np.full(len(xyz), -1, dtype=np.int16))
    group_id, class_id = _compute_group_and_class(
        scene["visible_src"], scene["visible_id"], behind_id, source)

    # Normalize
    center, scale = _compute_smart_center_scale(xyz, source)

    # Priority sample
    indices, mask = _priority_sample(source, group_id, SEQ_LEN, COLMAP_QUOTA, DEPTH_QUOTA)

    xyz_norm = (xyz[indices] - center) / scale

    result = {
        "xyz_norm": xyz_norm.astype(np.float32),
        "class_id": class_id[indices].astype(np.int64),
        "source": source[indices].astype(np.int64),
        "mask": mask,
        "center": center.astype(np.float32),
        "scale": np.float32(scale),
    }

    # Optional fields
    if "behind_gest_id" in scene:
        # Negative ids are clipped to 0 before being used as embedding indices.
        behind = np.clip(scene["behind_gest_id"][indices].astype(np.int16), 0, None)
        result["behind"] = behind.astype(np.int64)
    if "n_views_voted" in scene:
        result["n_views_voted"] = scene["n_views_voted"][indices].astype(np.float32)
    if "vote_frac" in scene:
        result["vote_frac"] = scene["vote_frac"][indices].astype(np.float32)

    # Visible src/id for snap post-processing
    result["visible_src"] = scene["visible_src"][indices].astype(np.int64)
    result["visible_id"] = scene["visible_id"][indices].astype(np.int64)

    return result


def load_model(checkpoint_path, device):
    """Load model from checkpoint.

    Hyper-parameters are read from the checkpoint's ``"args"`` dict, with the
    defaults below used for any missing key. Weights come from
    ``ckpt["model"]`` and are loaded strictly; the model is returned in eval
    mode on ``device``.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # The driver may download this file over the network, so only point this
    # at checkpoints you trust.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    args = ckpt.get("args", {})

    norm_class = torch.nn.RMSNorm if args.get("rms_norm") else None
    seq_cfg = EdgeDepthSequenceConfig(
        seq_len=SEQ_LEN, colmap_points=COLMAP_QUOTA, depth_points=DEPTH_QUOTA)

    model = EdgeDepthSegmentsModel(
        seq_cfg=seq_cfg,
        segments=args.get("segments", 64),
        hidden=args.get("hidden", 256),
        num_heads=args.get("num_heads", 4),
        kv_heads_cross=args.get("kv_heads_cross", 2),
        kv_heads_self=args.get("kv_heads_self", 2),
        dim_feedforward=args.get("ff", 1024),
        dropout=args.get("dropout", 0.1),
        latent_tokens=args.get("latent_tokens", 256),
        latent_layers=args.get("latent_layers", 7),
        decoder_layers=args.get("decoder_layers", 3),
        cross_attn_interval=args.get("cross_attn_interval", 4),
        norm_class=norm_class,
        activation=args.get("activation", "gelu"),
        segment_conf=args.get("segment_conf", True),
        behind_emb_dim=args.get("behind_emb_dim", 8),
        use_vote_features=args.get("vote_features", True),
        arch=args.get("arch", "perceiver"),
        encoder_layers=args.get("encoder_layers", 4),
        pre_encoder_layers=args.get("pre_encoder_layers", 0),
        segment_param=args.get("segment_param", "midpoint_dir_len"),
        qk_norm=args.get("qk_norm", True),
    ).to(device)

    # Handle torch.compile _orig_mod prefix
    state = ckpt["model"]
    fixed = {k.replace("segmenter._orig_mod.", "segmenter."): v for k, v in state.items()}
    model.load_state_dict(fixed, strict=True)
    model.eval()
    return model


def build_tokens_single(sample_dict, model, device):
    """Build token tensor for a single sample (no DataLoader).

    Every input gets a leading batch dimension of 1. Returns
    ``(tokens, masks)`` where ``tokens`` concatenates, along the last axis:
    xyz, the positional encoding (if the tokenizer has one), the class and
    source embeddings, and optionally the "behind" embedding and the two vote
    features.
    """
    xyz = torch.as_tensor(sample_dict["xyz_norm"], dtype=torch.float32).unsqueeze(0).to(device)
    cid = torch.as_tensor(sample_dict["class_id"], dtype=torch.long).unsqueeze(0).to(device)
    src = torch.as_tensor(sample_dict["source"], dtype=torch.long).unsqueeze(0).to(device)
    masks = torch.as_tensor(sample_dict["mask"], dtype=torch.bool).unsqueeze(0).to(device)

    B, T, _ = xyz.shape
    tok = model.tokenizer

    if tok.pos_enc is not None:
        fourier = tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1)
    else:
        # Zero-width placeholder so the concatenation below is unaffected.
        fourier = xyz.new_zeros(B, T, 0)

    # Source ids are clamped to {0, 1} before the embedding lookup.
    parts = [xyz, fourier, tok.label_emb(cid), tok.src_emb(src.clamp(0, 1))]

    if tok.behind_emb_dim > 0:
        if "behind" in sample_dict:
            beh = torch.as_tensor(sample_dict["behind"], dtype=torch.long).unsqueeze(0).to(device)
        else:
            beh = xyz.new_zeros(B, T, dtype=torch.long)
        parts.append(tok.behind_emb(beh))

    if tok.use_vote_features:
        if "n_views_voted" in sample_dict and "vote_frac" in sample_dict:
            # The shift/scale constants are presumably the normalization used
            # at training time — TODO confirm against the training tokenizer.
            nv = ((torch.as_tensor(sample_dict["n_views_voted"], dtype=torch.float32).unsqueeze(0).to(device) - 2.7) / 1.0).unsqueeze(-1)
            vf = ((torch.as_tensor(sample_dict["vote_frac"], dtype=torch.float32).unsqueeze(0).to(device) - 0.5) / 0.25).unsqueeze(-1)
            parts.extend([nv, vf])
        else:
            parts.extend([xyz.new_zeros(B, T, 1), xyz.new_zeros(B, T, 1)])

    tokens = torch.cat(parts, dim=-1)
    return tokens, masks


def predict_sample(sample_dict, model, device):
    """Run model inference + post-processing on a fused sample.

    Returns (vertices, edges) in world space. Falls back to
    ``empty_solution()`` when no segment survives the confidence filter, when
    post-processing leaves fewer than 2 vertices or no edge, or when any
    vertex coordinate is non-finite.
    """
    tokens, masks = build_tokens_single(sample_dict, model, device)
    scale = float(sample_dict["scale"])
    center = sample_dict["center"]

    # fp16 autocast is only enabled on CUDA devices.
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16, enabled=(device.type == 'cuda')):
        out = model.forward_tokens(tokens, masks)

    segs = out["segments"][0].float().cpu()
    conf = torch.sigmoid(out["conf"][0].float()).cpu().numpy() if "conf" in out else None

    # Confidence filter
    if conf is not None:
        keep = conf > CONF_THRESH
        segs = segs[keep]

    if len(segs) < 1:
        return empty_solution()

    # To world space
    segs_world = segs.numpy() * scale + center

    # Vertices + edges from segments
    pv, pe = segments_to_vertices_edges(torch.tensor(segs_world))
    pv, pe = pv.numpy(), np.array(pe, dtype=np.int32)

    # Merge
    pv, pe = merge_vertices_iterative(pv, pe)

    # Snap to point cloud
    xyz_norm = sample_dict["xyz_norm"]
    mask = sample_dict["mask"]
    cid = sample_dict["class_id"]
    xyz_world = xyz_norm[mask] * scale + center
    cid_valid = cid[mask]
    pv = snap_to_point_cloud(pv, xyz_world, cid_valid, snap_radius=SNAP_RADIUS)

    # Horizontal snap
    pv = snap_horizontal(pv, pe)

    if len(pv) < 2 or len(pe) < 1:
        return empty_solution()

    # Half-precision inference can overflow; NaN/Inf coordinates would be
    # serialized as non-standard JSON, so treat them as a failed prediction.
    if not np.isfinite(np.asarray(pv)).all():
        return empty_solution()

    edges = [(int(a), int(b)) for a, b in pe]
    return pv, edges


def hybrid_merge(pred_v, pred_e, track_v, track_e, merge_radius=0.8):
    """Merge a second ("track") wireframe into the predicted wireframe.

    Track vertices with a non-finite coordinate are dropped together with
    their edges. Each remaining track vertex within ``merge_radius`` of its
    nearest predicted vertex is identified with that vertex; the others are
    appended as new vertices. Track edges are re-indexed accordingly, and
    self-loops and edges already present (in either direction) are skipped.

    Returns ``(vertices, edges)``. When there is nothing to merge the inputs
    ``pred_v, pred_e`` are returned; otherwise ``vertices`` is a new ndarray
    and ``edges`` a new list.
    """
    if len(track_v) == 0:
        return pred_v, pred_e

    pred_v = np.array(pred_v) if isinstance(pred_v, list) else pred_v
    track_v = np.array(track_v)

    # Filter out NaNs and Infs from track_v
    valid_mask = np.isfinite(track_v).all(axis=1)
    if not valid_mask.all():
        valid_indices = np.where(valid_mask)[0]
        idx_map = {int(old_idx): new_idx for new_idx, old_idx in enumerate(valid_indices)}
        track_v = track_v[valid_mask]
        track_e = [
            (idx_map[u], idx_map[v])
            for u, v in track_e
            if u in idx_map and v in idx_map
        ]

    if len(track_v) == 0:
        return pred_v, pred_e

    # We will append track vertices that are NOT close to any pred_v
    if len(pred_v) > 0:
        from scipy.spatial import cKDTree
        tree = cKDTree(pred_v)
        dists, indices = tree.query(track_v, k=1)
    else:
        # No predicted vertices: infinite distances make every track vertex new.
        dists = np.full(len(track_v), np.inf)
        indices = np.zeros(len(track_v), dtype=int)

    # Map track vertex indices to final vertex indices
    track_to_final = {}
    new_vertices = []
    for i, (d, idx) in enumerate(zip(dists, indices)):
        if d <= merge_radius and len(pred_v) > 0:
            # Map to existing pred_v
            track_to_final[i] = int(idx)
        else:
            # Add as new vertex
            track_to_final[i] = len(pred_v) + len(new_vertices)
            new_vertices.append(track_v[i])

    final_v = list(pred_v) + new_vertices
    final_e = list(pred_e)

    # Add track edges, mapping their indices
    existing_edges = set()
    for u, v in final_e:
        existing_edges.add((min(u, v), max(u, v)))

    for u_t, v_t in track_e:
        u_f = track_to_final.get(u_t)
        v_f = track_to_final.get(v_t)
        if u_f is not None and v_f is not None and u_f != v_f:
            e = (min(u_f, v_f), max(u_f, v_f))
            if e not in existing_edges:
                final_e.append(e)
                existing_edges.add(e)

    return np.array(final_v), final_e


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Run the full pipeline over every dataset split and write the submission.

    Reads ``params.json`` from the working directory, downloads the dataset
    and checkpoint if they are missing, predicts one wireframe per sample
    (falling back to ``empty_solution()`` on any per-sample failure), and
    writes ``submission.json`` plus, best-effort, ``submission.parquet``.
    """
    t_start = time.time()

    # Load params
    param_path = Path("params.json")
    with param_path.open() as f:
        params = json.load(f)
    print(f"Competition: {params.get('competition_id', '?')}")
    print(f"Dataset: {params.get('dataset', '?')}")

    # Load test data
    data_path = Path("/tmp/data")
    if not data_path.exists():
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id=params["dataset"],
            local_dir="/tmp/data",
            repo_type="dataset",
        )

    from datasets import load_dataset

    data_files = {}
    public_tars = sorted([str(p) for p in data_path.rglob('*public*/**/*.tar')])
    private_tars = sorted([str(p) for p in data_path.rglob('*private*/**/*.tar')])
    if public_tars:
        data_files["validation"] = public_tars
    if private_tars:
        data_files["test"] = private_tars
    print(f"Data files: {data_files}")

    loading_scripts = sorted(data_path.rglob('*.py'))
    loading_script = str(loading_scripts[0]) if loading_scripts else str(data_path)
    # NOTE(security): trust_remote_code=True executes the loading script that
    # ships with the dataset.
    dataset = load_dataset(
        loading_script,
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print(f"Loaded: {dataset}")

    # Load model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    checkpoint_path = SCRIPT_DIR / "checkpoint.pt"
    # Auto-download checkpoint if missing or just an LFS pointer
    if not checkpoint_path.exists() or checkpoint_path.stat().st_size < 1000:
        print("Downloading checkpoint.pt from upstream learned baseline...")
        ckpt_url = "https://huggingface.co/jacklangerman/s23dr-2026-submission/resolve/main/checkpoint.pt"
        urllib.request.urlretrieve(ckpt_url, str(checkpoint_path))
        print("Downloaded checkpoint.pt")

    model = load_model(checkpoint_path, device)
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params")

    # Point fusion config
    cfg = FuserConfig()
    rng = np.random.RandomState(2718)

    # Process all samples
    solution = []
    total_samples = sum(len(dataset[s]) for s in dataset)
    processed = 0

    for subset_name in dataset:
        print(f"\nProcessing {subset_name} ({len(dataset[subset_name])} samples)...")
        for sample in tqdm(dataset[subset_name], desc=subset_name):
            order_id = sample["order_id"]

            # Fusion runs inside the guard as well: fuse_and_sample only
            # catches errors from build_compact_scene, and one bad sample must
            # not abort the whole submission.
            try:
                # Fuse + sample
                fused = fuse_and_sample(sample, cfg, rng)
                if fused is None:
                    pred_v, pred_e = empty_solution()
                else:
                    pred_v, pred_e = predict_sample(fused, model, device)

                    # Apply handcrafted triangulation tracking to catch missing
                    # corners/edges (currently disabled):
                    # try:
                    #     from triangulation import predict_wireframe_tracks
                    #     # Force TRACK_MIN_VIEWS = 2 for aggressive recall
                    #     track_v, track_e = predict_wireframe_tracks(sample, min_views=2)
                    #
                    #     pred_v, pred_e = hybrid_merge(pred_v, pred_e, track_v, track_e, merge_radius=0.8)
                    # except Exception as track_e_err:
                    #     print(f" Track ensemble failed for {order_id}: {track_e_err}")
            except Exception:
                print(f" Predict failed for {order_id}:\n{traceback.format_exc()}")
                pred_v, pred_e = empty_solution()
            finally:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            solution.append({
                "order_id": order_id,
                "wf_vertices": pred_v.tolist() if isinstance(pred_v, np.ndarray) else pred_v,
                "wf_edges": [(int(a), int(b)) for a, b in pred_e],
            })

            processed += 1
            if processed % 50 == 0:
                elapsed = time.time() - t_start
                rate = elapsed / processed
                remaining = (total_samples - processed) * rate
                print(f" [{processed}/{total_samples}] "
                      f"{elapsed:.0f}s elapsed, ~{remaining:.0f}s remaining")

    # Save
    output_path = Path(params.get('output_path', '.'))
    # Create the directory up front so a missing path cannot discard a
    # finished run at the very last step.
    output_path.mkdir(parents=True, exist_ok=True)
    with open(output_path / "submission.json", "w") as f:
        json.dump(solution, f)

    # The parquet copy is best-effort; the JSON file above is already written.
    try:
        import pandas as pd
        sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
        sub.to_parquet(output_path / "submission.parquet")
    except Exception as e:
        print(f"Failed to write parquet: {e}")

    elapsed = time.time() - t_start
    print(f"\nDone. {processed} samples in {elapsed:.0f}s ({elapsed/max(processed,1):.1f}s/sample)")
    print(f"Saved submission.json ({len(solution)} entries)")


if __name__ == "__main__":
    main()