Source code for fastdev.robo.kinematics

# ruff: noqa: F821
from typing import TYPE_CHECKING, Optional, Tuple, Union

import torch
from einops import rearrange
from jaxtyping import Float, Int, Bool

from fastdev.xform.rotation import (
    axis_angle_to_matrix,
    matrix_to_quaternion,
    quaternion_invert,
    quaternion_multiply,
    quaternion_to_axis_angle_vector,
)
from fastdev.xform.transforms import rot_tl_to_tf_mat

if TYPE_CHECKING:
    from fastdev.robo.articulation import Articulation


def forward_kinematics(
    joint_values: Float[torch.Tensor, "... total_num_joints"],
    articulation: "Articulation",
    root_poses: Optional[Float[torch.Tensor, "... num_arti 4 4"]] = None,
) -> Float[torch.Tensor, "... total_num_links 4 4"]:
    """Compute the world-frame pose of every link from packed joint values.

    If root_poses is not provided, identity root poses are used for all articulations.
    """
    batch_shape = joint_values.shape[:-1]
    total_num_links = articulation.total_num_links
    device = joint_values.device

    requires_grad = joint_values.requires_grad or (root_poses is not None and root_poses.requires_grad)
    link_poses = torch.eye(4, device=device, requires_grad=requires_grad).repeat(*batch_shape, total_num_links, 1, 1)

    if root_poses is None:
        root_poses = torch.eye(4, device=device).expand(*batch_shape, articulation.num_arti, 4, 4)

    joint_axes = articulation.get_packed_full_joint_axes(return_tensors="pt")
    pris_jnt_tf = rot_tl_to_tf_mat(tl=joint_axes * joint_values.unsqueeze(-1))  # type: ignore
    rev_jnt_tf = rot_tl_to_tf_mat(rot_mat=axis_angle_to_matrix(joint_axes, joint_values))  # type: ignore

    link_topo_indices = articulation.get_packed_link_indices_topological_order(return_tensors="pt")
    parent_link_indices = articulation.get_packed_parent_link_indices(return_tensors="pt")
    link_joint_types = articulation.get_packed_link_joint_types(return_tensors="pt")
    link_joint_indices = articulation.get_packed_link_joint_indices(return_tensors="pt")
    link_joint_origins = articulation.get_packed_link_joint_origins(return_tensors="pt")
    joint_first_indices = articulation.get_joint_first_indices(return_tensors="pt")
    link_first_indices = articulation.get_link_first_indices(return_tensors="pt")

    identity_matrix = torch.eye(4, device=device).expand(*batch_shape, 4, 4)

    for arti_idx in range(articulation.num_arti):
        link_start = link_first_indices[arti_idx].item()
        link_end = (
            link_first_indices[arti_idx + 1].item()
            if arti_idx < len(link_first_indices) - 1
            else len(link_topo_indices)
        )
        joint_start = joint_first_indices[arti_idx].item()

        for local_link_idx in link_topo_indices[link_start:link_end]:  # type: ignore
            glb_link_idx = local_link_idx + link_start
            joint_type = link_joint_types[glb_link_idx]

            if joint_type == -1:  # Root link
                link_poses[..., glb_link_idx, :, :] = root_poses[..., arti_idx, :, :]
                continue

            glb_parent_idx = parent_link_indices[glb_link_idx] + link_start
            parent_pose = link_poses[..., glb_parent_idx, :, :]

            if joint_type == 1:  # Prismatic
                glb_joint_idx = link_joint_indices[glb_link_idx] + joint_start
                local_tf = pris_jnt_tf[..., glb_joint_idx, :, :]
            elif joint_type == 2:  # Revolute
                glb_joint_idx = link_joint_indices[glb_link_idx] + joint_start
                local_tf = rev_jnt_tf[..., glb_joint_idx, :, :]
            else:  # Fixed
                local_tf = identity_matrix

            origin = link_joint_origins[glb_link_idx]
            link_poses[..., glb_link_idx, :, :] = (parent_pose @ origin) @ local_tf

    return link_poses
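
# Hedged usage sketch (not part of the library): given an already-constructed
# `Articulation` (see fastdev.robo.articulation), compute the world poses of all
# links for one or more joint configurations. Only attributes this module already
# uses (total_num_links) are assumed to exist.
def _example_forward_kinematics(articulation: "Articulation", joint_values: torch.Tensor) -> torch.Tensor:
    # joint_values: (..., total_num_joints) -> link_poses: (..., total_num_links, 4, 4)
    link_poses = forward_kinematics(joint_values, articulation)
    assert link_poses.shape[-3:] == (articulation.total_num_links, 4, 4)
    return link_poses
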
def calculate_jacobian(
    joint_values: Float[torch.Tensor, "... total_num_joints"],
    target_link_indices: Int[torch.Tensor, "num_arti"],
    articulation: "Articulation",
    root_poses: Optional[Float[torch.Tensor, "... num_arti 4 4"]] = None,
    return_target_link_poses: bool = False,
) -> Union[
    Float[torch.Tensor, "... 6 total_num_joints"],
    Tuple[Float[torch.Tensor, "... 6 total_num_joints"], Float[torch.Tensor, "... num_arti 4 4"]],
]:
    """Calculate the geometric Jacobian of the end-effector with respect to each joint of the articulated robot.

    The Jacobian is computed in the world frame. For revolute joints, the linear velocity
    component is the cross product of the joint axis (in the world frame) and the vector
    from the joint position to the end-effector position, while the angular velocity
    component is the joint axis itself. For prismatic joints, the linear velocity component
    is the joint axis and the angular velocity component is zero.
    """
    batch_shape = joint_values.shape[:-1]
    device = joint_values.device

    # compute full forward kinematics for all links
    link_poses = forward_kinematics(joint_values, articulation, root_poses=root_poses)

    total_num_joints = joint_values.shape[-1]
    jacobian = torch.zeros(*batch_shape, 6, total_num_joints, device=device, dtype=joint_values.dtype)

    # extract articulation parameters
    joint_axes = articulation.get_packed_full_joint_axes(return_tensors="pt")  # (total_num_joints, 3)
    link_joint_types = articulation.get_packed_link_joint_types(return_tensors="pt")  # (total_num_links,)
    link_joint_indices = articulation.get_packed_link_joint_indices(return_tensors="pt")  # (total_num_links,)
    link_joint_origins = articulation.get_packed_link_joint_origins(return_tensors="pt")  # (total_num_links, 4, 4)
    parent_link_indices = articulation.get_packed_parent_link_indices(return_tensors="pt")  # (total_num_links,)
    joint_first_indices = articulation.get_joint_first_indices(return_tensors="pt")  # (num_arti,)
    link_first_indices = articulation.get_link_first_indices(return_tensors="pt")  # (num_arti,)
    ancestor_link_masks: torch.Tensor = articulation.get_packed_ancestor_links_mask(  # type: ignore
        target_link_indices, return_tensors="pt"
    )  # (total_num_links,)
    total_num_links = articulation.total_num_links
    num_arti = articulation.num_arti

    if return_target_link_poses:
        target_link_poses = torch.zeros(*batch_shape, num_arti, 4, 4, device=device, dtype=joint_values.dtype)

    for arti_idx in range(num_arti):
        # determine the link and joint index ranges for the current articulation
        link_start = int(link_first_indices[arti_idx].item())
        link_end = int(link_first_indices[arti_idx + 1].item()) if arti_idx < num_arti - 1 else total_num_links
        joint_start = int(joint_first_indices[arti_idx].item())

        # select the designated end-effector for this articulation and extract its position
        eef_idx = target_link_indices[arti_idx] + link_start
        eef_pose = link_poses[..., eef_idx, :, :]
        if return_target_link_poses:
            target_link_poses[..., arti_idx, :, :] = eef_pose
        p_eef = eef_pose[..., :3, 3]

        # identify ancestor links using the respective mask
        link_mask = ancestor_link_masks[link_start:link_end]  # (L,)
        valid_local_links = torch.nonzero(link_mask, as_tuple=True)[0]
        if valid_local_links.numel() == 0:
            continue
        valid_global_links = valid_local_links + link_start

        # filter for joints that are either prismatic (1) or revolute (2)
        joint_types = link_joint_types[valid_global_links].to(torch.int64)  # (N,)
        valid_joint_mask = (joint_types == 1) | (joint_types == 2)
        if valid_joint_mask.sum() == 0:
            continue
        valid_global_links = valid_global_links[valid_joint_mask]
        joint_types = joint_types[valid_joint_mask]
        joint_ids = (link_joint_indices[valid_global_links] + joint_start).long()

        # gather parent's transformation for valid links
        parent_indices = parent_link_indices[valid_global_links] + link_start
        parent_pose = torch.index_select(link_poses, dim=-3, index=parent_indices)
        origin_tf = link_joint_origins[valid_global_links].view(*(1,) * len(batch_shape), -1, 4, 4)
        joint_tf = parent_pose @ origin_tf
        p_joint = joint_tf[..., :3, 3]
        R_joint = joint_tf[..., :3, :3]

        # transform local joint axis into the world frame
        local_axis = joint_axes[joint_ids].view(*(1,) * len(batch_shape), -1, 3)  # type: ignore
        axis_world = torch.matmul(R_joint, local_axis.unsqueeze(-1)).squeeze(-1)
        diff = p_eef.unsqueeze(-2) - p_joint

        is_revolute = joint_types == 2
        rev_mask = is_revolute.view(*(1,) * len(batch_shape), -1, 1).to(joint_values.dtype)
        # for revolute joints, compute linear velocity using the cross product; for prismatic, use the axis directly
        lin_component = torch.cross(axis_world, diff, dim=-1) * rev_mask + axis_world * (1 - rev_mask)
        ang_component = axis_world * rev_mask
        update = torch.cat((lin_component, ang_component), dim=-1)  # shape: (*batch, N, 6)
        jacobian[..., :, joint_ids] = rearrange(update, "... n m -> ... m n")

    if return_target_link_poses:
        return jacobian, target_link_poses
    else:
        return jacobian
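
# Hedged illustration (not part of the library): the 6-vector Jacobian column of a
# single revolute joint, matching the docstring above. axis_world is the joint axis
# in the world frame, p_joint and p_eef are the joint and end-effector positions.
def _example_revolute_jacobian_column(
    axis_world: torch.Tensor, p_joint: torch.Tensor, p_eef: torch.Tensor
) -> torch.Tensor:
    lin = torch.cross(axis_world, p_eef - p_joint, dim=-1)  # linear velocity per unit joint velocity
    ang = axis_world  # angular velocity per unit joint velocity
    return torch.cat([lin, ang], dim=-1)  # column layout: [linear, angular]

# For example, axis_world = (0, 0, 1), p_joint = (0, 0, 0), p_eef = (1, 0, 0)
# yields the column (0, 1, 0, 0, 0, 1).
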
@torch.no_grad()
def delta_pose(T_current: torch.Tensor, pos_target: torch.Tensor, quat_target: torch.Tensor) -> torch.Tensor:
    """Compute the error between current and target poses."""
    pos_error = pos_target - T_current[..., :3, 3]
    current_quat = matrix_to_quaternion(T_current[..., :3, :3])
    quat_err = quaternion_multiply(quat_target, quaternion_invert(current_quat))
    rot_error = quaternion_to_axis_angle_vector(quat_err)
    return torch.cat([pos_error, rot_error], dim=-1)  # shape (..., 6)
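
# Hedged sketch (not part of the library): delta_pose packs the position error and the
# axis-angle rotation error into one 6-vector. With an identity current pose and a target
# that is only translated by (1, 0, 0), the error is (1, 0, 0, 0, 0, 0).
def _example_delta_pose() -> torch.Tensor:
    T_current = torch.eye(4)
    T_target = torch.eye(4)
    T_target[0, 3] = 1.0
    pos_target = T_target[:3, 3]
    quat_target = matrix_to_quaternion(T_target[:3, :3])
    return delta_pose(T_current, pos_target, quat_target)  # tensor([1., 0., 0., 0., 0., 0.])
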
@torch.no_grad()
def inverse_kinematics(
    target_link_poses: Float[torch.Tensor, "... num_arti 4 4"],
    target_link_indices: Int[torch.Tensor, "num_arti"],
    articulation: "Articulation",
    max_iterations: int = 100,
    learning_rate: float = 0.1,
    tolerance: float = 1e-6,
    damping: float = 0.01,
    num_retries: int = 50,
    init_joint_values: Optional[Float[torch.Tensor, "... total_num_dofs"]] = None,
    jitter_strength: float = 1.0,
) -> Tuple[Float[torch.Tensor, "... total_num_dofs"], Bool[torch.Tensor, "... num_arti"]]:
    """Solve inverse kinematics with damped least-squares updates and random restarts.

    Returns the best joint values across retries and a boolean mask indicating,
    per articulation, whether any retry reached the given tolerance.
    """
    batch_shape = target_link_poses.shape[:-3]
    device = target_link_poses.device
    num_arti = articulation.num_arti
    total_num_dofs = articulation.total_num_dofs

    if init_joint_values is None:
        init_joint_values = (
            articulation.get_packed_zero_joint_values(return_tensors="pt").expand(*batch_shape, -1).clone()  # type: ignore
        )

    if not articulation.has_none_joint_limits:
        lower_limit, upper_limit = articulation.joint_limits[..., 0], articulation.joint_limits[..., 1]  # type: ignore
        lower_q_init = (lower_limit - init_joint_values) * jitter_strength + init_joint_values
        upper_q_init = (upper_limit - init_joint_values) * jitter_strength + init_joint_values
        lower_q_init = lower_q_init.unsqueeze(-2)  # add a retry dimension
        upper_q_init = upper_q_init.unsqueeze(-2)
        q_init = lower_q_init + (upper_q_init - lower_q_init) * torch.rand(
            *batch_shape, num_retries, total_num_dofs, device=device, dtype=target_link_poses.dtype
        )
    else:
        # NOTE use a small range of 0.1 when joint limits are not specified
        lower_q_init = init_joint_values - 0.1
        upper_q_init = init_joint_values + 0.1
        lower_q_init = lower_q_init.unsqueeze(-2)  # add a retry dimension
        upper_q_init = upper_q_init.unsqueeze(-2)
        q_init = lower_q_init + (upper_q_init - lower_q_init) * torch.rand(
            *batch_shape, num_retries, total_num_dofs, device=device, dtype=target_link_poses.dtype
        )

    q = q_init  # shape: (*batch_shape, num_retries, total_num_dofs)
    pos_target = target_link_poses[..., :3, 3].unsqueeze(-3)  # rely on broadcasting
    quat_target = matrix_to_quaternion(target_link_poses[..., :3, :3]).unsqueeze(-3)

    joint_first_indices_tensor = articulation.get_joint_first_indices(return_tensors="pt")
    joint_limits = articulation.get_packed_joint_limits(return_tensors="pt")
    jfi = joint_first_indices_tensor.tolist()
    joint_slices = [slice(jfi[i], jfi[i + 1]) for i in range(num_arti - 1)] + [slice(jfi[-1], total_num_dofs)]

    for _ in range(max_iterations):
        J, current_poses = calculate_jacobian(q, target_link_indices, articulation, return_target_link_poses=True)
        err = delta_pose(current_poses, pos_target, quat_target)  # shape: (*batch_shape, num_retries, A, 6)
        err_norm = err.norm(dim=-1)  # shape: (*batch_shape, num_retries, A)
        success = (err_norm < tolerance).any(dim=-2)  # shape: (*batch_shape, A)
        if success.all():
            break

        dq_list = []  # collect updates for each joint block
        for i, js in enumerate(joint_slices):
            err_i = err[..., i, :]  # shape: (*batch_shape, num_retries, 6)
            J_i = J[..., js]  # shape: (*batch_shape, num_retries, 6, dofs_i)
            reg = damping * torch.eye(6, device=device, dtype=q.dtype).expand(*J_i.shape[:-2], 6, 6)  # regularize
            err_i_unsq = err_i.unsqueeze(-1)
            JJt_i = J_i @ J_i.transpose(-1, -2) + reg
            A_i = torch.linalg.solve(JJt_i, err_i_unsq)
            dq_i = (J_i.transpose(-1, -2) @ A_i).squeeze(-1)
            dq_list.append(dq_i)
        dq_total = torch.cat(dq_list, dim=-1)
        q = q + learning_rate * dq_total
        if joint_limits is not None:
            q = torch.clamp(q, min=joint_limits[..., 0], max=joint_limits[..., 1])  # type: ignore

    # select the best retry per articulation based on the final error norm
    _, current_poses = calculate_jacobian(q, target_link_indices, articulation, return_target_link_poses=True)
    final_err = delta_pose(current_poses, pos_target, quat_target)  # shape: (*batch_shape, num_retries, A, 6)
    final_err_norm = final_err.norm(dim=-1)  # shape: (*batch_shape, num_retries, A)
    final_success = (final_err_norm < tolerance).any(dim=-2)  # shape: (*batch_shape, A)
    best_idx = final_err_norm.argmin(dim=-2)  # shape: (*batch_shape, A)

    best_q_list = []
    for i, js in enumerate(joint_slices):
        sel = best_idx[..., i].unsqueeze(-1)  # shape: (*batch_shape, 1)
        q_seg = q[..., js]  # shape: (*batch_shape, num_retries, dof_range)
        indices = sel.unsqueeze(-1).expand(*sel.shape, q_seg.shape[-1])
        best_q_list.append(torch.gather(q_seg, dim=-2, index=indices).squeeze(-2))
    best_q = torch.cat(best_q_list, dim=-1)
    return best_q, final_success
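
# Hedged illustration (not part of the library): a single damped least-squares step as
# used in the loop above, for one 6 x n Jacobian J and one 6-vector pose error err:
#   dq = J^T (J J^T + damping * I)^{-1} err
def _example_dls_step(J: torch.Tensor, err: torch.Tensor, damping: float = 0.01) -> torch.Tensor:
    JJt = J @ J.transpose(-1, -2) + damping * torch.eye(6, device=J.device, dtype=J.dtype)
    return (J.transpose(-1, -2) @ torch.linalg.solve(JJt, err.unsqueeze(-1))).squeeze(-1)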