import io
import itertools
import json
import os
import sys
import traceback

import numpy as np
from datasets import load_dataset
from hoho2025.metric_helper import hss

import sklearn_submission
|
|
# Number of leading samples from the training split to evaluate.
NUM_SAMPLES = 10

print("Loading dataset...")
# streaming=True iterates the split lazily instead of materializing it up front.
dataset = load_dataset('usm3d/hoho22k_2026_trainval', split='train', streaming=True, trust_remote_code=True)

# islice stops after exactly NUM_SAMPLES items. An enumerate/break loop would
# pull one extra sample from the stream before noticing it should stop.
samples = list(itertools.islice(dataset, NUM_SAMPLES))
|
|
# Score each collected sample with HSS; failed predictions are scored with a
# trivial placeholder so they still count against the average.
scores = []
for idx, sample in enumerate(samples):
    print(f"Testing sample {idx}")

    # Check ground truth first: a sample without it cannot be scored, so there
    # is no reason to run the (potentially expensive) predictor on it.
    gt_v = sample.get('wf_vertices')
    gt_e = sample.get('wf_edges')
    if gt_v is None or gt_e is None:
        print(f"Skipping sample {idx} due to missing ground truth.")
        continue

    try:
        pred_v, pred_e = sklearn_submission.predict_wireframe_sklearn(sample)
    except Exception as e:
        # Broad catch is deliberate at this harness boundary: one failing sample
        # must not abort the whole run. Print the traceback so the failure is
        # debuggable (str(e) alone can be as terse as a bare dict key), then
        # fall back to a placeholder wireframe: a 2x3 zero vertex array and a
        # single edge between those two vertices.
        print(f"Error on sample {idx}: {e}")
        traceback.print_exc()
        pred_v, pred_e = np.zeros((2, 3)), [(0, 1)]

    res = hss(pred_v, pred_e, gt_v, gt_e)
    scores.append(res.hss)
    print(f"Sample {idx} HSS: {res.hss:.4f}")
|
|
# Summarize the run: mean HSS over every scored sample, or a notice when none
# could be scored.
if not scores:
    print("No valid scores.")
else:
    mean_hss = sum(scores) / len(scores)
    print(f"Average HSS: {mean_hss:.4f}")
|
|