Commit
·
9638d60
1
Parent(s):
97ba9b9
Add mandatory register_for_auto_class function for transformers-4.26+ support
Browse files- tokenization_hat.py +22 -0
tokenization_hat.py
CHANGED
|
@@ -247,3 +247,25 @@ class HATTokenizer:
|
|
| 247 |
torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
|
| 248 |
))
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
|
| 248 |
))
|
| 249 |
|
| 250 |
+
@classmethod
|
| 251 |
+
def register_for_auto_class(cls, auto_class="AutoModel"):
|
| 252 |
+
"""
|
| 253 |
+
Register this class with a given auto class. This should only be used for custom models as the ones in the
|
| 254 |
+
library are already mapped with an auto class.
|
| 255 |
+
<Tip warning={true}>
|
| 256 |
+
This API is experimental and may have some slight breaking changes in the next releases.
|
| 257 |
+
</Tip>
|
| 258 |
+
Args:
|
| 259 |
+
auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
|
| 260 |
+
The auto class to register this new model with.
|
| 261 |
+
"""
|
| 262 |
+
if not isinstance(auto_class, str):
|
| 263 |
+
auto_class = auto_class.__name__
|
| 264 |
+
|
| 265 |
+
import transformers.models.auto as auto_module
|
| 266 |
+
|
| 267 |
+
if not hasattr(auto_module, auto_class):
|
| 268 |
+
raise ValueError(f"{auto_class} is not a valid auto class.")
|
| 269 |
+
|
| 270 |
+
cls._auto_class = auto_class
|
| 271 |
+
|