| |
| import pytest |
| import torch |
|
|
| from mmocr.models.textrecog.encoders import (ABIVisionModel, BaseEncoder, |
| NRTREncoder, SAREncoder, |
| SatrnEncoder, TransformerEncoder) |
|
|
|
|
def test_sar_encoder():
    """Check SAREncoder argument validation and its forward output shape."""
    # Each entry is a constructor argument of the wrong type/range; SAREncoder
    # is expected to reject every one of them with an AssertionError.
    bad_kwargs = (
        dict(enc_bi_rnn='bi'),
        dict(enc_do_rnn=2),
        dict(enc_gru='gru'),
        dict(d_model=512.5),
        dict(d_enc=200.5),
        dict(mask='mask'),
    )
    for kwargs in bad_kwargs:
        with pytest.raises(AssertionError):
            SAREncoder(**kwargs)

    sar = SAREncoder()
    sar.init_weights()
    sar.train()

    features = torch.randn(1, 512, 4, 40)
    metas = [{'valid_ratio': 1.0}]

    # Two meta dicts for a batch of one feature map must be rejected.
    with pytest.raises(AssertionError):
        sar(features, metas * 2)

    holistic = sar(features, metas)
    assert holistic.shape == torch.Size([1, 512])
|
|
|
|
def test_nrtr_encoder():
    """Check the NRTREncoder output shape for a (1, 512, 1, 25) feature map.

    The 25 columns of the input are expected to come out as a sequence of
    length 25, with 512 features each.
    """
    tf_encoder = NRTREncoder()
    tf_encoder.init_weights()
    tf_encoder.train()

    feat = torch.randn(1, 512, 1, 25)
    out_enc = tf_encoder(feat)
    # The former debug print('hello', ...) was removed; the assertion below
    # is the only check this test needs.
    assert out_enc.shape == torch.Size([1, 25, 512])
|
|
|
|
def test_satrn_encoder():
    """Check the SatrnEncoder output shape for an 8 x 25 feature map."""
    batch, channels, height, width = 1, 512, 8, 25

    satrn = SatrnEncoder()
    satrn.init_weights()
    satrn.train()

    encoded = satrn(torch.randn(batch, channels, height, width))
    # 8 * 25 = 200: the spatial grid appears flattened into one sequence axis.
    assert encoded.shape == torch.Size([batch, height * width, channels])
|
|
|
|
def test_base_encoder():
    """BaseEncoder's forward output must have the same shape as its input."""
    base = BaseEncoder()
    base.init_weights()
    base.train()

    inp = torch.randn(1, 256, 4, 40)
    assert base(inp).shape == inp.shape
|
|
|
|
def test_transformer_encoder():
    """TransformerEncoder must keep the (N, C, H, W) shape of its input."""
    shape = (10, 512, 8, 32)
    encoder = TransformerEncoder()
    output = encoder(torch.randn(*shape))
    assert output.shape == torch.Size(shape)
|
|
|
|
def test_abi_vision_model():
    """Check the shapes of every entry in ABIVisionModel's output dict."""
    max_len = 10
    vision_model = ABIVisionModel(
        decoder=dict(
            type='ABIVisionDecoder', max_seq_len=max_len, use_result=None))

    outputs = vision_model(torch.randn(1, 512, 8, 32))

    # 90 is presumably the default number of character classes -- confirm
    # against the decoder's default config.
    expected_shapes = {
        'feature': (1, max_len, 512),
        'logits': (1, max_len, 90),
        'attn_scores': (1, max_len, 8, 32),
    }
    for key, shape in expected_shapes.items():
        assert outputs[key].shape == torch.Size(shape)
|
|