Spaces:
Runtime error
Runtime error
| import torch | |
| from lib.body_models.skel.joints_def import curve_torch_3d | |
| from lib.body_models.skel.utils import axis_angle_to_matrix, euler_angles_to_matrix, rodrigues | |
class OsimJoint(torch.nn.Module):
    """Base class for the joint models defined in this module.

    Subclasses in this module each provide ``q_to_rot``; this base class only
    supplies a default ``q_to_translation`` that produces no translation.
    (The name suggests these mirror OpenSim joint types — unconfirmed.)
    """

    def __init__(self) -> None:
        super().__init__()

    def q_to_translation(self, q, **kwargs):
        """Return a zero translation for every row of ``q``.

        Args:
            q: Joint-coordinate tensor; only ``q.shape[0]``, ``q.dtype`` and
                ``q.device`` are used. Presumably shaped (batch, dofs) —
                confirm against callers.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A ``(q.shape[0], 3)`` tensor of zeros with ``q``'s dtype, located
            on ``q``'s device.
        """
        # Allocate directly on the target device with q's precision instead
        # of building a float32 CPU tensor and copying it over.
        return torch.zeros(q.shape[0], 3, dtype=q.dtype, device=q.device)
class CustomJoint(OsimJoint):
    """Joint with one rotational degree of freedom per entry of ``axis``.

    For each DOF ``i`` an axis-angle vector ``q[:, i] * axis_flip[i] * axis[i]``
    is converted to a rotation ``R_i`` and left-multiplied onto the running
    product, so the result is ``R_{n-1} @ ... @ R_1 @ R_0``.

    Args:
        axis: Sequence of rotation axes, one per DOF (presumably 3-vectors;
            confirm against joint definitions). Stored as a float32 buffer.
        axis_flip: Per-DOF sign/scale applied to the coordinate, indexed as
            ``axis_flip[i]``. Stored as a float32 buffer.
    """

    def __init__(self, axis, axis_flip) -> None:
        super().__init__()
        # torch.tensor (not the legacy torch.FloatTensor) so the argument is
        # always interpreted as data, never as a tensor size.
        self.register_buffer('axis', torch.tensor(axis, dtype=torch.float32))
        self.register_buffer('axis_flip', torch.tensor(axis_flip, dtype=torch.float32))
        self.register_buffer('nb_dof', torch.tensor(len(axis)))

    def q_to_rot(self, q, **kwargs):
        """Compose the per-DOF rotations for each row of ``q``.

        Args:
            q: Coordinate tensor indexed as ``q[:, i]`` for DOF ``i``;
                presumably (batch, dofs) — confirm.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A ``(q.shape[0], 3, 3)`` tensor. With zero DOFs this is an
            expanded (non-contiguous) view of the identity.
        """
        # Cast once to q's device/dtype so every matmul operand shares a dtype
        # (a float32 axis would otherwise promote half-precision q and make
        # matmul against the half-precision identity fail).
        axes = self.axis.to(device=q.device, dtype=q.dtype)
        flips = self.axis_flip.to(device=q.device, dtype=q.dtype)
        ident = torch.eye(3, dtype=q.dtype, device=q.device)
        Rp = ident.unsqueeze(0).expand(q.shape[0], 3, 3)
        for i in range(int(self.nb_dof)):
            angle_axis = q[:, i:i + 1] * flips[i] * axes[i]
            # Later DOFs are applied on the left of earlier ones.
            Rp = torch.matmul(axis_angle_to_matrix(angle_axis), Rp)
        return Rp
class CustomJoint1D(OsimJoint):
    """Single-DOF joint rotating about a fixed, normalized axis.

    Args:
        axis: Rotation axis; normalized to unit length at construction.
            A zero-length axis yields NaNs (no guard here).
        axis_flip: Sign/scale applied to the coordinate. May be a scalar or
            a one-element sequence.

    Note:
        ``axis``/``axis_flip`` are plain attributes rather than registered
        buffers, so ``Module.to()`` does not move them and they are not in
        ``state_dict``; ``q_to_rot`` moves them to ``q``'s device on each call.
    """

    def __init__(self, axis, axis_flip) -> None:
        super().__init__()
        # torch.tensor instead of the legacy torch.FloatTensor: a scalar
        # flip such as 1 or -1 would otherwise be read as a tensor *size*
        # (uninitialized memory, or an error for negative values).
        axis = torch.tensor(axis, dtype=torch.float32)
        self.axis = axis / torch.linalg.norm(axis)
        self.axis_flip = torch.tensor(axis_flip, dtype=torch.float32)
        self.nb_dof = 1

    def q_to_rot(self, q, **kwargs):
        """Return the rotation for coordinate ``q[:, 0]`` about ``self.axis``.

        Args:
            q: Coordinate tensor; only column 0 is used.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Output of ``axis_angle_to_matrix`` for the per-row axis-angle
            vectors (presumably ``(q.shape[0], 3, 3)`` — confirm).
        """
        axis = self.axis.to(device=q.device, dtype=q.dtype)
        flip = self.axis_flip.to(device=q.device, dtype=q.dtype)
        angle_axis = q[:, 0:1] * flip * axis
        return axis_angle_to_matrix(angle_axis)
class WalkerKnee(OsimJoint):
    """Simplified one-DOF knee: a rotation of ``-q[:, 0]`` about the z axis.

    TODO: this is a basic hinge placeholder; any coupled motion a full knee
    model would have is not represented here.
    """

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer('nb_dof', torch.tensor(1))

    def q_to_rot(self, q, **kwargs):
        """Return the knee rotation for coordinate ``q[:, 0]``.

        Args:
            q: Coordinate tensor; only column 0 is used.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Output of ``axis_angle_to_matrix`` for axis-angle vectors
            ``(0, 0, -q[:, 0])`` (presumably ``(q.shape[0], 3, 3)`` — confirm).
        """
        # Allocate with q's dtype/device: a default float32 buffer would
        # silently truncate float64 coordinates on the assignment below.
        theta_i = torch.zeros(q.shape[0], 3, dtype=q.dtype, device=q.device)
        theta_i[:, 2] = -q[:, 0]
        return axis_angle_to_matrix(theta_i)
class PinJoint(OsimJoint):
    """One-DOF hinge about the z axis of a frame given by Euler angles.

    Args:
        parent_frame_ori: ``'XYZ'`` Euler angles (units presumably radians —
            confirm) whose rotation matrix defines the hinge frame; stored as
            a float32 buffer.
    """

    def __init__(self, parent_frame_ori) -> None:
        super().__init__()
        self.register_buffer(
            'parent_frame_ori', torch.tensor(parent_frame_ori, dtype=torch.float32)
        )
        self.register_buffer('nb_dof', torch.tensor(1))

    def q_to_rot(self, q, **kwargs):
        """Return the rotation by ``q[:, 0]`` about the frame's z axis.

        Args:
            q: Coordinate tensor; only column 0 is used.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Output of ``axis_angle_to_matrix`` for the per-row axis-angle
            vectors (presumably ``(q.shape[0], 3, 3)`` — confirm).
        """
        frame_ori = self.parent_frame_ori.to(q.device)
        Ra_i = euler_angles_to_matrix(frame_ori, 'XYZ')
        # Ra_i @ [0, 0, 1] is simply Ra_i's third column; select it instead of
        # allocating a z-axis tensor on the CPU and copying it every call.
        axis = Ra_i[..., :, 2]
        axis_angle = q[:, 0:1] * axis
        return axis_angle_to_matrix(axis_angle)
class ConstantCurvatureJoint(CustomJoint):
    """Joint modeled exactly as a ``CustomJoint`` (same constructor and rotation).

    No joint-specific behavior is added here. NOTE(review): any
    curvature-specific coupling is therefore not modeled by this class —
    confirm this approximation is intended.

    The constructor is inherited unchanged, so both keyword and positional
    ``axis``/``axis_flip`` arguments are accepted.
    """
class EllipsoidJoint(CustomJoint):
    """Joint modeled exactly as a ``CustomJoint`` (same constructor and rotation).

    No joint-specific behavior is added here. NOTE(review): any
    ellipsoid-specific translation coupling is therefore not modeled by this
    class — confirm this approximation is intended.

    The constructor is inherited unchanged, so both keyword and positional
    ``axis``/``axis_flip`` arguments are accepted.
    """