fastdev.geom.utils

Module Contents

fastdev.geom.utils.masked_gather(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor

Helper function around torch.gather that collects the points at the indices given in idx, where some indices may be -1 to indicate padding. The -1 indices are first replaced with 0 so that torch.gather only sees valid indices, the points are then gathered, and finally the padded positions in the result are set to 0.0.
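The three steps (replace -1 with 0, gather, zero the padded slots) can be sketched in plain PyTorch for the (N, K) form of idx. This is a minimal illustration, not fastdev's implementation; the name masked_gather_2d is hypothetical.

```python
import torch


def masked_gather_2d(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: out[n, k] = points[n, idx[n, k]], zeros where idx == -1.

    points: (N, P, D) float tensor; idx: (N, K) long tensor.
    Returns an (N, K, D) tensor.
    """
    D = points.shape[-1]
    pad_mask = idx.eq(-1)                          # (N, K), True at padded slots
    # Step 1: replace -1 with 0 so gather never receives an invalid index.
    safe_idx = idx.masked_fill(pad_mask, 0)
    # Step 2: broadcast indices over the feature axis and gather whole points.
    index = safe_idx[..., None].expand(-1, -1, D)  # (N, K, D)
    selected = points.gather(dim=1, index=index)   # (N, K, D)
    # Step 3: the padded slots currently hold point 0; overwrite them with 0.0.
    return selected.masked_fill(pad_mask[..., None], 0.0)


points = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)
idx = torch.tensor([[2, 0, -1]])
out = masked_gather_2d(points, idx)
print(out)  # rows [6, 7, 8], [0, 1, 2], [0, 0, 0]
```

Masking idx before expanding it keeps the mask small and leaves the caller's idx tensor unmodified, since masked_fill returns a new tensor.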

Parameters:
  • points (torch.Tensor) – (N, P, D) float32 tensor of points

  • idx (torch.Tensor) – (N, K) or (N, P, K) long tensor of indices into points, where an index of -1 indicates padding

Returns:

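selected_points – (N, K, D) float32 tensor of the points at the given indices, or (N, P, K, D) when idx has shape (N, P, K); padded positions are filled with 0.0

Return type:

torch.Tensor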
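For the (N, P, K) form of idx, for example K neighbor indices per query point padded with -1 when fewer than K neighbors exist, the gathered result keeps the extra query axis. The sketch below assumes an output of shape (N, P, K, D); the name masked_gather_3d and the flatten-then-gather strategy are illustrative choices, not taken from fastdev.

```python
import torch


def masked_gather_3d(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch for neighbor indices of shape (N, P2, K).

    points: (N, P, D) float tensor; idx: (N, P2, K) long tensor.
    Returns an (N, P2, K, D) tensor with zeros where idx == -1.
    """
    D = points.shape[-1]
    N, P2, K = idx.shape
    pad_mask = idx.eq(-1)                              # (N, P2, K)
    safe_idx = idx.masked_fill(pad_mask, 0)            # -1 -> 0 for gather
    # Flatten the (P2, K) axes so a single gather along the point axis suffices.
    flat_idx = safe_idx.reshape(N, P2 * K, 1).expand(-1, -1, D)
    selected = points.gather(dim=1, index=flat_idx).reshape(N, P2, K, D)
    # Zero out the padded neighbor slots.
    return selected.masked_fill(pad_mask[..., None], 0.0)


points = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)
idx = torch.tensor([[[1, 3], [2, -1]]])  # 2 queries, up to 2 neighbors each
out = masked_gather_3d(points, idx)
print(out.shape)  # torch.Size([1, 2, 2, 3])
```

Flattening the query and neighbor axes avoids building a (N, P, K, D) view of points; expanding points along a new neighbor axis and gathering with a 4-D index is an equally valid alternative.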