File size: 2,779 Bytes
2b7aae2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
TorchScript model predictor base class.
Loads and runs inference on TorchScript (.pt) models.
"""

import os
import torch
import yaml
import logging


def resolve_model_path(model_path):
	"""Resolve a model path to a concrete model file.

	If model_path is a file, it is returned as-is. If it is a directory, the
	'best' entry of its .state.yaml (when present and usable) selects the
	model; a TorchScript .pt file is preferred over a .chkpt checkpoint with
	the same base name. Otherwise the directory is scanned for visible
	(non-dot) regular files: a single .pt file wins, else a single
	.pt/.chkpt file.

	Args:
		model_path: path to a model file or to a model directory.

	Returns:
		Path of the resolved model file.

	Raises:
		ValueError: if model_path does not exist, or no unique model file can
			be determined inside the directory.
	"""
	if os.path.isfile(model_path):
		return model_path

	if not os.path.isdir(model_path):
		raise ValueError(f'Model path not found: {model_path}')

	state_file = os.path.join(model_path, '.state.yaml')
	if os.path.isfile(state_file):
		with open(state_file, 'r', encoding='utf-8') as f:
			state = yaml.safe_load(f)

		# An empty file loads as None; ignore anything that is not a mapping
		# instead of crashing with AttributeError.
		best = state.get('best') if isinstance(state, dict) else None
		if isinstance(best, str) and best:
			# Strip a known model extension from best to get the base name.
			base = best
			for ext in ('.chkpt', '.pt'):
				if best.endswith(ext):
					base = best[:-len(ext)]
					break

			# Prefer .pt (TorchScript) over .chkpt (checkpoint).
			for ext in ('.pt', '.chkpt', ''):
				candidate = os.path.join(model_path, base + ext)
				if os.path.isfile(candidate):
					return candidate

	# Fallback: scan the directory once. Only regular files count, so a
	# sub-directory that merely ends in '.pt' is not mistaken for a model.
	model_files = [
		name for name in os.listdir(model_path)
		if name.endswith(('.pt', '.chkpt'))
		and not name.startswith('.')
		and os.path.isfile(os.path.join(model_path, name))
	]
	pt_files = [name for name in model_files if name.endswith('.pt')]
	if len(pt_files) == 1:
		return os.path.join(model_path, pt_files[0])

	if len(model_files) == 1:
		return os.path.join(model_path, model_files[0])

	raise ValueError(f'Cannot resolve model file in directory: {model_path}')


class TorchScriptPredictor:
	"""Common base for predictors backed by a TorchScript (.pt) model.

	This base resolves and loads the model and runs it with autograd
	disabled; subclasses provide preprocess, postprocess and predict.
	"""

	def __init__(self, model_path, device='cuda'):
		"""Resolve model_path, load the TorchScript model onto device, and log it.

		model_path: a model file, or a directory resolved via resolve_model_path.
		device: passed to torch.jit.load as map_location (default 'cuda').
		"""
		# _load_model reads self.device, so it has to be assigned first.
		self.device = device
		model_file = resolve_model_path(model_path)
		self.model = self._load_model(model_file)
		logging.info('TorchScript model loaded: %s (device: %s)', model_file, device)

	def _load_model(self, model_path):
		"""Return the TorchScript module at model_path, mapped to self.device, in eval mode."""
		module = torch.jit.load(model_path, map_location=self.device)
		module.eval()
		return module

	def preprocess(self, images):
		"""
		Convert raw images into a model input batch.
		Subclasses must implement this.
		images: list of numpy arrays
		returns: torch.Tensor
		"""
		raise NotImplementedError

	def postprocess(self, outputs):
		"""
		Convert raw model outputs into final results.
		Subclasses must implement this.
		outputs: torch.Tensor
		returns: processed results
		"""
		raise NotImplementedError

	def predict(self, streams, **kwargs):
		"""
		Produce predictions for the given input streams.
		Subclasses must implement this.
		streams: list of image byte buffers
		yields: prediction results
		"""
		raise NotImplementedError

	def run_inference(self, batch):
		"""Call the model on batch inside torch.no_grad() and return its output."""
		with torch.no_grad():
			result = self.model(batch)
		return result