| |
| import io |
| import json |
| import os |
| import platform |
| import random |
| import sys |
| import tempfile |
| from pathlib import Path |
| from unittest import mock |
|
|
| import mmcv |
| import numpy as np |
| import pytest |
| import torch |
|
|
| from mmocr.apis import init_detector |
| from mmocr.datasets.kie_dataset import KIEDataset |
| from mmocr.utils.ocr import MMOCR |
|
|
|
|
| def test_ocr_init_errors(): |
| |
| with pytest.raises(ValueError): |
| _ = MMOCR(det='test') |
| with pytest.raises(ValueError): |
| _ = MMOCR(recog='test') |
| with pytest.raises(ValueError): |
| _ = MMOCR(kie='test') |
| with pytest.raises(NotImplementedError): |
| _ = MMOCR(det=None, recog=None, kie='SDMGR') |
| with pytest.raises(NotImplementedError): |
| _ = MMOCR(det='DB_r18', recog=None, kie='SDMGR') |
|
|
|
|
| cfg_default_prefix = os.path.join(str(Path.cwd()), 'configs/') |
|
|
|
|
| @pytest.mark.parametrize( |
| 'det, recog, kie, config_dir, gt_cfg, gt_ckpt', |
| [('DB_r18', None, '', '', |
| cfg_default_prefix + 'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py', |
| 'https://download.openmmlab.com/mmocr/textdet/' |
| 'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'), |
| (None, 'CRNN', '', '', |
| cfg_default_prefix + 'textrecog/crnn/crnn_academic_dataset.py', |
| 'https://download.openmmlab.com/mmocr/textrecog/' |
| 'crnn/crnn_academic-a723a1c5.pth'), |
| ('DB_r18', 'CRNN', 'SDMGR', '', [ |
| cfg_default_prefix + |
| 'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py', |
| cfg_default_prefix + 'textrecog/crnn/crnn_academic_dataset.py', |
| cfg_default_prefix + 'kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py' |
| ], [ |
| 'https://download.openmmlab.com/mmocr/textdet/' |
| 'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth', |
| 'https://download.openmmlab.com/mmocr/textrecog/' |
| 'crnn/crnn_academic-a723a1c5.pth', |
| 'https://download.openmmlab.com/mmocr/kie/' |
| 'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth' |
| ]), |
| ('DB_r18', 'CRNN', 'SDMGR', 'test/', [ |
| 'test/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py', |
| 'test/textrecog/crnn/crnn_academic_dataset.py', |
| 'test/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py' |
| ], [ |
| 'https://download.openmmlab.com/mmocr/textdet/' |
| 'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth', |
| 'https://download.openmmlab.com/mmocr/textrecog/' |
| 'crnn/crnn_academic-a723a1c5.pth', |
| 'https://download.openmmlab.com/mmocr/kie/' |
| 'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth' |
| ])], |
| ) |
| @mock.patch('mmocr.utils.ocr.init_detector') |
| @mock.patch('mmocr.utils.ocr.build_detector') |
| @mock.patch('mmocr.utils.ocr.Config.fromfile') |
| @mock.patch('mmocr.utils.ocr.load_checkpoint') |
| def test_ocr_init(mock_loading, mock_config, mock_build_detector, |
| mock_init_detector, det, recog, kie, config_dir, gt_cfg, |
| gt_ckpt): |
|
|
| def loadcheckpoint_assert(*args, **kwargs): |
| assert args[1] == gt_ckpt[-1] |
| assert kwargs['map_location'] == torch.device( |
| 'cuda' if torch.cuda.is_available() else 'cpu') |
|
|
| mock_loading.side_effect = loadcheckpoint_assert |
| with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'): |
| if kie == '': |
| if config_dir == '': |
| _ = MMOCR(det=det, recog=recog) |
| else: |
| _ = MMOCR(det=det, recog=recog, config_dir=config_dir) |
| else: |
| if config_dir == '': |
| _ = MMOCR(det=det, recog=recog, kie=kie) |
| else: |
| _ = MMOCR(det=det, recog=recog, kie=kie, config_dir=config_dir) |
| if isinstance(gt_cfg, str): |
| gt_cfg = [gt_cfg] |
| if isinstance(gt_ckpt, str): |
| gt_ckpt = [gt_ckpt] |
|
|
| i_range = range(len(gt_cfg)) |
| if kie: |
| i_range = i_range[:-1] |
| mock_config.assert_called_with(gt_cfg[-1]) |
| mock_build_detector.assert_called_once() |
| mock_loading.assert_called_once() |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| calls = [ |
| mock.call(gt_cfg[i], gt_ckpt[i], device=device) for i in i_range |
| ] |
| mock_init_detector.assert_has_calls(calls) |
|
|
|
|
| @pytest.mark.parametrize( |
| 'det, det_config, det_ckpt, recog, recog_config, recog_ckpt,' |
| 'kie, kie_config, kie_ckpt, config_dir, gt_cfg, gt_ckpt', |
| [('DB_r18', 'test.py', '', 'CRNN', 'test.py', '', 'SDMGR', 'test.py', '', |
| 'configs/', ['test.py', 'test.py', 'test.py'], [ |
| 'https://download.openmmlab.com/mmocr/textdet/' |
| 'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth', |
| 'https://download.openmmlab.com/mmocr/textrecog/' |
| 'crnn/crnn_academic-a723a1c5.pth', |
| 'https://download.openmmlab.com/mmocr/kie/' |
| 'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth' |
| ]), |
| ('DB_r18', '', 'test.ckpt', 'CRNN', '', 'test.ckpt', 'SDMGR', '', |
| 'test.ckpt', '', [ |
| 'textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py', |
| 'textrecog/crnn/crnn_academic_dataset.py', |
| 'kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py' |
| ], ['test.ckpt', 'test.ckpt', 'test.ckpt']), |
| ('DB_r18', 'test.py', 'test.ckpt', 'CRNN', 'test.py', 'test.ckpt', |
| 'SDMGR', 'test.py', 'test.ckpt', '', ['test.py', 'test.py', 'test.py'], |
| ['test.ckpt', 'test.ckpt', 'test.ckpt'])]) |
| @mock.patch('mmocr.utils.ocr.init_detector') |
| @mock.patch('mmocr.utils.ocr.build_detector') |
| @mock.patch('mmocr.utils.ocr.Config.fromfile') |
| @mock.patch('mmocr.utils.ocr.load_checkpoint') |
| def test_ocr_init_customize_config(mock_loading, mock_config, |
| mock_build_detector, mock_init_detector, |
| det, det_config, det_ckpt, recog, |
| recog_config, recog_ckpt, kie, kie_config, |
| kie_ckpt, config_dir, gt_cfg, gt_ckpt): |
|
|
| def loadcheckpoint_assert(*args, **kwargs): |
| assert args[1] == gt_ckpt[-1] |
|
|
| mock_loading.side_effect = loadcheckpoint_assert |
| with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'): |
| _ = MMOCR( |
| det=det, |
| det_config=det_config, |
| det_ckpt=det_ckpt, |
| recog=recog, |
| recog_config=recog_config, |
| recog_ckpt=recog_ckpt, |
| kie=kie, |
| kie_config=kie_config, |
| kie_ckpt=kie_ckpt, |
| config_dir=config_dir) |
|
|
| i_range = range(len(gt_cfg)) |
| if kie: |
| i_range = i_range[:-1] |
| mock_config.assert_called_with(gt_cfg[-1]) |
| mock_build_detector.assert_called_once() |
| mock_loading.assert_called_once() |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| calls = [ |
| mock.call(gt_cfg[i], gt_ckpt[i], device=device) for i in i_range |
| ] |
| mock_init_detector.assert_has_calls(calls) |
|
|
|
|
| @mock.patch('mmocr.utils.ocr.init_detector') |
| @mock.patch('mmocr.utils.ocr.build_detector') |
| @mock.patch('mmocr.utils.ocr.Config.fromfile') |
| @mock.patch('mmocr.utils.ocr.load_checkpoint') |
| @mock.patch('mmocr.utils.ocr.model_inference') |
| def test_single_inference(mock_model_inference, mock_loading, mock_config, |
| mock_build_detector, mock_init_detector): |
|
|
| def dummy_inference(model, arr, batch_mode): |
| return arr |
|
|
| mock_model_inference.side_effect = dummy_inference |
| mmocr = MMOCR() |
|
|
| data = list(range(20)) |
| model = 'dummy' |
| res = mmocr.single_inference(model, data, batch_mode=False) |
| assert (data == res) |
| mock_model_inference.reset_mock() |
|
|
| res = mmocr.single_inference(model, data, batch_mode=True) |
| assert (data == res) |
| mock_model_inference.assert_called_once() |
| mock_model_inference.reset_mock() |
|
|
| res = mmocr.single_inference(model, data, batch_mode=True, batch_size=100) |
| assert (data == res) |
| mock_model_inference.assert_called_once() |
| mock_model_inference.reset_mock() |
|
|
| res = mmocr.single_inference(model, data, batch_mode=True, batch_size=3) |
| assert (data == res) |
|
|
|
|
| @mock.patch('mmocr.utils.ocr.init_detector') |
| @mock.patch('mmocr.utils.ocr.load_checkpoint') |
| def MMOCR_testobj(mock_loading, mock_init_detector, **kwargs): |
| |
| |
| def init_detector_skip_ckpt(config, ckpt, device): |
| return init_detector(config, device=device) |
|
|
| def modify_kie_class(model, ckpt, map_location): |
| model.class_list = 'tests/data/kie_toy_dataset/class_list.txt' |
|
|
| mock_init_detector.side_effect = init_detector_skip_ckpt |
| mock_loading.side_effect = modify_kie_class |
| kwargs['det'] = kwargs.get('det', 'DB_r18') |
| kwargs['recog'] = kwargs.get('recog', 'CRNN') |
| kwargs['kie'] = kwargs.get('kie', 'SDMGR') |
| device = 'cuda:0' if torch.cuda.is_available() else 'cpu' |
| return MMOCR(**kwargs, device=device) |
|
|
|
|
| @pytest.mark.skipif( |
| platform.system() == 'Windows', |
| reason='Win container on Github Action does not have enough RAM to run') |
| @mock.patch('mmocr.utils.ocr.KIEDataset') |
| def test_readtext(mock_kiedataset): |
| |
| |
| torch.manual_seed(4) |
| random.seed(4) |
| mmocr = MMOCR_testobj() |
| mmocr_det = MMOCR_testobj(kie='', recog='') |
| mmocr_recog = MMOCR_testobj(kie='', det='', recog='CRNN_TPS') |
| mmocr_det_recog = MMOCR_testobj(kie='') |
|
|
| def readtext(imgs, ocr_obj=mmocr, **kwargs): |
| |
| |
| e2e_res = ocr_obj.readtext(imgs, **kwargs) |
| for res in e2e_res: |
| res.pop('filename') |
| return e2e_res |
|
|
| def kiedataset_with_test_dict(**kwargs): |
| kwargs['dict_file'] = 'tests/data/kie_toy_dataset/dict.txt' |
| return KIEDataset(**kwargs) |
|
|
| mock_kiedataset.side_effect = kiedataset_with_test_dict |
|
|
| |
| toy_dir = 'tests/data/toy_dataset/imgs/test/' |
| toy_img1_path = toy_dir + 'img_1.jpg' |
| str_e2e_res = readtext(toy_img1_path) |
| toy_img1 = mmcv.imread(toy_img1_path) |
| np_e2e_res = readtext(toy_img1) |
| assert str_e2e_res == np_e2e_res |
|
|
| |
| toy_img2_path = toy_dir + 'img_2.jpg' |
| toy_img2 = mmcv.imread(toy_img2_path) |
| toy_imgs = [toy_img1, toy_img2] |
| toy_img_paths = [toy_img1_path, toy_img2_path] |
| np_e2e_results = readtext(toy_imgs) |
| str_e2e_results = readtext(toy_img_paths) |
| str_tuple_e2e_results = readtext(tuple(toy_img_paths)) |
| assert np_e2e_results == str_e2e_results |
| assert str_e2e_results == str_tuple_e2e_results |
|
|
| |
| toy_imgs.append(toy_dir + 'img_3.jpg') |
| e2e_res = readtext(toy_imgs) |
| full_batch_e2e_res = readtext(toy_imgs, batch_mode=True) |
| assert full_batch_e2e_res == e2e_res |
| batch_e2e_res = readtext( |
| toy_imgs, batch_mode=True, recog_batch_size=2, det_batch_size=2) |
| assert batch_e2e_res == full_batch_e2e_res |
|
|
| |
| full_batch_det_res = mmocr_det.readtext(toy_imgs, batch_mode=True) |
| det_res = mmocr_det.readtext(toy_imgs) |
| batch_det_res = mmocr_det.readtext( |
| toy_imgs, batch_mode=True, single_batch_size=2) |
| assert len(full_batch_det_res) == len(det_res) |
| assert len(batch_det_res) == len(det_res) |
| assert all([ |
| np.allclose(full_batch_det_res[i]['boundary_result'], |
| det_res[i]['boundary_result']) |
| for i in range(len(full_batch_det_res)) |
| ]) |
| assert all([ |
| np.allclose(batch_det_res[i]['boundary_result'], |
| det_res[i]['boundary_result']) |
| for i in range(len(batch_det_res)) |
| ]) |
|
|
| |
| full_batch_recog_res = mmocr_recog.readtext(toy_imgs, batch_mode=True) |
| recog_res = mmocr_recog.readtext(toy_imgs) |
| batch_recog_res = mmocr_recog.readtext( |
| toy_imgs, batch_mode=True, single_batch_size=2) |
| full_batch_recog_res.sort(key=lambda x: x['text']) |
| batch_recog_res.sort(key=lambda x: x['text']) |
| recog_res.sort(key=lambda x: x['text']) |
| assert np.all([ |
| np.allclose(full_batch_recog_res[i]['score'], recog_res[i]['score']) |
| for i in range(len(full_batch_recog_res)) |
| ]) |
| assert np.all([ |
| np.allclose(batch_recog_res[i]['score'], recog_res[i]['score']) |
| for i in range(len(full_batch_recog_res)) |
| ]) |
|
|
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| mmocr.readtext(toy_imgs, export=tmpdirname) |
| assert len(os.listdir(tmpdirname)) == len(toy_imgs) |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| mmocr_det.readtext(toy_imgs, export=tmpdirname) |
| assert len(os.listdir(tmpdirname)) == len(toy_imgs) |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| mmocr_recog.readtext(toy_imgs, export=tmpdirname) |
| assert len(os.listdir(tmpdirname)) == len(toy_imgs) |
|
|
| |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| tmp_output = os.path.join(tmpdirname, '1.jpg') |
| mmocr.readtext(toy_imgs[0], output=tmp_output) |
| assert os.path.exists(tmp_output) |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| mmocr.readtext(toy_imgs, output=tmpdirname) |
| assert len(os.listdir(tmpdirname)) == len(toy_imgs) |
|
|
| |
| with mock.patch('mmocr.utils.ocr.mmcv.imshow') as mock_imshow: |
| mmocr.readtext(toy_img1_path, imshow=True) |
| mock_imshow.assert_called_once() |
| mock_imshow.reset_mock() |
| mmocr.readtext(toy_imgs, imshow=True) |
| assert mock_imshow.call_count == len(toy_imgs) |
|
|
| |
| with io.StringIO() as capturedOutput: |
| sys.stdout = capturedOutput |
| res = mmocr.readtext(toy_imgs, print_result=True) |
| assert json.loads('[%s]' % capturedOutput.getvalue().strip().replace( |
| '\n\n', ',').replace("'", '"')) == res |
| sys.stdout = sys.__stdout__ |
| with io.StringIO() as capturedOutput: |
| sys.stdout = capturedOutput |
| res = mmocr.readtext(toy_imgs, details=True, print_result=True) |
| assert json.loads('[%s]' % capturedOutput.getvalue().strip().replace( |
| '\n\n', ',').replace("'", '"')) == res |
| sys.stdout = sys.__stdout__ |
|
|
| |
| with mock.patch('mmocr.utils.ocr.stitch_boxes_into_lines') as mock_merge: |
| mmocr_det_recog.readtext(toy_imgs, merge=True) |
| assert mock_merge.call_count == len(toy_imgs) |
|
|