| |
| import pytest |
| import torch |
|
|
| from mmocr.models.textrecog.preprocessor import (BasePreprocessor, |
| TPSPreprocessor) |
|
|
|
|
def test_tps_preprocessor():
    """Check TPSPreprocessor constructor assertions and its output shape."""
    # Each keyword set below is expected to trip an assertion in the
    # constructor; they are tried one at a time so every case is checked.
    rejected_kwargs = (
        dict(num_fiducial=-1),
        dict(img_size=32),
        dict(rectified_img_size=100),
        dict(num_img_channel='bgr'),
    )
    for kwargs in rejected_kwargs:
        with pytest.raises(AssertionError):
            TPSPreprocessor(**kwargs)

    # A configuration the constructor accepts.
    valid_kwargs = dict(
        num_fiducial=20,
        img_size=(32, 100),
        rectified_img_size=(32, 100),
        num_img_channel=1)
    module = TPSPreprocessor(**valid_kwargs)
    module.init_weights()
    module.train()

    # Forward a random single-image batch and verify the result's shape.
    inputs = torch.randn(1, 1, 32, 100)
    outputs = module(inputs)
    assert outputs.shape == torch.Size([1, 1, 32, 100])
|
|
|
|
def test_base_preprocessor():
    """Run BasePreprocessor on a random batch and check the output shape."""
    module = BasePreprocessor()
    module.init_weights()
    module.train()

    # The output is expected to have the same shape as the input batch.
    input_shape = (1, 1, 32, 100)
    outputs = module(torch.randn(*input_shape))
    assert outputs.shape == torch.Size(input_shape)
|
|