Trouter-Library committed
Commit b42e229 · verified · 1 Parent(s): 72b4439

Create inference.py

Files changed (1)
  1. inference.py +178 -0
inference.py ADDED
@@ -0,0 +1,178 @@
"""
Helion-V1 Inference Script
Safe and helpful conversational AI model
"""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict
import warnings

warnings.filterwarnings('ignore')


class HelionInference:
    def __init__(self, model_name: str = "DeepXR/Helion-V1", device: str = "auto"):
        """
        Initialize the Helion model for inference.

        Args:
            model_name: HuggingFace model identifier
            device: Device to run inference on ('cuda', 'cpu', or 'auto')
        """
        print(f"Loading Helion-V1 model from {model_name}...")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Some causal LM tokenizers ship without a pad token; fall back to the
        # EOS token so generate() receives a valid pad_token_id.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True
        )

        self.model.eval()
        print("Model loaded successfully!")

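        # Note: bfloat16 assumes hardware with native bf16 support; on older
        # devices, torch.float16 (CUDA) or torch.float32 (CPU) are the usual
        # fallbacks for the torch_dtype passed above.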
        # Safety keywords to monitor
        self.safety_keywords = [
            "harm", "illegal", "weapon", "violence", "dangerous",
            "exploit", "hack", "steal", "abuse"
        ]

    def check_safety(self, text: str) -> bool:
        """
        Basic safety check on input text.

        Note: this is a coarse substring match, so benign words that contain
        a keyword (e.g. "pharmacy" contains "harm") are also flagged.

        Args:
            text: Input text to check

        Returns:
            True if text appears safe, False otherwise
        """
        text_lower = text.lower()
        for keyword in self.safety_keywords:
            if keyword in text_lower:
                return False
        return True

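    # A stricter check could match whole words only. A minimal sketch, not
    # part of the original script, shown for illustration:
    #
    #   import re
    #   pattern = re.compile(r"\b(" + "|".join(self.safety_keywords) + r")\b")
    #   return pattern.search(text.lower()) is None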
    def generate_response(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True
    ) -> str:
        """
        Generate a response from the model.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            do_sample: Whether to use sampling

        Returns:
            Generated response text
        """
        # Apply chat template
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(self.model.device)

        # Generate response
        with torch.no_grad():
            output = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        # Decode response, skipping the prompt tokens that were fed in
        response = self.tokenizer.decode(
            output[0][input_ids.shape[1]:],
            skip_special_tokens=True
        )

        return response.strip()

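    # Illustrative one-off call (hypothetical prompt, shown for reference):
    #
    #   helion = HelionInference()
    #   reply = helion.generate_response(
    #       [{"role": "user", "content": "Explain nucleus sampling briefly."}],
    #       max_new_tokens=128,
    #   )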
    def chat(self):
        """Interactive chat mode."""
        print("\n" + "="*60)
        print("Helion-V1 Interactive Chat")
        print("Type 'quit' or 'exit' to end the conversation")
        print("="*60 + "\n")

        conversation_history = []

        while True:
            user_input = input("You: ").strip()

            if user_input.lower() in ['quit', 'exit']:
                print("Goodbye! Have a great day!")
                break

            if not user_input:
                continue

            # Basic safety check
            if not self.check_safety(user_input):
                print("Helion: I apologize, but I can't assist with that request. "
                      "Let me know if there's something else I can help you with!")
                continue

            # Add user message to history
            conversation_history.append({
                "role": "user",
                "content": user_input
            })

            # Generate response
            try:
                response = self.generate_response(conversation_history)
                print(f"Helion: {response}\n")

                # Add assistant response to history
                conversation_history.append({
                    "role": "assistant",
                    "content": response
                })
            except Exception as e:
                print(f"Error generating response: {e}")
                conversation_history.pop()  # Remove failed user message


def main():
    """Main function for CLI usage."""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-V1 Inference")
    parser.add_argument("--model", default="DeepXR/Helion-V1", help="Model name or path")
    parser.add_argument("--device", default="auto", help="Device to use (cuda/cpu/auto)")
    parser.add_argument("--interactive", action="store_true", help="Start interactive chat")
    parser.add_argument("--prompt", type=str, help="Single prompt to process")

    args = parser.parse_args()

    # Initialize model
    helion = HelionInference(model_name=args.model, device=args.device)

    if args.interactive:
        helion.chat()
    elif args.prompt:
        messages = [{"role": "user", "content": args.prompt}]
        response = helion.generate_response(messages)
        print(f"Response: {response}")
    else:
        print("Please specify --interactive or --prompt")


if __name__ == "__main__":
    main()
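
# Example invocations (prompts are illustrative):
#   python inference.py --prompt "Give me three tips for writing clear code."
#   python inference.py --interactive --device cpu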