# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from math import isqrt

import torch
from einops import einsum

try:
    from e3nn.o3 import matrix_to_angles, wigner_D
except ImportError:
    from depth_anything_3.utils.logger import logger

    logger.warning("Dependency 'e3nn' not found. Required for rotating the camera-space SH coefficients")


def project_to_so3_strict(M: torch.Tensor) -> torch.Tensor:
    """Project a batch of arbitrary 3x3 matrices onto SO(3), i.e. the nearest proper rotation matrices."""
    if M.shape[-2:] != (3, 3):
        raise ValueError("Input must be a batch of 3x3 matrices (i.e., shape [..., 3, 3]).")

    # 1. Compute SVD
    U, S, Vh = torch.linalg.svd(M)
    V = Vh.mH

    # 2. Handle reflection case (det = -1) by flipping the sign of the last singular direction
    det_U = torch.det(U)
    det_V = torch.det(V)
    is_reflection = (det_U * det_V) < 0
    correction_sign = torch.where(
        is_reflection[..., None],
        torch.tensor([1, 1, -1.0], device=M.device, dtype=M.dtype),
        torch.tensor([1, 1, 1.0], device=M.device, dtype=M.dtype),
    )
    correction_matrix = torch.diag_embed(correction_sign)
    U_corrected = U @ correction_matrix
    R_so3_initial = U_corrected @ V.transpose(-2, -1)

    # 3. Explicitly rescale so the determinant is 1 (or extremely close)
    current_det = torch.det(R_so3_initial)
    det_correction_factor = torch.pow(current_det, -1 / 3)[..., None, None]
    R_so3_final = R_so3_initial * det_correction_factor
    return R_so3_final
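
# Illustrative usage (a minimal sketch, not executed at import time):
#
#     noisy = torch.eye(3) + 1e-3 * torch.randn(3, 3)  # "almost" a rotation matrix
#     R = project_to_so3_strict(noisy.unsqueeze(0))     # shape [1, 3, 3], det(R) ~= 1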


def rotate_sh(
    sh_coefficients: torch.Tensor,  # "*#batch n"
    rotations: torch.Tensor,  # "*#batch 3 3"
) -> torch.Tensor:  # "*batch n"
    """Rotate spherical-harmonic coefficients by the given rotation matrices, one SH degree at a time."""
    # https://github.com/graphdeco-inria/gaussian-splatting/issues/176#issuecomment-2452412653
    device = sh_coefficients.device
    dtype = sh_coefficients.dtype
    *_, n = sh_coefficients.shape

    # Disable autocast and work in float32 so the SO(3) projection and Euler-angle
    # extraction stay numerically stable.
    with torch.autocast(device_type=rotations.device.type, enabled=False):
        rotations_float32 = rotations.to(torch.float32)
        # switch axes: yzx -> xyz
        P = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]]).unsqueeze(0).to(rotations_float32)
        permuted_rotations = torch.linalg.inv(P) @ rotations_float32 @ P
        # ensure the rotation has det == 1 in float32
        permuted_rotations_so3 = project_to_so3_strict(permuted_rotations)
        alpha, beta, gamma = matrix_to_angles(permuted_rotations_so3)

    # Rotate each degree's band with its Wigner D matrix, then re-concatenate the bands.
    result = []
    for degree in range(isqrt(n)):
        with torch.device(device):
            sh_rotations = wigner_D(degree, alpha, -beta, gamma).type(dtype)
        sh_rotated = einsum(
            sh_rotations,
            sh_coefficients[..., degree**2 : (degree + 1) ** 2],
            "... i j, ... j -> ... i",
        )
        result.append(sh_rotated)
    return torch.cat(result, dim=-1)
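

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original module);
    # it assumes 'e3nn' is installed so matrix_to_angles / wigner_D are available.
    # Rotating by the identity should leave the coefficients unchanged.
    batch, n = 4, 9  # degrees 0..2 -> 9 SH coefficients
    sh = torch.randn(batch, n)
    identity = torch.eye(3).expand(batch, 3, 3)
    rotated = rotate_sh(sh, identity)
    torch.testing.assert_close(rotated, sh, atol=1e-4, rtol=1e-4)
    print("rotate_sh identity check passed:", tuple(rotated.shape))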