""" |
|
|
Helion-V1.5-XL Inference Script |
|
|
Supports multiple inference modes and optimization techniques |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForCausalLM, |
|
|
BitsAndBytesConfig, |
|
|
GenerationConfig |
|
|
) |
|
|
from typing import Optional, Dict, Any, List |
|
|
import argparse |
|
|
import json |
|
|
import time |
|
|
|
|
|
|
|
|
class HelionInference:
    """Inference wrapper for Helion-V1.5-XL"""

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-V1.5-XL",
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        device_map: str = "auto",
        torch_dtype: str = "bfloat16"
    ):
        """
        Initialize the model and tokenizer.

        Args:
            model_name: HuggingFace model identifier
            load_in_4bit: Enable 4-bit quantization
            load_in_8bit: Enable 8-bit quantization
            device_map: Device mapping strategy
            torch_dtype: PyTorch dtype for model weights
        """
        self.model_name = model_name
        print(f"Loading model: {model_name}")

        # Resolve the requested dtype, falling back to bfloat16.
        dtype_map = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32
        }
        torch_dtype = dtype_map.get(torch_dtype, torch.bfloat16)

        # Build a quantization config if 4-bit or 8-bit loading was requested.
        quantization_config = None
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
        elif load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        # Some tokenizers ship without a pad token; fall back to EOS so that
        # padding (batch_generate) and pad_token_id (generate) are defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        model_kwargs = {
            "device_map": device_map,
            "trust_remote_code": True,
        }
        if quantization_config:
            model_kwargs["quantization_config"] = quantization_config
        else:
            model_kwargs["torch_dtype"] = torch_dtype

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **model_kwargs
        )
        self.model.eval()
        print("Model loaded successfully!")
|
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
        do_sample: bool = True,
        num_return_sequences: int = 1,
        **kwargs
    ) -> List[str]:
        """
        Generate text from a prompt.

        Args:
            prompt: Input text prompt
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.0 to 2.0)
            top_p: Nucleus sampling threshold
            top_k: Top-k sampling threshold
            repetition_penalty: Penalty for repetition
            do_sample: Whether to use sampling
            num_return_sequences: Number of sequences to generate

        Returns:
            List of generated text strings
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            num_return_sequences=num_return_sequences,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            **kwargs
        )

        start_time = time.time()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                generation_config=generation_config
            )
        generation_time = time.time() - start_time

        # Decode only the newly generated tokens; slicing the token ids is more
        # robust than stripping the prompt string from the decoded output.
        responses = []
        for output in outputs:
            response = self.tokenizer.decode(
                output[prompt_length:], skip_special_tokens=True
            ).strip()
            responses.append(response)

        # Report throughput over newly generated tokens only (excludes the prompt).
        total_tokens = sum(len(output) - prompt_length for output in outputs)
        tokens_per_sec = total_tokens / generation_time

        print("\nGeneration Stats:")
        print(f"  Time: {generation_time:.2f}s")
        print(f"  Tokens/sec: {tokens_per_sec:.2f}")

        return responses
|
    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate a response in chat format.

        Args:
            messages: List of message dicts with 'role' and 'content'
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated response string
        """
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        responses = self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **kwargs
        )

        return responses[0]
|
    def batch_generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 512,
        **kwargs
    ) -> List[str]:
        """
        Generate responses for multiple prompts in batch.

        Args:
            prompts: List of input prompts
            max_new_tokens: Maximum tokens per generation

        Returns:
            List of generated responses
        """
        # Decoder-only models should be left-padded for batched generation so
        # that new tokens are appended directly after each prompt.
        self.tokenizer.padding_side = "left"
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )

        # Decode only the tokens generated after the (left-padded) prompt.
        responses = []
        for output in outputs:
            response = self.tokenizer.decode(
                output[prompt_length:], skip_special_tokens=True
            ).strip()
            responses.append(response)

        return responses
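# Example programmatic usage (a minimal sketch; assumes enough GPU memory for the
# chosen quantization setting and that the model ships a chat template):
#
#   engine = HelionInference(load_in_4bit=True)
#   reply = engine.chat([{"role": "user", "content": "Summarize attention in one sentence."}])
#   print(reply)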
|
|
def main():
    parser = argparse.ArgumentParser(description="Helion-V1.5-XL Inference")
    parser.add_argument(
        "--model",
        type=str,
        default="DeepXR/Helion-V1.5-XL",
        help="Model name or path"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Input prompt"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="Maximum tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.9,
        help="Nucleus sampling threshold"
    )
    parser.add_argument(
        "--load-in-4bit",
        action="store_true",
        help="Load model in 4-bit quantization"
    )
    parser.add_argument(
        "--load-in-8bit",
        action="store_true",
        help="Load model in 8-bit quantization"
    )
    parser.add_argument(
        "--chat-mode",
        action="store_true",
        help="Use chat format"
    )

    args = parser.parse_args()

    inference = HelionInference(
        model_name=args.model,
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit
    )

    if args.chat_mode:
        messages = [
            {"role": "user", "content": args.prompt}
        ]
        response = inference.chat(
            messages,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p
        )
    else:
        responses = inference.generate(
            args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p
        )
        response = responses[0]

    print("\n" + "=" * 80)
    print("PROMPT:")
    print("=" * 80)
    print(args.prompt)
    print("\n" + "=" * 80)
    print("RESPONSE:")
    print("=" * 80)
    print(response)
    print("=" * 80)
|
|
if __name__ == "__main__":
    main()
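# Example CLI invocations (a sketch; the filename "inference.py" is assumed, and
# 4-bit/8-bit loading requires the bitsandbytes package):
#
#   python inference.py --prompt "Explain quantum computing in two sentences."
#   python inference.py --prompt "Write a haiku about GPUs." --chat-mode --load-in-4bit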