| from transformers.modeling_outputs import SequenceClassifierOutput | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.configuration_utils import PretrainedConfig | |
| import torch | |
| from transformers import ZeroShotClassificationPipeline | |
class CustomConfig(PretrainedConfig):
    """Minimal configuration for the dummy zero-shot model.

    It declares no fields of its own: every keyword argument is handed
    unchanged to :class:`PretrainedConfig`.
    """

    # Identifier of this custom configuration type.
    model_type = "test-zeroshot"

    def __init__(self, **kwargs):
        super(CustomConfig, self).__init__(**kwargs)
class CustomModel(PreTrainedModel):
    """Dummy sequence-classification model whose output ignores input values.

    Every forward pass returns the constant logits row ``[1, 2, 3]`` for each
    example in the batch, which keeps pipeline outputs deterministic.
    """

    config_class = CustomConfig

    def __init__(self, config: CustomConfig):
        super().__init__(config)
        # Bind the exact config object the caller passed in. This is
        # presumably redundant with the base class; kept as in the original.
        self.config = config
        # Unused by forward(); presumably here so the model has at least one
        # parameter to save/load -- confirm before removing.
        self.embeddings = torch.nn.Embedding(num_embeddings=1, embedding_dim=1)

    def forward(self, **kwargs) -> SequenceClassifierOutput:
        """Return constant logits, one row per example in the batch.

        Args:
            **kwargs: Model inputs (for example ``input_ids``). Their values
                are ignored. Only the leading dimension of ``input_ids``, when
                it is a tensor of rank >= 1, is used to size the batch.

        Returns:
            SequenceClassifierOutput whose ``logits`` is an integer tensor of
            shape ``(batch_size, 3)``, with every row equal to ``[1, 2, 3]``.
            ``batch_size`` is 1 when no usable ``input_ids`` is given.
        """
        logits = torch.tensor([[1, 2, 3]])
        input_ids = kwargs.get("input_ids")
        # Always returning a batch of one breaks batched pipeline calls
        # (batch_size > 1), so mirror the input batch size instead.
        if isinstance(input_ids, torch.Tensor) and input_ids.dim() > 0:
            logits = logits.repeat(input_ids.shape[0], 1)
        return SequenceClassifierOutput(logits=logits)
| from transformers.pipelines import PIPELINE_REGISTRY | |
| from transformers import AutoModelForSequenceClassification, TFAutoModelForSequenceClassification | |
if __name__ == "__main__":
    # Script-only dependencies: imported here so that importing this module
    # (for example when loading the saved custom code) stays lightweight.
    import importlib
    from pathlib import Path

    from huggingface_hub import HfApi
    from transformers import AutoTokenizer

    # transformers will not copy code for classes defined in ``__main__``
    # (it logs a warning and skips saving). That would make
    # ``register_for_auto_class`` a no-op here. Re-import this file under its
    # real module name and use those class objects, so that ``save_pretrained``
    # writes the defining .py file and the ``auto_map`` entry.
    # ``__spec__`` is set under ``python -m pkg.mod`` and is None for a plain
    # ``python file.py`` run.
    module_name = __spec__.name if __spec__ is not None else Path(__file__).stem
    this_module = importlib.import_module(module_name)

    # Only the tokenizer of this checkpoint is reused, so load it directly
    # instead of building a full pipeline, which would also fetch the weights.
    tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")

    this_module.CustomConfig.register_for_auto_class()
    this_module.CustomModel.register_for_auto_class("AutoModel")

    zero_shot_pipeline = ZeroShotClassificationPipeline(
        model=this_module.CustomModel(this_module.CustomConfig()),
        tokenizer=tokenizer,
    )

    save_dir = "zero-shot-classification"
    zero_shot_pipeline.save_pretrained(save_dir)

    # The git-based ``huggingface_hub.Repository`` is deprecated, so upload
    # the saved folder through the HTTP API instead.
    HfApi().upload_folder(
        repo_id="paulhindemith/zero-shot-classification",
        folder_path=save_dir,
    )