Spaces:
Runtime error
Runtime error
| """ | |
| Copyright©2024 Max-Planck-Gesellschaft zur Förderung | |
| der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| for Intelligent Systems. All rights reserved. | |
| Author: Marilyn Keller | |
| See https://skel.is.tue.mpg.de/license.html for licensing and contact information. | |
| """ | |
| import os | |
| import torch.nn as nn | |
| import torch | |
| import numpy as np | |
| import pickle as pkl | |
| from typing import NewType, Optional | |
| from lib.body_models.skel.joints_def import curve_torch_3d, left_scapula, right_scapula | |
| from lib.body_models.skel.osim_rot import ConstantCurvatureJoint, CustomJoint, EllipsoidJoint, PinJoint, WalkerKnee | |
| from lib.body_models.skel.utils import build_homog_matrix, rotation_matrix_from_vectors, sparce_coo_matrix2tensor, with_zeros, matmul_chain | |
| from dataclasses import dataclass, fields | |
| from lib.body_models.skel.kin_skel import scaling_keypoints, pose_param_names, smpl_joint_corresp | |
| import lib.body_models.skel.config as cg | |
Tensor = NewType('Tensor', torch.Tensor)


@dataclass
class ModelOutput:
    """Generic container for body-model outputs with a read-only dict-like interface.

    Every field defaults to ``None``. ``keys``/``values``/``items`` iterate over the
    dataclass fields in declaration order, and ``out[key]`` / ``out.get(key)`` read
    attributes by name.
    """
    vertices: Optional[Tensor] = None
    joints: Optional[Tensor] = None
    full_pose: Optional[Tensor] = None
    global_orient: Optional[Tensor] = None
    transl: Optional[Tensor] = None
    v_shaped: Optional[Tensor] = None

    def __getitem__(self, key):
        """Return the attribute named ``key`` (raises AttributeError if absent)."""
        return getattr(self, key)

    def get(self, key, default=None):
        """Return the attribute named ``key``, or ``default`` if it does not exist."""
        return getattr(self, key, default)

    def __iter__(self):
        """Iterate over field names, like a dict."""
        return self.keys()

    def keys(self):
        """Return an iterator over the field names."""
        keys = [t.name for t in fields(self)]
        return iter(keys)

    def values(self):
        """Return an iterator over the field values."""
        values = [getattr(self, t.name) for t in fields(self)]
        return iter(values)

    def items(self):
        """Return an iterator over ``(field name, value)`` pairs."""
        data = [(t.name, getattr(self, t.name)) for t in fields(self)]
        return iter(data)


@dataclass
class SKELOutput(ModelOutput):
    """Fields returned by ``SKEL.forward`` (in addition to the inherited ones).

    ``joints`` is inherited from ModelOutput; it is not redeclared here.
    """
    betas: Optional[Tensor] = None
    body_pose: Optional[Tensor] = None
    skin_verts: Optional[Tensor] = None
    skel_verts: Optional[Tensor] = None
    joints_ori: Optional[Tensor] = None
    poses: Optional[Tensor] = None
    trans: Optional[Tensor] = None
    pose_offsets: Optional[Tensor] = None
    joints_tpose: Optional[Tensor] = None
    v_skin_shaped: Optional[Tensor] = None
class SKEL(nn.Module):
    """SKEL body model: a skin mesh coupled to an articulated skeleton mesh.

    Model data is read from ``skel_{gender}.pkl`` (version 1.1.1 required).
    ``forward`` shapes and poses both meshes from 46 pose parameters,
    10 shape parameters (betas) and a global translation.
    """

    num_betas = 10

    def __init__(self, gender, model_path=None, custom_joint_reg_path=None, **kwargs):
        """Load the SKEL pickle and register its arrays as module buffers.

        Args:
            gender: 'male' or 'female'; selects ``skel_{gender}.pkl``.
            model_path: folder containing the SKEL pickles. Defaults to ``cg.skel_folder``.
            custom_joint_reg_path: optional pickle holding a replacement anatomical joint
                regressor (dense array, or scipy sparse matrix which is densified).
            **kwargs: accepted and ignored.

        Raises:
            RuntimeError: if ``gender`` is not 'male' or 'female'.
            AssertionError: if the model file is missing or is not version 1.1.1.
        """
        super().__init__()

        if gender not in ['male', 'female']:
            raise RuntimeError(f'Invalid Gender, got {gender}')
        self.gender = gender

        folder = cg.skel_folder if model_path is None else model_path
        skel_file = os.path.join(folder, f"skel_{gender}.pkl")
        assert os.path.exists(skel_file), f"Skel model file {skel_file} does not exist"
        # NOTE: unpickling can execute arbitrary code; only load model files from a trusted source.
        with open(skel_file, 'rb') as f:
            skel_data = pkl.load(f)

        # Check that the version of the skel model is compatible with this loader
        assert 'version' in skel_data, f"Expected version 1.1.1 of the SKEL picke. Please download the latest skel pkl versions from https://skel.is.tue.mpg.de/download.html"
        version = skel_data['version']
        assert version == '1.1.1', f"Expected version 1.1.1, got {version}. Please download the latest skel pkl versions from https://skel.is.tue.mpg.de/download.html"

        self.num_betas = 10
        self.num_q_params = 46
        self.bone_names = skel_data['bone_names']
        self.num_joints = skel_data['J_regressor_osim'].shape[0]
        self.num_joints_smpl = skel_data['J_regressor'].shape[0]
        self.joints_name = skel_data['joints_name']
        self.pose_params_name = skel_data['pose_params_name']

        # Template meshes (vertices and faces) for the skin and the skeleton
        self.register_buffer('skin_template_v', torch.FloatTensor(skel_data['skin_template_v']))
        self.register_buffer('skin_f', torch.LongTensor(skel_data['skin_template_f']))
        self.register_buffer('skel_template_v', torch.FloatTensor(skel_data['skel_template_v']))
        self.register_buffer('skel_f', torch.LongTensor(skel_data['skel_template_f']))

        # Shape corrective and pose corrective blend shapes
        self.register_buffer('shapedirs', torch.FloatTensor(np.array(skel_data['shapedirs'][:, :, :self.num_betas])))
        self.register_buffer('posedirs', torch.FloatTensor(np.array(skel_data['posedirs'])))

        # Sparse regressor from mesh vertices to SMPL joint locations
        self.register_buffer('J_regressor', sparce_coo_matrix2tensor(skel_data['J_regressor']))

        # Regressor for the anatomical joint locations (learned from BioAmass unless overridden)
        if custom_joint_reg_path is not None:
            with open(custom_joint_reg_path, 'rb') as f:
                J_regressor_skel = pkl.load(f)
            if 'scipy.sparse' in str(type(J_regressor_skel)):
                J_regressor_skel = J_regressor_skel.todense()
            self.register_buffer('J_regressor_osim', torch.FloatTensor(J_regressor_skel))
            print('WARNING: Using custom joint regressor')
        else:
            self.register_buffer('J_regressor_osim', sparce_coo_matrix2tensor(skel_data['J_regressor_osim'], make_dense=True))

        self.register_buffer('per_joint_rot', torch.FloatTensor(skel_data['per_joint_rot']))

        # Skinning weights of the skin mesh and of the skeleton mesh
        self.register_buffer('skin_weights', sparce_coo_matrix2tensor(skel_data['skin_weights']))
        self.register_buffer('skel_weights', sparce_coo_matrix2tensor(skel_data['skel_weights']))
        self.register_buffer('skel_weights_rigid', sparce_coo_matrix2tensor(skel_data['skel_weights_rigid']))

        # Kinematic tree of the model
        self.register_buffer('kintree_table', torch.from_numpy(skel_data['osim_kintree_table'].astype(np.int64)))
        self.register_buffer('parameter_mapping', torch.from_numpy(skel_data['parameter_mapping'].astype(np.int64)))

        # Transformations from the osim canonical pose to the T pose and to the A pose
        self.register_buffer('tpose_transfo', torch.FloatTensor(skel_data['tpose_transfo']))
        self.register_buffer('apose_transfo', torch.FloatTensor(skel_data['apose_transfo']))
        self.register_buffer('apose_rel_transfo', torch.FloatTensor(skel_data['apose_rel_transfo']))

        # Indices of bones whose orientation should not vary with beta in T pose
        joint_idx_fixed_beta = [0, 5, 10, 13, 18, 23]
        self.register_buffer('joint_idx_fixed_beta', torch.IntTensor(joint_idx_fixed_beta))

        # parent[k] is the column of the parent of the joint in column k+1 (the root has no entry)
        id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])}
        self.register_buffer('parent', torch.LongTensor(
            [id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])]))

        # child[i] is the id of the first joint listing i as parent, or 0 when i has no child.
        # TODO create this array in the SKEL creator
        child_array = []
        for i in range(self.num_joints):
            child_cols = torch.where(self.kintree_table[0] == i)[0]  # candidate child columns
            child_array.append(self.kintree_table[1, child_cols[0]].item() if len(child_cols) > 0 else 0)
        self.register_buffer('child', torch.LongTensor(child_array))

        # One joint object per bone; each maps its slice of the pose vector to a rotation.
        self.joints_dict = nn.ModuleList([
            CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 0 pelvis
            CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, 1, 1]), # 1 femur_r
            WalkerKnee(), # 2 tibia_r
            PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 3 talus_r Field taken from .osim Joint-> frames -> PhysicalOffsetFrame -> orientation
            PinJoint(parent_frame_ori = [-1.76818999, 0.906223, 1.8196000]), # 4 calcn_r
            PinJoint(parent_frame_ori = [-3.141589999, 0.6199010, 0]), # 5 toes_r
            CustomJoint(axis=[[0,0,1], [1,0,0], [0,1,0]], axis_flip=[1, -1, -1]), # 6 femur_l
            WalkerKnee(), # 7 tibia_l
            PinJoint(parent_frame_ori = [0.175895, -0.105208, 0.0186622]), # 8 talus_l
            PinJoint(parent_frame_ori = [1.768189999 ,-0.906223, 1.8196000]), # 9 calcn_l
            PinJoint(parent_frame_ori = [-3.141589999, -0.6199010, 0]), # 10 toes_l
            ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 11 lumbar
            ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 12 thorax
            ConstantCurvatureJoint(axis=[[1,0,0], [0,0,1], [0,1,0]], axis_flip=[1, 1, 1]), # 13 head
            EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, -1, -1]), # 14 scapula_r
            CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 15 humerus_r
            CustomJoint(axis=[[0.0494, 0.0366, 0.99810825]], axis_flip=[[1]]), # 16 ulna_r
            CustomJoint(axis=[[-0.01716099, 0.99266564, -0.11966796]], axis_flip=[[1]]), # 17 radius_r
            CustomJoint(axis=[[1,0,0], [0,0,-1]], axis_flip=[1, 1]), # 18 hand_r
            EllipsoidJoint(axis=[[0,1,0], [0,0,1], [1,0,0]], axis_flip=[1, 1, 1]), # 19 scapula_l
            CustomJoint(axis=[[1,0,0], [0,1,0], [0,0,1]], axis_flip=[1, 1, 1]), # 20 humerus_l
            CustomJoint(axis=[[-0.0494, -0.0366, 0.99810825]], axis_flip=[[1]]), # 21 ulna_l
            CustomJoint(axis=[[0.01716099, -0.99266564, -0.11966796]], axis_flip=[[1]]), # 22 radius_l
            CustomJoint(axis=[[-1,0,0], [0,0,-1]], axis_flip=[1, 1]), # 23 hand_l
        ])

    def pose_params_to_rot(self, osim_poses):
        """Transform the pose parameters into per-joint 3x3 rotation matrices.

        The pose vector is consumed left to right. Each joint in ``joints_dict`` takes
        ``nb_dof`` parameters and converts them with its own ``q_to_rot``.

        Args:
            osim_poses: B x num_q_params tensor of pose parameters.

        Returns:
            (Rp, tp): Rp is B x Nj x 3 x 3 rotations; tp is a B x Nj x 3 zero tensor.
        """
        B = osim_poses.shape[0]
        Nj = self.num_joints
        ident = torch.eye(3, dtype=osim_poses.dtype, device=osim_poses.device)
        Rp = ident.unsqueeze(0).unsqueeze(0).repeat(B, Nj, 1, 1)
        tp = torch.zeros(B, Nj, 3, device=osim_poses.device)
        start_index = 0
        for i in range(Nj):
            joint_object = self.joints_dict[i]
            end_index = start_index + joint_object.nb_dof
            Rp[:, i] = joint_object.q_to_rot(osim_poses[:, start_index:end_index])
            start_index = end_index
        return Rp, tp

    def params_name_to_index(self, param_name):
        """Return the index of ``param_name`` in ``pose_param_names`` (asserts it exists)."""
        assert param_name in pose_param_names
        param_index = pose_param_names.index(param_name)
        return param_index

    def forward(self, poses, betas, trans, poses_type='skel', skelmesh=True, dJ=None, pose_dep_bs=True):
        """Shape and pose the skin and skeleton meshes.

        Params:
            poses : B x 46 tensor of pose parameters (B x 43 when poses_type is 'bsm')
            betas : B x 10 tensor of shape parameters, same as SMPL
            trans : B x 3 tensor of translation
            poses_type : str, 'skel' (default) or 'bsm'; should normally not be changed
            skelmesh : bool, if True, returns the posed skeleton vertices. The skeleton mesh
                is heavy, so to fit in GPU memory set it to False when not needed.
            dJ : B x Nj x 3 offsets added to the regressed anatomical joint locations.
                If None, no offset is applied.
            pose_dep_bs : bool, if True (default), applies the pose dependent blend shapes.

        Returns:
            SKELOutput with the following fields:
                betas : Bx10 tensor of shape parameters (as passed in)
                poses : Bx46 tensor of pose parameters
                skin_verts : B x Ns x 3 tensor of skin vertices
                skel_verts : tensor of skeleton vertices (the unposed template when skelmesh is False)
                joints : B x Nj x 3 tensor of joint locations
                joints_ori : B x Nj x 3 x 3 tensor of joint orientations
                trans : Bx3 translation that was applied
                pose_offsets : B x Ns x 3 pose dependent blend shape offsets (zeros if pose_dep_bs is False)
                joints_tpose : B x Nj x 3 joint locations in T pose
                v_shaped : B x Ns x 3 shaped (unposed) skin vertices

        Conventions: B batch size, Ns skin vertices, Nk skeleton vertices, Nj joints.
        """
        Ns = self.skin_template_v.shape[0]  # nb skin vertices
        Nk = self.skel_template_v.shape[0]  # nb skeleton vertices
        Nj = self.num_joints
        B = poses.shape[0]
        device = poses.device

        # Check the shapes of the inputs
        assert len(betas.shape) == 2, f"Betas should be of shape (B, {self.num_betas}), but got {betas.shape}"
        assert poses.shape[0] == betas.shape[0], f"Expected poses and betas to have the same batch size, but got {poses.shape[0]} and {betas.shape[0]}"
        assert poses.shape[0] == trans.shape[0], f"Expected poses and trans to have the same batch size, but got {poses.shape[0]} and {trans.shape[0]}"
        if dJ is not None:
            assert len(dJ.shape) == 3, f"Expected dJ to have shape (B, {Nj}, 3), but got {dJ.shape}"
            assert dJ.shape[0] == B, f"Expected dJ to have the same batch size as poses, but got {dJ.shape[0]} and {poses.shape[0]}"
            assert dJ.shape[1] == Nj, f"Expected dJ to have the same number of joints as the model, but got {dJ.shape[1]} and {Nj}"

        # Check the device of the inputs
        assert betas.device == device, f"Betas should be on device {device}, but got {betas.device}"
        assert trans.device == device, f"Trans should be on device {device}, but got {trans.device}"

        skin_v0 = self.skin_template_v[None, :]
        skel_v0 = self.skel_template_v[None, :]
        betas_col = betas[:, :, None]  # B x num_betas x 1, used for the blend shape product

        assert poses_type in ['skel', 'bsm'], f"got {poses_type}"
        if poses_type == 'bsm':
            assert poses.shape[1] == self.num_q_params - 3, f'With poses_type bsm, expected parameters of shape (B, {self.num_q_params - 3}, got {poses.shape}'
            poses_bsm = poses
            # Allocate on the input's device/dtype so GPU inputs keep working.
            poses_skel = torch.zeros(B, self.num_q_params, dtype=poses.dtype, device=device)
            poses_skel[:, :3] = poses_bsm[:, :3]
            trans = poses_bsm[:, 3:6]  # In BSM parametrization, the hips translation is given by params 3 to 5
            # NOTE(review): this copies all 43 BSM params into slots 3..45 (params 0..2 therefore
            # appear twice); kept as in the original — verify against the BSM parametrization.
            poses_skel[:, 3:] = poses_bsm
            poses = poses_skel
        else:
            assert poses.shape[1] == self.num_q_params, f'With poses_type skel, expected parameters of shape (B, {self.num_q_params}), got {poses.shape}'

        # ------- Shape ----------
        # Apply the beta offsets to the template
        shapedirs = self.shapedirs.view(-1, self.num_betas)[None, :].expand(B, -1, -1)  # B x 3*Ns x num_betas
        v_shaped = skin_v0 + torch.matmul(shapedirs, betas_col).view(B, Ns, 3)

        # ------- Joints ----------
        # Regress the anatomical joint locations with the osim regressor
        J = torch.einsum('bik,ji->bjk', [v_shaped, self.J_regressor_osim])  # BxJx3
        if dJ is not None:
            J = J + dJ
        J_tpose = J.clone()

        # Local translation of each joint relative to its parent
        J_ = J.clone()  # BxJx3
        J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :]
        t = J_[:, :, :, None]  # BxJx3x1

        # ------- Bones transformation matrix ----------
        # Bone initial transform to go from unposed to SMPL T pose
        Rk01 = self.compute_bone_orientation(J, J_)
        # BSM default pose rotations
        Ra = self.apose_rel_transfo[:, :3, :3].view(1, Nj, 3, 3).expand(B, Nj, 3, 3)
        # Local bone rotations given by the pose params
        Rp, tp = self.pose_params_to_rot(poses)  # BxNjx3x3
        R = matmul_chain([Rk01, Ra.transpose(2, 3), Rp, Ra, Rk01.transpose(2, 3)])

        # ------- Translations of the joints that are not pure rotations ----------
        t_posed = t.clone()

        # Scapulae: offset = scapula position at the current angles minus at zero angles
        thorax_width = torch.norm(J[:, 19, :] - J[:, 14, :], dim=1)  # distance between the two scapula joints, size B
        thorax_height = torch.norm(J[:, 12, :] - J[:, 11, :], dim=1)  # distance between the thorax and lumbar joints, size B
        # (joint index, index of its first pose param: abduction, elevation, rotation follow)
        for joint_idx, first_param, scapula_fn in ((14, 26, right_scapula), (19, 36, left_scapula)):
            abduction = poses[:, first_param]
            elevation = poses[:, first_param + 1]
            rotation = poses[:, first_param + 2]
            zero = torch.zeros_like(abduction)
            t_posed[:, joint_idx] = t_posed[:, joint_idx] + (
                scapula_fn(abduction, elevation, rotation, thorax_width, thorax_height).view(-1, 3, 1)
                - scapula_fn(zero, zero, zero, thorax_width, thorax_height).view(-1, 3, 1))

        # TODO: add the Walker knee translation offset (not implemented yet).

        # Spine (lumbar, thorax, head): constant curvature offsets along each segment.
        # (joint index, lower joint index for the segment length, bending param, extension param)
        for joint_idx, lower_idx, bend_param, ext_param in ((11, 0, 17, 18), (12, 11, 20, 21), (13, 12, 23, 24)):
            bending = poses[:, bend_param]
            extension = poses[:, ext_param]
            zero = torch.zeros_like(bending)
            interp_t = torch.ones_like(bending)
            seg_len = torch.abs(J[:, joint_idx, 1] - J[:, lower_idx, 1])  # length of the spine section along y
            t_posed[:, joint_idx] = t_posed[:, joint_idx] + (
                curve_torch_3d(bending, extension, t=interp_t, l=seg_len)
                - curve_torch_3d(zero, zero, t=interp_t, l=seg_len))

        # ------- Body surface transformation matrix ----------
        G_ = torch.cat([R, t_posed], dim=-1)  # BxJx3x4 local transformation matrix
        pad_row = torch.tensor([0, 0, 0, 1], dtype=G_.dtype, device=device).view(1, 1, 1, 4).expand(B, Nj, -1, -1)  # BxJx1x4
        G_ = torch.cat([G_, pad_row], dim=2)  # BxJx4x4 so the kinematic chain can be multiplied

        # Global transforms along the kinematic chain
        G = [G_[:, 0].clone()]
        for i in range(1, Nj):
            G.append(torch.matmul(G[self.parent[i - 1]], G_[:, i, :, :]))
        G = torch.stack(G, dim=1)

        # ------- Pose dependent blend shapes ----------
        if not pose_dep_bs:
            v_shaped_pd = v_shaped
            # Defined in both branches so the output can always be built.
            pose_offsets = torch.zeros_like(v_shaped)
        else:
            # Note: these should be retrained for SKEL as its joint locations differ from SMPL,
            # but they give decent pose dependent deformations for the shoulders, belly and knees.
            ident = torch.eye(3, dtype=v_shaped.dtype, device=device)
            # Per SMPL joint rotations, identity for SMPL joints without a SKEL counterpart
            Rsmpl = ident.unsqueeze(0).unsqueeze(0).expand(B, self.num_joints_smpl, -1, -1).clone()
            Rskin = G_[:, :, :3, :3]  # BxNjx3x3
            Rsmpl[:, smpl_joint_corresp] = Rskin[:]
            pose_feature = Rsmpl[:, 1:].view(B, -1, 3, 3) - ident
            pose_offsets = torch.matmul(pose_feature.view(B, -1),
                                        self.posedirs.view(Ns * 3, -1).T).view(B, -1, 3)
            v_shaped_pd = v_shaped + pose_offsets

        # ------- Transform skin mesh ----------
        # Remove the rest-pose joint location from each global transform
        rest = torch.cat([J, J.new_zeros(B, Nj, 1)], dim=2).view(B, Nj, 4, 1)  # BxJx4x1
        zeros = J.new_zeros(B, Nj, 4, 3)  # BxJx4x3
        rest = torch.cat([zeros, rest], dim=-1)  # BxJx4x4
        rest = torch.matmul(G, rest)
        Gskin = G - rest

        # Per vertex transformation matrix (after weighting)
        T = torch.matmul(self.skin_weights, Gskin.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Ns, B, 4, 4).transpose(0, 1)
        rest_shape_h = torch.cat([v_shaped_pd, torch.ones_like(v_shaped_pd)[:, :, [0]]], dim=-1)
        v_posed = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0]
        v_trans = v_posed + trans[:, None, :]

        # ------- Transform joints ----------
        root_transform = with_zeros(torch.cat((R[:, 0], J[:, 0][:, :, None]), 2))
        results = [root_transform]
        for i in range(self.parent.shape[0]):
            transform_i = with_zeros(torch.cat((R[:, i + 1], t_posed[:, i + 1]), 2))
            results.append(torch.matmul(results[self.parent[i]], transform_i))
        results = torch.stack(results, dim=1)
        posed_joints = results[:, :, :3, 3]
        J_transformed = posed_joints + trans[:, None, :]

        # ------- Transform skeleton ----------
        Gk01 = build_homog_matrix(Rk01, J.unsqueeze(-1))  # BxJx4x4, aligns bones with their axis
        G_bones = torch.matmul(G, Gk01)
        if skelmesh:
            # Shape the skeleton by scaling its bones
            skel_rest_shape_h = torch.cat([skel_v0, torch.ones_like(skel_v0)[:, :, [0]]], dim=-1).expand(B, Nk, -1)
            # TODO: when dJ is optimized the shape of the mesh should be affected by the gradients
            bone_scale = self.compute_bone_scale(J_, v_shaped, skin_v0)
            scaled_axes = [
                (torch.matmul(bone_scale[:, :, axis], self.skel_weights_rigid.T) * skel_rest_shape_h[:, :, axis])[:, :, None]
                for axis in range(3)
            ]
            skel_v_shaped = torch.cat(scaled_axes + [torch.ones(B, Nk, 1, device=device)], dim=-1)

            # Align the bones with the proper axis
            T = torch.matmul(self.skel_weights_rigid, Gk01.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Nk, B, 4, 4).transpose(0, 1)
            skel_v_align = torch.matmul(T, skel_v_shaped[:, :, :, None])[:, :, :, 0]

            # Unpose the skeleton in each joint frame, then apply the weighted per-joint transforms
            eye3 = torch.eye(3, device=device).view(1, 1, 3, 3).expand(B, Nj, 3, 3)
            G_tpose_to_unposed = build_homog_matrix(eye3, -J.unsqueeze(-1))  # BxJx4x4
            G_skel = torch.matmul(G, G_tpose_to_unposed)
            T = torch.matmul(self.skel_weights, G_skel.permute(1, 0, 2, 3).contiguous().view(Nj, -1)).view(Nk, B, 4, 4).transpose(0, 1)
            skel_v_posed = torch.matmul(T, skel_v_align[:, :, :, None])[:, :, :3, 0]
            skel_trans = skel_v_posed + trans[:, None, :]
        else:
            skel_trans = skel_v0

        return SKELOutput(skin_verts=v_trans,
                          skel_verts=skel_trans,
                          joints=J_transformed,
                          joints_ori=G_bones[:, :, :3, :3],
                          betas=betas,
                          poses=poses,
                          trans=trans,
                          pose_offsets=pose_offsets,
                          joints_tpose=J_tpose,
                          v_shaped=v_shaped)

    def compute_bone_scale(self, J_, v_shaped, skin_v0):
        """Compute per-bone, per-axis scale factors for the skeleton mesh.

        Most bones are scaled by the ratio of their child segment length in the shaped
        body (``J_``) to the same segment in the BSM template (``apose_rel_transfo``).
        Some bones are then overridden:
        - the ulnas copy the radius scale;
        - the thorax and lumbar get special rules;
        - the entries listed in ``scaling_keypoints`` and the thorax front-back axis use
          skin-vertex distance ratios.

        Args:
            J_: B x Nj x 3 joint offsets relative to the parent joint (root is absolute).
            v_shaped: B x Ns x 3 shaped skin vertices.
            skin_v0: 1 x Ns x 3 skin template vertices.

        Returns:
            B x Nj x 3 tensor of scale factors.
        """
        B = J_.shape[0]
        Nj = J_.shape[1]

        bone_scale = torch.ones(B, Nj, device=J_.device)

        # BSM template joint offsets and segment lengths
        osim_joints_r = self.apose_rel_transfo[:, :3, 3].view(1, Nj, 3).expand(B, Nj, 3).clone()
        length_bones_bsm = torch.norm(osim_joints_r, dim=-1).expand(B, -1)
        length_bones_smpl = torch.norm(J_, dim=-1)  # (B, Nj)
        bone_scale_parent = length_bones_smpl / length_bones_bsm

        # A bone takes the scale of the segment leading to its child; leaves keep 1
        non_leaf_node = (self.child != 0)
        bone_scale[:, non_leaf_node] = (bone_scale_parent[:, self.child])[:, non_leaf_node]

        # Ulna should have the same scale as radius
        bone_scale[:, 16] = bone_scale[:, 17]
        bone_scale[:, 21] = bone_scale[:, 22]

        # Thorax: takes the lumbar's child-based scale (must run before the lumbar override below)
        bone_scale[:, 12] = bone_scale[:, 11]

        # Lumbar: scale defined by the y relative position of the lumbar joint
        length_lumbar_bsm = torch.abs(osim_joints_r[:, 11, 1])
        length_lumbar_smpl = torch.abs(J_[:, 11, 1])  # (B,)
        bone_scale[:, 11] = length_lumbar_smpl / length_lumbar_bsm

        # Expand to 3 dimensions, then adjust per-axis scales to avoid skin-skeleton
        # intersection and to handle leaf body parts (hands, feet)
        bone_scale = bone_scale.reshape(B, Nj, 1).expand(B, Nj, 3).clone()
        for (ji, doi, dsi), (v1, v2) in scaling_keypoints.items():
            bone_scale[:, ji, doi] = ((v_shaped[:, v1] - v_shaped[:, v2]) / (skin_v0[:, v1] - skin_v0[:, v2]))[:, dsi]
        # TODO: Add keypoints for feet scaling in scaling_keypoints

        # Adjust thorax front-back scaling
        # TODO fix this part
        v_back = 3027  # thorax back
        v_front = 3495  # thorax front
        v_sternum = 3506  # sternum
        scale_thorax_up = ((v_shaped[:, v_back] - v_shaped[:, v_front]) / (skin_v0[:, v_back] - skin_v0[:, v_front]))[:, 2]  # good for large people
        scale_thorax_sternum = ((v_shaped[:, v_back] - v_shaped[:, v_sternum]) / (skin_v0[:, v_back] - skin_v0[:, v_sternum]))[:, 2]  # good for skinny people
        # min avoids a super expanded ribcage for large people and the sternum poking out for skinny people
        bone_scale[:, 12, 0] = torch.min(scale_thorax_up, scale_thorax_sternum)

        # Lumbar width follows the thorax
        bone_scale[:, 11, 0] = bone_scale[:, 12, 0]

        return bone_scale

    def compute_bone_orientation(self, J, J_, method='learn_adjust'):
        """Compute each bone's orientation in T pose.

        Args:
            J: B x Nj x 3 joint locations (currently unused).
            J_: B x Nj x 3 joint offsets relative to the parent joint.
            method: 'learn_adjust' (default) rotates the learned-corrected template bone
                vectors onto the subject's bone vectors. 'learned' returns the learned
                per-joint rotations. 'unposed' returns identities.

        Returns:
            B x Nj x 3 x 3 rotation matrices.

        Raises:
            NotImplementedError: for an unknown ``method``.
        """
        B = J_.shape[0]
        Nj = J_.shape[1]

        if method == 'unposed':
            return torch.eye(3, device=J_.device).repeat(B, Nj, 1, 1)
        if method == 'learned':
            # Apply the learned transformation only
            return self.per_joint_rot.view(1, Nj, 3, 3).expand(B, -1, -1, -1)
        if method != 'learn_adjust':
            raise NotImplementedError

        # Target bone vectors: most bones align with the segment to their child joint
        bone_vect = J_[:, self.child]  # advanced indexing returns a new tensor
        bone_vect[:, 16] = bone_vect[:, 16] + bone_vect[:, 17]  # align the ulna to the segment joint 16 to 18
        bone_vect[:, 21] = bone_vect[:, 21] + bone_vect[:, 22]  # same for the other ulna
        # TODO Check indices here
        bone_vect[:, 12] = bone_vect[:, 11].clone()  # align the thorax on the thorax-lumbar segment

        # Matching template bone vectors
        osim_vect = self.apose_rel_transfo[:, :3, 3].view(1, Nj, 3).expand(B, Nj, 3)[:, self.child].clone()
        osim_vect[:, 16] = osim_vect[:, 16] + osim_vect[:, 17]  # align the ulna to the segment joint 16 to 18
        osim_vect[:, 21] = osim_vect[:, 21] + osim_vect[:, 22]  # same for the other ulna

        Gk_learned = self.per_joint_rot.view(1, Nj, 3, 3).expand(B, -1, -1, -1)  # learned rotations
        osim_vect_corr = torch.matmul(Gk_learned, osim_vect.unsqueeze(-1)).squeeze(-1)

        Gk = torch.eye(3, device=J_.device).repeat(B, Nj, 1, 1)
        Gk[:, :] = rotation_matrix_from_vectors(osim_vect_corr, bone_vect)
        # set nan to zero
        # TODO: Check again why the following line was required
        Gk[torch.isnan(Gk)] = 0
        # Pelvis, torso and shoulder blade orientation should not vary with beta
        Gk[:, self.joint_idx_fixed_beta] = torch.eye(3, device=J_.device).view(1, 3, 3).expand(B, len(self.joint_idx_fixed_beta), 3, 3)
        Gk = torch.matmul(Gk, Gk_learned)
        return Gk