"""
TorchScript model predictor base class.
Loads and runs inference on TorchScript (.pt) models.
"""
import os
import torch
import yaml
import logging
def resolve_model_path(model_path):
    """Resolve a concrete model file from a file or directory path.

    If *model_path* is a file, it is returned unchanged. If it is a
    directory, the '.state.yaml' inside it (when present) is consulted:
    its 'best' field names the preferred checkpoint, and a TorchScript
    '.pt' file with the same base name is preferred over a '.chkpt'
    checkpoint. Failing that, the directory is scanned for a single
    unambiguous model file.

    Args:
        model_path: Path to a model file, or to a directory containing one.

    Returns:
        Path to the resolved model file.

    Raises:
        ValueError: If the path does not exist, or no single model file
            can be determined inside the directory.
    """
    if os.path.isfile(model_path):
        return model_path
    if not os.path.isdir(model_path):
        raise ValueError(f'Model path not found: {model_path}')
    state_file = os.path.join(model_path, '.state.yaml')
    if os.path.exists(state_file):
        with open(state_file, 'r') as f:
            # safe_load returns None for an empty file; treat as no state
            # so we fall through to the directory scan instead of crashing.
            state = yaml.safe_load(f) or {}
        best = state.get('best')
        if best:
            # Strip a known extension from 'best' to get the base name
            base = best
            for ext in ('.chkpt', '.pt'):
                if best.endswith(ext):
                    base = best[:-len(ext)]
                    break
            # Prefer .pt (TorchScript) over .chkpt (checkpoint)
            for ext in ('.pt', '.chkpt', ''):
                candidate = os.path.join(model_path, base + ext)
                if os.path.isfile(candidate):
                    return candidate
    # Fallback: a lone .pt file wins; otherwise a lone model file of any
    # supported kind. List the directory once and reuse it for both passes.
    entries = [f for f in os.listdir(model_path) if not f.startswith('.')]
    pt_files = [f for f in entries if f.endswith('.pt')]
    if len(pt_files) == 1:
        return os.path.join(model_path, pt_files[0])
    model_files = [f for f in entries if f.endswith(('.pt', '.chkpt'))]
    if len(model_files) == 1:
        return os.path.join(model_path, model_files[0])
    raise ValueError(f'Cannot resolve model file in directory: {model_path}')
class TorchScriptPredictor:
    """Base predictor backed by a TorchScript (.pt) model.

    Handles resolving and loading the serialized module; subclasses
    supply the preprocess/postprocess/predict steps.
    """

    def __init__(self, model_path, device='cuda'):
        self.device = device
        resolved = resolve_model_path(model_path)
        self.model = self._load_model(resolved)
        logging.info('TorchScript model loaded: %s (device: %s)', resolved, device)

    def _load_model(self, model_path):
        """Deserialize the TorchScript module and switch it to eval mode."""
        loaded = torch.jit.load(model_path, map_location=self.device)
        loaded.eval()
        return loaded

    def preprocess(self, images):
        """Turn a list of numpy-array images into a model-ready tensor.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def postprocess(self, outputs):
        """Convert the raw output tensor into processed results.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def predict(self, streams, **kwargs):
        """Yield prediction results for a list of image byte buffers.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def run_inference(self, batch):
        """Forward *batch* through the model with gradients disabled."""
        with torch.no_grad():
            outputs = self.model(batch)
        return outputs