# BugGen / app.py
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
import torch
import autopep8
import glob
import re
import os
from huggingface_hub import hf_hub_download
# ==========================
# Utility functions
# ==========================
def normalize_indentation(code):
"""
    Normalizes indentation in example code by collapsing doubled tabs (or eight
    spaces) inside function bodies down to a single indent level.
    Also removes any backslash characters.
"""
code = code.replace("\\", "")
lines = code.split("\n")
if not lines:
return ""
fixed_lines = []
indent_fix_mode = False
for i, line in enumerate(lines):
if line.strip().startswith("def "):
fixed_lines.append(line)
indent_fix_mode = True
elif indent_fix_mode and line.strip():
# For indented lines in a function
if line.startswith("\t\t"): # Two tabs
fixed_lines.append("\t" + line[2:]) # Replace with one tab
            elif line.startswith(" " * 8):  # Eight spaces (two indent levels)
                fixed_lines.append(" " * 4 + line[8:])  # Replace with four spaces
else:
fixed_lines.append(line)
else:
fixed_lines.append(line)
return "\n".join(fixed_lines)
def clear_text(text):
"""
    Converts the literal escape sequences \\n and \\t back into real newlines and
    tabs, strips any remaining backslashes, and preserves the original formatting.
"""
temp_newline = "TEMP_NEWLINE_PLACEHOLDER"
temp_tab = "TEMP_TAB_PLACEHOLDER"
text = text.replace("\\n", temp_newline)
text = text.replace("\\t", temp_tab)
text = text.replace("\\", "")
text = text.replace(temp_newline, "\n")
text = text.replace(temp_tab, "\t")
return text
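# Illustrative behaviour (made-up string): escaped sequences become real control
# characters, e.g. clear_text("x = 1\\ny = 2") -> "x = 1" + newline + "y = 2".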
def encode_text(text):
"""
    Encodes newlines and tabs as the escape sequences \\n and \\t (the inverse of clear_text).
"""
text = text.replace("\n", "\\n")
text = text.replace("\t", "\\t")
return text
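# Illustrative behaviour: encode_text("a\nb") -> "a\\nb"
# (a real newline in, the two-character escape sequence out).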
def format_code(code):
"""
Format Python code using autopep8 with aggressive settings.
"""
try:
formatted_code = autopep8.fix_code(
code,
options={
"aggressive": 2,
"max_line_length": 88,
"indent_size": 4,
},
)
# Additional formatting for consistent spacing around parentheses and operators
formatted_code = formatted_code.replace("( ", "(").replace(" )", ")")
for op in ["+", "-", "*", "/", "=", "==", "!=", ">=", "<=", ">", "<"]:
formatted_code = formatted_code.replace(f"{op} ", op + " ")
formatted_code = formatted_code.replace(f" {op}", " " + op)
formatted_code = re.sub(r"(\w+)\s+\(", r"\1(", formatted_code)
return formatted_code
except Exception as e:
print(f"Error formatting code: {str(e)}")
return code
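# Rough illustration (made-up input; exact output depends on the installed
# autopep8 version):
#   format_code("x=1+2")  would typically come back as  "x = 1 + 2\n"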
def fix_common_syntax_issues(code):
"""
Fix common syntax issues in generated code without modifying indentation.
"""
lines = code.split("\n")
fixed_lines = []
for line in lines:
stripped = line.strip()
if (
stripped.startswith("if ")
or stripped.startswith("elif ")
or stripped.startswith("else")
or stripped.startswith("for ")
or stripped.startswith("while ")
or stripped.startswith("def ")
or stripped.startswith("class ")
):
if not stripped.endswith(":") and not stripped.endswith("\\"):
line = line.rstrip() + ":"
fixed_lines.append(line)
code = "\n".join(fixed_lines)
# Fix mismatched quotes
quote_chars = ['"', "'"]
for quote in quote_chars:
if code.count(quote) % 2 != 0:
lines = code.split("\n")
for i, line in enumerate(lines):
if line.count(quote) % 2 != 0:
lines[i] = line.rstrip() + quote
break
code = "\n".join(lines)
# Fix missing parentheses in function calls
pattern = r"(\w+)\s*\([^)]*$"
if re.search(pattern, code):
lines = code.split("\n")
for i, line in enumerate(lines):
if re.search(pattern, line) and not any(
lines[j].strip().startswith(")")
for j in range(i + 1, min(i + 3, len(lines)))
):
lines[i] = line.rstrip() + ")"
code = "\n".join(lines)
return code
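# Illustrative example (hypothetical model output): a block header missing its
# colon is repaired, e.g.
#   fix_common_syntax_issues("if x > 0\n\tprint(x)")  ->  "if x > 0:\n\tprint(x)"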
def load_example_from_file(example_path):
"""
Load example from a file with format:
description_BREAK_code
where 'code' uses \\n and \\t for formatting.
"""
try:
with open(example_path, "r") as f:
content = f.read()
parts = content.split("_BREAK_")
if len(parts) == 2:
description = parts[0].strip()
code = parts[1].strip()
code = code.replace("\\n", "\n").replace("\\t", "\t")
code = normalize_indentation(code)
return description, code
else:
print(f"Invalid format in example file: {example_path}")
return "", ""
except Exception as e:
print(f"Error loading example file {example_path}: {str(e)}")
return "", ""
def find_example_files():
"""
Find all raw.in example files in the examples directory.
"""
example_files = glob.glob("examples/*/raw.in")
return example_files
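# Expected layout (directory names are illustrative): examples/<name>/raw.in,
# e.g. examples/binary_search/raw.in, examples/parse_config/raw.in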
# ==========================
# Load model from HF Hub
# ==========================
BASE_MODEL_ID = "Salesforce/codet5p-770m"
FINETUNED_REPO_ID = "OSS-forge/codet5p-770m-pyresbugs"
FINETUNED_FILENAME = "pytorch_model.bin"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Loading tokenizer from base model: {BASE_MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
print(f"Loading base model: {BASE_MODEL_ID}")
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
model.to(device)
print(f"Downloading fine-tuned weights from repo: {FINETUNED_REPO_ID}")
ckpt_path = hf_hub_download(FINETUNED_REPO_ID, FINETUNED_FILENAME)
print(f"Loading state_dict from: {ckpt_path}")
state_dict = torch.load(ckpt_path, map_location="cpu")
if "model_state_dict" in state_dict:
state_dict = state_dict["model_state_dict"]
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"Loaded fine-tuned weights. Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
model.eval()
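# Optional sanity check of the loaded weights (illustrative prompt, kept commented out):
#   probe = tokenizer("Description: swap the comparison operator _BREAK_ Code: return a < b",
#                     return_tensors="pt").input_ids.to(device)
#   print(tokenizer.decode(model.generate(probe, max_new_tokens=64)[0], skip_special_tokens=True))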
# ==========================
# Gradio logic
# ==========================
# State variables
current_code = None
bug_counter = 0
def generate_bugged_code(description, code, chat_history, is_first_time):
global current_code, bug_counter
if chat_history is None:
chat_history = []
if is_first_time:
bug_counter = 0
current_code = None
chat_history = []
bug_counter += 1
if bug_counter == 1:
input_for_model = code
input_type = "original"
else:
if current_code is None:
return chat_history, gr.update(value=""), False
input_for_model = current_code
input_type = "previous bugged code"
print(f"Using {input_type} - counter: {bug_counter}\n{input_for_model}")
encoded_code = encode_text(input_for_model)
combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"
inputs = tokenizer(
combined_input,
return_tensors="pt",
truncation=True,
max_length=512,
).input_ids.to(device)
try:
print("Starting generation...")
with torch.no_grad():
outputs = model.generate(
inputs,
max_new_tokens=256,
num_beams=1,
do_sample=False,
early_stopping=True,
)
print("Generation done.")
except Exception as e:
print("Generation error:", repr(e))
raise e
bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)
bugged_code = clear_text(bugged_code_escaped)
bugged_code = fix_common_syntax_issues(bugged_code)
bugged_code = format_code(bugged_code)
current_code = bugged_code
user_message = f"**Description**: {description}"
if input_type == "original":
user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
else:
user_message += (
f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"
)
ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"
chat_history = chat_history + [
{"role": "user", "content": user_message},
{"role": "assistant", "content": ai_message},
]
return chat_history, gr.update(value=""), False
def reset_interface():
global current_code, bug_counter
current_code = None
bug_counter = 0
return [], gr.update(value=""), True
example_files = find_example_files()
example_names = [
f"Example {i+1}: {os.path.basename(os.path.dirname(f))}"
for i, f in enumerate(example_files)
]
def load_example(example_index):
if example_index < len(example_files):
return load_example_from_file(example_files[example_index])
return "", ""
with gr.Blocks(title="Software-Fault Injection from NL") as demo:
gr.Markdown("# 🐞 Software-Fault Injection from Natural Language")
gr.Markdown(
"Generate Python code with specific bugs based on a description and original code. "
"The model used is **BugGen (CodeT5+ 770M, PyResBugs)**."
)
with gr.Row():
with gr.Column(scale=2):
description_input = gr.Textbox(
label="Bug Description",
placeholder="Describe the type of bug to introduce...",
lines=3,
)
code_input = gr.Code(
label="Original Code",
language="python",
lines=12,
)
is_first = gr.State(True)
submit_btn = gr.Button("Generate Bugged Code")
reset_btn = gr.Button("Start Over")
gr.Markdown("### Examples")
example_buttons = [gr.Button(name) for name in example_names]
with gr.Column(scale=3):
            chat_output = gr.Chatbot(
                label="Conversation",
                type="messages",  # history entries are role/content dicts
                height=500,
            )
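    # Wire each example button; the `i=i` default argument freezes the loop index
    # so every lambda loads its own example rather than the last one.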
for i, btn in enumerate(example_buttons):
btn.click(
fn=lambda i=i: load_example(i),
outputs=[description_input, code_input],
)
submit_btn.click(
fn=generate_bugged_code,
inputs=[description_input, code_input, chat_output, is_first],
outputs=[chat_output, description_input, is_first],
)
reset_btn.click(
fn=reset_interface,
outputs=[chat_output, description_input, is_first],
)
print("Launching Gradio interface...")
demo.queue(max_size=10).launch()
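# To run locally (assuming transformers, torch, gradio, autopep8 and
# huggingface_hub are installed):  python app.py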