| | |
"""
Example usage script that loads a fine-tuned OlmoE adapter model
and demonstrates text generation with it.
"""
| |
|
| | import argparse |
| | import torch |
| | from transformers import AutoTokenizer |
| | from modeling_olmoe import OlmoEWithAdaptersForCausalLM, OlmoConfig |
| |
|
def generate_text(
    model_path: str,
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
    top_p: float = 0.9,
    device: str = "auto",
):
    """Generate text using a fine-tuned OlmoE adapter model.

    Args:
        model_path: Location handed to ``from_pretrained`` for both the
            tokenizer and the model.
        prompt: Text that is tokenized and fed to ``model.generate``.
        max_new_tokens: Upper bound on the number of newly generated tokens.
        temperature: Sampling temperature. A value ``<= 0`` switches to
            greedy decoding instead of sampling.
        top_p: Top-p sampling parameter; only used when sampling.
        device: ``"auto"`` picks ``"cuda"`` when available, else ``"cpu"``.
            Any other value is passed to ``.to(...)`` as given
            (e.g. ``"cpu"``, ``"cuda"``, ``"cuda:0"``).

    Returns:
        The first output sequence decoded with special tokens skipped.
    """
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print(f"Loading model from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Use half precision on every CUDA device string ("cuda", "cuda:0", ...);
    # an exact comparison against "cuda" would load indexed devices in float32.
    on_cuda = str(device).startswith("cuda")
    model = OlmoEWithAdaptersForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16 if on_cuda else torch.float32,
    )
    model = model.to(device)
    model.eval()

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)

    generate_kwargs = {"max_new_tokens": max_new_tokens}

    # Forward the tokenizer's attention mask when it provides one, so that
    # generate() does not have to infer it from the input ids.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.to(device)

    # Sampling needs a strictly positive temperature; treat temperature <= 0
    # as a request for deterministic (greedy) decoding.
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        generate_kwargs["do_sample"] = False

    print("\nGenerating text...\n")
    with torch.no_grad():
        outputs = model.generate(input_ids, **generate_kwargs)

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(f"Prompt: {prompt}")
    print("\nGenerated text:")
    print("=" * 40)
    print(generated_text)
    print("=" * 40)

    return generated_text
| |
|
def main():
    """Parse the command-line options and run one text generation with them."""
    cli = argparse.ArgumentParser(description="Generate text with OlmoE adapter model")

    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--model_path", {"type": str, "required": True, "help": "Path to the fine-tuned model"}),
        ("--prompt", {"type": str, "default": "This is an example of", "help": "Prompt for text generation"}),
        ("--max_new_tokens", {"type": int, "default": 128, "help": "Maximum number of new tokens to generate"}),
        ("--temperature", {"type": float, "default": 0.7, "help": "Sampling temperature"}),
        ("--top_p", {"type": float, "default": 0.9, "help": "Top-p sampling parameter"}),
        ("--device", {"type": str, "default": "auto", "help": "Device to use (cuda, cpu, or auto)"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()

    # Hand the parsed values to generate_text under the same keyword names.
    generate_text(
        model_path=opts.model_path,
        prompt=opts.prompt,
        max_new_tokens=opts.max_new_tokens,
        temperature=opts.temperature,
        top_p=opts.top_p,
        device=opts.device,
    )
| |
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()