from typing import Optional, Union

import numpy as np
import torch

from lib.kits.basic import *  # assumed to provide project helpers such as `to_numpy`

def T_to_Rt(
    T : Union[torch.Tensor, np.ndarray],
):
    ''' Get (..., 3, 3) rotation matrix and (..., 3) translation vector from (..., 4, 4) transformation matrix. '''
    if isinstance(T, np.ndarray):
        T = torch.from_numpy(T).float()
    assert T.shape[-2:] == (4, 4), f'T.shape[-2:] = {T.shape[-2:]}'
    R = T[..., :3, :3]
    t = T[..., :3, 3]
    return R, t

def Rt_to_T(
    R : Union[torch.Tensor, np.ndarray],
    t : Union[torch.Tensor, np.ndarray],
):
    ''' Get (..., 4, 4) transformation matrix from (..., 3, 3) rotation matrix and (..., 3) translation vector. '''
    if isinstance(R, np.ndarray):
        R = torch.from_numpy(R).float()
    if isinstance(t, np.ndarray):
        t = torch.from_numpy(t).float()
    assert R.shape[-2:] == (3, 3), f'R should be a (..., 3, 3) matrix, but R.shape = {R.shape}'
    assert t.shape[-1] == 3, f't should be a (..., 3) vector, but t.shape = {t.shape}'
    assert R.shape[:-2] == t.shape[:-1], f'R and t should have the same shape prefix, but {R.shape[:-2]} != {t.shape[:-1]}'
    T = torch.eye(4, device=R.device, dtype=R.dtype).repeat(R.shape[:-2] + (1, 1))  # (..., 4, 4)
    T[..., :3, :3] = R
    T[..., :3, 3] = t
    return T
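
# Minimal usage sketch (illustrative only; the `_demo_` helper below is not part of
# the original module): round-trip a batch of transforms through T_to_Rt / Rt_to_T.
def _demo_T_roundtrip():
    T = torch.eye(4).repeat(2, 1, 1)  # (2, 4, 4) identity transforms
    R, t = T_to_Rt(T)                 # R: (2, 3, 3), t: (2, 3)
    T2 = Rt_to_T(R, t)                # (2, 4, 4), reconstructs T exactly
    assert torch.allclose(T, T2)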

def apply_Ts_on_pts(Ts:torch.Tensor, pts:torch.Tensor):
    '''
    Apply the transformation matrices `Ts` on the points `pts`.

    ### Args
    - Ts: torch.Tensor, (...B, 4, 4)
    - pts: torch.Tensor, (...B, N, 3)
    '''
    assert len(pts.shape) >= 3 and pts.shape[-1] == 3, f'Shape of pts should be (...B, N, 3) but {pts.shape}'
    assert Ts.shape[-2:] == (4, 4), f'Shape of Ts should be (..., 4, 4) but {Ts.shape}'
    assert Ts.device == pts.device, f'Device of Ts and pts should be the same but {Ts.device} != {pts.device}'
    # Rotate then translate: R @ p + t for every point.
    ret_pts = torch.einsum('...ij,...nj->...ni', Ts[..., :3, :3], pts) + Ts[..., None, :3, 3]  # (...B, N, 3)
    # NOTE: the batch dims are kept as-is; squeezing dim 0 here would silently
    # collapse a legitimate batch of size 1.
    return ret_pts

def apply_T_on_pts(T:torch.Tensor, pts:torch.Tensor):
    '''
    Apply the transformation matrix `T` on the points `pts`.

    ### Args
    - T: torch.Tensor, (4, 4)
    - pts: torch.Tensor, (B, N, 3) or (N, 3)
    '''
    unbatched = len(pts.shape) == 2
    if unbatched:
        pts = pts[None]
    ret = apply_Ts_on_pts(T[None], pts)
    return ret.squeeze(0) if unbatched else ret
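
# Minimal usage sketch (illustrative only, not part of the original module):
# a pure x-translation applied to unbatched (N, 3) points.
def _demo_apply_T():
    T = torch.eye(4)
    T[:3, 3] = torch.tensor([1., 0., 0.])  # translate by +1 along x
    pts = torch.rand(5, 3)                 # (N, 3), unbatched
    out = apply_T_on_pts(T, pts)           # (5, 3), batch dim handled internally
    assert torch.allclose(out, pts + torch.tensor([1., 0., 0.]))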

def apply_Ks_on_pts(Ks:torch.Tensor, pts:torch.Tensor):
    '''
    Apply the intrinsic camera matrices `Ks` on the points `pts`, i.e., project the 3D points to 2D.

    ### Args
    - Ks: torch.Tensor, (...B, 3, 3)
    - pts: torch.Tensor, (...B, N, 3)
    '''
    assert len(pts.shape) >= 3 and pts.shape[-1] == 3, f'Shape of pts should be (...B, N, 3) but {pts.shape}'
    assert Ks.shape[-2:] == (3, 3), f'Shape of Ks should be (..., 3, 3) but {Ks.shape}'
    assert Ks.device == pts.device, f'Device of Ks and pts should be the same but {Ks.device} != {pts.device}'
    pts_proj_homo = torch.einsum('...ij,...vj->...vi', Ks, pts)  # (...B, N, 3), homogeneous image coordinates
    pts_proj = pts_proj_homo[..., :2] / pts_proj_homo[..., 2:3]  # (...B, N, 2), divide by depth
    return pts_proj

def apply_K_on_pts(K:torch.Tensor, pts:torch.Tensor):
    '''
    Apply the intrinsic camera matrix `K` on the points `pts`, i.e., project the 3D points to 2D.

    ### Args
    - K: torch.Tensor, (3, 3)
    - pts: torch.Tensor, (B, N, 3) or (N, 3)
    '''
    unbatched = len(pts.shape) == 2
    if unbatched:
        pts = pts[None]
    ret = apply_Ks_on_pts(K[None], pts)
    return ret.squeeze(0) if unbatched else ret
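
# Minimal usage sketch (illustrative only, not part of the original module):
# a point on the optical axis projects to the principal point.
def _demo_apply_K():
    K = torch.tensor([[500.,   0., 112.],
                      [  0., 500., 112.],
                      [  0.,   0.,   1.]])  # pinhole intrinsics: f = 500, c = (112, 112)
    pts = torch.tensor([[0., 0., 2.]])      # (N, 3) with N = 1, on the optical axis
    uv = apply_K_on_pts(K, pts)             # (1, 2)
    assert torch.allclose(uv, torch.tensor([[112., 112.]]))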

def perspective_projection(
    points        : torch.Tensor,
    translation   : torch.Tensor,
    focal_length  : torch.Tensor,
    camera_center : Optional[torch.Tensor] = None,
    rotation      : Optional[torch.Tensor] = None,
) -> torch.Tensor:
    '''
    Computes the perspective projection of a set of 3D points.
    https://github.com/shubham-goel/4D-Humans/blob/6ec79656a23c33237c724742ca2a0ec00b398b53/hmr2/utils/geometry.py#L64-L102

    ### Args
    - points: torch.Tensor, (B, N, 3)
        - The input 3D points.
    - translation: torch.Tensor, (B, 3)
        - The 3D camera translation.
    - focal_length: torch.Tensor, (B, 2)
        - The focal length in pixels.
    - camera_center: Optional[torch.Tensor], (B, 2)
        - The camera center in pixels. Defaults to (0, 0).
    - rotation: Optional[torch.Tensor], (B, 3, 3)
        - The camera rotation. Defaults to the identity.

    ### Returns
    - torch.Tensor, (B, N, 2)
        - The projection of the input points.
    '''
    B = points.shape[0]
    if rotation is None:
        rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(B, -1, -1)
    if camera_center is None:
        camera_center = torch.zeros(B, 2, device=points.device, dtype=points.dtype)
    # Populate the intrinsic camera matrix K.
    K = torch.zeros([B, 3, 3], device=points.device, dtype=points.dtype)
    K[:, 0, 0] = focal_length[:, 0]
    K[:, 1, 1] = focal_length[:, 1]
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center
    # Transform the points to camera coordinates.
    points = torch.einsum('bij, bkj -> bki', rotation, points)
    points = points + translation.unsqueeze(1)
    # Apply perspective distortion (divide by depth).
    projected_points = points / points[:, :, -1].unsqueeze(-1)
    # Apply the camera intrinsics.
    projected_points = torch.einsum('bij, bkj -> bki', K, projected_points)
    return projected_points[:, :, :-1]
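
# Minimal usage sketch (illustrative only, not part of the original module):
# identity rotation and a (0, 0) camera center are filled in by the defaults.
def _demo_perspective_projection():
    B, N = 2, 4
    points = torch.rand(B, N, 3)                            # 3D points in camera-aligned coordinates
    translation = torch.tensor([0., 0., 5.]).expand(B, 3)   # push the points in front of the camera
    focal_length = torch.full((B, 2), 500.)
    kp2d = perspective_projection(points, translation, focal_length)
    assert kp2d.shape == (B, N, 2)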

def estimate_translation_np(S, joints_2d, joints_conf, focal_length=5000, img_size=224):
    '''
    Find the camera translation that brings the 3D joints S closest to the corresponding 2D joints joints_2d.
    Copied from: https://github.com/nkolot/SPIN/blob/2476c436013055be5cb3905e4e4ecfa86966fac3/utils/geometry.py#L94-L132

    ### Args
    - S: np.ndarray, shape = (25, 3)
        - 3D joint locations.
    - joints_2d: np.ndarray, shape = (25, 2)
        - 2D joint locations.
    - joints_conf: np.ndarray, shape = (25,)
        - Per-joint confidence.

    ### Returns
    - np.ndarray, shape = (3,)
        - Camera translation vector.
    '''
    num_joints = S.shape[0]
    # Focal length.
    f = np.array([focal_length, focal_length])
    # Optical center.
    center = np.array([img_size / 2., img_size / 2.])
    # Transformations.
    Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1)
    XY = np.reshape(S[:, 0:2], -1)
    O = np.tile(center, num_joints)
    F = np.tile(f, num_joints)
    weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
    # Least squares.
    Q = np.array([F * np.tile(np.array([1, 0]), num_joints), F * np.tile(np.array([0, 1]), num_joints), O - np.reshape(joints_2d, -1)]).T
    c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
    # Weighted least squares.
    W = np.diagflat(weight2)
    Q = np.dot(W, Q)
    c = np.dot(W, c)
    # Square matrix.
    A = np.dot(Q.T, Q)
    b = np.dot(Q.T, c)
    # Solution.
    trans = np.linalg.solve(A, b)
    return trans
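
# Explanatory note (not in the original code): under a pinhole model with identity
# rotation, each joint (X, Y, Z) with image observation (u, v) satisfies
#     f * (X + t_x) / (Z + t_z) = u - c_x,    f * (Y + t_y) / (Z + t_z) = v - c_y.
# Multiplying out and collecting the unknowns t = (t_x, t_y, t_z) gives, per joint,
#     f * t_x + (c_x - u) * t_z = (u - c_x) * Z - f * X    (and analogously for y),
# which is exactly the row structure of Q and c above. The sqrt-confidence weights W
# enter both sides, so solving the normal equations (Q^T W^2 Q) t = Q^T W^2 c is a
# confidence-weighted least-squares fit for the camera translation.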

def estimate_camera_trans(
    S            : torch.Tensor,
    joints_2d    : torch.Tensor,
    focal_length : float = 5000.,
    img_size     : float = 224.,
    conf_thre    : float = 4.,
):
    '''
    Find the camera translation that brings the 3D joints S closest to the corresponding 2D joints joints_2d.
    Modified from: https://github.com/nkolot/SPIN/blob/2476c436013055be5cb3905e4e4ecfa86966fac3/utils/geometry.py#L135-L157

    ### Args
    - S: torch.Tensor, shape = (B, J, 3)
        - 3D joint locations.
    - joints_2d: torch.Tensor, shape = (B, J, 3)
        - Ground-truth 2D joint locations and confidence, as (u, v, conf).
    - focal_length: float
    - img_size: float
    - conf_thre: float
        - Confidence threshold: if the ground-truth keypoints (index 25 onwards) are not confident enough, all joints (including the OpenPose ones) are used instead.

    ### Returns
    - torch.Tensor, shape = (B, 3)
        - Camera translation vectors.
    '''
    device = S.device
    B = len(S)
    S = to_numpy(S)
    joints_2d = to_numpy(joints_2d)
    joints_conf = joints_2d[:, :, -1]  # (B, J)
    joints_2d = joints_2d[:, :, :-1]   # (B, J, 2)
    trans = np.zeros((S.shape[0], 3), dtype=np.float32)
    # Find the translation for each example in the batch.
    for i in range(B):
        conf_i = joints_conf[i]
        # When the ground-truth joints are not confident enough, use all the joints.
        if conf_i[25:].sum() < conf_thre:
            S_i = S[i]
            joints_i = joints_2d[i]
        else:
            S_i = S[i, 25:]
            conf_i = joints_conf[i, 25:]
            joints_i = joints_2d[i, 25:]
        trans[i] = estimate_translation_np(S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size)
    return torch.from_numpy(trans).to(device)
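
# Minimal usage sketch (illustrative only, not part of the original module): random
# inputs following the SPIN convention of 25 OpenPose joints followed by extra
# ground-truth joints, all given full confidence here so the ground-truth branch runs.
def _demo_estimate_camera_trans():
    B, J = 2, 49
    S = torch.rand(B, J, 3)                  # 3D joints
    kp2d = torch.rand(B, J, 3) * 224.        # (u, v, conf) per joint
    kp2d[..., 2] = 1.                        # full confidence
    trans = estimate_camera_trans(S, kp2d)   # (B, 3)
    assert trans.shape == (B, 3)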