| |
| import os.path as osp |
| import warnings |
| from typing import Any, Iterable |
|
|
| import numpy as np |
| import torch |
| from mmdet.models.builder import DETECTORS |
|
|
| from mmocr.models.textdet.detectors.single_stage_text_detector import \ |
| SingleStageTextDetector |
| from mmocr.models.textdet.detectors.text_detector_mixin import \ |
| TextDetectorMixin |
| from mmocr.models.textrecog.recognizer.encode_decode_recognizer import \ |
| EncodeDecodeRecognizer |
|
|
|
|
| def inference_with_session(sess, io_binding, input_name, output_names, |
| input_tensor): |
| device_type = input_tensor.device.type |
| device_id = input_tensor.device.index |
| device_id = 0 if device_id is None else device_id |
| io_binding.bind_input( |
| name=input_name, |
| device_type=device_type, |
| device_id=device_id, |
| element_type=np.float32, |
| shape=input_tensor.shape, |
| buffer_ptr=input_tensor.data_ptr()) |
| for name in output_names: |
| io_binding.bind_output(name) |
| sess.run_with_iobinding(io_binding) |
| pred = io_binding.copy_outputs_to_cpu() |
| return pred |
|
|
|
|
| @DETECTORS.register_module() |
| class ONNXRuntimeDetector(TextDetectorMixin, SingleStageTextDetector): |
| """The class for evaluating onnx file of detection.""" |
|
|
| def __init__(self, |
| onnx_file: str, |
| cfg: Any, |
| device_id: int, |
| show_score: bool = False): |
| if 'type' in cfg.model: |
| cfg.model.pop('type') |
| SingleStageTextDetector.__init__(self, **(cfg.model)) |
| TextDetectorMixin.__init__(self, show_score) |
| import onnxruntime as ort |
|
|
| |
| ort_custom_op_path = '' |
| try: |
| from mmcv.ops import get_onnxruntime_op_path |
| ort_custom_op_path = get_onnxruntime_op_path() |
| except (ImportError, ModuleNotFoundError): |
| warnings.warn('If input model has custom op from mmcv, \ |
| you may have to build mmcv with ONNXRuntime from source.') |
| session_options = ort.SessionOptions() |
| |
| if osp.exists(ort_custom_op_path): |
| session_options.register_custom_ops_library(ort_custom_op_path) |
| sess = ort.InferenceSession(onnx_file, session_options) |
| providers = ['CPUExecutionProvider'] |
| options = [{}] |
| is_cuda_available = ort.get_device() == 'GPU' |
| if is_cuda_available: |
| providers.insert(0, 'CUDAExecutionProvider') |
| options.insert(0, {'device_id': device_id}) |
|
|
| sess.set_providers(providers, options) |
|
|
| self.sess = sess |
| self.device_id = device_id |
| self.io_binding = sess.io_binding() |
| self.output_names = [_.name for _ in sess.get_outputs()] |
| for name in self.output_names: |
| self.io_binding.bind_output(name) |
| self.cfg = cfg |
|
|
| def forward_train(self, img, img_metas, **kwargs): |
| raise NotImplementedError('This method is not implemented.') |
|
|
| def aug_test(self, imgs, img_metas, **kwargs): |
| raise NotImplementedError('This method is not implemented.') |
|
|
| def extract_feat(self, imgs): |
| raise NotImplementedError('This method is not implemented.') |
|
|
| def simple_test(self, |
| img: torch.Tensor, |
| img_metas: Iterable, |
| rescale: bool = False): |
| onnx_pred = inference_with_session(self.sess, self.io_binding, 'input', |
| self.output_names, img) |
| onnx_pred = torch.from_numpy(onnx_pred[0]) |
| if len(img_metas) > 1: |
| boundaries = [ |
| self.bbox_head.get_boundary(*(onnx_pred[i].unsqueeze(0)), |
| [img_metas[i]], rescale) |
| for i in range(len(img_metas)) |
| ] |
|
|
| else: |
| boundaries = [ |
| self.bbox_head.get_boundary(*onnx_pred, img_metas, rescale) |
| ] |
|
|
| return boundaries |
|
|
|
|
| @DETECTORS.register_module() |
| class ONNXRuntimeRecognizer(EncodeDecodeRecognizer): |
| """The class for evaluating onnx file of recognition.""" |
|
|
| def __init__(self, |
| onnx_file: str, |
| cfg: Any, |
| device_id: int, |
| show_score: bool = False): |
| if 'type' in cfg.model: |
| cfg.model.pop('type') |
| EncodeDecodeRecognizer.__init__(self, **(cfg.model)) |
| import onnxruntime as ort |
|
|
| |
| ort_custom_op_path = '' |
| try: |
| from mmcv.ops import get_onnxruntime_op_path |
| ort_custom_op_path = get_onnxruntime_op_path() |
| except (ImportError, ModuleNotFoundError): |
| warnings.warn('If input model has custom op from mmcv, \ |
| you may have to build mmcv with ONNXRuntime from source.') |
| session_options = ort.SessionOptions() |
| |
| if osp.exists(ort_custom_op_path): |
| session_options.register_custom_ops_library(ort_custom_op_path) |
| sess = ort.InferenceSession(onnx_file, session_options) |
| providers = ['CPUExecutionProvider'] |
| options = [{}] |
| is_cuda_available = ort.get_device() == 'GPU' |
| if is_cuda_available: |
| providers.insert(0, 'CUDAExecutionProvider') |
| options.insert(0, {'device_id': device_id}) |
|
|
| sess.set_providers(providers, options) |
|
|
| self.sess = sess |
| self.device_id = device_id |
| self.io_binding = sess.io_binding() |
| self.output_names = [_.name for _ in sess.get_outputs()] |
| for name in self.output_names: |
| self.io_binding.bind_output(name) |
| self.cfg = cfg |
|
|
| def forward_train(self, img, img_metas, **kwargs): |
| raise NotImplementedError('This method is not implemented.') |
|
|
| def aug_test(self, imgs, img_metas, **kwargs): |
| if isinstance(imgs, list): |
| for idx, each_img in enumerate(imgs): |
| if each_img.dim() == 3: |
| imgs[idx] = each_img.unsqueeze(0) |
| imgs = imgs[0] |
| img_metas = img_metas[0] |
| else: |
| if len(img_metas) == 1 and isinstance(img_metas[0], list): |
| img_metas = img_metas[0] |
| return self.simple_test(imgs, img_metas=img_metas) |
|
|
| def extract_feat(self, imgs): |
| raise NotImplementedError('This method is not implemented.') |
|
|
| def simple_test(self, |
| img: torch.Tensor, |
| img_metas: Iterable, |
| rescale: bool = False): |
| """Test function. |
| |
| Args: |
| imgs (torch.Tensor): Image input tensor. |
| img_metas (list[dict]): List of image information. |
| |
| Returns: |
| list[str]: Text label result of each image. |
| """ |
| onnx_pred = inference_with_session(self.sess, self.io_binding, 'input', |
| self.output_names, img) |
| onnx_pred = torch.from_numpy(onnx_pred[0]) |
|
|
| label_indexes, label_scores = self.label_convertor.tensor2idx( |
| onnx_pred, img_metas) |
| label_strings = self.label_convertor.idx2str(label_indexes) |
|
|
| |
| results = [] |
| for string, score in zip(label_strings, label_scores): |
| results.append(dict(text=string, score=score)) |
|
|
| return results |
|
|
|
|
| @DETECTORS.register_module() |
| class TensorRTDetector(TextDetectorMixin, SingleStageTextDetector): |
| """The class for evaluating TensorRT file of detection.""" |
|
|
| def __init__(self, |
| trt_file: str, |
| cfg: Any, |
| device_id: int, |
| show_score: bool = False): |
| if 'type' in cfg.model: |
| cfg.model.pop('type') |
| SingleStageTextDetector.__init__(self, **(cfg.model)) |
| TextDetectorMixin.__init__(self, show_score) |
| from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin |
| try: |
| load_tensorrt_plugin() |
| except (ImportError, ModuleNotFoundError): |
| warnings.warn('If input model has custom op from mmcv, \ |
| you may have to build mmcv with TensorRT from source.') |
| model = TRTWrapper( |
| trt_file, input_names=['input'], output_names=['output']) |
|
|
| self.model = model |
| self.device_id = device_id |
| self.cfg = cfg |
|
|
| def forward_train(self, img, img_metas, **kwargs): |
| raise NotImplementedError('This method is not implemented.') |
|
|
| def aug_test(self, imgs, img_metas, **kwargs): |
| raise NotImplementedError('This method is not implemented.') |
|
|
| def extract_feat(self, imgs): |
| raise NotImplementedError('This method is not implemented.') |
|
|
| def simple_test(self, |
| img: torch.Tensor, |
| img_metas: Iterable, |
| rescale: bool = False): |
| with torch.cuda.device(self.device_id), torch.no_grad(): |
| trt_pred = self.model({'input': img})['output'] |
| if len(img_metas) > 1: |
| boundaries = [ |
| self.bbox_head.get_boundary(*(trt_pred[i].unsqueeze(0)), |
| [img_metas[i]], rescale) |
| for i in range(len(img_metas)) |
| ] |
|
|
| else: |
| boundaries = [ |
| self.bbox_head.get_boundary(*trt_pred, img_metas, rescale) |
| ] |
|
|
| return boundaries |
|
|
|
|
| @DETECTORS.register_module() |
| class TensorRTRecognizer(EncodeDecodeRecognizer): |
| """The class for evaluating TensorRT file of recognition.""" |
|
|
| def __init__(self, |
| trt_file: str, |
| cfg: Any, |
| device_id: int, |
| show_score: bool = False): |
| if 'type' in cfg.model: |
| cfg.model.pop('type') |
| EncodeDecodeRecognizer.__init__(self, **(cfg.model)) |
| from mmcv.tensorrt import TRTWrapper, load_tensorrt_plugin |
| try: |
| load_tensorrt_plugin() |
| except (ImportError, ModuleNotFoundError): |
| warnings.warn('If input model has custom op from mmcv, \ |
| you may have to build mmcv with TensorRT from source.') |
| model = TRTWrapper( |
| trt_file, input_names=['input'], output_names=['output']) |
|
|
| self.model = model |
| self.device_id = device_id |
| self.cfg = cfg |
|
|
| def forward_train(self, img, img_metas, **kwargs): |
| raise NotImplementedError('This method is not implemented.') |
|
|
| def aug_test(self, imgs, img_metas, **kwargs): |
| if isinstance(imgs, list): |
| for idx, each_img in enumerate(imgs): |
| if each_img.dim() == 3: |
| imgs[idx] = each_img.unsqueeze(0) |
| imgs = imgs[0] |
| img_metas = img_metas[0] |
| else: |
| if len(img_metas) == 1 and isinstance(img_metas[0], list): |
| img_metas = img_metas[0] |
| return self.simple_test(imgs, img_metas=img_metas) |
|
|
| def extract_feat(self, imgs): |
| raise NotImplementedError('This method is not implemented.') |
|
|
| def simple_test(self, |
| img: torch.Tensor, |
| img_metas: Iterable, |
| rescale: bool = False): |
| """Test function. |
| |
| Args: |
| imgs (torch.Tensor): Image input tensor. |
| img_metas (list[dict]): List of image information. |
| |
| Returns: |
| list[str]: Text label result of each image. |
| """ |
| with torch.cuda.device(self.device_id), torch.no_grad(): |
| trt_pred = self.model({'input': img})['output'] |
|
|
| label_indexes, label_scores = self.label_convertor.tensor2idx( |
| trt_pred, img_metas) |
| label_strings = self.label_convertor.idx2str(label_indexes) |
|
|
| |
| results = [] |
| for string, score in zip(label_strings, label_scores): |
| results.append(dict(text=string, score=score)) |
|
|
| return results |
|
|