# mypy: disable-error-code="empty-body"
from typing import Literal, Optional, Tuple, Union, overload
import torch
from jaxtyping import Float
from typing_extensions import deprecated
@deprecated("`transform` is deprecated, use `transform_points` instead.")
def transform(
    pts: Float[torch.Tensor, "... n 3"], tf_mat: Float[torch.Tensor, "... 4 4"]
) -> Float[torch.Tensor, "... n 3"]:
    """Deprecated alias that forwards its arguments to ``transform_points``.

    Args:
        pts (torch.Tensor): 3D points in shape [... n 3].
        tf_mat (torch.Tensor): Transformation matrix in shape [... 4 4].

    Returns:
        torch.Tensor: The result of ``transform_points(pts, tf_mat)``.
    """
    transformed = transform_points(pts, tf_mat)  # type: ignore
    return transformed
def rotate_points(
    pts: Float[torch.Tensor, "... n 3"], rot_mat: Float[torch.Tensor, "... 3 3"]
) -> Float[torch.Tensor, "... n 3"]:
    """Apply a rotation matrix to a set of 3D points.

    Points are stored as row vectors, so the rotation is evaluated as
    ``pts @ rot_mat^T`` (the same values as ``(rot_mat @ pts^T)^T``).

    Args:
        pts (torch.Tensor): 3D points in shape [... n 3].
        rot_mat (torch.Tensor): Rotation matrix in shape [... 3 3].

    Returns:
        torch.Tensor: Rotated points in shape [... n 3].

    Raises:
        ValueError: If ``pts`` and ``rot_mat`` do not have the same number of dimensions.
    """
    same_rank = pts.ndim == rot_mat.ndim
    if not same_rank:
        raise ValueError(
            f"The dimension number of pts and rot_mat should be the same, but got {pts.ndim=} and {rot_mat.ndim=}"
        )
    # Swap the two trailing axes only, so any leading batch axes stay aligned.
    return pts @ rot_mat.transpose(-2, -1)
@overload
def project_points(pts: torch.Tensor, intr_mat: torch.Tensor, return_depth: Literal[False] = False) -> torch.Tensor: ...


@overload
def project_points(
    pts: torch.Tensor, intr_mat: torch.Tensor, return_depth: Literal[True]
) -> Tuple[torch.Tensor, torch.Tensor]: ...


def project_points(
    pts: torch.Tensor, intr_mat: torch.Tensor, return_depth: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Project 3D points in the camera space to the image plane.

    Each point is divided by its own z component and then multiplied by the
    transposed intrinsic matrix. No guard is applied for a zero z component,
    so such points yield non-finite pixel values.

    Args:
        pts: 3D points, could be Nx3 or BxNx3.
        intr_mat: Intrinsic matrix, could be 3x3 or Bx3x3.
        return_depth: If True, also return the z component of ``pts``. Defaults to False.

    Returns:
        pixels: Projected pixels, Nx2 or BxNx2. The order is uv rather than xy.
        depth (if return_depth): Depth in the camera space, N or BxN.
    """
    # Perspective divide: slicing with 2:3 keeps the last axis so it broadcasts.
    normalized = pts / pts[..., 2:3]
    # Row-vector convention: p @ K^T is the same as (K @ p^T)^T.
    projected = torch.matmul(normalized, torch.transpose(intr_mat, -2, -1))
    pixels = projected[..., :2]
    if return_depth:
        return pixels, pts[..., 2]
    return pixels
def unproject_points(pixels, depth, intr_mat):
    """Lift image-plane pixels back to 3D points in the camera space.

    Args:
        pixels: Pixels in the image plane, could be Nx2 or BxNx2. The order is uv rather than xy.
        depth: Depth in the camera space, could be N, Nx1, BxN or BxNx1.
        intr_mat: Intrinsic matrix, could be 3x3 or Bx3x3.

    Returns:
        pts: 3D points, Nx3 or BxNx3.
    """
    # Give depth a trailing singleton axis so it broadcasts against pixels.
    if depth.ndim < pixels.ndim:
        depth = depth.unsqueeze(-1)  # N -> Nx1, BxN -> BxNx1

    # Third column of the top two rows, with an axis inserted for the points.
    center = intr_mat[..., :2, 2].unsqueeze(-2)  # 1x2, Bx1x2
    # The two diagonal entries, sliced (not indexed) so they keep a 1x1 shape.
    scale_u = intr_mat[..., 0:1, 0:1]
    scale_v = intr_mat[..., 1:2, 1:2]
    scale = torch.cat((scale_u, scale_v), dim=-1)  # 1x2, Bx1x2

    offsets = pixels - center
    xys = offsets * depth / scale
    return torch.cat((xys, depth), dim=-1)
def inverse_tf_mat(rot_or_tf_mat: torch.Tensor) -> torch.Tensor:
    """Invert a rotation matrix or a transformation matrix. Reference_

    A matrix whose last dimension is 3 is treated as a rotation and inverted
    by transposing it. Anything else is treated as a transformation matrix:
    its rotation part ``R`` becomes ``R^T`` and its translation ``t`` becomes
    ``-R^T t``.

    Args:
        rot_or_tf_mat (torch.Tensor): Rotation matrix (in shape [..., 3, 3]) or transformation matrix (in shape [..., 4, 4]).

    Returns:
        torch.Tensor: Inverted matrix.

    Examples:
        >>> tf_mat = torch.tensor([[0, 1, 0, 1], [0, 0, 1, 2], [1, 0, 0, 3], [0, 0, 0, 1]], dtype=torch.float32)
        >>> torch.allclose(inverse_tf_mat(tf_mat) @ tf_mat, torch.eye(4))
        True
        >>> rot_mat = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32)
        >>> torch.allclose(inverse_tf_mat(rot_mat) @ rot_mat, torch.eye(3))
        True

    .. _Reference: https://math.stackexchange.com/a/1315407/757569
    """
    if rot_or_tf_mat.shape[-1] == 3:  # rotation matrix
        return rot_or_tf_mat.transpose(-2, -1)

    # transformation matrix
    inv_rot = rot_or_tf_mat[..., :3, :3].transpose(-2, -1)
    # Treat the translation as a single row-vector point so it can be rotated.
    tl_as_point = rot_or_tf_mat[..., :3, 3].unsqueeze(-2)  # 1x3, Bx1x3
    rotated_tl = rotate_points(tl_as_point, inv_rot)
    inv_tl = -rotated_tl.squeeze(-2)  # 3, Bx3
    return rot_tl_to_tf_mat(inv_rot, inv_tl)
def swap_major(rot_or_tf_mat: torch.Tensor) -> torch.Tensor:
    """Switch a rotation or transformation matrix between row-major and column-major. Reference_

    The swap is done by transposing the last two dimensions.

    Args:
        rot_or_tf_mat (torch.Tensor): Rotation or transformation matrix in shape [..., 3, 3] or [..., 4, 4].

    Returns:
        torch.Tensor: Matrix with swapped major.

    .. _Reference: https://www.scratchapixel.com/lessons/mathematics-physics-for-computer-graphics/geometry/row-major-vs-column-major-vector # noqa
    """
    swapped = rot_or_tf_mat.transpose(-2, -1)
    return swapped
def to_homo(pts_3d: Float[torch.Tensor, "... 3"]) -> Float[torch.Tensor, "... 4"]:
    """Append a unit coordinate to Cartesian 3D points to make them homogeneous.

    Args:
        pts_3d (torch.Tensor): Cartesian 3D points in shape [... 3].

    Returns:
        torch.Tensor: Homogeneous 4D points in shape [... 4].

    Examples:
        >>> pts = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32)
        >>> to_homo(pts)
        tensor([[0., 1., 0., 1.],
                [0., 0., 1., 1.],
                [1., 0., 0., 1.]])
    """
    # Build the column of ones from a slice of the input so it inherits the
    # input's dtype, device and leading dimensions.
    first_coord = pts_3d[..., :1]
    ones_col = torch.ones_like(first_coord)
    return torch.cat((pts_3d, ones_col), dim=-1)
def rot_tl_to_tf_mat(
    rot_mat: Optional[Float[torch.Tensor, "... 3 3"]] = None, tl: Optional[Float[torch.Tensor, "... 3"]] = None
) -> Float[torch.Tensor, "... 4 4"]:
    """Assemble a transformation matrix from a rotation matrix and/or a translation vector.

    If only one of the two is given, the other is filled in: a missing
    translation becomes zeros and a missing rotation becomes the identity,
    both matching the given tensor's leading dimensions, dtype and device.

    Args:
        rot_mat (torch.Tensor, optional): Rotation matrix in shape [... 3 3]. Defaults to None.
        tl (torch.Tensor, optional): Translation vector in shape [... 3]. Defaults to None.

    Returns:
        torch.Tensor: Transformation matrix in shape [... 4 4].

    Raises:
        ValueError: If both ``rot_mat`` and ``tl`` are None.

    Examples:
        >>> rot_mat = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32)
        >>> tl = torch.tensor([1, 2, 3], dtype=torch.float32)
        >>> rot_tl_to_tf_mat(rot_mat, tl)
        tensor([[0., 1., 0., 1.],
                [0., 0., 1., 2.],
                [1., 0., 0., 3.],
                [0., 0., 0., 1.]])
        >>> rot_tl_to_tf_mat(tl=tl)
        tensor([[1., 0., 0., 1.],
                [0., 1., 0., 2.],
                [0., 0., 1., 3.],
                [0., 0., 0., 1.]])
        >>> rot_tl_to_tf_mat(rot_mat=rot_mat)
        tensor([[0., 1., 0., 0.],
                [0., 0., 1., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 1.]])
    """
    if rot_mat is None:
        if tl is None:
            raise ValueError("Either rot_mat or tl should be provided.")
        # Identity rotation, tiled once per leading (batch) entry of tl.
        identity = torch.eye(3).to(tl)
        rot_mat = identity.repeat(*tl.shape[:-1], 1, 1)
    elif tl is None:
        # Zero translation with the same leading dimensions as rot_mat.
        tl = torch.zeros((*rot_mat.shape[:-2], 3), device=rot_mat.device, dtype=rot_mat.dtype)

    # Append the translation as a fourth column to get the [... 3 4] top rows.
    top_rows = torch.cat((rot_mat, tl[..., None]), dim=-1)  # type: ignore
    return expand_tf_mat(top_rows)
def expand_tf_mat(tf_mat: Float[torch.Tensor, "... 3 4"]) -> Float[torch.Tensor, "... 4 4"]:
    """Pad a [... 3 4] transformation matrix with a ``[0, 0, 0, 1]`` row to make it [... 4 4].

    Input whose last two dimensions are not (3, 4) is returned unchanged.

    Args:
        tf_mat (torch.Tensor): Transformation matrix in shape [... 3 4] or [... 4 4].

    Returns:
        torch.Tensor: Expanded transformation matrix in shape [... 4 4].

    Examples:
        >>> tf_mat = torch.tensor([[0, 1, 0, 1], [0, 0, 1, 2], [1, 0, 0, 3]], dtype=torch.float32)
        >>> expand_tf_mat(tf_mat)
        tensor([[0., 1., 0., 1.],
                [0., 0., 1., 2.],
                [1., 0., 0., 3.],
                [0., 0., 0., 1.]])
    """
    if tf_mat.shape[-2:] != (3, 4):
        return tf_mat

    leading_dims = tf_mat.shape[:-2]
    # `.to(tf_mat)` matches dtype and device; `expand` broadcasts the row over
    # the leading dimensions without copying, and `cat` then materializes it.
    bottom_row = torch.tensor([0.0, 0.0, 0.0, 1.0]).to(tf_mat).expand(*leading_dims, 1, 4)
    return torch.cat((tf_mat, bottom_row), dim=-2)
# Names exported by a star-import of this module. The deprecated `transform`
# wrapper defined in this file is not listed here.
# NOTE(review): `transform_points` is exported but its definition is not
# visible in this part of the file -- confirm it is defined in this module.
__all__ = [
"transform_points",
"rotate_points",
"project_points",
"unproject_points",
"inverse_tf_mat",
"swap_major",
"rot_tl_to_tf_mat",
"to_homo",
"expand_tf_mat",
]