| import gc |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| from datasets import load_dataset |
| from tqdm import tqdm |
| from solution import predict_wireframe |
|
|
|
|
def empty_solution():
    """Build the placeholder prediction used when a sample cannot be processed.

    Returns:
        A ``(vertices, edges)`` tuple where ``vertices`` is a 2x3 array of
        zeros and ``edges`` is an empty list.
    """
    placeholder_vertices = np.zeros((2, 3))
    placeholder_edges = []
    return placeholder_vertices, placeholder_edges
|
|
|
|
def _predict_entry(entry):
    """Predict the wireframe for one sample, falling back to the empty solution.

    Args:
        entry: A single dataset record; it is passed unchanged to
            ``predict_wireframe`` and its ``'order_id'`` is used in the
            error message when available.

    Returns:
        A ``(vertices, edges)`` tuple where ``vertices`` is a plain nested
        list (ready for serialization) and ``edges`` is whatever the
        predictor (or ``empty_solution``) returned.
    """
    try:
        pred_vertices, pred_edges = predict_wireframe(entry)
        # Convert inside the guarded region: a predictor that returns a list
        # or some other non-ndarray must not abort the whole run.
        vertices = np.asarray(pred_vertices).tolist()
    except Exception as e:
        # Deliberate best-effort boundary: one bad sample is logged and
        # replaced by a placeholder so the submission can still be written.
        print(f"Error processing sample {entry.get('order_id', 'UNKNOWN')}: {e}")
        pred_vertices, pred_edges = empty_solution()
        vertices = pred_vertices.tolist()
    return vertices, pred_edges


def main():
    """
    Main script for the S23DR 2025 Challenge.

    Loads the dataset through the loader script found under ``/tmp/data``,
    runs ``predict_wireframe`` on every entry of every subset, and writes the
    collected predictions to ``submission.parquet`` in the current working
    directory. Samples whose prediction fails are replaced by the result of
    ``empty_solution`` instead of stopping the run.
    """
    print("------------ Setting up data paths ------------")
    data_path = Path('/tmp/data')

    print("------------ Loading dataset ------------")
    # Tar shards under "*public*" directories feed the validation split and
    # those under "*private*" directories feed the test split.
    data_files = {
        "validation": [str(p) for p in data_path.rglob('*public*/**/*.tar')],
        "test": [str(p) for p in data_path.rglob('*private*/**/*.tar')],
    }
    print(f"Found data files: {data_files}")

    dataset = load_dataset(
        str(data_path / 'hoho25k_test_x.py'),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print(f"Dataset loaded successfully: {dataset}")

    print('------------ Starting prediction loop ---------------')
    solution = []
    for subset_name, subset in dataset.items():
        print(f"Predicting for subset: {subset_name}")
        progress = tqdm(subset, desc=f"Processing {subset_name}")
        for processed, entry in enumerate(progress, start=1):
            vertices, edges = _predict_entry(entry)
            solution.append(
                {
                    'order_id': entry['order_id'],
                    'wf_vertices': vertices,
                    'wf_edges': edges,
                }
            )

            # Periodically force a collection to keep memory in check over
            # a long run.
            if processed % 50 == 0:
                gc.collect()

    print('------------ Saving results ---------------')
    sub = pd.DataFrame(solution, columns=["order_id", "wf_vertices", "wf_edges"])
    sub.to_parquet("submission.parquet", index=False)
    print("------------ Done ------------")
|
|
|
|
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()