Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import torch | |
# Substrings identifying parameters that are dropped from converted
# checkpoints (any key containing one of these is discarded).
_DROPPED_KEY_PATTERNS = (
    'spatial_pos_encoder',
    'skeleton_head.MLP',
    'skeleton_head.adj_output_mlp',
)


def _convert_key(key):
    """Return the renamed parameter key, or None if the parameter is dropped."""
    if any(pattern in key for pattern in _DROPPED_KEY_PATTERNS):
        return None
    # Both renames are safe to re-apply: neither replacement string contains
    # its search string, so running the conversion twice is a no-op.
    return (key.replace("keypoint_head.", "keypoint_head_module.")
               .replace('bias_function_prior_weight', 'markov_structural_mlp'))


def load_state_dicts(folder_path, map_location=None):
    """Convert every ``.pth`` checkpoint in *folder_path* in place.

    For each regular file ending in ``.pth`` (processed in sorted order):

    * parameters whose key contains one of ``_DROPPED_KEY_PATTERNS`` are removed;
    * ``keypoint_head.`` is renamed to ``keypoint_head_module.`` and
      ``bias_function_prior_weight`` to ``markov_structural_mlp`` in the
      remaining keys;
    * the top-level ``optimizer`` and ``meta`` entries are carried over when
      present (other top-level entries are not kept);
    * the result overwrites the original file.

    Args:
        folder_path: Directory containing the checkpoints. It is not recursed.
        map_location: Forwarded to ``torch.load``. The default ``None`` loads
            tensors onto the device they were saved from.

    Returns:
        A dict mapping each processed filename to its converted model
        ``state_dict``.

    Raises:
        KeyError: If a ``.pth`` file has no top-level ``'state_dict'`` entry.
    """
    state_dicts = {}
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        # Skip non-checkpoints and directories that happen to end in ".pth".
        if not filename.endswith(".pth") or not os.path.isfile(file_path):
            continue
        print('Processing {}'.format(filename))
        checkpoint = torch.load(file_path, map_location=map_location)

        new_state_dict = {"state_dict": {}}
        # Inference-only checkpoints may lack these; don't abort the whole folder.
        for extra_key in ('optimizer', 'meta'):
            if extra_key in checkpoint:
                new_state_dict[extra_key] = checkpoint[extra_key]

        for key, value in checkpoint['state_dict'].items():
            new_key = _convert_key(key)
            if new_key is not None:
                new_state_dict['state_dict'][new_key] = value

        print(f'Saving to {file_path}')
        # Write to a sibling temp file, then atomically swap it in, so an
        # interrupted save cannot leave a truncated checkpoint behind.
        tmp_path = file_path + '.tmp'
        torch.save(new_state_dict, tmp_path)
        os.replace(tmp_path, file_path)

        state_dicts[filename] = new_state_dict['state_dict']
    return state_dicts
if __name__ == "__main__":
    # Usage: python <script> FOLDER
    # Converts every .pth checkpoint in FOLDER in place.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} FOLDER_WITH_PTH_CHECKPOINTS")
    folder_path = sys.argv[1]
    if not os.path.isdir(folder_path):
        sys.exit(f"error: {folder_path!r} is not a directory")
    load_state_dicts(folder_path)