fastdev.xform.warp_transforms

Module Contents

fastdev.xform.warp_transforms.transform_points(pts: jaxtyping.Float[torch.Tensor, '... n 3'], tf_mat: jaxtyping.Float[torch.Tensor, '... 4 4']) -> jaxtyping.Float[torch.Tensor, '... n 3']

Apply a 4×4 homogeneous transformation matrix to a set of 3D points.
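As a worked form of this operation, assume tf_mat is affine with last row [0 0 0 1] (an assumption: the docstring does not say whether a perspective divide is applied otherwise). Writing R = tf_mat[..., :3, :3] and t = tf_mat[..., :3, 3], each point p is mapped in homogeneous coordinates as below, where P is the n×3 matrix whose rows are the points:

```latex
T = \begin{bmatrix} R & \mathbf{t} \\ \mathbf{0}^\top & 1 \end{bmatrix},
\qquad
\begin{bmatrix} \mathbf{p}' \\ 1 \end{bmatrix}
  = T \begin{bmatrix} \mathbf{p} \\ 1 \end{bmatrix}
\;\Longrightarrow\;
\mathbf{p}' = R\,\mathbf{p} + \mathbf{t},
\qquad
P' = P\,R^\top + \mathbf{1}_n \mathbf{t}^\top
```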

Parameters:
  • pts (torch.Tensor) – 3D points in shape [… n 3].

  • tf_mat (torch.Tensor) – Transformation matrix in shape [… 4 4].

Returns:

Transformed points in shape [… n 3].

Return type:

jaxtyping.Float[torch.Tensor, … n 3]

Examples

>>> import torch
>>> from fastdev.xform.warp_transforms import transform_points
>>> pts = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
>>> tf_mat = torch.tensor([[0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 2.0], [1.0, 0.0, 0.0, 3.0], [0.0, 0.0, 0.0, 1.0]])
>>> transform_points(pts, tf_mat)
tensor([[3., 5., 4.],
        [6., 8., 7.]])

Note

pts and tf_mat must have the same number of dimensions. Their batch dimensions (…) are broadcast against each other and therefore must be broadcastable. We don't adopt the shapes [… 3] and [… 4 4] because PyTorch has no true broadcast vector-matrix multiplication: such inputs would simply be converted to [… 1 3] and [… 4 4] and multiplied with a broadcast matrix-matrix product, so the n dimension is kept explicit instead. To transform a single point, pass it with shape [… 1 3].
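For reference, the broadcasting behaviour described in the note can be sketched in plain PyTorch. This is a minimal sketch, not fastdev's implementation: the name `transform_points_ref` is hypothetical, and it assumes an affine tf_mat (last row [0 0 0 1]) so that no perspective divide is applied.

```python
import torch


def transform_points_ref(pts: torch.Tensor, tf_mat: torch.Tensor) -> torch.Tensor:
    """Hypothetical pure-PyTorch reference for transform_points.

    pts: [... n 3], tf_mat: [... 4 4], with the same number of dimensions.
    Assumes an affine tf_mat (last row [0, 0, 0, 1]); no perspective divide.
    """
    if pts.dim() != tf_mat.dim():
        raise ValueError("pts and tf_mat must have the same number of dimensions")
    rot = tf_mat[..., :3, :3]   # [... 3 3] linear part
    trans = tf_mat[..., :3, 3]  # [... 3] translation
    # Row-vector convention: p' = p @ R^T + t.
    # torch.matmul broadcasts the batch dims of [... n 3] and [... 3 3].
    return pts @ rot.transpose(-1, -2) + trans.unsqueeze(-2)


# Reproduce the docstring example.
pts = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
tf_mat = torch.tensor(
    [[0.0, 1.0, 0.0, 1.0],
     [0.0, 0.0, 1.0, 2.0],
     [1.0, 0.0, 0.0, 3.0],
     [0.0, 0.0, 0.0, 1.0]]
)
out = transform_points_ref(pts, tf_mat)  # tensor([[3., 5., 4.], [6., 8., 7.]])

# Broadcasting: one shared transform [1 4 4] applied to 5 point sets [5 100 3].
batch_pts = torch.randn(5, 100, 3)
batch_out = transform_points_ref(batch_pts, tf_mat.unsqueeze(0))
print(batch_out.shape)  # torch.Size([5, 100, 3])

# A single point is passed as [1 3] (n = 1) rather than [3].
single = transform_points_ref(torch.tensor([[1.0, 2.0, 3.0]]), tf_mat)
```

Splitting tf_mat into its rotation and translation parts avoids building an [… n 4] homogeneous copy of pts; concatenating a column of ones and multiplying by the full 4×4 matrix would give the same result under the affine assumption.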

fastdev.xform.warp_transforms.rotate_points(pts: jaxtyping.Float[torch.Tensor, '... n 3'], tf_mat: jaxtyping.Float[torch.Tensor, '... 3 3']) -> jaxtyping.Float[torch.Tensor, '... n 3']

Apply a 3×3 rotation matrix to a set of 3D points.

Parameters:
  • pts (torch.Tensor) – 3D points in shape [… n 3].

  • tf_mat (torch.Tensor) – Rotation matrix in shape [… 3 3].

Returns:

Rotated points in shape [… n 3].

Return type:

jaxtyping.Float[torch.Tensor, … n 3]
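A comparable plain-PyTorch sketch of the rotation case follows, again hypothetical (`rotate_points_ref` and the helper `rot_z` are illustrative names, not fastdev API). It shows a single rotation and a batch of per-set rotations broadcast against [… n 3] points.

```python
import math

import torch


def rotate_points_ref(pts: torch.Tensor, tf_mat: torch.Tensor) -> torch.Tensor:
    """Hypothetical pure-PyTorch reference for rotate_points.

    pts: [... n 3], tf_mat: [... 3 3] rotation matrices, same number of dims.
    """
    # Row-vector convention: p' = p @ R^T, broadcast over the leading dims.
    return pts @ tf_mat.transpose(-1, -2)


def rot_z(theta: float) -> torch.Tensor:
    """Rotation by theta radians about the z axis (helper for this example)."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


pts = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
# A 90-degree turn about z maps x -> y and y -> -x.
out = rotate_points_ref(pts, rot_z(math.pi / 2))

# Per-set rotations: [4 3 3] against [4 2 3], one rotation per point set.
rots = torch.stack([rot_z(k * math.pi / 2) for k in range(4)])
batch_out = rotate_points_ref(pts.expand(4, -1, -1), rots)
```

Under the same affine assumption as for transform_points, rotating with R should match transforming with a 4×4 matrix that embeds R and has zero translation.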