Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| import pickle | |
| from typing import Optional | |
| import smplx | |
| from smplx.lbs import vertices2joints | |
| from smplx.utils import SMPLOutput | |
class SMPLWrapper(smplx.SMPLLayer):
    """SMPL layer that remaps its joints and can append extra regressed joints.

    Extension of the official SMPL implementation to support more joints.
    """

    def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs):
        """
        Extension of the official SMPL implementation to support more joints.

        Args:
            Same as SMPLLayer (``*args`` / ``**kwargs`` are forwarded unchanged).
            joint_regressor_extra (str): Path to a pickled extra joint regressor. When given,
                it is loaded into a float32 buffer named ``joint_regressor_extra`` and used in
                ``forward`` to regress additional joints from the mesh vertices.
            update_hips (bool): If True, ``forward`` adjusts rows 9 and 12 of the remapped
                joint set (see ``forward``).
        """
        super().__init__(*args, **kwargs)
        # Index list applied to the joints returned by SMPLLayer.forward. It has 25 entries and,
        # per its name, produces an OpenPose-style ordering (presumably BODY_25 -- confirm
        # against the consumer of these joints).
        smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
                            7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
        if joint_regressor_extra is not None:
            # SECURITY: pickle.load can execute arbitrary code -- only load regressor files
            # from trusted sources. encoding='latin1' is what the pickle docs recommend for
            # reading NumPy arrays pickled under Python 2.
            # The context manager closes the file handle (the original leaked it).
            with open(joint_regressor_extra, 'rb') as regressor_file:
                regressor = pickle.load(regressor_file, encoding='latin1')
            self.register_buffer('joint_regressor_extra', torch.tensor(regressor, dtype=torch.float32))
        self.register_buffer('joint_map', torch.tensor(smpl_to_openpose, dtype=torch.long))
        self.update_hips = update_hips

    def forward(self, *args, **kwargs) -> SMPLOutput:
        """
        Run forward pass. Same as SMPL and also append an extra set of joints if joint_regressor_extra is specified.

        Returns:
            The ``SMPLOutput`` produced by ``SMPLLayer.forward``, modified so that:
              * ``joints`` holds the layer's joints gathered along dim 1 by ``joint_map``,
                with rows 9/12 adjusted when ``update_hips`` is set, and with the extra
                regressed joints concatenated along dim 1 when an extra regressor was loaded;
              * a new attribute ``joints_smpl`` holds a copy of the layer's original joints.
        """
        smpl_output = super().forward(*args, **kwargs)
        # Untouched copy of the layer's native joint set, exposed as ``joints_smpl``.
        joints_smpl = smpl_output.joints.clone()
        # Index-tensor gather returns a new tensor, so the in-place edit below does not write
        # back into smpl_output.joints.
        joints = smpl_output.joints[:, self.joint_map, :]
        if self.update_hips:
            # Rows 9 and 12 come from native joints 2 and 1, and row 8 from native joint 0
            # (SMPL's hips and pelvis, presumably -- confirm). Pairing [9, 12] with [12, 9]
            # matches each row with its partner. Each row is pushed 25% of the pair distance
            # away from its partner, then shifted by half the offset from the pair's midpoint
            # to row 8.
            hips = joints[:, [9, 12]]
            partners = joints[:, [12, 9]]
            midpoint = 0.5 * (hips + partners)
            joints[:, [9, 12]] = hips + 0.25 * (hips - partners) + 0.5 * (joints[:, [8]] - midpoint)
        if hasattr(self, 'joint_regressor_extra'):
            # The buffer only exists when __init__ received joint_regressor_extra.
            extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)
        smpl_output.joints = joints
        smpl_output.joints_smpl = joints_smpl
        return smpl_output