Source code for fastdev.robo.warp_kinematics

# mypy: disable-error-code="valid-type"
# ruff: noqa: F821
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

import numpy as np
import torch
import warp as wp
from jaxtyping import Bool, Float, Int
from warp.fem.linalg import inverse_qr, solve_triangular

from fastdev.xform import matrix_to_quaternion

if TYPE_CHECKING:
    from fastdev.robo.articulation import Articulation


@wp.func
def axis_angle_to_tf_mat(axis: wp.vec3, angle: wp.float32):
    x, y, z = axis[0], axis[1], axis[2]
    s, c = wp.sin(angle), wp.cos(angle)
    C = 1.0 - c

    xs, ys, zs = x * s, y * s, z * s
    xC, yC, zC = x * C, y * C, z * C
    xyC, yzC, zxC = x * yC, y * zC, z * xC

    # fmt: off
    return wp.mat44(
        x * xC + c, xyC - zs, zxC + ys, 0.0,
        xyC + zs, y * yC + c, yzC - xs, 0.0,
        zxC - ys, yzC + xs, z * zC + c, 0.0,
        0.0, 0.0, 0.0, 1.0,
    )
    # fmt: on
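
# Illustrative check (not part of the original module): the matrix built by
# axis_angle_to_tf_mat is the Rodrigues rotation formula embedded in a 4x4
# homogeneous transform. A minimal NumPy sketch, assuming a unit axis; the
# helper name is hypothetical.
def _example_axis_angle_reference(axis, angle):
    """Reference Rodrigues rotation for comparison against axis_angle_to_tf_mat."""
    axis = np.asarray(axis, dtype=np.float64)
    # skew-symmetric cross-product matrix K of the axis
    K = np.array(
        [
            [0.0, -axis[2], axis[1]],
            [axis[2], 0.0, -axis[0]],
            [-axis[1], axis[0], 0.0],
        ]
    )
    # Rodrigues: R = I + sin(theta) K + (1 - cos(theta)) K^2
    R = np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)
    tf = np.eye(4)
    tf[:3, :3] = R
    return tf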

@wp.func
def axis_distance_to_tf_mat(axis: wp.vec3, distance: wp.float32):
    x, y, z = axis[0], axis[1], axis[2]
    # fmt: off
    return wp.mat44(
        1.0, 0.0, 0.0, distance * x,
        0.0, 1.0, 0.0, distance * y,
        0.0, 0.0, 1.0, distance * z,
        0.0, 0.0, 0.0, 1.0,
    )
    # fmt: on

@wp.kernel
def forward_kinematics_kernel(
    joint_values: wp.array2d(dtype=wp.float32),  # [b, num_dofs]
    root_poses: wp.array2d(dtype=wp.mat44),  # [b, num_arti] of mat44, optional
    joint_first_indices: wp.array(dtype=wp.int32),
    link_indices_topological_order: wp.array(dtype=wp.int32),
    parent_link_indices: wp.array(dtype=wp.int32),
    link_joint_indices: wp.array(dtype=wp.int32),
    link_joint_types: wp.array(dtype=wp.int32),
    link_joint_origins: wp.array(dtype=wp.mat44),
    link_joint_axes: wp.array(dtype=wp.vec3),
    link_first_indices: wp.array(dtype=wp.int32),
    link_poses: wp.array2d(dtype=wp.mat44),  # output, [b, num_links]
):
    b_idx, arti_idx = wp.tid()
    joint_first_idx = joint_first_indices[arti_idx]
    link_first_idx = link_first_indices[arti_idx]
    if arti_idx == wp.int32(link_first_indices.shape[0] - 1):
        link_last_idx = wp.int32(link_indices_topological_order.shape[0])
    else:
        link_last_idx = link_first_indices[arti_idx + 1]

    if root_poses.shape[0] > 0:
        root_pose = root_poses[b_idx, arti_idx]
    else:
        root_pose = wp.identity(n=4, dtype=wp.float32)  # type: ignore

    for glb_topo_idx in range(link_first_idx, link_last_idx):
        glb_link_idx = link_indices_topological_order[glb_topo_idx] + link_first_idx
        joint_type = link_joint_types[glb_link_idx]
        if joint_type == -1:
            # root link: its pose is the (optional) root pose
            glb_joint_pose = root_pose
        else:
            # non-root links: compose parent pose, joint origin, and joint motion
            glb_parent_link_idx = parent_link_indices[glb_link_idx] + link_first_idx
            parent_link_pose = link_poses[b_idx, glb_parent_link_idx]
            glb_joint_idx = link_joint_indices[glb_link_idx] + joint_first_idx
            if joint_type == 0:  # fixed
                local_joint_tf = wp.identity(n=4, dtype=wp.float32)  # type: ignore
            elif joint_type == 1:  # prismatic
                joint_value = joint_values[b_idx, glb_joint_idx]
                joint_axis = link_joint_axes[glb_link_idx]
                local_joint_tf = axis_distance_to_tf_mat(joint_axis, joint_value)
            elif joint_type == 2:  # revolute
                joint_value = joint_values[b_idx, glb_joint_idx]
                joint_axis = link_joint_axes[glb_link_idx]
                local_joint_tf = axis_angle_to_tf_mat(joint_axis, joint_value)
            joint_origin = link_joint_origins[glb_link_idx]
            glb_joint_pose = (parent_link_pose @ joint_origin) @ local_joint_tf  # type: ignore
        link_poses[b_idx, glb_link_idx] = glb_joint_pose
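
# Illustrative sketch (not from the original source): the kernel above is the
# standard forward-kinematics recursion. For a single serial chain of revolute
# joints it reduces to composing, per joint, parent_pose @ joint_origin @
# joint_motion. A minimal NumPy version with hypothetical inputs:
def _example_serial_chain_fk(origins, axes, q):
    """origins: list of 4x4 arrays; axes: list of unit 3-vectors; q: joint angles."""
    pose = np.eye(4)
    poses = []
    for origin, axis, angle in zip(origins, axes, q):
        # same composition order as the kernel: parent @ origin @ joint motion
        pose = pose @ origin @ _example_axis_angle_reference(axis, angle)
        poses.append(pose.copy())
    return poses  # world pose of each link, root-to-tip order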

_KERNEL_PARAMS_TYPES_AND_GETTERS = {
    "joint_first_indices": (wp.int32, "get_joint_first_indices"),
    "link_indices_topological_order": (wp.int32, "get_packed_link_indices_topological_order"),
    "parent_link_indices": (wp.int32, "get_packed_parent_link_indices"),
    "link_joint_indices": (wp.int32, "get_packed_link_joint_indices"),
    "link_joint_types": (wp.int32, "get_packed_link_joint_types"),
    "link_joint_origins": (wp.mat44, "get_packed_link_joint_origins"),
    "link_joint_axes": (wp.vec3, "get_packed_link_joint_axes"),
    "link_first_indices": (wp.int32, "get_link_first_indices"),
}

class ForwardKinematics(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        joint_values: Float[torch.Tensor, "... total_num_joints"],
        articulation: "Articulation",
        root_poses: Optional[Float[torch.Tensor, "... num_arti 4 4"]] = None,
    ) -> Float[torch.Tensor, "... total_num_links 4 4"]:
        batch_shape = joint_values.shape[:-1]
        total_num_joints = joint_values.shape[-1]
        total_num_links = articulation.total_num_links
        num_arti = articulation.num_arti
        requires_grad = joint_values.requires_grad or (root_poses is not None and root_poses.requires_grad)

        wp.init()
        joint_values_wp = wp.from_torch(
            joint_values.contiguous().view(-1, total_num_joints),
            dtype=wp.float32,
            requires_grad=joint_values.requires_grad,
        )
        root_poses_wp = (
            wp.from_torch(
                root_poses.contiguous().view(-1, num_arti, 4, 4),
                dtype=wp.mat44,
                requires_grad=root_poses.requires_grad,
            )
            if root_poses is not None
            else wp.zeros(shape=(0, 0), dtype=wp.mat44, requires_grad=False, device=joint_values_wp.device)
        )
        link_poses_wp = wp.from_torch(
            torch.zeros(
                (joint_values_wp.shape[0], total_num_links, 4, 4),
                device=joint_values.device,
                dtype=joint_values.dtype,
                requires_grad=requires_grad,
            ),
            dtype=wp.mat44,
            requires_grad=requires_grad,
        )
        wp_params = {
            name: wp.from_torch(getattr(articulation, fn)(return_tensors="pt"), dtype=dtype)
            for name, (dtype, fn) in _KERNEL_PARAMS_TYPES_AND_GETTERS.items()
        }
        wp.launch(
            kernel=forward_kinematics_kernel,
            dim=(joint_values_wp.shape[0], num_arti),
            inputs=[joint_values_wp, root_poses_wp, *wp_params.values()],
            outputs=[link_poses_wp],
            device=joint_values_wp.device,
        )
        if joint_values_wp.requires_grad or root_poses_wp.requires_grad:
            ctx.shapes = (batch_shape, total_num_joints, total_num_links, num_arti)
            ctx.joint_values_wp = joint_values_wp
            ctx.root_poses_wp = root_poses_wp
            ctx.link_poses_wp = link_poses_wp
            ctx.wp_params = wp_params
        return wp.to_torch(link_poses_wp).view(*batch_shape, total_num_links, 4, 4)
    @staticmethod
    def backward(  # type: ignore
        ctx, link_poses_grad: Float[torch.Tensor, "... total_num_links 4 4"]
    ) -> Tuple[
        Optional[Float[torch.Tensor, "... total_num_joints"]],
        None,
        Optional[Float[torch.Tensor, "... num_arti 4 4"]],
    ]:
        if not ctx.joint_values_wp.requires_grad and not ctx.root_poses_wp.requires_grad:
            return None, None, None
        batch_shape, total_num_joints, total_num_links, num_arti = ctx.shapes
        ctx.link_poses_wp.grad = wp.from_torch(
            link_poses_grad.contiguous().view(-1, total_num_links, 4, 4), dtype=wp.mat44
        )
        wp.launch(
            kernel=forward_kinematics_kernel,
            dim=(ctx.joint_values_wp.shape[0], num_arti),
            inputs=[ctx.joint_values_wp, ctx.root_poses_wp, *ctx.wp_params.values()],
            outputs=[ctx.link_poses_wp],
            adj_inputs=[ctx.joint_values_wp.grad, ctx.root_poses_wp.grad, *([None] * len(ctx.wp_params))],
            adj_outputs=[ctx.link_poses_wp.grad],
            adjoint=True,
            device=ctx.joint_values_wp.device,
        )
        joint_values_grad = (
            wp.to_torch(ctx.joint_values_wp.grad).view(*batch_shape, total_num_joints)
            if ctx.joint_values_wp.requires_grad
            else None
        )
        root_poses_grad = (
            wp.to_torch(ctx.root_poses_wp.grad).view(*batch_shape, num_arti, 4, 4)
            if ctx.root_poses_wp.requires_grad
            else None
        )
        return joint_values_grad, None, root_poses_grad

def forward_kinematics(
    joint_values: Float[torch.Tensor, "... total_num_joints"],
    articulation: "Articulation",
    root_poses: Optional[Float[torch.Tensor, "... num_arti 4 4"]] = None,
) -> Float[torch.Tensor, "... total_num_links 4 4"]:
    return ForwardKinematics.apply(joint_values, articulation, root_poses)
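
# Illustrative usage (a sketch, not from the original source): batched FK with
# gradients flowing through the adjoint kernel launch. The caller is assumed to
# have loaded an Articulation; the helper name is hypothetical.
def _example_fk_usage(arti: "Articulation"):
    q = torch.zeros(8, arti.total_num_dofs, requires_grad=True)
    link_poses = forward_kinematics(q, arti)  # [8, total_num_links, 4, 4]
    # any scalar loss over the poses backpropagates to the joint values
    link_poses[..., :3, 3].sum().backward()
    return q.grad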

def forward_kinematics_numpy(
    joint_values: Float[np.ndarray, "... total_num_joints"],  # noqa: F821
    articulation: "Articulation",
    root_poses: Optional[Float[np.ndarray, "... num_arti 4 4"]] = None,
) -> Float[np.ndarray, "... total_num_links 4 4"]:
    total_num_joints = joint_values.shape[-1]
    total_num_links = articulation.total_num_links
    num_arti = articulation.num_arti

    wp.init()
    joint_values_wp = wp.from_numpy(joint_values.reshape(-1, total_num_joints), dtype=wp.float32)  # [B, num_dofs]
    link_poses_wp = wp.from_numpy(
        np.zeros((joint_values_wp.shape[0], total_num_links, 4, 4), dtype=joint_values.dtype),
        dtype=wp.mat44,
    )
    root_poses_wp = (
        wp.from_numpy(root_poses.reshape(-1, num_arti, 4, 4), dtype=wp.mat44)
        if root_poses is not None
        else wp.zeros(shape=(0, 0), dtype=wp.mat44, requires_grad=False, device=joint_values_wp.device)
    )
    wp_params = {
        name: wp.from_numpy(getattr(articulation, fn)("np"), dtype=dtype)
        for name, (dtype, fn) in _KERNEL_PARAMS_TYPES_AND_GETTERS.items()
    }
    wp.launch(
        kernel=forward_kinematics_kernel,
        dim=(joint_values_wp.shape[0], num_arti),
        inputs=[joint_values_wp, root_poses_wp, *wp_params.values()],
        outputs=[link_poses_wp],
        device=joint_values_wp.device,
    )
    return link_poses_wp.numpy().reshape(joint_values.shape[:-1] + (total_num_links, 4, 4))
@wp.kernel
def calculate_jacobian_kernel(
    link_poses: wp.array2d(dtype=wp.mat44),  # [b, total_num_links]
    target_link_indices: wp.array(dtype=wp.int32),  # [num_arti]
    ancestor_mask: wp.array(dtype=wp.int32),  # [total_num_links]
    joint_axes: wp.array(dtype=wp.vec3),  # [total_num_joints]
    link_joint_types: wp.array(dtype=wp.int32),  # [total_num_links]
    link_joint_indices: wp.array(dtype=wp.int32),  # [total_num_links]
    link_joint_origins: wp.array(dtype=wp.mat44),  # [total_num_links]
    parent_link_indices: wp.array(dtype=wp.int32),  # [total_num_links]
    joint_first_indices: wp.array(dtype=wp.int32),  # [num_arti]
    link_first_indices: wp.array(dtype=wp.int32),  # [num_arti]
    jacobian: wp.array3d(dtype=wp.float32),  # [b, 6, total_num_joints]
):
    """Compute the geometric Jacobian using precomputed target link poses."""
    total_num_links = link_joint_types.shape[0]
    b_idx, global_link_idx = wp.tid()

    # find which articulation this link belongs to
    arti_idx = -1
    num_arti = link_first_indices.shape[0]
    for i in range(num_arti):
        start = link_first_indices[i]
        if i < num_arti - 1:
            end = link_first_indices[i + 1]
        else:
            end = total_num_links
        if global_link_idx >= start and global_link_idx < end:
            arti_idx = i
            break
    if arti_idx == -1:
        return
    # only ancestors of the target link contribute a Jacobian column
    if ancestor_mask[global_link_idx] == 0:
        return

    # use the precomputed target link pose instead of recomputing indices
    target_global_link_idx = target_link_indices[arti_idx] + link_first_indices[arti_idx]
    eef_pose = link_poses[b_idx, target_global_link_idx]
    p_eef = wp.vec3(eef_pose[0, 3], eef_pose[1, 3], eef_pose[2, 3])

    jt = link_joint_types[global_link_idx]
    if jt == -1 or (jt != 1 and jt != 2):  # only prismatic and revolute joints contribute
        return
    glb_joint_idx = link_joint_indices[global_link_idx] + joint_first_indices[arti_idx]

    if global_link_idx == link_first_indices[arti_idx]:
        joint_pose = link_poses[b_idx, global_link_idx] @ link_joint_origins[global_link_idx]
    else:
        parent_idx = parent_link_indices[global_link_idx] + link_first_indices[arti_idx]
        joint_pose = link_poses[b_idx, parent_idx] @ link_joint_origins[global_link_idx]

    p_joint = wp.vec3(joint_pose[0, 3], joint_pose[1, 3], joint_pose[2, 3])
    r0 = wp.vec3(joint_pose[0, 0], joint_pose[0, 1], joint_pose[0, 2])
    r1 = wp.vec3(joint_pose[1, 0], joint_pose[1, 1], joint_pose[1, 2])
    r2 = wp.vec3(joint_pose[2, 0], joint_pose[2, 1], joint_pose[2, 2])
    axis_local = joint_axes[glb_joint_idx]
    # rotate the local joint axis into the world frame
    axis_world = wp.vec3(
        r0[0] * axis_local[0] + r0[1] * axis_local[1] + r0[2] * axis_local[2],
        r1[0] * axis_local[0] + r1[1] * axis_local[1] + r1[2] * axis_local[2],
        r2[0] * axis_local[0] + r2[1] * axis_local[1] + r2[2] * axis_local[2],
    )
    if jt == 2:  # revolute: linear part is axis x (p_eef - p_joint), angular part is the axis
        v = wp.vec3(p_eef[0] - p_joint[0], p_eef[1] - p_joint[1], p_eef[2] - p_joint[2])
        linear_jac = wp.vec3(
            axis_world[1] * v[2] - axis_world[2] * v[1],
            axis_world[2] * v[0] - axis_world[0] * v[2],
            axis_world[0] * v[1] - axis_world[1] * v[0],
        )
        angular_jac = axis_world
    elif jt == 1:  # prismatic: linear part is the axis, no angular contribution
        linear_jac = axis_world
        angular_jac = wp.vec3(0.0, 0.0, 0.0)
    jacobian[b_idx, 0, glb_joint_idx] = linear_jac[0]
    jacobian[b_idx, 1, glb_joint_idx] = linear_jac[1]
    jacobian[b_idx, 2, glb_joint_idx] = linear_jac[2]
    jacobian[b_idx, 3, glb_joint_idx] = angular_jac[0]
    jacobian[b_idx, 4, glb_joint_idx] = angular_jac[1]
    jacobian[b_idx, 5, glb_joint_idx] = angular_jac[2]
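
# Illustrative sketch (not in the original source): each revolute column of the
# geometric Jacobian is [a x (p_eef - p_joint); a] and each prismatic column is
# [a; 0], where a is the world-frame joint axis. A minimal NumPy version of the
# per-joint math above; the helper name is hypothetical.
def _example_jacobian_column(axis_world, p_joint, p_eef, revolute=True):
    a = np.asarray(axis_world, dtype=np.float64)
    if revolute:
        lin = np.cross(a, np.asarray(p_eef) - np.asarray(p_joint))
        ang = a
    else:
        lin, ang = a, np.zeros(3)
    return np.concatenate([lin, ang])  # 6-vector: [linear; angular]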

class CalculateJacobian(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        joint_values: Float[torch.Tensor, "... total_num_joints"],
        target_link_indices: torch.Tensor,
        articulation: "Articulation",
        root_poses: Optional[Float[torch.Tensor, "... num_arti 4 4"]] = None,
    ) -> Tuple[Float[torch.Tensor, "... 6 total_num_joints"], Float[torch.Tensor, "... num_arti 4 4"]]:
        """Forward pass to compute the Jacobian and the target link poses."""
        batch_shape = joint_values.shape[:-1]
        total_num_joints = joint_values.shape[-1]
        total_num_links = articulation.total_num_links
        num_arti = articulation.num_arti
        requires_grad = joint_values.requires_grad or (root_poses is not None and root_poses.requires_grad)

        wp.init()
        joint_values_wp = wp.from_torch(
            joint_values.contiguous().view(-1, total_num_joints),
            dtype=wp.float32,
            requires_grad=joint_values.requires_grad,
        )
        root_poses_wp = (
            wp.from_torch(
                root_poses.contiguous().view(-1, num_arti, 4, 4),
                dtype=wp.mat44,
                requires_grad=root_poses.requires_grad,
            )
            if root_poses is not None
            else wp.zeros(shape=(0, 0), dtype=wp.mat44, requires_grad=False, device=joint_values_wp.device)
        )
        link_poses_wp = wp.from_torch(
            torch.zeros(
                (joint_values_wp.shape[0], total_num_links, 4, 4),
                device=joint_values.device,
                dtype=joint_values.dtype,
                requires_grad=requires_grad,
            ),
            dtype=wp.mat44,
            requires_grad=requires_grad,
        )
        target_link_indices_wp = wp.from_torch(target_link_indices, dtype=wp.int32, requires_grad=False)
        wp_params = {
            name: wp.from_torch(getattr(articulation, fn)(return_tensors="pt"), dtype=dtype)
            for name, (dtype, fn) in _KERNEL_PARAMS_TYPES_AND_GETTERS.items()
        }

        # launch forward kinematics kernel
        wp.launch(
            kernel=forward_kinematics_kernel,
            dim=(joint_values_wp.shape[0], num_arti),
            inputs=[joint_values_wp, root_poses_wp, *wp_params.values()],
            outputs=[link_poses_wp],
            device=joint_values_wp.device,
        )

        joint_axes_wp = wp.from_torch(articulation.get_packed_full_joint_axes(return_tensors="pt"), dtype=wp.vec3)
        ancestor_mask_pt = articulation.get_packed_ancestor_links_mask(target_link_indices, return_tensors="pt")
        ancestor_mask_wp = wp.from_torch(ancestor_mask_pt, dtype=wp.int32)
        jacobian_torch = torch.zeros(
            (joint_values_wp.shape[0], 6, total_num_joints),
            device=joint_values.device,
            dtype=joint_values.dtype,
            requires_grad=False,
        )
        jacobian_wp = wp.from_torch(jacobian_torch, dtype=wp.float32, requires_grad=False)

        # launch Jacobian kernel with the precomputed link poses as input
        wp.launch(
            kernel=calculate_jacobian_kernel,
            dim=(joint_values_wp.shape[0], total_num_links),
            inputs=[
                link_poses_wp,
                target_link_indices_wp,
                ancestor_mask_wp,
                joint_axes_wp,
                wp_params["link_joint_types"],
                wp_params["link_joint_indices"],
                wp_params["link_joint_origins"],
                wp_params["parent_link_indices"],
                wp_params["joint_first_indices"],
                wp_params["link_first_indices"],
            ],
            outputs=[jacobian_wp],
            device=joint_values_wp.device,
        )

        if requires_grad:
            ctx.batch_shape = batch_shape
            ctx.total_num_joints = total_num_joints
            ctx.total_num_links = total_num_links  # stored for the backward launch
            ctx.num_arti = num_arti
            ctx.joint_values_wp = joint_values_wp
            ctx.root_poses_wp = root_poses_wp
            ctx.link_poses_wp = link_poses_wp
            ctx.wp_params = wp_params
            ctx.jacobian_wp = jacobian_wp
            ctx.joint_axes_wp = joint_axes_wp
            ctx.ancestor_mask_wp = ancestor_mask_wp

        # gather the target link poses from the packed link poses
        link_poses_pt = wp.to_torch(link_poses_wp).view(-1, total_num_links, 4, 4)
        link_first_indices_pt = wp.to_torch(wp_params["link_first_indices"]).view(-1)  # [num_arti]
        target_glb_indices = target_link_indices.to(link_poses_pt.device) + link_first_indices_pt
        target_link_poses = link_poses_pt[:, target_glb_indices.long(), :, :]
        return (
            wp.to_torch(jacobian_wp).view(*batch_shape, 6, total_num_joints),
            target_link_poses.view(*batch_shape, -1, 4, 4),
        )
    @staticmethod
    def backward(  # type: ignore
        ctx,
        jacobian_grad: Float[torch.Tensor, "... 6 total_num_joints"],
        target_link_poses_grad: Float[torch.Tensor, "... num_arti 4 4"],
    ) -> Tuple[
        Optional[Float[torch.Tensor, "... total_num_joints"]],
        None,
        None,
        Optional[Float[torch.Tensor, "... num_arti 4 4"]],
    ]:
        """Backward pass for propagating Jacobian gradients."""
        raise NotImplementedError("Backward pass for Jacobian calculation is not verified yet.")

def calculate_jacobian(
    joint_values: Float[torch.Tensor, "... total_num_joints"],
    target_link_indices: torch.Tensor,
    articulation: "Articulation",
    root_poses: Optional[Float[torch.Tensor, "... num_arti 4 4"]] = None,
    return_target_link_poses: bool = False,
) -> Union[
    Float[torch.Tensor, "... 6 total_num_joints"],
    Tuple[Float[torch.Tensor, "... 6 total_num_joints"], Float[torch.Tensor, "... num_arti 4 4"]],
]:
    jacobian, target_link_poses = CalculateJacobian.apply(joint_values, target_link_indices, articulation, root_poses)
    if return_target_link_poses:
        return jacobian, target_link_poses
    else:
        return jacobian
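
# Illustrative check (a sketch, not from the original source): the analytic
# Jacobian can be validated against central differences of the FK target pose.
# Assumes a single articulation for simplicity; the helper name is hypothetical.
def _example_jacobian_fd_check(arti: "Articulation", target_link_indices: torch.Tensor, eps: float = 1e-4):
    q = torch.zeros(1, arti.total_num_dofs)
    J = calculate_jacobian(q, target_link_indices, arti)  # [1, 6, total_num_dofs]
    J_fd = torch.zeros_like(J[:, :3, :])
    for k in range(arti.total_num_dofs):
        dq = torch.zeros_like(q)
        dq[..., k] = eps
        _, pose_p = calculate_jacobian(q + dq, target_link_indices, arti, return_target_link_poses=True)
        _, pose_m = calculate_jacobian(q - dq, target_link_indices, arti, return_target_link_poses=True)
        # central difference of the target link position -> linear rows of J
        J_fd[0, :, k] = (pose_p[0, 0, :3, 3] - pose_m[0, 0, :3, 3]) / (2.0 * eps)
    return torch.allclose(J[:, :3, :], J_fd, atol=10.0 * eps)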
@wp.kernel
def delta_pose_kernel(
    T_current: wp.array(dtype=wp.mat44),
    pos_target: wp.array(dtype=wp.vec3),
    quat_target: wp.array(dtype=wp.vec4),
    delta: wp.array2d(dtype=wp.float32),
) -> None:
    """Compute the delta pose error between current and target poses in axis-angle form."""
    tid = wp.tid()

    # extract current translation error
    T = T_current[tid]
    p_current = wp.vec3(T[0, 3], T[1, 3], T[2, 3])
    pt = pos_target[tid]
    pos_err = wp.vec3(pt[0] - p_current[0], pt[1] - p_current[1], pt[2] - p_current[2])

    # extract rotation matrix elements
    r00 = T[0, 0]
    r01 = T[0, 1]
    r02 = T[0, 2]
    r10 = T[1, 0]
    r11 = T[1, 1]
    r12 = T[1, 2]
    r20 = T[2, 0]
    r21 = T[2, 1]
    r22 = T[2, 2]

    # convert rotation matrix to quaternion (w, x, y, z)
    trace = r00 + r11 + r22
    if trace > 0.0:
        s = wp.sqrt(trace + 1.0) * 2.0
        qw = 0.25 * s
        qx = (r21 - r12) / s
        qy = (r02 - r20) / s
        qz = (r10 - r01) / s
    else:
        if r00 > r11 and r00 > r22:
            s = wp.sqrt(1.0 + r00 - r11 - r22) * 2.0
            qw = (r21 - r12) / s
            qx = 0.25 * s
            qy = (r01 + r10) / s
            qz = (r02 + r20) / s
        elif r11 > r22:
            s = wp.sqrt(1.0 + r11 - r00 - r22) * 2.0
            qw = (r02 - r20) / s
            qx = (r01 + r10) / s
            qy = 0.25 * s
            qz = (r12 + r21) / s
        else:
            s = wp.sqrt(1.0 + r22 - r00 - r11) * 2.0
            qw = (r10 - r01) / s
            qx = (r02 + r20) / s
            qy = (r12 + r21) / s
            qz = 0.25 * s

    # get target quaternion components (w, x, y, z)
    target_q = quat_target[tid]
    qtw = target_q[0]
    qtx = target_q[1]
    qty = target_q[2]
    qtz = target_q[3]

    # compute quaternion error: quat_err = quat_target * quaternion_invert(current)
    q_err_w = qtw * qw + qtx * qx + qty * qy + qtz * qz
    q_err_x = -qtw * qx + qtx * qw - qty * qz + qtz * qy
    q_err_y = -qtw * qy + qtx * qz + qty * qw - qtz * qx
    q_err_z = -qtw * qz - qtx * qy + qty * qx + qtz * qw

    # ensure the error quaternion is in the same hemisphere (w >= 0)
    if q_err_w < 0.0:
        q_err_w = -q_err_w
        q_err_x = -q_err_x
        q_err_y = -q_err_y
        q_err_z = -q_err_z

    # clamp to avoid numerical issues (acos domain)
    if q_err_w > 1.0:
        q_err_w = 1.0
    if q_err_w < -1.0:
        q_err_w = -1.0

    # convert the quaternion error to axis-angle representation
    angle = 2.0 * wp.acos(q_err_w)
    sin_half_angle = wp.sqrt(1.0 - q_err_w * q_err_w)
    eps = 1e-6
    if sin_half_angle < eps:
        rot_error_x = 0.0
        rot_error_y = 0.0
        rot_error_z = 0.0
    else:
        rot_error_x = q_err_x / sin_half_angle * angle
        rot_error_y = q_err_y / sin_half_angle * angle
        rot_error_z = q_err_z / sin_half_angle * angle

    # write concatenated translation and rotation error to the output
    delta[tid, 0] = pos_err[0]
    delta[tid, 1] = pos_err[1]
    delta[tid, 2] = pos_err[2]
    delta[tid, 3] = rot_error_x
    delta[tid, 4] = rot_error_y
    delta[tid, 5] = rot_error_z

def delta_pose_warp(
    T_current: Float[torch.Tensor, "... 4 4"],
    pos_target: Float[torch.Tensor, "... 3"],
    quat_target: Float[torch.Tensor, "... 4"],
) -> Float[torch.Tensor, "... 6"]:
    batch_shape = T_current.shape[:-2]
    T_current_wp = wp.from_torch(T_current.contiguous().view(-1, 4, 4), dtype=wp.mat44, requires_grad=False)
    pos_target_wp = wp.from_torch(pos_target.contiguous().view(-1, 3), dtype=wp.vec3, requires_grad=False)
    quat_target_wp = wp.from_torch(quat_target.contiguous().view(-1, 4), dtype=wp.vec4, requires_grad=False)
    delta_wp = wp.zeros(shape=(T_current_wp.shape[0], 6), dtype=wp.float32, device=T_current_wp.device)
    wp.launch(
        kernel=delta_pose_kernel,
        dim=(T_current_wp.shape[0],),
        inputs=[T_current_wp, pos_target_wp, quat_target_wp],
        outputs=[delta_wp],
        device=T_current_wp.device,
    )
    return wp.to_torch(delta_wp).view(*batch_shape, 6)
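
# Illustrative usage (a sketch, not from the original source): when the current
# pose already matches the target, the 6-vector error should be zero. The
# helper name is hypothetical.
def _example_delta_pose_identity():
    T = torch.eye(4).unsqueeze(0)  # current pose: identity
    pos = torch.zeros(1, 3)  # target position at the origin
    quat = torch.tensor([[1.0, 0.0, 0.0, 0.0]])  # target rotation: identity, (w, x, y, z)
    return delta_pose_warp(T, pos, quat)  # expected: zeros of shape [1, 6]
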
@wp.func
def solve_lower_triangular(L: Any, b: Any) -> Any:
    """Solve for y in L y = b where L is lower triangular with unit diagonal.

    Returns:
        y: the solution vector.
    """
    y = type(b)()  # zero-initialized vector
    for i in range(type(b).length):
        sum_val = b.dtype(0)
        for j in range(i):
            # accumulate lower-triangular contributions
            sum_val = sum_val + L[i, j] * y[j]
        y[i] = b[i] - sum_val
    return y
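
# Illustrative check (not in the original source): the same forward substitution
# in NumPy, assuming a unit-diagonal lower-triangular L; the helper name is
# hypothetical.
def _example_forward_substitution(L, b):
    L = np.asarray(L, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    y = np.zeros_like(b)
    for i in range(len(b)):
        y[i] = b[i] - L[i, :i] @ y[:i]  # unit diagonal, so no division needed
    return y  # satisfies L @ y == b when diag(L) == 1
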
@wp.func
def inverse_lu(A: Any) -> Any:
    """Compute the inverse of a square matrix using LU decomposition without pivoting.

    Returns:
        A_inv: the inverse of A.
    """
    # initialize L as identity and U as a copy of A
    L = wp.identity(n=type(A[0]).length, dtype=A.dtype)
    U = type(A)()
    for i in range(type(A[0]).length):
        for j in range(type(A[0]).length):
            U[i, j] = A[i, j]

    # perform LU decomposition (Doolittle algorithm)
    for i in range(type(A[0]).length):
        for j in range(i + 1, type(A[0]).length):
            pivot = U[i, i]
            factor = U[j, i] / pivot
            L[j, i] = factor
            for k in range(i, type(A[0]).length):
                U[j, k] = U[j, k] - factor * U[i, k]

    # compute inverse column-by-column by solving A x = e for each basis vector e
    A_inv = type(A)()
    for i in range(type(A[0]).length):
        # create standard basis vector e with 1 at index i and 0 elsewhere
        e = type(A[0])()
        for j in range(type(A[0]).length):
            e[j] = A.dtype(0)
        e[i] = A.dtype(1)
        # forward substitution (L y = e), then back substitution (U x = y)
        y = solve_lower_triangular(L, e)
        x = solve_triangular(U, y)
        A_inv[i] = x
    return wp.transpose(A_inv)
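
# Illustrative check (not from the original source): the same Doolittle
# factorization in NumPy. A @ _example_inverse_lu(A) should be close to the
# identity for well-conditioned A without a zero pivot; the helper name is
# hypothetical.
def _example_inverse_lu(A):
    A = np.asarray(A, dtype=np.float64)
    n = A.shape[0]
    L, U = np.eye(n), A.copy()
    for i in range(n):
        for j in range(i + 1, n):
            factor = U[j, i] / U[i, i]  # no pivoting, as in the Warp version
            L[j, i] = factor
            U[j, :] -= factor * U[i, :]
    # solve L y = e and U x = y for each basis vector e
    A_inv = np.zeros_like(A)
    for i in range(n):
        e = np.zeros(n)
        e[i] = 1.0
        y = _example_forward_substitution(L, e)
        A_inv[:, i] = np.linalg.solve(U, y)  # back substitution on upper-triangular U
    return A_inv
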
@wp.kernel
def compute_dq_kernel_v2(
    J: wp.array3d(dtype=wp.float32),  # shape: [b, 6, total_num_joints]
    err: wp.array3d(dtype=wp.float32),  # shape: [b, num_arti, 6]
    A: wp.array3d(dtype=wp.float32),  # shape: [b, num_arti, 6] (scratch)
    joint_first_indices: wp.array(dtype=wp.int32),  # shape: [num_arti]
    damping: wp.float32,  # damping factor for regularization
    dq: wp.array2d(dtype=wp.float32),  # shape: [b, total_num_joints] (output)
) -> None:
    """Compute joint velocity dq using a damped least-squares formulation."""
    b_idx, arti_idx = wp.tid()

    # determine number of articulations and total joints from input arrays
    num_arti = joint_first_indices.shape[0]
    total_num_joints = int(J.shape[2])

    # determine joint index range for the current articulation
    start_idx = joint_first_indices[arti_idx]
    if arti_idx == num_arti - 1:
        end_idx = total_num_joints
    else:
        end_idx = joint_first_indices[arti_idx + 1]

    # compute the damped 6x6 J J^T matrix for the current articulation
    M = wp.identity(n=6, dtype=wp.float32)  # local 6x6 matrix
    for i in range(6):
        for j in range(6):
            sum_val = float(0.0)
            for k in range(start_idx, end_idx):
                # accumulate product contributions over joint indices
                sum_val = sum_val + J[b_idx, i, k] * J[b_idx, j, k]
            if i == j:
                # add damping regularization on the diagonal
                sum_val = sum_val + damping
            M[i, j] = sum_val

    # compute inverse of the JJt matrix via QR factorization
    inv_M = inverse_qr(M)
    # inv_M = inverse_lu(M)

    # compute vector A = inv_M @ err locally for the articulation block
    for i in range(6):
        acc = float(0.0)
        for j in range(6):
            acc = acc + inv_M[i, j] * err[b_idx, arti_idx, j]
        A[b_idx, arti_idx, i] = acc

    # compute dq = J^T @ A for each joint in the current articulation block
    for k in range(start_idx, end_idx):
        acc = float(0.0)
        for i in range(6):
            acc = acc + J[b_idx, i, k] * A[b_idx, arti_idx, i]
        dq[b_idx, k] = acc
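
# Illustrative reference (not from the original source): the kernel implements
# the damped least-squares (Levenberg-Marquardt-style) update
# dq = J^T (J J^T + lambda I)^{-1} err. A minimal NumPy version for one
# articulation; the helper name is hypothetical.
def _example_damped_least_squares(J, err, damping):
    J = np.asarray(J, dtype=np.float64)  # [6, n]
    err = np.asarray(err, dtype=np.float64)  # [6]
    M = J @ J.T + damping * np.eye(6)  # damped 6x6 system
    return J.T @ np.linalg.solve(M, err)  # [n] joint update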

def compute_dq_warp(
    J: Float[torch.Tensor, "... 6 total_num_joints"],
    err: Float[torch.Tensor, "... num_arti 6"],
    joint_first_indices_pt: Int[torch.Tensor, "num_arti"],
    damping: float,
) -> Float[torch.Tensor, "... total_num_joints"]:
    batch_shape = J.shape[:-2]
    total_num_joints = J.shape[-1]
    num_arti = len(joint_first_indices_pt)
    J_wp = wp.from_torch(J.contiguous().view(-1, 6, total_num_joints), dtype=wp.float32, requires_grad=False)
    err_wp = wp.from_torch(err.contiguous().view(-1, num_arti, 6), dtype=wp.float32, requires_grad=False)  # [b, A, 6]
    A_wp = wp.zeros(shape=(err_wp.shape[0], num_arti, 6), dtype=wp.float32, device=err_wp.device)  # [b, A, 6]
    dq_wp = wp.zeros(shape=(J_wp.shape[0], total_num_joints), dtype=wp.float32, device=J_wp.device)
    wp.launch(
        kernel=compute_dq_kernel_v2,
        dim=(J_wp.shape[0], num_arti),
        inputs=[J_wp, err_wp, A_wp, wp.from_torch(joint_first_indices_pt, dtype=wp.int32), damping, dq_wp],
        device=J_wp.device,
    )
    return wp.to_torch(dq_wp).view(*batch_shape, total_num_joints)
@torch.no_grad()
def inverse_kinematics(
    target_link_poses: Float[torch.Tensor, "... num_arti 4 4"],
    target_link_indices: Int[torch.Tensor, "num_arti"],
    articulation: "Articulation",
    max_iterations: int = 100,
    learning_rate: float = 0.1,
    tolerance: float = 1e-6,
    damping: float = 0.01,
    num_retries: int = 10,
    init_joint_values: Optional[Float[torch.Tensor, "... total_num_dofs"]] = None,
    jitter_strength: float = 1.0,
) -> Tuple[Float[torch.Tensor, "... total_num_dofs"], Bool[torch.Tensor, "... num_arti"]]:
    batch_shape = target_link_poses.shape[:-3]
    device = target_link_poses.device
    num_arti = articulation.num_arti
    total_num_dofs = articulation.total_num_dofs

    if init_joint_values is None:
        init_joint_values = (
            articulation.get_packed_zero_joint_values(return_tensors="pt").expand(*batch_shape, -1).clone()  # type: ignore
        )

    # sample initial joint values for each retry, jittered around init_joint_values
    if not articulation.has_none_joint_limits:
        lower_limit, upper_limit = articulation.joint_limits[..., 0], articulation.joint_limits[..., 1]  # type: ignore
        lower_q_init = (lower_limit - init_joint_values) * jitter_strength + init_joint_values
        upper_q_init = (upper_limit - init_joint_values) * jitter_strength + init_joint_values
    else:
        # NOTE: use a small range of 0.1 when joint limits are not specified
        lower_q_init = init_joint_values - 0.1
        upper_q_init = init_joint_values + 0.1
    lower_q_init = lower_q_init.unsqueeze(-2)  # add a retry dimension
    upper_q_init = upper_q_init.unsqueeze(-2)
    q_init = lower_q_init + (upper_q_init - lower_q_init) * torch.rand(
        *batch_shape, num_retries, total_num_dofs, device=device, dtype=target_link_poses.dtype
    )
    q = q_init  # shape: (*batch_shape, num_retries, total_num_dofs)

    target_link_poses_particles = target_link_poses.unsqueeze(-4).expand(*batch_shape, num_retries, num_arti, 4, 4)
    pos_target = target_link_poses_particles[..., :3, 3]  # shape: (*batch_shape, num_retries, num_arti, 3)
    # shape: (*batch_shape, num_retries, num_arti, 4)
    quat_target = matrix_to_quaternion(target_link_poses_particles[..., :3, :3])

    joint_first_indices_pt = articulation.get_joint_first_indices(return_tensors="pt")
    joint_limits = articulation.get_packed_joint_limits(return_tensors="pt")
    jfi = joint_first_indices_pt.tolist()
    joint_slices = [slice(jfi[i], jfi[i + 1]) for i in range(num_arti - 1)] + [slice(jfi[-1], total_num_dofs)]

    for _ in range(max_iterations):
        J, current_poses = calculate_jacobian(q, target_link_indices, articulation, return_target_link_poses=True)
        err = delta_pose_warp(current_poses, pos_target, quat_target)  # shape: (*batch_shape, num_retries, A, 6)
        err_norm = err.norm(dim=-1)  # shape: (*batch_shape, num_retries, A)
        if (err_norm < tolerance).any(dim=-2).all():
            break
        dq = compute_dq_warp(J, err, joint_first_indices_pt, damping)  # type: ignore
        q = q + learning_rate * dq
        if joint_limits is not None:
            q = torch.clamp(q, min=joint_limits[..., 0], max=joint_limits[..., 1])  # type: ignore

    # select the best retry per articulation based on the final error norm
    _, current_poses = calculate_jacobian(q, target_link_indices, articulation, return_target_link_poses=True)
    final_err = delta_pose_warp(current_poses, pos_target, quat_target)  # shape: (*batch_shape, num_retries, A, 6)
    final_err_norm = final_err.norm(dim=-1)  # shape: (*batch_shape, num_retries, A)
    best_idx = final_err_norm.argmin(dim=-2)  # shape: (*batch_shape, A)
    final_success = (final_err_norm < tolerance).any(dim=-2)  # shape: (*batch_shape, A)

    best_q_list = []
    for i, js in enumerate(joint_slices):
        sel = best_idx[..., i].unsqueeze(-1)  # shape: (*batch_shape, 1)
        q_seg = q[..., js]  # shape: (*batch_shape, num_retries, dof_range)
        indices = sel.unsqueeze(-1).expand(*sel.shape, q_seg.shape[-1])
        best_q_list.append(torch.gather(q_seg, dim=-2, index=indices).squeeze(-2))
    best_q = torch.cat(best_q_list, dim=-1)
    return best_q, final_success
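
# Illustrative usage (a sketch, not from the original source): solve IK for a
# batch of target poses. `arti` and `ee_link_indices` are assumed to be
# prepared by the caller (e.g. from the articulation's link names); the helper
# name is hypothetical.
def _example_ik_usage(arti: "Articulation", ee_link_indices: torch.Tensor, targets: torch.Tensor):
    """targets: [batch, num_arti, 4, 4] desired end-effector poses."""
    best_q, success = inverse_kinematics(
        targets,
        ee_link_indices,
        arti,
        max_iterations=50,
        num_retries=5,
        tolerance=1e-4,  # looser than the default, for faster convergence
    )
    # verify: recompute FK at the solution and compare against the targets
    link_poses = forward_kinematics(best_q, arti)
    return best_q, success, link_poses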