Spaces:
Runtime error
Runtime error
| """ | |
| Copyright©2023 Max-Planck-Gesellschaft zur Förderung | |
| der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| for Intelligent Systems. All rights reserved. | |
| Author: Soyong Shin, Marilyn Keller | |
| See https://skel.is.tue.mpg.de/license.html for licensing and contact information. | |
| """ | |
| import traceback | |
| import math | |
| import os | |
| import pickle | |
| import torch | |
| import smplx | |
| import omegaconf | |
| import torch.nn.functional as F | |
| from psbody.mesh import Mesh, MeshViewer, MeshViewers | |
| from tqdm import trange | |
| from pathlib import Path | |
| import lib.body_models.skel.config as cg | |
| from lib.body_models.skel.skel_model import SKEL | |
| from .losses import compute_anchor_pose, compute_anchor_trans, compute_pose_loss, compute_scapula_loss, compute_spine_loss, compute_time_loss, pretty_loss_print | |
| from .utils import location_to_spheres, to_numpy, to_params, to_torch | |
| from .align_config import config | |
| from .align_config_joint import config as config_joint | |
class SkelFitter(object):
    """Fit SKEL body-model parameters to a sequence of SMPL bodies.

    The SMPL sequence is turned into target vertices with a SMPL forward pass.
    SKEL parameters are then optimized with L-BFGS so that the SKEL skin
    vertices and joints match the SMPL vertices and the anatomical joints
    regressed from them. The sequence is processed in batches of frames, and
    each batch runs the optimization steps listed in ``self.cfg.optim_steps``.
    """

    def __init__(self, gender, device, num_betas=10, export_meshes=False, joint_optim=False) -> None:
        """Build the SMPL and SKEL models, the fitting masks and the viewer.

        Args:
            gender: Gender forwarded to both ``smplx.create`` and ``SKEL``.
            device: Torch device the models and masks live on.
            num_betas: Number of SMPL shape coefficients.
            export_meshes: If True, ``run_fit`` also returns vertices and faces.
            joint_optim: Use ``align_config_joint`` instead of ``align_config``.
        """
        self.smpl = smplx.create(cg.smpl_folder, model_type='smpl', gender=gender, num_betas=num_betas, batch_size=1, export_meshes=False).to(device)
        self.skel = SKEL(gender).to(device)
        self.gender = gender
        self.device = device
        self.num_betas = num_betas

        # Instantiate the masks used for the vertex-to-vertex fitting.
        # NOTE: pickle is only acceptable because this file ships with the package.
        fitting_mask_file = Path(__file__).parent / 'riggid_parts_mask.pkl'
        with open(fitting_mask_file, 'rb') as fh:  # close the handle right away
            fitting_indices = pickle.load(fh)
        # 6890 must match the vertex count of the SMPL verts this mask multiplies.
        fitting_mask = torch.zeros(6890, dtype=torch.bool, device=self.device)
        fitting_mask[fitting_indices] = 1
        self.fitting_mask = fitting_mask.reshape(1, -1, 1)  # 1xVx1 to be applied to verts that are BxVx3

        # Vertices whose skinning weight to SMPL joints 0 or 3 exceeds 0.5 (the torso).
        smpl_torso_joints = [0, 3]
        verts_mask = (self.smpl.lbs_weights[:, smpl_torso_joints] > 0.5).sum(dim=-1) > 0
        self.torso_verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1)  # Because verts are of shape BxVx3
        self.export_meshes = export_meshes

        # Wrap the python config dict in an OmegaConf object for attribute access.
        if joint_optim:
            self.cfg = omegaconf.OmegaConf.create(config_joint)
        else:
            self.cfg = omegaconf.OmegaConf.create(config)

        # Instantiate the mesh viewer used to visualize the fitting.
        if 'DISABLE_VIEWER' in os.environ:
            self.mv = None
            print("\n DISABLE_VIEWER flag is set, running in headless mode")
        else:
            self.mv = MeshViewers((1, 2), keepalive=self.cfg.keepalive_meshviewer)

    def run_fit(self,
                trans_in,
                betas_in,
                poses_in,
                batch_size=20,
                skel_data_init=None,
                force_recompute=False,
                debug=False,
                watch_frame=0,
                freevert_mesh=None,
                opt_sequence=False,
                fix_poses=False,
                variant_exp=''):
        """Align SKEL to a SMPL sequence.

        Args:
            trans_in: SMPL translations, one row per frame.
            betas_in: SMPL shape coefficients, one row per frame (first 10 used).
            poses_in: SMPL poses; ``poses_in.shape[0]`` is the number of frames.
            batch_size: Frames optimized together (clamped to the sequence length).
            skel_data_init: Optional dict with 'poses', 'betas' and 'trans' from
                a previous alignment, used as initialization.
            force_recompute: Ignore ``skel_data_init`` and initialize from SMPL.
            debug: Only fit the first batch.
            watch_frame: Stored on ``self.watch_frame`` and printed.
            freevert_mesh: Unused.
            opt_sequence: Enable the temporal loss, and initialize the frames
                after each batch with that batch's last fitted frame.
            fix_poses: Only optimize translation and betas (see ``_fit_batch``).
            variant_exp: Global-orientation init variant (see ``_init_params``).

        Returns:
            dict holding numpy arrays 'poses', 'betas' and 'trans', concatenated
            over the fitted frames, plus 'gender'. If ``export_meshes`` is set,
            it also holds the vertex arrays 'skel_v', 'skin_v', 'smpl_v' and the
            faces 'skel_f', 'skin_f', 'smpl_f'.
        """
        self.nb_frames = poses_in.shape[0]
        self.watch_frame = watch_frame
        self.is_skel_data_init = skel_data_init is not None
        self.force_recompute = force_recompute
        print('Fitting {} frames'.format(self.nb_frames))
        print('Watching frame: {}'.format(watch_frame))

        # Initialize SKEL torch params
        body_params = self._init_params(betas_in, poses_in, trans_in, skel_data_init, variant_exp)

        # Cut the whole sequence into batches for parallel optimization
        if batch_size > self.nb_frames:
            batch_size = self.nb_frames
            print('Batch size is larger than the number of frames. Setting batch size to {}'.format(batch_size))
        n_batch = math.ceil(self.nb_frames / batch_size)
        pbar = trange(n_batch, desc='Running batch optimization')

        # Per-batch results, concatenated at the end
        out_keys = ['poses', 'betas', 'trans']
        if self.export_meshes:
            out_keys += ['skel_v', 'skin_v', 'smpl_v']
        res_dict = {key: [] for key in out_keys}
        res_dict['gender'] = self.gender
        if self.export_meshes:
            res_dict['skel_f'] = self.skel.skel_f.cpu().numpy().copy()
            res_dict['skin_f'] = self.skel.skin_f.cpu().numpy().copy()
            res_dict['smpl_f'] = self.smpl.faces

        # Iterate over the batches to fit the whole sequence
        for i in pbar:
            if debug and i > 0:
                # Only run the first batch to test, ignore the rest
                # (the previous `i > 1` check ran the first two batches).
                continue

            # Batch start (inclusive) and end (exclusive) frame indices
            i_start = i * batch_size
            i_end = min((i + 1) * batch_size, self.nb_frames)

            # Fit the batch
            betas, poses, trans, verts = self._fit_batch(body_params, i, i_start, i_end, enable_time=opt_sequence, fix_poses=fix_poses)

            # Store the results
            res_dict['poses'].append(poses)
            res_dict['betas'].append(betas)
            res_dict['trans'].append(trans)
            if self.export_meshes:
                # Store the mesh vertices; no gradients are needed for export.
                with torch.no_grad():
                    skel_output = self.skel.forward(poses=poses, betas=betas, trans=trans, poses_type='skel', skelmesh=True)
                res_dict['skel_v'].append(skel_output.skel_verts)
                res_dict['skin_v'].append(skel_output.skin_verts)
                res_dict['smpl_v'].append(verts)

            if opt_sequence:
                # Initialize the following frames with the last frame of this batch
                body_params['poses_skel'][i_end:] = poses[-1:].detach()
                body_params['trans_skel'][i_end:] = trans[-1:].detach()
                body_params['betas_skel'][i_end:] = betas[-1:].detach()

        # Concatenate the batches and convert to numpy
        for key, val in res_dict.items():
            if isinstance(val, list):
                res_dict[key] = torch.cat(val, dim=0).detach().cpu().numpy()
        return res_dict

    def _init_params(self, betas_smpl, poses_smpl, trans_smpl, skel_data_init=None, variant_exp=''):
        """Return initial SKEL parameters from SMPL data and optional SKEL data.

        Args:
            betas_smpl, poses_smpl, trans_smpl: SMPL sequence tensors, one row
                per frame.
            skel_data_init: Optional dict with 'betas', 'poses' and 'trans' from
                a previous alignment. Ignored when ``self.force_recompute`` is set.
            variant_exp: How the SKEL global orientation is initialized:
                '' or '_official_old' copy the first 3 SMPL pose values,
                '_official_fix' leaves them at zero, and '_my_fix' converts the
                SMPL axis-angle to 'YXZ' Euler angles.

        Returns:
            dict with keys 'betas_skel', 'poses_skel', 'trans_skel',
            'betas_smpl', 'poses_smpl' and 'trans_smpl'.

        Raises:
            ValueError: if ``variant_exp`` is unknown.
        """
        if skel_data_init is None or self.force_recompute:
            poses_skel = torch.zeros((self.nb_frames, self.skel.num_q_params), device=self.device)
            if variant_exp == '' or variant_exp == '_official_old':
                poses_skel[:, :3] = poses_smpl[:, :3]  # Global orient are similar between SMPL and SKEL, so init with SMPL angles
            elif variant_exp == '_official_fix':
                # https://github.com/MarilynKeller/SKEL/commit/d1f6ff62235c142ba010158e00e21fd4fe25807f#diff-09188717a56a42e9589e9bd289f9ddb4fb53160e03c81a7ced70b3a84c1d9d0bR157
                pass
            elif variant_exp == '_my_fix':
                gt_orient_aa = poses_smpl[:, :3]
                # IMPORTANT: The alignment comes from `exp/inspect_skel/archive/orientation.py`.
                from lib.utils.geometry.rotation import axis_angle_to_matrix, matrix_to_euler_angles
                gt_orient_mat = axis_angle_to_matrix(gt_orient_aa)
                gt_orient_ea = matrix_to_euler_angles(gt_orient_mat, 'YXZ')
                flip = torch.tensor([-1, 1, 1], device=self.device)
                poses_skel[:, :3] = gt_orient_ea[:, [2, 1, 0]] * flip
            else:
                raise ValueError(f'Unknown variant_exp {variant_exp}')

            betas_skel = torch.zeros((self.nb_frames, 10), device=self.device)
            betas_skel[:] = betas_smpl[..., :10]

            # Translation is similar between SMPL and SKEL, so init with the SMPL
            # translation. Clone it: the opt_sequence propagation in run_fit writes
            # into trans_skel, and an alias would overwrite the SMPL target
            # translation (and the caller's tensor).
            trans_skel = trans_smpl.clone()
        else:
            # Load from previous alignment
            betas_skel = to_torch(skel_data_init['betas'], self.device)
            poses_skel = to_torch(skel_data_init['poses'], self.device)
            trans_skel = to_torch(skel_data_init['trans'], self.device)

        # Make a dictionary out of the necessary body parameters
        body_params = {
            'betas_skel': betas_skel,
            'poses_skel': poses_skel,
            'trans_skel': trans_skel,
            'betas_smpl': betas_smpl,
            'poses_smpl': poses_smpl,
            'trans_smpl': trans_smpl,
        }
        return body_params

    def _fit_batch(self, body_params, i, i_start, i_end, enable_time=False, fix_poses=False):
        """Create parameters for frames [i_start, i_end) and run the optimization.

        Args:
            body_params: Dict returned by ``_init_params``.
            i: Batch index (unused here).
            i_start, i_end: Frame range of the batch.
            enable_time: Forwarded to the loss; also skips step 0 when a SKEL
                initialization was provided.
            fix_poses: Only optimize translation and betas, using the last step.

        Returns:
            (betas, poses, trans, verts): the optimized SKEL parameters for the
            batch and the SMPL target vertices.
        """
        # Slice every parameter to the current batch
        body_params = {key: val[i_start:i_end] for key, val in body_params.items()}

        # SMPL params
        betas_smpl = body_params['betas_smpl']
        poses_smpl = body_params['poses_smpl']
        trans_smpl = body_params['trans_smpl']

        # SKEL params
        betas = to_params(body_params['betas_skel'], device=self.device)
        poses = to_params(body_params['poses_skel'], device=self.device)
        trans = to_params(body_params['trans_skel'], device=self.device)

        if 'verts' in body_params:
            verts = body_params['verts']
        else:
            # Run a SMPL forward pass to get the SMPL body vertices
            smpl_output = self.smpl(betas=betas_smpl, body_pose=poses_smpl[:, 3:], transl=trans_smpl, global_orient=poses_smpl[:, :3])
            verts = smpl_output.vertices

        # Optimize
        optim_steps = self.cfg.optim_steps
        # Work on a private copy of the first step's settings. The update() calls
        # below accumulate the later steps' overrides. Doing that on
        # optim_steps[0] itself would mutate self.cfg, so every later batch would
        # start step 0 with the last step's settings.
        current_cfg = omegaconf.OmegaConf.create(omegaconf.OmegaConf.to_container(optim_steps[0], resolve=True))

        try:
            if fix_poses:
                # Only the last configured step is applied, to translation and betas
                for ci, step_cfg in enumerate([optim_steps[-1]]):
                    current_cfg.update(step_cfg)
                    print(f'Step {ci+1}: {current_cfg.description}')
                    self._optim([trans, betas], poses, betas, trans, verts, current_cfg, enable_time)
            else:
                if not enable_time or not self.is_skel_data_init:
                    # Optimize the global rotation and translation for the initial fitting
                    print(f'Step 0: {current_cfg.description}')
                    self._optim([trans, poses], poses, betas, trans, verts, current_cfg, enable_time)

                for ci, step_cfg in enumerate(optim_steps[1:]):
                    current_cfg.update(step_cfg)
                    print(f'Step {ci+1}: {current_cfg.description}')
                    self._optim([poses], poses, betas, trans, verts, current_cfg, enable_time)
        except Exception as e:
            # Best effort: report the failure and return the parameters as
            # optimized so far instead of aborting the whole sequence.
            print(e)
            traceback.print_exc()

        return betas, poses, trans, verts

    def _optim(self,
               params,
               poses,
               betas,
               trans,
               verts,
               cfg,
               enable_time=False,
               ):
        """Run ``cfg.num_steps`` L-BFGS steps on ``params`` to fit SKEL to ``verts``.

        ``params`` is the subset of (poses, betas, trans) that is optimized. The
        other tensors are still read by the loss.
        """
        # Regress anatomical joints from SMPL's vertices
        anat_joints = torch.einsum('bik,ji->bjk', [verts, self.skel.J_regressor_osim])
        # Passed to SKEL forward as `dJ` and kept at zero; 24 presumably matches
        # SKEL's joint count — TODO confirm against self.skel.num_joints.
        dJ = torch.zeros((poses.shape[0], 24, 3), device=betas.device)

        optimizer = torch.optim.LBFGS(params,
                                      lr=cfg.lr,
                                      max_iter=cfg.max_iter,
                                      line_search_fn=cfg.line_search_fn,
                                      tolerance_change=cfg.tolerance_change)

        # Reference values for the frozen (masked) pose entries and the anchor losses
        poses_init = poses.detach().clone()
        trans_init = trans.detach().clone()

        def closure():
            optimizer.zero_grad()
            # To visualize, run self.skel.forward(..., skelmesh=True) on frame
            # self.watch_frame and pass the output to self._fstep_plot.
            loss_dict = self._fitting_loss(poses,
                                           poses_init,
                                           betas,
                                           trans,
                                           trans_init,
                                           dJ,
                                           anat_joints,
                                           verts,
                                           cfg,
                                           enable_time)
            loss = sum(loss_dict.values())
            loss.backward()
            return loss

        for _ in range(cfg.num_steps):
            optimizer.step(closure)

    def _get_masks(self, cfg):
        """Return (pose_mask, verts_mask, joint_mask) for the step mode ``cfg.mode``.

        pose_mask is 1 x num_q_params: 1 means optimized, 0 means frozen.
        verts_mask is 1 x V x 1 and selects the fitted vertices.
        joint_mask is a boolean 1 x num_joints x 3 mask of the fitted joints.

        Raises:
            ValueError: if ``cfg.mode`` is unknown.
        """
        pose_mask = torch.ones((self.skel.num_q_params)).to(self.device).unsqueeze(0)
        verts_mask = torch.ones_like(self.fitting_mask)
        joint_mask = torch.ones((self.skel.num_joints, 3)).to(self.device).unsqueeze(0).bool()

        if cfg.mode == 'root_only':
            # Only optimize the global rotation, i.e. the first 3 pose values
            pose_mask[:] = 0
            pose_mask[:, :3] = 1
            # Only fit the torso vertices to recover the body orientation and translation
            verts_mask = self.torso_verts_mask
        elif cfg.mode == 'fixed_upper_limbs':
            upper_limbs_joints = [0, 1, 2, 3, 6, 9, 12, 15, 17]
            verts_mask = (self.smpl.lbs_weights[:, upper_limbs_joints] > 0.5).sum(dim=-1) > 0
            verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1)
            joint_mask[:, [3, 4, 5, 8, 9, 10, 18, 23], :] = 0  # Do not try to match the joints of the upper limbs
            pose_mask[:] = 1
            pose_mask[:, :3] = 0  # Block the global rotation
            pose_mask[:, 19] = 0  # Block the lumbar twist
        elif cfg.mode == 'fixed_root':
            pose_mask[:] = 1
            pose_mask[:, :3] = 0  # Block the global rotation
            # The orientation of the upper limbs is often wrong in SMPL, so ignore these vertices for the final step
            upper_limbs_joints = [1, 2, 16, 17]
            verts_mask = (self.smpl.lbs_weights[:, upper_limbs_joints] > 0.5).sum(dim=-1) > 0
            verts_mask = torch.logical_not(verts_mask)
            verts_mask = verts_mask.unsqueeze(0).unsqueeze(-1)
        elif cfg.mode == 'free':
            verts_mask = torch.ones_like(self.fitting_mask)
            joint_mask[:] = 0
            joint_mask[:, [19, 14], :] = 1  # Only fit the scapula joints to avoid collapsing shoulders
        else:
            raise ValueError(f'Unknown mode {cfg.mode}')

        return pose_mask, verts_mask, joint_mask

    def _fitting_loss(self,
                      poses,
                      poses_init,
                      betas,
                      trans,
                      trans_init,
                      dJ,
                      anat_joints,
                      verts,
                      cfg,
                      enable_time=False):
        """Return a dict of weighted loss terms; the optimizer minimizes their sum."""
        loss_dict = {}

        # Freeze the masked pose entries at their initial values before computing the losses
        pose_mask, verts_mask, joint_mask = self._get_masks(cfg)
        poses = poses * pose_mask + poses_init * (1 - pose_mask)

        output = self.skel.forward(poses=poses, betas=betas, trans=trans, poses_type='skel', dJ=dJ, skelmesh=False)

        # Fit the SMPL vertices.
        # NOTE(review): normalized by the whole sequence length (self.nb_frames),
        # not the batch size, so this term's weight relative to the others depends
        # on the sequence length. Kept as-is since the weights were tuned with it.
        loss_dict['verts_loss_loose'] = cfg.l_verts_loose * (verts_mask * (output.skin_verts - verts) ** 2).sum() / (verts_mask.sum() * self.nb_frames)

        # Fit the regressed anatomical joints; this avoids collapsing shoulders
        loss_dict['joint_loss'] = cfg.l_joint * (joint_mask * (output.joints - anat_joints) ** 2).mean()

        # Time consistency between consecutive frames of the batch
        if poses.shape[0] > 1 and enable_time:
            # This avoids unstable hips orientation
            loss_dict['time_loss'] = cfg.l_time_loss * F.mse_loss(poses[1:], poses[:-1])

        loss_dict['pose_loss'] = cfg.l_pose_loss * compute_pose_loss(poses, poses_init)

        if cfg.use_basic_loss is False:
            # Extra regularizers, not always necessary
            loss_dict['anch_rot'] = cfg.l_anch_pose * compute_anchor_pose(poses, poses_init)
            loss_dict['anch_trans'] = cfg.l_anch_trans * compute_anchor_trans(trans, trans_init)

            loss_dict['verts_loss'] = cfg.l_verts * (verts_mask * self.fitting_mask * (output.skin_verts - verts) ** 2).sum() / (self.fitting_mask * verts_mask).sum()

            # Regularize the pose
            loss_dict['scapula_loss'] = cfg.l_scapula_loss * compute_scapula_loss(poses)
            loss_dict['spine_loss'] = cfg.l_spine_loss * compute_spine_loss(poses)

            # Scale all pose-regularization terms by pose_reg_factor
            for key in ['scapula_loss', 'spine_loss', 'pose_loss']:
                loss_dict[key] = cfg.pose_reg_factor * loss_dict[key]

        return loss_dict

    def _fstep_plot(self, output, cfg, verts, anat_joints):
        """Show the current fit of the first frame of ``output`` in the mesh viewer.

        Left pane: SKEL skin colored by per-vertex distance to the SMPL target,
        plus the SMPL point cloud. Right pane: masked SMPL vertices, SKEL skin
        and skeleton, plus fitted/target joints when ``cfg.l_joint > 0``.
        Does nothing in headless mode.
        """
        if 'DISABLE_VIEWER' in os.environ or self.mv is None:
            return

        _, verts_mask, joint_mask = self._get_masks(cfg)

        # Per-vertex distance, divided by 0.05 before being used as color weights
        # (presumably meters, i.e. 5 cm maps to 1.0 — confirm units).
        skin_err_value = ((output.skin_verts[0] - verts[0]) ** 2).sum(dim=-1).sqrt()
        skin_err_value = skin_err_value / 0.05
        skin_err_value = to_numpy(skin_err_value)

        skin_mesh = Mesh(v=to_numpy(output.skin_verts[0]), f=[], vc='white')
        skel_mesh = Mesh(v=to_numpy(output.skel_verts[0]), f=self.skel.skel_f.cpu().numpy(), vc='white')

        # Display vertex distance on SMPL
        smpl_verts = to_numpy(verts[0])
        smpl_mesh = Mesh(v=smpl_verts, f=self.smpl.faces)
        smpl_mesh.set_vertex_colors_from_weights(skin_err_value, scale_to_range_1=False)

        smpl_mesh_masked = Mesh(v=smpl_verts[to_numpy(verts_mask[0, :, 0])], f=[], vc='green')
        smpl_mesh_pc = Mesh(v=smpl_verts, f=[], vc='green')

        skin_mesh_err = Mesh(v=to_numpy(output.skin_verts[0]), f=self.skel.skin_f.cpu().numpy(), vc='white')
        skin_mesh_err.set_vertex_colors_from_weights(skin_err_value, scale_to_range_1=False)

        # List the meshes to display
        meshes_left = [skin_mesh_err, smpl_mesh_pc]
        meshes_right = [smpl_mesh_masked, skin_mesh, skel_mesh]

        if cfg.l_joint > 0:
            # Plot the fitted (red) and target anatomical (green) joints
            meshes_right += location_to_spheres(to_numpy(output.joints[joint_mask[:, :, 0]]), color=(1, 0, 0), radius=0.02)
            meshes_right += location_to_spheres(to_numpy(anat_joints[joint_mask[:, :, 0]]), color=(0, 1, 0), radius=0.02)

        self.mv[0][0].set_dynamic_meshes(meshes_left)
        self.mv[0][1].set_dynamic_meshes(meshes_right)