| from transformers import MarianMTModel, MarianTokenizer |
| from typing import Any, List, Dict |
|
|
class EndpointHandler:
    """Inference handler that translates text with a MarianMT model.

    The model and tokenizer are both loaded from ``path`` when the handler is
    constructed; every call then tokenizes, generates and decodes.
    """

    def __init__(self, path=""):
        """Load the MarianMT model and its tokenizer.

        Args:
            path (str): Model directory or hub identifier, passed to both
                ``MarianMTModel.from_pretrained`` and
                ``MarianTokenizer.from_pretrained``.
        """
        self.model = MarianMTModel.from_pretrained(path)
        self.tokenizer = MarianTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Translate the text found under ``data["inputs"]``.

        Args:
            data (dict): The request payload. ``"inputs"`` holds either a single
                string or a list of strings (a missing key is treated as ``""``).
                An optional ``"parameters"`` dict is forwarded to
                ``model.generate`` as keyword arguments.

        Returns:
            List[Dict]: One ``{"translation_text": ...}`` dict per generated
            sequence, in order. This is one per input string unless the
            parameters request several return sequences. A single string
            yields a one-element list; an empty list of inputs yields ``[]``.
        """
        text = data.get("inputs", "")
        parameters = data.get("parameters") or {}

        # Accept both a lone string and a batch of strings.
        texts = [text] if isinstance(text, str) else list(text)
        if not texts:
            return []

        # truncation=True clips inputs longer than the model's maximum length
        # instead of passing over-long sequences through to the model.
        inputs = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        )
        translated = self.model.generate(**inputs, **parameters)

        # Decode every generated sequence, not only the first one, so that a
        # batched request gets a translation for each input.
        decoded = self.tokenizer.batch_decode(translated, skip_special_tokens=True)
        return [{"translation_text": translation} for translation in decoded]