| | import os |
| | import pandas as pd |
| | import random |
| | import time |
| |
|
| | from code_generation import generate_and_debug, prepare_working_folder |
| |
|
def select_seed_implementations(
    total_num_sample_solvers,
    num_sample_for_refine=None,
):
    """Randomly choose which seed solvers to build on in a refinement round.

    Args:
        total_num_sample_solvers: Size of the seed pool; candidate indices
            are ``0 .. total_num_sample_solvers - 1``.
        num_sample_for_refine: Number of distinct indices to draw. ``None``,
            ``-1``, or a value greater than the pool size all mean "use the
            whole pool".

    Returns:
        list[int]: Distinct indices drawn without replacement via
        ``random.sample``, so the order is random even when the whole pool
        is selected.
    """
    take_everything = (
        num_sample_for_refine is None
        or num_sample_for_refine > total_num_sample_solvers
        or num_sample_for_refine == -1
    )
    how_many = (
        total_num_sample_solvers if take_everything else num_sample_for_refine
    )
    return random.sample(range(total_num_sample_solvers), how_many)
| | |
| |
|
| |
|
def refine(cfg):
    """Run the refinement loop, seeding each round with existing sample solvers.

    For every round index in ``range(cfg.method.start_round,
    cfg.method.num_repeated_samples)`` this draws a random subset of seed
    solver indices (see ``select_seed_implementations``) and passes them to
    ``generate_and_debug``. A failure in one round is reported and the loop
    continues with the next round.

    Args:
        cfg: Configuration object. Fields read here:
            ``method.num_repeated_samples``,
            ``method.num_debugging_trials_per_sample``,
            ``method.num_sample_for_refine``, ``method.start_round``,
            ``method.use_sample_solver_init``, ``pde.name``,
            ``pde.pde_setting_name``, ``working_folder`` and ``model.name``.

    Raises:
        ValueError: If ``cfg.method.use_sample_solver_init`` is falsy, since
            refinement needs seed solvers to start from.
    """
    import traceback  # local import keeps this block self-contained

    num_repeated_samples = cfg.method.num_repeated_samples
    num_trials = cfg.method.num_debugging_trials_per_sample
    pde_name = cfg.pde.name
    working_folder = cfg.working_folder
    model_name = cfg.model.name
    num_sample_for_refine = cfg.method.num_sample_for_refine
    start_round = cfg.method.start_round
    use_sample_solver_init = cfg.method.use_sample_solver_init
    # Explicit raise instead of `assert`: an assert is stripped under
    # `python -O`, which would silently skip this config check.
    if not use_sample_solver_init:
        raise ValueError('Sample solvers must be enabled for refinement')

    # Only the row count of the seed results table is used here: it sizes
    # the pool that select_seed_implementations draws indices from.
    sample_solver_folder = os.path.join(
        'solvers', pde_name, cfg.pde.pde_setting_name, 'seeds'
    )
    sample_solver_info = pd.read_csv(
        os.path.join(sample_solver_folder, 'seed_results.csv')
    )
    total_num_sample_solvers = len(sample_solver_info)

    # The working folder is prepared only on a fresh run; a non-zero
    # start_round presumably resumes into an existing folder (TODO confirm).
    if start_round == 0:
        prepare_working_folder(
            cfg,
            working_folder=working_folder,
            pde_name=pde_name,
            use_sample_solver_init=use_sample_solver_init
        )

    for round_idx in range(start_round, num_repeated_samples):
        try:
            seed_implementations = select_seed_implementations(
                total_num_sample_solvers=total_num_sample_solvers,
                num_sample_for_refine=num_sample_for_refine
            )
            generate_and_debug(
                cfg,
                round_idx=round_idx,
                num_trials=num_trials,
                pde_name=pde_name,
                working_folder=working_folder,
                seed_implementations=seed_implementations,
                model_name=model_name
            )
        except Exception as e:
            # Deliberately best-effort: report the failure and keep going.
            # The traceback is printed as well so the root cause is not lost.
            print(f'Error in sample {round_idx}: {e}. Move on to the next sample.')
            traceback.print_exc()

        # NOTE(review): fixed pause between rounds, presumably to throttle
        # successive model/API calls — confirm the intent.
        time.sleep(2)
| |
|