from typing import Dict, Literal, Optional, Union, get_args
import numpy as np
import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor
[docs]
def random_rotation_matrix(
    num: Optional[int] = None,
    random_state: Optional[Union[int, np.random.Generator, np.random.RandomState]] = None,
    return_tensors: Literal["np", "pt"] = "np",
):
    """
    Sample random rotation matrices with ``scipy.spatial.transform.Rotation.random``.

    Args:
        num: Number of rotations to sample, forwarded to ``Rotation.random``.
            With ``None`` a single rotation is requested.
        random_state: Seed or random generator forwarded to ``Rotation.random``.
        return_tensors: Output container: "np" returns the array produced by
            scipy unchanged, "pt" returns a ``torch.float32`` tensor.

    Returns:
        The sampled rotation matrices, as a numpy array or a torch tensor
        depending on ``return_tensors``.

    Raises:
        ImportError: If scipy is not installed.
        ValueError: If ``return_tensors`` is neither "np" nor "pt".
    """
    try:
        from scipy.spatial.transform import Rotation as R
    except ImportError as err:
        # Keep the original failure attached so the root cause stays visible.
        raise ImportError("This function requires scipy to be installed.") from err

    # Reject an unsupported output format before spending time on sampling
    # (and before consuming values from a caller-supplied generator).
    if return_tensors not in ("np", "pt"):
        raise ValueError("return_tensors should be either 'np' or 'pt'")

    rotation_matrices = R.random(num=num, random_state=random_state).as_matrix()
    if return_tensors == "pt":
        return torch.as_tensor(rotation_matrices, dtype=torch.float32)
    return rotation_matrices
# Adapted from https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Return ``sqrt(x)`` where ``x > 0`` and zero everywhere else.

    The square root is evaluated only on the strictly positive entries, so
    zero and negative entries never pass through ``torch.sqrt``.
    """
    is_positive = x > 0
    result = torch.zeros_like(x)
    result[is_positive] = torch.sqrt(x[is_positive])
    return result
[docs]
def split_axis_angle_vector(axis_angle):
    """
    Split axis-angle vectors into a direction and a rotation angle.

    Args:
        axis_angle: Rotation vectors with the components in the last
            dimension; anything accepted by ``torch.as_tensor``.

    Returns:
        A tuple ``(axis, angle)``: ``angle`` is the Euclidean norm of
        ``axis_angle`` over the last dimension (kept as a size-1 dimension)
        and ``axis`` is ``axis_angle`` divided by that norm. For a zero
        vector the direction is undefined; the returned axis is all zeros
        (not NaN), so ``axis * angle`` still reproduces the input.
    """
    axis_angle = torch.as_tensor(axis_angle)
    angle = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)  # type: ignore
    # Dividing a zero vector by its zero norm would give NaN; divide by 1
    # there instead so the axis comes out as zeros.
    safe_angle = torch.where(angle > 0, angle, torch.ones_like(angle))
    axis = axis_angle / safe_angle
    return axis, angle
[docs]
def compose_axis_angle_vector(axis, angle):
    """
    Combine a rotation axis and an angle into one axis-angle vector.

    Args:
        axis: Rotation axis.
        angle: Rotation angle; multiplied with ``axis`` using the usual
            broadcasting rules of the operand types.

    Returns:
        The product ``axis * angle``.
    """
    axis_angle = axis * angle
    return axis_angle
[docs]
def axis_angle_vector_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Turn axis-angle rotation vectors into rotation matrices.

    The conversion goes through the quaternion representation.

    Args:
        axis_angle: Rotations in axis-angle form as a tensor of shape
            (..., 3). The direction of each vector is the rotation axis and
            its magnitude is the anticlockwise rotation angle in radians.

    Returns:
        Rotation matrices as a tensor of shape (..., 3, 3).
    """
    as_quaternion = axis_angle_vector_to_quaternion(axis_angle)
    return quaternion_to_matrix(as_quaternion)
[docs]
def matrix_to_axis_angle_vector(matrix: torch.Tensor) -> torch.Tensor:
    """
    Turn rotation matrices into axis-angle rotation vectors.

    The conversion goes through the quaternion representation.

    Args:
        matrix: Rotation matrices as a tensor of shape (..., 3, 3).

    Returns:
        Rotations in axis-angle form as a tensor of shape (..., 3). The
        direction of each vector is the rotation axis and its magnitude is
        the anticlockwise rotation angle in radians.
    """
    as_quaternion = matrix_to_quaternion(matrix)
    return quaternion_to_axis_angle_vector(as_quaternion)
[docs]
def axis_angle_vector_to_quaternion(axis_angle):
    """
    Turn axis-angle rotation vectors into quaternions.

    Args:
        axis_angle: Rotations in axis-angle form as a tensor of shape
            (..., 3) (anything accepted by ``torch.as_tensor``). The
            direction of each vector is the rotation axis and its magnitude
            is the anticlockwise rotation angle in radians.

    Returns:
        Quaternions with real part first, as a tensor of shape (..., 4).

    Reference: https://en.wikipedia.org/wiki/Axis%E2%80%93angle_representation#Unit_quaternions
    """
    rotvec = torch.as_tensor(axis_angle)
    theta = torch.norm(rotvec, p=2, dim=-1, keepdim=True)  # type: ignore
    half_theta = theta * 0.5

    # Below this angle sin(theta / 2) / theta is replaced by its series.
    near_zero = torch.abs(theta) < 1e-6
    regular = ~near_zero

    # scale == sin(theta / 2) / theta, filled in separately for each regime.
    scale = torch.empty_like(theta)
    scale[regular] = torch.sin(half_theta[regular]) / theta[regular]
    # sin(x/2) ~ x/2 - (x/2)^3 / 6 for small x, hence sin(x/2) / x ~ 1/2 - x^2 / 48.
    scale[near_zero] = 0.5 - (theta[near_zero] * theta[near_zero]) / 48

    return torch.cat([torch.cos(half_theta), rotvec * scale], dim=-1)
[docs]
def quaternion_to_axis_angle_vector(quaternions):
    """
    Turn quaternions into axis-angle rotation vectors.

    Args:
        quaternions: Quaternions with real part first, as a tensor of
            shape (..., 4).

    Returns:
        Rotations in axis-angle form as a tensor of shape (..., 3). The
        direction of each vector is the rotation axis and its magnitude is
        the anticlockwise rotation angle in radians.

    Reference: https://en.wikipedia.org/wiki/Axis%E2%80%93angle_representation#Unit_quaternions
    """
    imaginary = quaternions[..., 1:]
    real = quaternions[..., :1]

    imaginary_norm = torch.norm(imaginary, p=2, dim=-1, keepdim=True)  # type: ignore
    half_theta = torch.atan2(imaginary_norm, real)
    theta = 2 * half_theta

    # Below this angle sin(theta / 2) / theta is replaced by its series.
    near_zero = torch.abs(theta) < 1e-6
    regular = ~near_zero

    # scale == sin(theta / 2) / theta, filled in separately for each regime.
    scale = torch.empty_like(theta)
    scale[regular] = torch.sin(half_theta[regular]) / theta[regular]
    # sin(x/2) ~ x/2 - (x/2)^3 / 6 for small x, hence sin(x/2) / x ~ 1/2 - x^2 / 48.
    scale[near_zero] = 0.5 - (theta[near_zero] * theta[near_zero]) / 48

    return imaginary / scale
def normalize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Rescale quaternions to unit length.

    Args:
        quaternions: Quaternions with real part first, as a tensor of
            shape (..., 4).

    Returns:
        The quaternions L2-normalized over the last dimension, as a tensor
        of shape (..., 4).
    """
    unit_quaternions = F.normalize(quaternions, p=2, dim=-1)
    return unit_quaternions
[docs]
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Bring unit quaternions into the standard form with a non-negative real part.

    Quaternions whose real part is negative are negated as a whole; all
    others are returned unchanged.

    Args:
        quaternions: Quaternions with real part first, as a tensor of
            shape (..., 4).

    Returns:
        Standardized quaternions as a tensor of shape (..., 4).
    """
    real_is_negative = quaternions[..., 0:1] < 0
    return torch.where(real_is_negative, -quaternions, quaternions)
[docs]
def quaternion_real_to_last(quaternions):
    """
    Reorder real-first quaternions so that the real part comes last.

    Args:
        quaternions: Quaternions with the four components in the last
            dimension, real part at index 0.

    Returns:
        The same quaternions with the last dimension reordered to
        (component 1, component 2, component 3, real part).
    """
    real_last_order = [1, 2, 3, 0]
    return quaternions[..., real_last_order]
[docs]
def quaternion_real_to_first(quaternions):
    """
    Reorder real-last quaternions so that the real part comes first.

    Args:
        quaternions: Quaternions with the four components in the last
            dimension, real part at index 3.

    Returns:
        The same quaternions with the last dimension reordered to
        (real part, component 0, component 1, component 2).
    """
    real_first_order = [3, 0, 1, 2]
    return quaternions[..., real_first_order]
def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Compute the Hamilton product of two quaternions.

    The result is returned as computed, without any sign standardization.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    aw, ax, ay, az = torch.unbind(a, -1)
    bw, bx, by, bz = torch.unbind(b, -1)  # type: ignore
    product = (
        aw * bw - ax * bx - ay * by - az * bz,  # real part
        aw * bx + ax * bw + ay * bz - az * by,  # first imaginary component
        aw * by - ax * bz + ay * bw + az * bx,  # second imaginary component
        aw * bz + ax * by - ay * bx + az * bw,  # third imaginary component
    )
    return torch.stack(product, -1)
[docs]
def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Compose two rotations given as quaternions.

    The quaternions are multiplied and the product is returned in
    standardized form, i.e. as the versor with non-negative real part.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    return standardize_quaternion(quaternion_raw_multiply(a, b))
[docs]
def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Return the inverse of a rotation quaternion.

    The inverse is obtained by flipping the sign of the three imaginary
    components, which is the true inverse only for unit quaternions.

    Args:
        quaternion: Quaternions as tensor of shape (..., 4), with real part
            first, which must be versors (unit quaternions).

    Returns:
        The inverse, a tensor of quaternions of shape (..., 4).
    """
    # +1 keeps the real part, -1 negates each imaginary component.
    conjugate_signs = torch.tensor([1, -1, -1, -1], device=quaternion.device)
    inverse = quaternion * conjugate_signs
    return inverse
def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
    """
    Rotate 3D points by the rotation a quaternion represents.

    Each point is embedded as a quaternion with zero real part, multiplied
    by the quaternion on the left and by its inverse on the right, and the
    imaginary part of the result is returned.
    Usual torch rules for broadcasting apply.

    Args:
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).

    Raises:
        ValueError: If the last dimension of ``point`` is not 3.
    """
    if point.size(-1) != 3:
        raise ValueError(f"Points are not in 3D, {point.shape}.")

    zero_real = point.new_zeros(point.shape[:-1] + (1,))
    pure_quaternion = torch.cat((zero_real, point), -1)

    rotated = quaternion_raw_multiply(quaternion, pure_quaternion)
    rotated = quaternion_raw_multiply(rotated, quaternion_invert(quaternion))
    return rotated[..., 1:]
[docs]
def quaternion_to_matrix(quaternions: Tensor) -> Tensor:
    """Turn quaternions into rotation matrices.

    The quaternions do not have to be normalized: every term is scaled by
    ``2 / |q|^2``.

    Args:
        quaternions (Tensor): quaternions with real part first with shape (..., 4).

    Returns:
        Tensor: Rotation matrices as tensor of shape (..., 3, 3).

    Reference: https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(dim=-1)

    # Matrix entries in row-major order.
    entries = [
        1 - two_s * (y * y + z * z),
        two_s * (x * y - z * w),
        two_s * (x * z + y * w),
        two_s * (x * y + z * w),
        1 - two_s * (x * x + z * z),
        two_s * (y * z - x * w),
        two_s * (x * z - y * w),
        two_s * (y * z + x * w),
        1 - two_s * (x * x + y * y),
    ]
    flat = torch.stack(entries, dim=-1)
    return torch.reshape(flat, quaternions.shape[:-1] + (3, 3))
[docs]
def matrix_to_quaternion(matrix: Float[torch.Tensor, "... 3 3"]) -> Float[torch.Tensor, "... 4"]:
    """
    Convert rotation matrices to quaternions using Shepperd's method.

    Four candidate quaternions are formed, one per component used as the
    divisor, and for every matrix the candidate whose divisor is largest
    is selected.

    Args:
        matrix: (np.ndarray, torch.Tensor): rotation matrices, the shape could be ...3x3.

    Returns:
        quaternions with real part first in shape of (..., 4).

    Example:
        >>> rot_mat = torch.tensor([[-0.2533, -0.6075,  0.7529],
        ...                         [ 0.8445, -0.5185, -0.1343],
        ...                         [ 0.4720,  0.6017,  0.6443]])
        >>> matrix_to_quaternion(rot_mat)
        tensor([0.4671, 0.3940, 0.1503, 0.7772])

    Ref: http://www.iri.upc.edu/files/scidoc/2068-Accurate-Computation-of-Quaternions-from-Rotation-Matrices.pdf
        Note that the way to determine the best solution is slightly different from the PDF.
    """
    leading_shape = matrix.shape[:-2]
    flat_matrix = torch.reshape(matrix, leading_shape + (9,))
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(flat_matrix, dim=-1)

    # Twice the absolute value of each quaternion component (r, i, j, k).
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # Row n holds the desired quaternion multiplied by its n-th component
    # (r, i, j, k in turn), up to a common factor.
    by_r = torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1)
    by_i = torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1)
    by_j = torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1)
    by_k = torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1)
    quat_by_rijk = torch.stack([by_r, by_i, by_j, by_k], dim=-2)

    # The divisor is floored at 0.1; the exact level is not important
    # because a candidate with a small q_abs is never the one selected.
    floor = torch.tensor([0.1], device=q_abs.device, dtype=q_abs.dtype)
    candidates = quat_by_rijk / (2.0 * torch.maximum(q_abs[..., None], floor))

    # Keep, for every matrix, the candidate with the largest divisor.
    best = F.one_hot(torch.argmax(q_abs, dim=-1), num_classes=4) > 0.5
    selected = candidates[best, :]
    return torch.reshape(selected, leading_shape + (4,))
[docs]
def axis_angle_to_matrix(
    axis: Float[torch.Tensor, "... 3"], angle: Float[torch.Tensor, "..."]
) -> Float[torch.Tensor, "... 3 3"]:
    """
    Build rotation matrices from separate axes and angles with Rodrigues' formula.

    Args:
        axis (torch.Tensor): axis, the shape could be [..., 3].
        angle (torch.Tensor): angle, the shape could be [...].

    Returns:
        torch.Tensor: Rotation matrices [..., 3, 3].

    Example:
        >>> axis = torch.tensor([1.0, 0.0, 0.0])
        >>> angle = torch.tensor(0.5)
        >>> axis_angle_to_matrix(axis, angle)
        tensor([[ 1.0000,  0.0000,  0.0000],
                [ 0.0000,  0.8776, -0.4794],
                [ 0.0000,  0.4794,  0.8776]])
    """
    x, y, z = torch.unbind(axis, -1)
    sin_a = torch.sin(angle)
    cos_a = torch.cos(angle)
    versine = 1 - cos_a

    # Axis components scaled by sin(angle) and by (1 - cos(angle)).
    x_sin, y_sin, z_sin = x * sin_a, y * sin_a, z * sin_a
    x_ver, y_ver, z_ver = x * versine, y * versine, z * versine
    # Mixed products used by the off-diagonal entries.
    xy_ver, yz_ver, zx_ver = x * y_ver, y * z_ver, z * x_ver

    # Matrix entries in row-major order.
    entries = [
        x * x_ver + cos_a,
        xy_ver - z_sin,
        zx_ver + y_sin,
        xy_ver + z_sin,
        y * y_ver + cos_a,
        yz_ver - x_sin,
        zx_ver - y_sin,
        yz_ver + x_sin,
        z * z_ver + cos_a,
    ]
    return torch.stack(entries, dim=-1).reshape(angle.shape + (3, 3))
def _index_from_letter(letter: str) -> int:
    """
    Map an axis letter to its index: "x" -> 0, "y" -> 1, "z" -> 2.

    Args:
        letter: A single axis letter, one of "x", "y" or "z".

    Returns:
        The position of ``letter`` in "xyz".

    Raises:
        ValueError: If ``letter`` is not exactly one of "x", "y" or "z".
    """
    # `in` on a string is a substring test, so "" and "xy" would pass a bare
    # `letter not in "xyz"` check; require exactly one character as well.
    if len(letter) != 1 or letter not in "xyz":
        raise ValueError(f"{letter} is not a valid axis letter")
    return "xyz".index(letter)
def _angle_from_tan(axis, other_axis, data, horizontal, tait_bryan):
    """
    Recover the first or third Euler angle from the two matrix entries
    that are a positive constant times its sine and cosine.

    Args:
        axis: Axis label "x" or "y" or "z" for the angle we are finding.
        other_axis: Axis label "x" or "y" or "z" for the middle axis in the
            convention.
        data: Rotation matrices as tensor of shape (..., 3, 3).
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler Angles in radians for each matrix in data as a tensor
        of shape (...).

    Raises:
        ValueError: If ``data`` is neither a numpy array nor a torch tensor.
    """
    sine_index, cosine_index = {"x": (2, 1), "y": (0, 2), "z": (1, 0)}[axis]
    if horizontal:
        sine_index, cosine_index = cosine_index, sine_index
    even = (axis + other_axis) in ["xy", "yz", "zx"]

    # Choose the arctangent of the library that owns `data`.
    if isinstance(data, np.ndarray):
        atan2 = np.arctan2
    elif isinstance(data, torch.Tensor):
        atan2 = torch.atan2
    else:
        raise ValueError("data must be a numpy array or torch tensor")

    first = data[..., sine_index]
    second = data[..., cosine_index]
    if horizontal == even:
        return atan2(first, second)
    if tait_bryan:
        return atan2(-second, first)
    return atan2(second, -first)
[docs]
def matrix_to_euler_angles(matrix: Tensor, convention: str = "xyz") -> Tensor:
    """
    Extract Euler angles in radians from rotation matrices.

    Args:
        matrix: Rotation matrices with shape (..., 3, 3).
        convention: Convention string of 3/4 letters, e.g. "xyz", "sxyz", "rxyz", "exyz".
            If the length is 3, the extrinsic rotation is assumed.
            If the length is 4, the first character is "r/i" (rotating/intrinsic), or "s/e" (static / extrinsic).
            The remaining characters are the axis "x, y, z" in the order.

    Returns:
        Euler angles in radians with shape (..., 3).

    Raises:
        ValueError: If the convention does not have 3 or 4 characters, if
            the first of 4 characters is not one of "r", "i", "s", "e", or
            if the first or last axis letter is not valid.
    """
    convention = convention.lower()
    if len(convention) not in (3, 4):
        raise ValueError(f"{convention} is not a valid convention")

    # A bare three-letter convention is treated as extrinsic.
    extrinsic = True
    if len(convention) == 4:
        frame, convention = convention[0], convention[1:]
        if frame not in ("r", "i", "s", "e"):
            raise ValueError(f"{frame} is not a valid first character for a convention")
        extrinsic = frame in ("s", "e")

    # Intrinsic conventions are handled by reversing the axis order here
    # and swapping the outer angles back at the end.
    if not extrinsic:
        convention = convention[::-1]

    first_axis, middle_axis, last_axis = convention[0], convention[1], convention[2]
    i0 = _index_from_letter(first_axis)
    i2 = _index_from_letter(last_axis)
    tait_bryan = i0 != i2

    matrix = torch.as_tensor(matrix)
    if tait_bryan:
        sign = -1.0 if i2 - i0 in [-1, 2] else 1.0
        central_angle = torch.asin(matrix[..., i2, i0] * sign)
    else:
        central_angle = torch.acos(matrix[..., i2, i2])

    angle3 = _angle_from_tan(last_axis, middle_axis, matrix[..., i0], False, tait_bryan)
    angle1 = _angle_from_tan(first_axis, middle_axis, matrix[..., i2, :], True, tait_bryan)
    if extrinsic:
        ordered = [angle1, central_angle, angle3]
    else:
        ordered = [angle3, central_angle, angle1]
    return torch.stack(ordered, -1)  # type: ignore
def _axis_angle_rotation(axis, angle):
    """
    Build the elementary rotation matrices about one coordinate axis, one
    matrix for every angle given.

    Args:
        axis: Axis label "x" or "y" or "z".
        angle: Any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: If ``axis`` is not "x", "y" or "z".
    """
    c = torch.cos(angle)
    s = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    if axis == "x":
        rows = ((one, zero, zero), (zero, c, -s), (zero, s, c))
    elif axis == "y":
        rows = ((c, zero, s), (zero, one, zero), (-s, zero, c))
    elif axis == "z":
        rows = ((c, -s, zero), (s, c, zero), (zero, zero, one))
    else:
        raise ValueError("letter must be either X, Y or Z.")

    # Stack the entries of each row, then the rows, to get (..., 3, 3).
    return torch.stack([torch.stack(row, -1) for row in rows], -2)
# All 24 supported axis sequences: a frame letter ("s" or "r") followed by
# three axis letters.
# fmt: off
_AXES = Literal[
    # "s" prefix
    "sxyz", "sxyx", "sxzy", "sxzx",
    "syzx", "syzy", "syxz", "syxy",
    "szxy", "szxz", "szyx", "szyz",
    # "r" prefix
    "rzyx", "rxyx", "ryzx", "rxzx",
    "rxzy", "ryzy", "rzxy", "ryxy",
    "ryxz", "rzxz", "rxyz", "rzyz",
]
# fmt: on

# Runtime lookup table built from the Literal above; only the keys matter.
_VALID_AXES: Dict[_AXES, None] = {name: None for name in get_args(_AXES)}
[docs]
def euler_angles_to_matrix(
    euler_angles: Float[torch.Tensor, "... 3"], axes: _AXES = "sxyz"
) -> Float[torch.Tensor, "... 3 3"]:
    """Build rotation matrices from Euler angles.

    Args:
        euler_angles (torch.Tensor): Euler angles, the shape could be [..., 3].
        axes (str): Axis specification; one of 24 axis string sequences - e.g. `sxyz` (the default).
            It's recommended to use the full name of the axes, e.g. "sxyz" instead of "xyz",
            but if 3 characters are provided, it will be prefixed with "s".

    Returns:
        torch.Tensor: Rotation matrices [..., 3, 3].

    Raises:
        ValueError: If ``axes`` is not one of the supported axis sequences.

    Example:
        >>> euler_angles = torch.tensor([1.0, 0.5, 2.0])
        >>> euler_angles_to_matrix(euler_angles, axes="sxyz")
        tensor([[-0.3652, -0.6592,  0.6574],
                [ 0.7980,  0.1420,  0.5857],
                [-0.4794,  0.7385,  0.4742]])
        >>> euler_angles_to_matrix(euler_angles, axes="rxyz")
        tensor([[-0.3652, -0.7980,  0.4794],
                [ 0.3234, -0.5917, -0.7385],
                [ 0.8729, -0.1146,  0.4742]])
    """
    axes = axes.lower()  # type: ignore
    if len(axes) == 3:
        # A bare three-letter sequence gets the "s" prefix.
        axes = f"s{axes}"  # type: ignore
    if axes not in _VALID_AXES:
        raise ValueError(f"Invalid axes: {axes}")

    frame, axis_letters = axes[0], axes[1:]
    per_axis_angles = torch.unbind(euler_angles, -1)
    matrices = [_axis_angle_rotation(letter, value) for letter, value in zip(axis_letters, per_axis_angles)]

    if frame == "s":
        # "s": the first listed rotation is the right-most factor.
        return (matrices[2] @ matrices[1]) @ matrices[0]
    # "r": the factors are multiplied in the listed order.
    return (matrices[0] @ matrices[1]) @ matrices[2]
[docs]
def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
    """Converts the 6D rotation representation by Zhou et al. [1] to rotation
    matrices using Gram--Schmidt orthogonalization per Section B of [1].

    The first three numbers are normalized into the first row, the last
    three are made orthogonal to it and normalized into the second row, and
    the third row is the cross product of the first two.

    Args:
        d6 (Tensor): 6D rotation representation of shape [..., 6]

    Returns:
        Tensor: Rotation matrices of shape [..., 3, 3]

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks. CVPR 2019. arxiv_

    `pytorch3d implementation`_

    .. _arxiv: https://arxiv.org/pdf/1812.07035
    .. _`pytorch3d implementation`: https://github.com/facebookresearch/pytorch3d/blob/bd52f4a408b29dc6b4357b70c93fd7a9749ca820/pytorch3d/transforms/rotation_conversions.py#L558
    """
    first_raw = d6[..., :3]
    second_raw = d6[..., 3:]

    row1 = F.normalize(first_raw, dim=-1)
    # Remove the component of the second vector that lies along row1.
    overlap = (row1 * second_raw).sum(dim=-1, keepdim=True)
    row2 = F.normalize(second_raw - overlap * row1, dim=-1)
    row3 = torch.cross(row1, row2, dim=-1)

    return torch.stack((row1, row2, row3), dim=-2)
[docs]
def matrix_to_rotation_6d(matrix: Tensor) -> Tensor:
    """Converts rotation matrices to the 6D rotation representation by Zhou et al. [1]
    by dropping the last row. Note that 6D representation is not unique.

    The first two rows are copied and flattened, so the result does not
    share memory with ``matrix``.

    Args:
        matrix: batch of rotation matrices of size [..., 3, 3]

    Returns:
        6D rotation representation, of shape [..., 6]

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks. CVPR 2019. arxiv_

    .. _arxiv: https://arxiv.org/pdf/1812.07035
    """
    leading_shape = matrix.shape[:-2]
    first_two_rows = torch.clone(matrix[..., :2, :])
    return torch.reshape(first_two_rows, leading_shape + (6,))
# Names exported by `from <module> import *`, kept in alphabetical order.
# NOTE(review): several public functions defined in this module
# (axis_angle_vector_to_matrix, matrix_to_axis_angle_vector,
# normalize_quaternion, quaternion_raw_multiply, quaternion_multiply,
# quaternion_invert, quaternion_apply) are not listed here — confirm whether
# leaving them out of the star-import API is intentional.
__all__ = [
"axis_angle_to_matrix",
"axis_angle_vector_to_quaternion",
"compose_axis_angle_vector",
"euler_angles_to_matrix",
"matrix_to_euler_angles",
"matrix_to_quaternion",
"matrix_to_rotation_6d",
"quaternion_real_to_first",
"quaternion_real_to_last",
"quaternion_to_axis_angle_vector",
"quaternion_to_matrix",
"random_rotation_matrix",
"rotation_6d_to_matrix",
"split_axis_angle_vector",
"standardize_quaternion",
]