import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer


# StarCoder2 is a decoder-only (causal) language model, so it has to be loaded
# with AutoModelForCausalLM; AutoModelForSeq2SeqLM has no mapping for it.
# The checkpoints live under the "bigcode" organisation on the Hugging Face
# Hub. The 3B variant is the smallest; swap in "bigcode/starcoder2-7b" or
# "bigcode/starcoder2-15b" if the hardware allows.
model_name = "bigcode/starcoder2-3b"


@st.cache_resource
def _load_model_and_tokenizer(name):
    """
    Load the model and tokenizer for the given checkpoint once per process.

    Streamlit re-executes the whole script on every user interaction; the
    cache_resource decorator keeps the loaded weights alive across reruns
    instead of reloading them each time.

    Args:
        name (str): Hugging Face Hub repo id or local path of the checkpoint.

    Returns:
        tuple: (model, tokenizer) for the checkpoint.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    loaded_model = AutoModelForCausalLM.from_pretrained(name)
    # Inference only: disable dropout and other training-time behaviour.
    loaded_model.eval()
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(model_name)


def code_complete(prompt, max_length=256):
    """
    Generate code completion suggestions for the given prompt.

    Args:
        prompt (str): The incomplete code snippet.
        max_length (int, optional): Upper bound, in tokens, applied both to
            the tokenized prompt (longer prompts are truncated) and to the
            newly generated code. Defaults to 256.

    Returns:
        list: A list of code completion suggestions (decoded strings). The
            list is empty when the prompt is empty or only whitespace.
    """
    # Nothing to complete; skip the (expensive) model call entirely.
    if not prompt or not prompt.strip():
        return []

    # A single prompt needs no padding. Padding to max_length fails for
    # tokenizers without a pad token and, with right-side padding, makes a
    # decoder-only model continue from pad tokens instead of from the prompt.
    inputs = tokenizer(
        prompt,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    ).to(model.device)

    # max_new_tokens bounds only the generated part. A plain max_length would
    # also count the prompt for decoder-only models and could leave no room
    # for any new tokens.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_length,
    )

    # Decoder-only models echo the prompt at the start of every output
    # sequence; drop it so that only the newly generated code is returned.
    if not model.config.is_encoder_decoder:
        prompt_token_count = inputs["input_ids"].shape[1]
        outputs = outputs[:, prompt_token_count:]

    return [
        tokenizer.decode(output, skip_special_tokens=True)
        for output in outputs
    ]


def code_fix(code):
    """
    Fix errors in the given code snippet.

    Args:
        code (str): The code snippet with errors.

    Returns:
        str: The corrected code snippet as produced by the model. An empty
            string is returned when the input is empty or only whitespace.
    """
    # Token budget for the (truncated) input and for the generated output.
    token_limit = 512

    # Nothing to fix; skip the (expensive) model call entirely.
    if not code or not code.strip():
        return ""

    # A single snippet needs no padding. Padding to max_length fails for
    # tokenizers without a pad token and, with right-side padding, makes a
    # decoder-only model continue from pad tokens instead of from the input.
    inputs = tokenizer(
        code,
        add_special_tokens=True,
        max_length=token_limit,
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    ).to(model.device)

    # max_new_tokens bounds only the generated part. A plain max_length would
    # also count the input for decoder-only models and could leave no room
    # for any new tokens.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=token_limit,
    )

    generated = outputs[0]
    # Decoder-only models echo the input at the start of the output sequence;
    # drop it so that only the newly generated code is returned.
    if not model.config.is_encoder_decoder:
        generated = generated[inputs["input_ids"].shape[1]:]

    return tokenizer.decode(generated, skip_special_tokens=True)


def text_to_code(text, max_length=256):
    """
    Generate code from a natural language description.

    Args:
        text (str): The natural language description of the code.
        max_length (int, optional): Upper bound, in tokens, applied both to
            the tokenized description (longer inputs are truncated) and to
            the newly generated code. Defaults to 256.

    Returns:
        str: The generated code. An empty string is returned when the
            description is empty or only whitespace.
    """
    # Nothing to generate from; skip the (expensive) model call entirely.
    if not text or not text.strip():
        return ""

    # A single description needs no padding. Padding to max_length fails for
    # tokenizers without a pad token and, with right-side padding, makes a
    # decoder-only model continue from pad tokens instead of from the input.
    inputs = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    ).to(model.device)

    # max_new_tokens bounds only the generated part. A plain max_length would
    # also count the input for decoder-only models and could leave no room
    # for any new tokens.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_length,
    )

    generated = outputs[0]
    # Decoder-only models echo the input at the start of the output sequence;
    # drop it so that only the newly generated code is returned.
    if not model.config.is_encoder_decoder:
        generated = generated[inputs["input_ids"].shape[1]:]

    return tokenizer.decode(generated, skip_special_tokens=True)


# --- Streamlit user interface -------------------------------------------------
# Start the app with `streamlit run <this file>`. Streamlit executes the script
# top to bottom itself, so no explicit entry point is needed (the library has
# no `st.run()` function).

st.title("Codebot")
st.write("Welcome to the Codebot! You can use this app to generate code completions, fix errors in your code, or generate code from a natural language description.")

# st.tabs() creates every tab in a single call and returns one container per
# label, in order; there is no singular `st.tab` API.
code_completion_tab, code_fixing_tab, text_to_code_tab = st.tabs(
    ["Code Completion", "Code Fixing", "Text-to-Code"]
)

with code_completion_tab:
    st.write("Enter an incomplete code snippet:")
    prompt_input = st.text_input("Prompt:", value="")
    generate_button = st.button("Generate Completions")

    if generate_button:
        if not prompt_input.strip():
            st.warning("Please enter a code snippet first.")
        else:
            completions = code_complete(prompt_input)
            st.write("Code completions:")
            for i, completion in enumerate(completions, start=1):
                # st.code keeps whitespace and symbols intact; rendering
                # source code as Markdown would mangle characters such as
                # `_`, `*` and leading indentation.
                st.write(f"Completion {i}:")
                st.code(completion)

with code_fixing_tab:
    st.write("Enter a code snippet with errors:")
    code_input = st.text_area("Code:", height=300)
    fix_button = st.button("Fix Errors")

    if fix_button:
        if not code_input.strip():
            st.warning("Please enter some code to fix first.")
        else:
            corrected_code = code_fix(code_input)
            st.write("Corrected code:")
            st.code(corrected_code)

with text_to_code_tab:
    st.write("Enter a natural language description of the code:")
    text_input = st.text_input("Description:", value="")
    generate_button = st.button("Generate Code")

    if generate_button:
        if not text_input.strip():
            st.warning("Please enter a description first.")
        else:
            generated_code = text_to_code(text_input)
            st.write("Generated code:")
            st.code(generated_code)