| import torch |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
| |
# Checkpoint to load. Placeholder — replace with a real Hugging Face Hub
# repo id (e.g. "org/model") or a local directory before running.
model_name = "your_model_repo"
# Load the matching tokenizer and encoder-decoder model from the same repo.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
|
| |
# Ensure the tokenizer defines the standard special tokens.
#
# NOTE(fix): `tokenizer.special_tokens_map` is a derived property and is never
# None for a loaded tokenizer, so the original `is None` check was dead code;
# assigning to the property also fails (it has no setter). The supported API
# is `add_special_tokens`, which registers only the tokens actually missing.
_default_special_tokens = {
    "bos_token": "<s>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "sep_token": "</s>",
    "pad_token": "<pad>",
    "cls_token": "<s>",
    "mask_token": "<mask>",
}
_missing_special_tokens = {
    name: token
    for name, token in _default_special_tokens.items()
    if getattr(tokenizer, name, None) is None
}
if _missing_special_tokens:
    tokenizer.add_special_tokens(_missing_special_tokens)
    # Newly added tokens enlarge the vocabulary; keep the model's embedding
    # matrix in sync so generate() does not index out of range.
    model.resize_token_embeddings(len(tokenizer))
# Persist the (possibly updated) tokenizer alongside the model files, as the
# original script did unconditionally.
tokenizer.save_pretrained(model_name)
|
|
# Tokenization settings applied when encoding prompts (consumed by
# generate_code below).
preprocessor_config = dict(
    do_lower_case=False,
    max_length=128,
    truncation=True,
    padding="max_length",
)
|
|
| |
def generate_code(prompt, max_length=128, temperature=0.7, top_p=0.9):
    """Generate a code completion for *prompt* via nucleus sampling.

    Args:
        prompt: Text (e.g. a function signature) to complete.
        max_length: Maximum length of the generated sequence.
        temperature: Softmax temperature for sampling.
        top_p: Nucleus-sampling probability mass cutoff.

    Returns:
        The decoded completion with special tokens stripped.
    """
    # FIX: honor the whole preprocessor config. The original hard-coded
    # truncation=True, padding=True, silently ignoring the configured
    # "max_length" padding strategy declared in preprocessor_config.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=preprocessor_config["truncation"],
        padding=preprocessor_config["padding"],
        max_length=preprocessor_config["max_length"],
    )

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

    # Decode the first (and only) sequence in the batch.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
| |
if __name__ == "__main__":
    # Smoke-test the pipeline with a simple completion prompt.
    demo_prompt = "def quicksort(arr):"
    completion = generate_code(demo_prompt)
    print("Generated Code:\n", completion)