Source code for fastdev.xform.warp_rotation

# mypy: disable-error-code="valid-type"
from typing import Dict, Literal, Optional, Tuple, Union, get_args

import numpy as np
import torch
import warp as wp
from beartype import beartype
from jaxtyping import Float


@wp.kernel
def axis_angle_to_matrix_via_quat_kernel(
    axis: wp.array(dtype=wp.vec3),
    angle: wp.array(dtype=wp.float32),
    rot_mat: wp.array(dtype=wp.mat33),
):
    # Axis-angle -> 3x3 rotation matrix, routed through Warp's quaternion helpers
    # rather than an explicit Rodrigues expansion. Thread i handles element i:
    # it reads axis[i] / angle[i] and writes rot_mat[i].
    i = wp.tid()
    rotation = wp.quat_from_axis_angle(axis[i], angle[i])
    rot_mat[i] = wp.quat_to_matrix(rotation)


@wp.kernel
def axis_angle_to_matrix_kernel(
    axis: wp.array(dtype=wp.vec3),
    angle: wp.array(dtype=wp.float32),
    rot_mat: wp.array(dtype=wp.mat33),
):
    # Per-element axis-angle -> 3x3 rotation matrix via Rodrigues' formula:
    #   R = c * I + s * [k]_x + (1 - c) * k k^T,
    # with k = axis[tid], c = cos(angle[tid]), s = sin(angle[tid]).
    # The axis is used exactly as given (no normalization), so R is a proper
    # rotation only when axis[tid] has unit length.
    tid = wp.tid()

    axis_elem = axis[tid]
    x, y, z = axis_elem[0], axis_elem[1], axis_elem[2]
    s, c = wp.sin(angle[tid]), wp.cos(angle[tid])
    C = 1.0 - c

    # Shared sub-expressions: (x, y, z) * s feed the skew-symmetric term,
    # the *C products feed the (1 - c) * k k^T outer-product term.
    xs, ys, zs = x * s, y * s, z * s
    xC, yC, zC = x * C, y * C, z * C
    xyC, yzC, zxC = x * yC, y * zC, z * xC

    # Nine entries in row-major order (row 0, row 1, row 2) of the Rodrigues matrix.
    rot_mat[tid] = wp.mat33(
        x * xC + c, xyC - zs, zxC + ys, xyC + zs, y * yC + c, yzC - xs, zxC - ys, yzC + xs, z * zC + c
    )


class AxisAngleToMatrix(torch.autograd.Function):
    """Autograd bridge that evaluates ``axis_angle_to_matrix_kernel`` with Warp.

    ``forward`` launches the Rodrigues kernel once per (axis, angle) pair; ``backward``
    replays the same kernel in Warp's adjoint mode to obtain gradients for both inputs.
    """

    @staticmethod
    def forward(
        ctx, axis: Float[torch.Tensor, "... 3"], angle: Float[torch.Tensor, "..."]
    ) -> Float[torch.Tensor, "... 3 3"]:
        """Compute rotation matrices for paired axes and angles.

        Args:
            axis: Rotation axes with shape [..., 3].
            angle: Rotation angles with shape [...]; must hold one angle per axis.

        Returns:
            Rotation matrices with shape ``angle.shape + (3, 3)``.

        Raises:
            ValueError: If the number of axes differs from the number of angles.
        """
        if axis.shape[:-1].numel() != angle.numel():
            # The kernel reads axis[tid] and angle[tid] for every axis, so fewer angles than
            # axes would be an out-of-bounds device read; fail before launching instead.
            raise ValueError(
                f"axis has {axis.shape[:-1].numel()} elements but angle has {angle.numel()}; "
                f"got axis shape {tuple(axis.shape)} and angle shape {tuple(angle.shape)}"
            )
        wp.init()
        # ``view`` requires compatible strides, so make the inputs contiguous before flattening.
        axis_wp = wp.from_torch(axis.contiguous().view(-1, 3), dtype=wp.vec3, requires_grad=axis.requires_grad)
        angles_wp = wp.from_torch(angle.contiguous().view(-1), dtype=wp.float32, requires_grad=angle.requires_grad)
        rot_mat_wp = wp.empty(
            axis_wp.shape,
            dtype=wp.mat33,  # type: ignore
            device=axis_wp.device,
            requires_grad=axis_wp.requires_grad or angles_wp.requires_grad,
        )
        wp.launch(
            kernel=axis_angle_to_matrix_kernel,
            dim=(axis_wp.shape[0],),
            inputs=[axis_wp, angles_wp],
            outputs=[rot_mat_wp],
            device=axis_wp.device,
        )
        if axis.requires_grad or angle.requires_grad:
            # Keep the Warp arrays around so backward can replay the kernel in adjoint mode.
            ctx.axis_wp = axis_wp
            ctx.angles_wp = angles_wp
            ctx.rot_mat_wp = rot_mat_wp
        return wp.to_torch(rot_mat_wp).view(angle.shape + (3, 3))

    @staticmethod
    def backward(  # type: ignore
        ctx, rot_mat_grad: Float[torch.Tensor, "... 3 3"]
    ) -> Tuple[Optional[Float[torch.Tensor, "... 3"]], Optional[Float[torch.Tensor, "..."]]]:
        """Back-propagate ``rot_mat_grad`` through the Rodrigues kernel.

        Returns:
            ``(axis_grad, angle_grad)``; an entry is ``None`` when that input did not require grad.
        """
        wp.init()
        ctx.rot_mat_wp.grad = wp.from_torch(rot_mat_grad.contiguous().view(-1, 3, 3), dtype=wp.mat33)
        # Warp adjoint kernels *accumulate* into their gradient arrays. Writing into fresh,
        # zero-filled buffers on every call keeps repeated backward passes through the same
        # graph from summing stale gradients. It also means the returned tensors never alias
        # memory that a later pass would mutate.
        axis_grad_wp = (
            wp.zeros(ctx.axis_wp.shape, dtype=ctx.axis_wp.dtype, device=ctx.axis_wp.device)
            if ctx.axis_wp.requires_grad
            else None
        )
        angle_grad_wp = (
            wp.zeros(ctx.angles_wp.shape, dtype=ctx.angles_wp.dtype, device=ctx.angles_wp.device)
            if ctx.angles_wp.requires_grad
            else None
        )
        wp.launch(
            kernel=axis_angle_to_matrix_kernel,
            dim=(ctx.axis_wp.shape[0],),
            inputs=[ctx.axis_wp, ctx.angles_wp],
            outputs=[ctx.rot_mat_wp],
            adj_inputs=[axis_grad_wp, angle_grad_wp],
            adj_outputs=[ctx.rot_mat_wp.grad],
            adjoint=True,
            device=ctx.axis_wp.device,
        )
        axis_grad = wp.to_torch(axis_grad_wp).view(rot_mat_grad.shape[:-1]) if axis_grad_wp is not None else None
        angle_grad = wp.to_torch(angle_grad_wp).view(rot_mat_grad.shape[:-2]) if angle_grad_wp is not None else None
        return axis_grad, angle_grad


@beartype
def axis_angle_to_matrix(
    axis: Float[torch.Tensor, "... 3"], angle: Float[torch.Tensor, "..."]
) -> Float[torch.Tensor, "... 3 3"]:
    """
    Converts axis angles to rotation matrices using Rodrigues formula.

    Args:
        axis (torch.Tensor): Rotation axes, shape [..., 3]. The axes are used as given
            (no normalization), so pass unit vectors to obtain proper rotations.
        angle (torch.Tensor): Rotation angles in radians, shape [...], one per axis.

    Returns:
        torch.Tensor: Rotation matrices with shape [..., 3, 3]. Gradients flow back
        to both ``axis`` and ``angle``.

    Example:
        >>> axis = torch.tensor([1.0, 0.0, 0.0])
        >>> angle = torch.tensor(0.5)
        >>> axis_angle_to_matrix(axis, angle)
        tensor([[ 1.0000,  0.0000,  0.0000],
                [ 0.0000,  0.8776, -0.4794],
                [ 0.0000,  0.4794,  0.8776]])
    """
    rot_mats = AxisAngleToMatrix.apply(axis, angle)  # type: ignore
    return rot_mats
# fmt: off
_AXES = Literal[
    "sxyz", "sxyx", "sxzy", "sxzx", "syzx", "syzy", "syxz", "syxy", "szxy", "szxz", "szyx", "szyz",
    "rzyx", "rxyx", "ryzx", "rxzx", "rxzy", "ryzy", "rzxy", "ryxy", "ryxz", "rzxz", "rxyz", "rzyz"
]
# fmt: on
# Each axes string is encoded as (frame, first axis, second axis, third axis):
# frame 0 = "s" (static/extrinsic), 1 = "r" (rotating/intrinsic); axis 0/1/2 = x/y/z.
_AXES_SPEC: Dict[_AXES, wp.vec4i] = {
    axes: wp.vec4i("sr".index(axes[0]), "xyz".index(axes[1]), "xyz".index(axes[2]), "xyz".index(axes[3]))
    for axes in get_args(_AXES)
}


@wp.func
def _euler_angle_to_matrix(angle: wp.float32, axis: wp.int32) -> wp.mat33:
    # Elementary rotation by ``angle`` radians about x (axis == 0), y (axis == 1) or z (otherwise).
    c, s = wp.cos(angle), wp.sin(angle)
    if axis == 0:
        return wp.mat33(1.0, 0.0, 0.0, 0.0, c, -s, 0.0, s, c)
    elif axis == 1:
        return wp.mat33(c, 0.0, s, 0.0, 1.0, 0.0, -s, 0.0, c)
    else:
        return wp.mat33(c, -s, 0.0, s, c, 0.0, 0.0, 0.0, 1.0)


@wp.kernel
def euler_angles_to_matrix_kernel(
    euler_angles: wp.array(dtype=wp.vec3), axes: wp.vec4i, rot_mat: wp.array(dtype=wp.mat33)
):
    # Per-element Euler angles (e0, e1, e2) -> rotation matrix; ``axes`` uses the _AXES_SPEC encoding.
    tid = wp.tid()
    euler_angles_elem = euler_angles[tid]
    if axes[0] == 0:  # static/extrinsic rotation
        # R = R(axes[3], e2) @ R(axes[2], e1) @ R(axes[1], e0)
        rot_mat[tid] = wp.mul(
            wp.mul(
                _euler_angle_to_matrix(euler_angles_elem[2], axes[3]),
                _euler_angle_to_matrix(euler_angles_elem[1], axes[2]),
            ),
            _euler_angle_to_matrix(euler_angles_elem[0], axes[1]),
        )
    else:  # rotating/intrinsic rotation
        # R = R(axes[1], e0) @ R(axes[2], e1) @ R(axes[3], e2)
        rot_mat[tid] = wp.mul(
            wp.mul(
                _euler_angle_to_matrix(euler_angles_elem[0], axes[1]),
                _euler_angle_to_matrix(euler_angles_elem[1], axes[2]),
            ),
            _euler_angle_to_matrix(euler_angles_elem[2], axes[3]),
        )


class EulerAnglesToMatrix(torch.autograd.Function):
    """Autograd bridge that evaluates ``euler_angles_to_matrix_kernel`` with Warp."""

    @staticmethod
    def forward(
        ctx, euler_angles: Float[torch.Tensor, "... 3"], axes: _AXES = "sxyz"
    ) -> Float[torch.Tensor, "... 3 3"]:
        """Compute rotation matrices for Euler angles of shape [..., 3].

        ``axes`` is lower-cased, and a 3-character sequence gets an "s" prefix.

        Raises:
            ValueError: If the normalized ``axes`` is not a key of ``_AXES_SPEC``.
        """
        axes = axes.lower()  # type: ignore
        if len(axes) == 3:
            axes = f"s{axes}"  # type: ignore
        if axes not in _AXES_SPEC:
            raise ValueError(f"Invalid axes: {axes}")
        wp.init()
        # Make the input contiguous *before* flattening: ``view`` fails on incompatible strides,
        # so calling ``contiguous`` afterwards (as before) could never help.
        euler_angles_wp = wp.from_torch(
            euler_angles.contiguous().view(-1, 3), dtype=wp.vec3, requires_grad=euler_angles.requires_grad
        )
        rot_mat_wp = wp.from_torch(
            torch.empty(
                euler_angles.shape + (3,),
                dtype=euler_angles.dtype,
                device=euler_angles.device,
                requires_grad=euler_angles.requires_grad,
            ).view(-1, 3, 3),
            dtype=wp.mat33,
            requires_grad=euler_angles.requires_grad,
        )
        axes_spec = _AXES_SPEC[axes]
        wp.launch(
            kernel=euler_angles_to_matrix_kernel,
            dim=(euler_angles_wp.shape[0],),
            inputs=[euler_angles_wp, axes_spec],
            outputs=[rot_mat_wp],
            device=euler_angles_wp.device,
        )
        if euler_angles.requires_grad:
            # Keep the Warp arrays so backward can replay the kernel in adjoint mode.
            ctx.euler_angles_wp = euler_angles_wp
            ctx.rot_mat_wp = rot_mat_wp
            ctx.axes_spec = axes_spec
        # euler_angles.shape ends in 3, so appending (3,) yields [..., 3, 3].
        return wp.to_torch(rot_mat_wp).view(euler_angles.shape + (3,))

    @staticmethod
    def backward(  # type: ignore
        ctx, rot_mat_grad: Float[torch.Tensor, "... 3 3"]
    ) -> Tuple[Optional[Float[torch.Tensor, "... 3"]], None]:
        """Back-propagate ``rot_mat_grad`` to the Euler angles; ``axes`` gets no gradient."""
        wp.init()
        ctx.rot_mat_wp.grad = wp.from_torch(rot_mat_grad.contiguous().view(-1, 3, 3), dtype=wp.mat33)
        # Warp adjoint kernels accumulate into their gradient arrays. A fresh zero-filled buffer per
        # call keeps repeated backward passes from summing stale gradients. It also means the
        # returned tensor does not alias memory a later pass would mutate.
        euler_angles_grad_wp = wp.zeros(
            ctx.euler_angles_wp.shape, dtype=ctx.euler_angles_wp.dtype, device=ctx.euler_angles_wp.device
        )
        wp.launch(
            kernel=euler_angles_to_matrix_kernel,
            dim=(ctx.euler_angles_wp.shape[0],),
            inputs=[ctx.euler_angles_wp, ctx.axes_spec],
            outputs=[ctx.rot_mat_wp],
            adj_inputs=[euler_angles_grad_wp, ctx.axes_spec],
            adj_outputs=[ctx.rot_mat_wp.grad],
            adjoint=True,
            device=ctx.euler_angles_wp.device,
        )
        return wp.to_torch(euler_angles_grad_wp).view(rot_mat_grad.shape[:-1]), None


@beartype
def euler_angles_to_matrix(
    euler_angles: Float[torch.Tensor, "... 3"], axes: Union[_AXES, str] = "sxyz"
) -> Float[torch.Tensor, "... 3 3"]:
    """Converts Euler angles to rotation matrices.

    Args:
        euler_angles (torch.Tensor): Tensor of Euler angles with shape [..., 3].
        axes (str): Axis specification string, one of 24 possible sequences (e.g., "sxyz").
            Matching is case-insensitive. If only 3 characters are provided, "s" will be prefixed.
            The annotation accepts any ``str`` so these shorthand forms get past the
            runtime type check; validity is checked against the 24 sequences inside.

    Returns:
        torch.Tensor: Rotation matrices with shape [..., 3, 3].

    Raises:
        ValueError: If ``axes`` is not a valid sequence after normalization.

    Example:
        >>> euler_angles = torch.tensor([1.0, 0.5, 2.0])
        >>> euler_angles_to_matrix(euler_angles, axes="sxyz")
        tensor([[-0.3652, -0.6592,  0.6574],
                [ 0.7980,  0.1420,  0.5857],
                [-0.4794,  0.7385,  0.4742]])
        >>> euler_angles_to_matrix(euler_angles, axes="rxyz")
        tensor([[-0.3652, -0.7980,  0.4794],
                [ 0.3234, -0.5917, -0.7385],
                [ 0.8729, -0.1146,  0.4742]])
    """
    return EulerAnglesToMatrix.apply(euler_angles, axes)  # type: ignore
@wp.kernel
def matrix_to_quaternion_kernel(
    rot_mat: wp.array(dtype=wp.mat33),
    scalar_first: bool,
    canonical: bool,
    quat: wp.array(dtype=wp.quat),
):
    # Per-element rotation matrix -> quaternion, with optional sign canonicalization and reordering.
    tid = wp.tid()
    q = wp.quat_from_matrix(rot_mat[tid])  # warp assumes scalar last, xyzw
    if canonical:
        # q and -q describe the same rotation; keep the one whose scalar part is non-negative.
        if q[3] < 0:
            q = wp.quat(-q[0], -q[1], -q[2], -q[3])
    if scalar_first:
        q = wp.quat(q[3], q[0], q[1], q[2])  # xyzw -> wxyz
    quat[tid] = q


class MatrixToQuaternion(torch.autograd.Function):
    """Autograd bridge that evaluates ``matrix_to_quaternion_kernel`` with Warp."""

    @staticmethod
    def forward(
        ctx, rot_mat: Float[torch.Tensor, "... 3 3"], scalar_first: bool = True, canonical: bool = True
    ) -> Float[torch.Tensor, "... 4"]:
        """Convert rotation matrices of shape [..., 3, 3] to quaternions of shape [..., 4]."""
        wp.init()
        rot_mat_wp = wp.from_torch(
            rot_mat.contiguous().view(-1, 3, 3), dtype=wp.mat33, requires_grad=rot_mat.requires_grad
        )
        quat_wp = wp.from_torch(
            torch.empty(
                rot_mat.shape[:-2] + (4,),
                dtype=rot_mat.dtype,
                device=rot_mat.device,
                requires_grad=rot_mat.requires_grad,
            ).view(-1, 4),
            dtype=wp.quat,
            requires_grad=rot_mat.requires_grad,
        )
        wp.launch(
            kernel=matrix_to_quaternion_kernel,
            dim=(rot_mat_wp.shape[0],),
            inputs=[rot_mat_wp, scalar_first, canonical],
            outputs=[quat_wp],
            device=rot_mat_wp.device,
        )
        if rot_mat.requires_grad:
            # Keep everything the adjoint replay of the kernel needs.
            ctx.scalar_first = scalar_first
            ctx.canonical = canonical
            ctx.rot_mat_wp = rot_mat_wp
            ctx.quat_wp = quat_wp
        return wp.to_torch(quat_wp).view(rot_mat.shape[:-2] + (4,))

    @staticmethod
    def backward(  # type: ignore
        ctx, quat_grad: Float[torch.Tensor, "... 4"]
    ) -> Tuple[Optional[Float[torch.Tensor, "... 3 3"]], None, None]:
        """Back-propagate ``quat_grad`` to the rotation matrices; the flags get no gradient."""
        wp.init()
        ctx.quat_wp.grad = wp.from_torch(quat_grad.contiguous().view(-1, 4), dtype=wp.quat)
        # Warp adjoint kernels accumulate into their gradient arrays. A fresh zero-filled buffer per
        # call keeps repeated backward passes from summing stale gradients. It also means the
        # returned tensor does not alias memory a later pass would mutate.
        rot_mat_grad_wp = wp.zeros(ctx.rot_mat_wp.shape, dtype=ctx.rot_mat_wp.dtype, device=ctx.rot_mat_wp.device)
        wp.launch(
            kernel=matrix_to_quaternion_kernel,
            dim=(ctx.rot_mat_wp.shape[0],),
            inputs=[ctx.rot_mat_wp, ctx.scalar_first, ctx.canonical],
            outputs=[ctx.quat_wp],
            adj_inputs=[rot_mat_grad_wp, ctx.scalar_first, ctx.canonical],
            adj_outputs=[ctx.quat_wp.grad],
            adjoint=True,
            device=ctx.rot_mat_wp.device,
        )
        return wp.to_torch(rot_mat_grad_wp).view(quat_grad.shape[:-1] + (3, 3)), None, None


@beartype
def matrix_to_quaternion(
    rot_mat: Float[torch.Tensor, "... 3 3"], scalar_first: bool = True, canonical: bool = True
) -> Float[torch.Tensor, "... 4"]:
    """
    Converts rotation matrices to quaternions.

    Args:
        rot_mat (torch.Tensor): Rotation matrices with shape [..., 3, 3].
        scalar_first (bool): If True, quaternions are returned as (w, x, y, z);
            otherwise as (x, y, z, w).
        canonical (bool): If True, quaternions with a negative scalar part are negated,
            so the returned scalar part is non-negative.

    Returns:
        torch.Tensor: Quaternions with shape [..., 4].

    Example:
        >>> rot_mat = torch.tensor([[-0.2533, -0.6075,  0.7529],
        ...                         [ 0.8445, -0.5185, -0.1343],
        ...                         [ 0.4720,  0.6017,  0.6443]])
        >>> matrix_to_quaternion(rot_mat)
        tensor([0.4671, 0.3940, 0.1503, 0.7772])

    .. note::
        The gradient of this function differs from the pytorch3d implementation,
        but it should be okay for most use cases. Ref_

    .. _Ref: https://github.com/facebookresearch/pytorch3d/issues/503#issuecomment-755493515
    """
    return MatrixToQuaternion.apply(rot_mat, scalar_first, canonical)
@beartype
def matrix_to_quaternion_numpy(
    rot_mat: Float[np.ndarray, "... 3 3"], scalar_first: bool = True, canonical: bool = True
) -> Float[np.ndarray, "... 4"]:
    """
    Converts rotation matrices to quaternions.

    Args:
        rot_mat (np.ndarray): Rotation matrices with shape [..., 3, 3].
        scalar_first (bool): If True, quaternions are returned as (w, x, y, z);
            otherwise as (x, y, z, w).
        canonical (bool): If True, quaternions with a negative scalar part are negated,
            so the returned scalar part is non-negative.

    Returns:
        np.ndarray: Quaternions with shape [..., 4].

    Example:
        >>> rot_mat = np.array([[ 0.45930517, -0.11985919, -0.88015485],
        ...                     [-0.4041326 ,  0.85417026, -0.32721555],
        ...                     [ 0.7910219 ,  0.50599104,  0.3438858 ]])
        >>> matrix_to_quaternion_numpy(rot_mat)
        array([ 0.8150708 ,  0.25556272, -0.5125865 , -0.08719288], dtype=float32)
    """
    wp.init()
    batch_shape = rot_mat.shape[:-2]
    rot_mat_wp = wp.from_numpy(rot_mat.reshape((-1, 3, 3)), dtype=wp.mat33)
    # Allocate the output directly on the device holding the input. Staging an uninitialized
    # host array only to copy it over is wasted work, since the kernel overwrites every element.
    quat_wp = wp.empty(rot_mat_wp.shape, dtype=wp.quat, device=rot_mat_wp.device)
    wp.launch(
        kernel=matrix_to_quaternion_kernel,
        dim=(rot_mat_wp.shape[0],),
        inputs=[rot_mat_wp, scalar_first, canonical],
        outputs=[quat_wp],
        device=rot_mat_wp.device,
    )
    return quat_wp.numpy().reshape(batch_shape + (4,))
# Public API of this module.
__all__ = [
    "axis_angle_to_matrix",
    "euler_angles_to_matrix",
    "matrix_to_quaternion",
    "matrix_to_quaternion_numpy",
]