"""
Helion-V1.5-XL Inference Script
Supports multiple inference modes and optimization techniques
"""
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
GenerationConfig
)
from typing import Optional, Dict, Any, List
import argparse
import json
import time
class HelionInference:
"""Inference wrapper for Helion-V1.5-XL"""
def __init__(
self,
model_name: str = "DeepXR/Helion-V1.5-XL",
load_in_4bit: bool = False,
load_in_8bit: bool = False,
device_map: str = "auto",
torch_dtype: str = "bfloat16"
):
"""
Initialize the model and tokenizer
Args:
model_name: HuggingFace model identifier
load_in_4bit: Enable 4-bit quantization
load_in_8bit: Enable 8-bit quantization
device_map: Device mapping strategy
torch_dtype: PyTorch dtype for model weights
"""
self.model_name = model_name
print(f"Loading model: {model_name}")
# Setup dtype
dtype_map = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32
}
torch_dtype = dtype_map.get(torch_dtype, torch.bfloat16)
# Setup quantization config
quantization_config = None
if load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
elif load_in_8bit:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
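        # Note: 4-bit NF4 with double quantization stores the weights in
        # roughly a quarter of the memory of bf16/fp16 weights; 8-bit roughly
        # halves it. Both paths need the bitsandbytes package and, in typical
        # setups, a CUDA GPU.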
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        # Many causal LM tokenizers ship without a pad token; fall back to the
        # EOS token so generation and batched padding do not fail.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
# Load model
model_kwargs = {
"device_map": device_map,
"trust_remote_code": True,
}
if quantization_config:
model_kwargs["quantization_config"] = quantization_config
else:
model_kwargs["torch_dtype"] = torch_dtype
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
**model_kwargs
)
self.model.eval()
print("Model loaded successfully!")
def generate(
self,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.1,
do_sample: bool = True,
num_return_sequences: int = 1,
**kwargs
) -> List[str]:
"""
Generate text from a prompt
Args:
prompt: Input text prompt
max_new_tokens: Maximum number of tokens to generate
temperature: Sampling temperature (0.0 to 2.0)
top_p: Nucleus sampling threshold
top_k: Top-k sampling threshold
repetition_penalty: Penalty for repetition
do_sample: Whether to use sampling
num_return_sequences: Number of sequences to generate
Returns:
List of generated text strings
"""
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
generation_config = GenerationConfig(
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
do_sample=do_sample,
num_return_sequences=num_return_sequences,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
**kwargs
)
start_time = time.time()
with torch.no_grad():
outputs = self.model.generate(
**inputs,
generation_config=generation_config
)
generation_time = time.time() - start_time
        # Decode only the newly generated tokens; slicing by token position is
        # more robust than stripping the prompt string from the decoded text.
        prompt_length = inputs["input_ids"].shape[1]
        responses = []
        generated_token_count = 0
        for output in outputs:
            generated_tokens = output[prompt_length:]
            generated_token_count += generated_tokens.shape[0]
            response = self.tokenizer.decode(
                generated_tokens, skip_special_tokens=True
            ).strip()
            responses.append(response)
        # Throughput is measured over generated tokens only (prompt excluded)
        tokens_per_sec = generated_token_count / generation_time
        print("\nGeneration Stats:")
        print(f"  Time: {generation_time:.2f}s")
        print(f"  Tokens/sec: {tokens_per_sec:.2f}")
        return responses
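    # Illustrative call; the sampling values below are examples rather than
    # settings recommended by the model card:
    #
    #   texts = engine.generate(
    #       "Write a haiku about GPUs.",
    #       max_new_tokens=64,
    #       temperature=0.8,
    #       top_p=0.95,
    #   )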
def chat(
self,
messages: List[Dict[str, str]],
max_new_tokens: int = 512,
temperature: float = 0.7,
**kwargs
) -> str:
"""
Generate response in chat format
Args:
messages: List of message dicts with 'role' and 'content'
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
Returns:
Generated response string
"""
# Apply chat template
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
responses = self.generate(
prompt,
max_new_tokens=max_new_tokens,
temperature=temperature,
**kwargs
)
return responses[0]
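    # Illustrative chat call using the role/content message schema expected by
    # tokenizer.apply_chat_template (whether a system role is supported depends
    # on the model's chat template):
    #
    #   reply = engine.chat([
    #       {"role": "user", "content": "Summarize the transformer architecture."}
    #   ])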
def batch_generate(
self,
prompts: List[str],
max_new_tokens: int = 512,
**kwargs
) -> List[str]:
"""
Generate responses for multiple prompts in batch
Args:
prompts: List of input prompts
max_new_tokens: Maximum tokens per generation
Returns:
List of generated responses
"""
        # Decoder-only models should be left-padded for batched generation so
        # that new tokens continue directly from each prompt.
        self.tokenizer.padding_side = "left"
        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )
        # Strip the (padded) prompt tokens and decode only the completions
        prompt_length = inputs["input_ids"].shape[1]
        responses = []
        for output in outputs:
            generated_tokens = output[prompt_length:]
            response = self.tokenizer.decode(
                generated_tokens, skip_special_tokens=True
            ).strip()
            responses.append(response)
        return responses
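    # Illustrative batched call; prompts of different lengths are left-padded
    # to a common length before generation:
    #
    #   answers = engine.batch_generate(
    #       ["Define entropy.", "Define enthalpy."],
    #       max_new_tokens=128,
    #   )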
def main():
parser = argparse.ArgumentParser(description="Helion-V1.5-XL Inference")
parser.add_argument(
"--model",
type=str,
default="DeepXR/Helion-V1.5-XL",
help="Model name or path"
)
parser.add_argument(
"--prompt",
type=str,
required=True,
help="Input prompt"
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="Maximum tokens to generate"
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Sampling temperature"
)
parser.add_argument(
"--top-p",
type=float,
default=0.9,
help="Nucleus sampling threshold"
)
parser.add_argument(
"--load-in-4bit",
action="store_true",
help="Load model in 4-bit quantization"
)
parser.add_argument(
"--load-in-8bit",
action="store_true",
help="Load model in 8-bit quantization"
)
parser.add_argument(
"--chat-mode",
action="store_true",
help="Use chat format"
)
args = parser.parse_args()
# Initialize model
inference = HelionInference(
model_name=args.model,
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit
)
# Generate response
if args.chat_mode:
messages = [
{"role": "user", "content": args.prompt}
]
response = inference.chat(
messages,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
top_p=args.top_p
)
else:
responses = inference.generate(
args.prompt,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
top_p=args.top_p
)
response = responses[0]
print("\n" + "="*80)
print("PROMPT:")
print("="*80)
print(args.prompt)
print("\n" + "="*80)
print("RESPONSE:")
print("="*80)
print(response)
print("="*80)
if __name__ == "__main__":
main()
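# Illustrative command lines (flags mirror the argparse definitions above;
# model weights are fetched from the Hugging Face Hub on first use):
#
#   python inference.py --prompt "Explain quantum computing" --max-tokens 256
#   python inference.py --prompt "Hello there" --chat-mode --temperature 0.3
#   python inference.py --prompt "Hello there" --load-in-4bit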