|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
import gradio as gr |
|
|
import torch |
|
|
import autopep8 |
|
|
import glob |
|
|
import re |
|
|
import os |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_indentation(code):
    """
    Normalize indentation in example code by removing one excess indent level.

    All backslash characters are deleted first. The first line whose stripped
    text starts with ``def `` switches on "fix mode" and is kept as-is. Every
    later non-blank line (nested ``def`` lines included) loses one level:
    a leading two-tab prefix becomes one tab, and a leading eight-space prefix
    becomes four spaces. Lines before the first ``def``, blank lines, and lines
    indented less deeply are left unchanged.

    Args:
        code: Source text containing real newlines/tabs.

    Returns:
        The adjusted source text.
    """
    code = code.replace("\\", "")

    double_tab = "\t\t"
    double_indent = " " * 8
    single_indent = " " * 4

    fixed_lines = []
    indent_fix_mode = False
    for line in code.split("\n"):
        if indent_fix_mode and line.strip():
            # Dedent every non-blank line after the first def, including nested
            # defs, so a nested def stays one level above its own body.
            if line.startswith(double_tab):
                line = "\t" + line[2:]
            elif line.startswith(double_indent):
                line = single_indent + line[8:]
        elif line.strip().startswith("def "):
            indent_fix_mode = True
        fixed_lines.append(line)

    return "\n".join(fixed_lines)
|
|
|
|
|
|
|
|
def clear_text(text):
    """
    Turn escaped model output back into plain text.

    Literal ``\\n`` / ``\\t`` sequences become real newlines / tabs, and every
    other backslash is dropped.
    """
    # (escape sequence, temporary marker, real character). The markers shield
    # the escapes while the remaining backslashes are stripped out.
    substitutions = (
        ("\\n", "TEMP_NEWLINE_PLACEHOLDER", "\n"),
        ("\\t", "TEMP_TAB_PLACEHOLDER", "\t"),
    )
    for escape, marker, _ in substitutions:
        text = text.replace(escape, marker)
    text = text.replace("\\", "")
    for _, marker, char in substitutions:
        text = text.replace(marker, char)
    return text
|
|
|
|
|
|
|
|
def encode_text(text):
    """
    Escape control characters for the model prompt.

    Each newline becomes the two characters ``\\n`` and each tab becomes ``\\t``.
    Every other character, including existing backslashes, passes through
    untouched.
    """
    escapes = str.maketrans({"\n": "\\n", "\t": "\\t"})
    return text.translate(escapes)
|
|
|
|
|
|
|
|
def format_code(code):
    """
    Format Python code with autopep8 (aggressive settings) and tidy call spacing.

    After autopep8 runs, two cleanups are applied:

    - Spaces/tabs just inside parentheses are removed on the same line.
    - ``name (`` becomes ``name(`` on the same line.

    Python keywords keep their following space (``if (x)``, ``return (a, b)``).
    No rewrite crosses a line break, so separate statements are never merged
    and a ``)`` on its own line keeps its indentation. These are plain text
    rewrites, so they also apply inside string literals and comments.

    Args:
        code: Python source text.

    Returns:
        The formatted text, or ``code`` unchanged if formatting raised (the
        error is printed).
    """
    import keyword

    def _tighten_call(match):
        # Join only real call targets; keep "if (", "return (", "not (" etc.
        name = match.group(1)
        if keyword.iskeyword(name):
            return match.group(0)
        return name + "("

    try:
        formatted_code = autopep8.fix_code(
            code,
            options={
                "aggressive": 2,
                "max_line_length": 88,
                "indent_size": 4,
            },
        )

        # "(  x" -> "(x" and "x  )" -> "x)", never touching leading indentation.
        formatted_code = re.sub(r"\([ \t]+", "(", formatted_code)
        formatted_code = re.sub(r"(?<=\S)[ \t]+\)", ")", formatted_code)

        # [ \t]+ (not \s+) so "y\n(a, b) = c" is not merged into "y(a, b) = c".
        formatted_code = re.sub(r"\b(\w+)[ \t]+\(", _tighten_call, formatted_code)

        return formatted_code

    except Exception as e:
        # Best-effort: fall back to the unformatted code.
        print(f"Error formatting code: {str(e)}")
        return code
|
|
|
|
|
|
|
|
def fix_common_syntax_issues(code):
    """
    Best-effort repair of common syntax slips in generated code.

    If ``code`` already compiles, it is returned unchanged. This ensures the
    heuristics below never damage valid code, such as one-line ``if x: y``
    statements, apostrophes inside strings, or multi-line calls.

    Otherwise, three line-based heuristics run. None of them modifies
    indentation:

    1. A block header (``if``/``elif``/``else``/``for``/``while``/``def``/
       ``class``) that does not end with ``:`` and is not continued with a
       backslash gets ``:`` appended.
    2. If the whole text contains an odd number of double (or single) quote
       characters, the first line with an odd count of that quote gets one
       appended.
    3. A line that opens ``name(`` without closing it on that line gets ``)``
       appended, unless one of the next two lines starts with ``)``.

    Args:
        code: Python source text.

    Returns:
        The possibly repaired source text.
    """
    try:
        # compile() only parses/compiles the text; nothing in it is executed.
        compile(code, "<generated>", "exec")
    except (SyntaxError, ValueError):
        pass
    else:
        return code

    headers = ("if ", "elif ", "for ", "while ", "def ", "class ")
    fixed_lines = []
    for line in code.split("\n"):
        stripped = line.strip()
        # "else" must be a whole word so identifiers like "elsewhere" are skipped.
        is_header = stripped.startswith(headers) or re.match(r"else\b", stripped)
        if is_header and not stripped.endswith(":") and not stripped.endswith("\\"):
            line = line.rstrip() + ":"
        fixed_lines.append(line)
    code = "\n".join(fixed_lines)

    for quote in ('"', "'"):
        if code.count(quote) % 2 != 0:
            lines = code.split("\n")
            for i, line in enumerate(lines):
                if line.count(quote) % 2 != 0:
                    lines[i] = line.rstrip() + quote
                    break
            code = "\n".join(lines)

    # Without re.MULTILINE, "$" anchors at the end of the whole text here, so
    # the pre-check asks "is some call still open at the very end?".
    pattern = r"(\w+)\s*\([^)]*$"
    if re.search(pattern, code):
        lines = code.split("\n")
        for i, line in enumerate(lines):
            if re.search(pattern, line) and not any(
                lines[j].strip().startswith(")")
                for j in range(i + 1, min(i + 3, len(lines)))
            ):
                lines[i] = line.rstrip() + ")"
        code = "\n".join(lines)

    return code
|
|
|
|
|
|
|
|
def load_example_from_file(example_path):
    """
    Load a ``(description, code)`` example pair from a text file.

    The file content has the form ``description_BREAK_code``. The code part
    stores formatting as literal ``\\n`` / ``\\t`` escapes, which are expanded
    here before the indentation is normalized.

    Only the first ``_BREAK_`` separates the two fields, so the code part may
    itself contain that marker.

    Args:
        example_path: Path to the example file (read as UTF-8).

    Returns:
        ``(description, code)``, or ``("", "")`` if the file has no separator or
        cannot be read/parsed. The problem is printed in either case.
    """
    try:
        with open(example_path, "r", encoding="utf-8") as f:
            content = f.read()

        parts = content.split("_BREAK_", 1)
        if len(parts) != 2:
            print(f"Invalid format in example file: {example_path}")
            return "", ""

        description = parts[0].strip()
        code = parts[1].strip()

        code = code.replace("\\n", "\n").replace("\\t", "\t")
        code = normalize_indentation(code)

        return description, code

    except Exception as e:
        # Deliberate best-effort: a bad example file must not crash the UI.
        print(f"Error loading example file {example_path}: {str(e)}")
        return "", ""
|
|
|
|
|
|
|
|
def find_example_files(pattern="examples/*/raw.in"):
    """
    Find example input files, returned in a stable (sorted) order.

    ``glob`` returns matches in filesystem-dependent order. Sorting keeps the
    "Example 1", "Example 2", ... button labels consistent across machines.

    Args:
        pattern: Glob pattern to search. The default is resolved relative to
            the current working directory.

    Returns:
        Sorted list of matching paths (empty if nothing matches).
    """
    return sorted(glob.glob(pattern))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tokenizer and architecture come from the base CodeT5+ 770M checkpoint; the
# fine-tuned weights are a separate file in another Hub repo, loaded on top.
BASE_MODEL_ID = "Salesforce/codet5p-770m"
FINETUNED_REPO_ID = "OSS-forge/codet5p-770m-pyresbugs"
FINETUNED_FILENAME = "pytorch_model.bin"

# Use the GPU when torch can see one, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Loading tokenizer from base model: {BASE_MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

print(f"Loading base model: {BASE_MODEL_ID}")
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_ID)
model.to(device)

print(f"Downloading fine-tuned weights from repo: {FINETUNED_REPO_ID}")
ckpt_path = hf_hub_download(FINETUNED_REPO_ID, FINETUNED_FILENAME)

print(f"Loading state_dict from: {ckpt_path}")
# NOTE(review): torch.load unpickles this downloaded file. On PyTorch versions
# where weights_only does not default to True, a tampered checkpoint could run
# arbitrary code at load time. Consider weights_only=True (or a .safetensors
# file) once the checkpoint is confirmed to hold only tensors/plain containers.
# Loaded onto CPU first; load_state_dict copies values into the model's params.
state_dict = torch.load(ckpt_path, map_location="cpu")

# The checkpoint may be a training dict wrapping the weights under
# "model_state_dict"; unwrap it when present.
if "model_state_dict" in state_dict:
    state_dict = state_dict["model_state_dict"]

# strict=False: mismatched keys do not raise, they are only counted below.
# Any parameter missing from the checkpoint keeps its base-model value.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"Loaded fine-tuned weights. Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

# Switch to inference mode (e.g. dropout disabled) before serving requests.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level bug-injection state: the most recent bugged code and how many
# generation rounds have run. These are process-wide globals, so every
# connected session reads and writes the same values.
current_code, bug_counter = None, 0
|
|
|
|
|
|
|
|
def generate_bugged_code(description, code, chat_history, is_first_time):
    """
    Run one round of bug injection and append it to the chat transcript.

    On the first round (``is_first_time`` true), the model receives ``code``.
    On later rounds it receives the bugged code from the previous round, so
    injected bugs accumulate.

    NOTE(review): round state lives in the module-level globals
    ``current_code`` and ``bug_counter``. These are shared by every browser
    session served by this process, so concurrent users see each other's
    "previous bugged code". Per-session ``gr.State`` would avoid this.

    Args:
        description: Natural-language description of the bug to inject.
        code: Original Python source. Only used on the first round.
        chat_history: Current chatbot message list (role/content dicts), or None.
        is_first_time: True to start a new session from ``code``.

    Returns:
        A 3-tuple ``(chat_history, description_update, is_first_time)``:
        the extended transcript, a Gradio update that clears the description
        box, and ``False`` so the next click continues from the bugged code.

    Raises:
        Re-raises any exception from ``model.generate`` after printing it.
    """
    global current_code, bug_counter

    if chat_history is None:
        chat_history = []

    if is_first_time:
        bug_counter = 0
        current_code = None
        chat_history = []

    bug_counter += 1

    if bug_counter == 1:
        input_for_model = code
        input_type = "original"
    else:
        if current_code is None:
            # No previous round to build on; leave the transcript unchanged.
            return chat_history, gr.update(value=""), False
        input_for_model = current_code
        input_type = "previous bugged code"

    print(f"Using {input_type} - counter: {bug_counter}\n{input_for_model}")

    # Newlines/tabs are sent as literal escapes, the same one-line encoding the
    # example files use. The prompt is truncated to 512 tokens, so very long
    # code is cut off silently.
    encoded_code = encode_text(input_for_model)
    combined_input = f"Description: {description} _BREAK_ Code: {encoded_code}"

    encoded = tokenizer(
        combined_input,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    try:
        print("Starting generation...")
        with torch.no_grad():
            # Greedy decoding. early_stopping only applies to beam search, so
            # it is not passed here (it would only trigger a config warning).
            outputs = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=256,
                num_beams=1,
                do_sample=False,
            )
        print("Generation done.")
    except Exception as e:
        print("Generation error:", repr(e))
        raise

    bugged_code_escaped = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Undo the escape encoding, patch obvious syntax slips, then reformat.
    bugged_code = clear_text(bugged_code_escaped)
    bugged_code = fix_common_syntax_issues(bugged_code)
    bugged_code = format_code(bugged_code)

    current_code = bugged_code

    user_message = f"**Description**: {description}"
    if input_type == "original":
        user_message += f"\n\n**Original code**:\n```python\n{input_for_model}\n```"
    else:
        user_message += (
            f"\n\n**Previous bugged code**:\n```python\n{input_for_model}\n```"
        )

    ai_message = f"**Bugged code**:\n```python\n{bugged_code}\n```"

    chat_history = chat_history + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ai_message},
    ]

    return chat_history, gr.update(value=""), False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_interface():
    """
    Clear the module-level bug-injection state so the next round starts fresh.

    Returns:
        ``(chat_history, description_update, is_first_time)``: an empty
        transcript, a Gradio update that empties the description box, and
        ``True`` so the next generation uses the original code.
    """
    global bug_counter, current_code
    bug_counter = 0
    current_code = None
    cleared_description = gr.update(value="")
    return [], cleared_description, True
|
|
|
|
|
|
|
|
# Discover the example inputs once at import time and derive one button label
# per file from its parent directory, numbered from 1.
example_files = find_example_files()
example_names = [
    f"Example {number}: {os.path.basename(os.path.dirname(path))}"
    for number, path in enumerate(example_files, start=1)
]
|
|
|
|
|
|
|
|
def load_example(example_index):
    """
    Return ``(description, code)`` for the discovered example at ``example_index``.

    Indices outside ``0 .. len(example_files) - 1`` return ``("", "")``.
    Negative values are rejected too, so they do not silently wrap around to
    the end of the list.
    """
    if 0 <= example_index < len(example_files):
        return load_example_from_file(example_files[example_index])
    return "", ""
|
|
|
|
|
|
|
|
# Gradio UI: inputs and controls on the left, the conversation on the right.
with gr.Blocks(title="Software-Fault Injection from NL") as demo:
    gr.Markdown("# 🐛 Software-Fault Injection from Natural Language")
    gr.Markdown(
        "Generate Python code with specific bugs based on a description and original code. "
        "The model used is **BugGen (CodeT5+ 770M, PyResBugs)**."
    )

    with gr.Row():
        with gr.Column(scale=2):
            description_input = gr.Textbox(
                label="Bug Description",
                placeholder="Describe the type of bug to introduce...",
                lines=3,
            )
            code_input = gr.Code(
                label="Original Code",
                language="python",
                lines=12,
            )

            # Per-session flag: True until the first generation, and again after
            # "Start Over" or loading an example.
            is_first = gr.State(True)

            submit_btn = gr.Button("Generate Bugged Code")
            reset_btn = gr.Button("Start Over")

            gr.Markdown("### Examples")
            example_buttons = [gr.Button(name) for name in example_names]

        with gr.Column(scale=3):
            chat_output = gr.Chatbot(
                label="Conversation",
                height=500,
            )

    # Loading an example also re-arms is_first, so the next generation starts
    # from the loaded example's code. Without this it would continue from the
    # previous round's bugged code. "i=i" binds the loop index per button.
    for i, btn in enumerate(example_buttons):
        btn.click(
            fn=lambda i=i: (*load_example(i), True),
            outputs=[description_input, code_input, is_first],
        )

    submit_btn.click(
        fn=generate_bugged_code,
        inputs=[description_input, code_input, chat_output, is_first],
        outputs=[chat_output, description_input, is_first],
    )

    reset_btn.click(
        fn=reset_interface,
        outputs=[chat_output, description_input, is_first],
    )

print("Launching Gradio interface...")
# At most 10 requests may wait in the queue.
demo.queue(max_size=10).launch()