| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import unittest |
|
|
| from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig |
| from transformers.models.bert.configuration_bert import BertConfig |
| from transformers.models.roberta.configuration_roberta import RobertaConfig |
| from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER |
|
|
|
|
# Absolute path to the config fixture stored in the "fixtures" directory next
# to this test module; the local-file test below expects it to load as a
# RobertaConfig.
SAMPLE_ROBERTA_CONFIG = os.path.join(
    os.path.split(os.path.abspath(__file__))[0],
    "fixtures/dummy-config.json",
)
|
|
|
|
class AutoConfigTest(unittest.TestCase):
    """Tests for resolving concrete configuration classes through ``AutoConfig``."""

    def test_config_from_model_shortcut(self):
        # A shortcut name ("bert-base-uncased") should resolve to BertConfig.
        config = AutoConfig.from_pretrained("bert-base-uncased")
        self.assertIsInstance(config, BertConfig)

    def test_config_model_type_from_local_file(self):
        # Loading the local fixture file should yield a RobertaConfig.
        config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
        self.assertIsInstance(config, RobertaConfig)

    def test_config_model_type_from_model_identifier(self):
        # The dummy test identifier is expected to resolve to a RobertaConfig.
        config = AutoConfig.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
        self.assertIsInstance(config, RobertaConfig)

    def test_config_for_model_str(self):
        # ``for_model`` builds a config directly from a model-type string.
        config = AutoConfig.for_model("roberta")
        self.assertIsInstance(config, RobertaConfig)

    def test_pattern_matching_fallback(self):
        """
        In cases where config.json doesn't include a model_type,
        perform a few safety checks on the config mapping's order.

        No key of ``CONFIG_MAPPING`` may be a substring of any later key
        (e.g. "bert" is contained in "roberta"). Presumably the fallback
        matches keys in mapping order, so such an earlier key would shadow
        the later one.
        """
        keys = list(CONFIG_MAPPING.keys())
        shadowed = [
            (key, later_key)
            for i, key in enumerate(keys)
            for later_key in keys[i + 1 :]
            if key in later_key
        ]
        # List every offending (earlier, later) pair so a failure is actionable,
        # instead of the bare "True is not false" that assertFalse(any(...)) gives.
        self.assertEqual(
            shadowed,
            [],
            f"CONFIG_MAPPING keys that are substrings of a later key: {shadowed}",
        )
|
|