| |
| import pytest |
| import torch |
|
|
| from mmocr.models.textrecog.backbones import (ResNet, ResNet31OCR, ResNetABI, |
| ShallowCNN, VeryDeepVgg) |
|
|
|
|
def test_resnet31_ocr_backbone():
    """Check ResNet31OCR argument validation and its output feature shape."""
    # Every (args, kwargs) pair below must make the constructor raise an
    # AssertionError.
    rejected_calls = (
        ((2.5, ), {}),
        ((3, ), {'layers': 5}),
        ((3, ), {'channels': 5}),
    )
    for args, kwargs in rejected_calls:
        with pytest.raises(AssertionError):
            ResNet31OCR(*args, **kwargs)

    # Default construction, weight initialisation, then training mode.
    backbone = ResNet31OCR()
    backbone.init_weights()
    backbone.train()

    # A (1, 3, 32, 160) input must yield a (1, 512, 4, 40) feature map.
    batch = torch.randn(1, 3, 32, 160)
    expected_shape = torch.Size([1, 512, 4, 40])
    assert backbone(batch).shape == expected_shape
|
|
|
|
def test_vgg_deep_vgg_ocr_backbone():
    """Check the output feature shape of a default VeryDeepVgg backbone."""
    backbone = VeryDeepVgg()
    backbone.init_weights()
    backbone.train()

    # A (1, 3, 32, 160) input must yield a (1, 512, 1, 41) feature map.
    batch = torch.randn(1, 3, 32, 160)
    expected_shape = torch.Size([1, 512, 1, 41])
    assert backbone(batch).shape == expected_shape
|
|
|
|
def test_shallow_cnn_ocr_backbone():
    """Check the output feature shape of a default ShallowCNN backbone."""
    backbone = ShallowCNN()
    backbone.init_weights()
    backbone.train()

    # A single-channel (1, 1, 32, 100) input must yield a (1, 512, 8, 25)
    # feature map.
    batch = torch.randn(1, 1, 32, 100)
    expected_shape = torch.Size([1, 512, 8, 25])
    assert backbone(batch).shape == expected_shape
|
|
|
|
def test_resnet_abi():
    """Check ResNetABI argument validation and its output feature shape."""
    # Every (args, kwargs) pair below must make the constructor raise an
    # AssertionError.
    rejected_calls = (
        ((2.5, ), {}),
        ((3, ), {'arch_settings': 5}),
        ((3, ), {'stem_channels': None}),
        # Four arch_settings entries paired with five strides.
        ((), {
            'arch_settings': [3, 4, 6, 6],
            'strides': [1, 2, 1, 2, 1]
        }),
    )
    for args, kwargs in rejected_calls:
        with pytest.raises(AssertionError):
            ResNetABI(*args, **kwargs)

    # Default construction, switched to training mode.
    backbone = ResNetABI()
    backbone.train()

    # A (1, 3, 32, 160) input must yield a (1, 512, 8, 40) feature map.
    batch = torch.randn(1, 3, 32, 160)
    expected_shape = torch.Size([1, 512, 8, 40])
    assert backbone(batch).shape == expected_shape
|
|
|
|
def test_resnet():
    """Test all ResNet backbones.

    Builds three differently configured ``ResNet`` instances and checks the
    shape each one produces for the same (1, 3, 32, 100) input.
    """
    # NOTE: ``use_conv1x1`` is passed as a real boolean. The string 'True'
    # only worked through truthiness, and 'False' would have enabled the
    # option as well.
    resnet45_aster = ResNet(
        in_channels=3,
        stem_channels=[64, 128],
        block_cfgs=dict(type='BasicBlock', use_conv1x1=True),
        arch_layers=[3, 4, 6, 6, 3],
        arch_channels=[32, 64, 128, 256, 512],
        strides=[(2, 2), (2, 2), (2, 1), (2, 1), (2, 1)])

    resnet45_abi = ResNet(
        in_channels=3,
        stem_channels=32,
        block_cfgs=dict(type='BasicBlock', use_conv1x1=True),
        arch_layers=[3, 4, 6, 6, 3],
        arch_channels=[32, 64, 128, 256, 512],
        strides=[2, 1, 2, 1, 1])

    # Four-stage variant configured with plugins: max-pooling placed
    # 'before_stage' for selected stages and a ConvModule placed
    # 'after_stage' for all four stages.
    resnet_31 = ResNet(
        in_channels=3,
        stem_channels=[64, 128],
        block_cfgs=dict(type='BasicBlock'),
        arch_layers=[1, 2, 5, 3],
        arch_channels=[256, 256, 512, 512],
        strides=[1, 1, 1, 1],
        plugins=[
            dict(
                cfg=dict(type='Maxpool2d', kernel_size=2, stride=(2, 2)),
                stages=(True, True, False, False),
                position='before_stage'),
            dict(
                cfg=dict(type='Maxpool2d', kernel_size=(2, 1), stride=(2, 1)),
                stages=(False, False, True, False),
                position='before_stage'),
            dict(
                cfg=dict(
                    type='ConvModule',
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=dict(type='BN'),
                    act_cfg=dict(type='ReLU')),
                stages=(True, True, True, True),
                position='after_stage')
        ])

    # One shared input; each backbone must map it to its expected shape.
    img = torch.rand(1, 3, 32, 100)

    assert resnet45_aster(img).shape == torch.Size([1, 512, 1, 25])
    assert resnet45_abi(img).shape == torch.Size([1, 512, 8, 25])
    assert resnet_31(img).shape == torch.Size([1, 512, 4, 25])
|
|