| | from typing import Dict, Any |
| | import torch |
| | from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM |
| |
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for causal-LM text generation.

    Loads a tokenizer and model from ``path`` and serves chat-style
    completions through ``__call__``, returning an OpenAI-like
    ``{"choices": [...]}`` payload.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, model, and generation pipeline from *path*.

        Args:
            path: Local path or hub id of the model repository.
        """
        # Force a BOS token at the start of every encoded prompt.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.tokenizer.add_bos_token = True

        # bfloat16 on GPU; float32 on CPU — float16 kernels are poorly
        # supported on CPU and commonly raise "not implemented" errors.
        use_cuda = torch.cuda.is_available()
        dtype = torch.bfloat16 if use_cuda else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=dtype,
        ).to("cuda" if use_cuda else "cpu")

        # The pipeline infers device and dtype from the already-placed
        # model object, so passing them again would just trigger a
        # redundant second move (and a warning).
        self.generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            return_full_text=False,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a completion for an inference request.

        Args:
            data: Request payload with ``inputs`` (prompt string) and an
                optional ``parameters`` dict of generation overrides.

        Returns:
            ``{"choices": [{"message": {...}, "finish_reason": ...}]}`` on
            success, or ``{"error": <message>}`` on failure.
        """
        prompt = data.get("inputs", "")
        if not prompt:
            return {"error": "Missing 'inputs' field."}

        defaults = {
            "max_new_tokens": 100,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        # Caller-supplied parameters override the defaults; ``or {}``
        # guards against an explicit ``"parameters": null`` payload.
        generation_args = {**defaults, **(data.get("parameters") or {})}

        try:
            outputs = self.generator(prompt, **generation_args)
            output_text = outputs[0]["generated_text"].strip()

            # Heuristic: if the output's token count (excluding special
            # tokens such as the forced BOS, which would inflate it)
            # reaches the budget, generation was truncated by length
            # rather than stopped by EOS.
            finish_reason = "stop"
            n_tokens = len(
                self.tokenizer.encode(output_text, add_special_tokens=False)
            )
            if n_tokens >= generation_args["max_new_tokens"]:
                finish_reason = "length"

            return {
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": output_text,
                    },
                    "finish_reason": finish_reason,
                }]
            }

        except Exception as e:
            # Endpoint boundary: surface the failure as a JSON error
            # payload instead of crashing the serving worker.
            return {"error": str(e)}
| |
|