import torch


def masked_gather(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    Helper function for torch.gather to collect the points at
    the given indices in idx, where some of the indices might be -1 to
    indicate padding. These indices are first replaced with 0.
    The points are then gathered, after which the padded values
    are set to 0.0.

    Args:
        points: (N, P, D) float32 tensor of points
        idx: (N, K) or (N, P', K) long tensor of indices into points, where
            some indices are -1 to indicate padding

    Returns:
        selected_points: float32 tensor of points at the given indices,
            of shape (N, K, D) if idx is (N, K), or (N, P', K, D) if
            idx is (N, P', K). Padded entries are 0.0.
    """

    if len(idx) != len(points):
        raise ValueError("points and idx must have the same batch dimension")

    N, P, D = points.shape

    if idx.ndim == 3:
        # Case: KNN or ball query, where idx has shape (N, P', K).
        # P' need not equal P, because the points may be gathered
        # from a different point cloud.
        K = idx.shape[2]
        # Match dimensions for points and indices: expand idx over D and
        # points over K so that gathering along dim=1 picks whole D-vectors.
        idx_expanded = idx[..., None].expand(-1, -1, -1, D)
        points = points[:, :, None, :].expand(-1, -1, K, -1)
    elif idx.ndim == 2:
        # Case: farthest point sampling, where idx has shape (N, K).
        idx_expanded = idx[..., None].expand(-1, -1, D)
    else:
        raise ValueError("idx format is not supported %s" % repr(idx.shape))

    idx_expanded_mask = idx_expanded.eq(-1)
    # expand() returns a view that shares memory with idx, so clone it
    # before writing into it in place.
    idx_expanded = idx_expanded.clone()
    # Replace -1 values with 0 so that gather receives valid indices.
    idx_expanded[idx_expanded_mask] = 0
    # Gather points.
    selected_points = points.gather(dim=1, index=idx_expanded)
    # Zero out the entries that came from padded indices.
    selected_points[idx_expanded_mask] = 0.0
    return selected_points
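The padding logic is easy to get subtly wrong. For example, plain Python indexing would silently wrap -1 to the last point instead of treating it as padding. A plain-Python reference of the same semantics can therefore act as a readable spec or a test oracle for `masked_gather`. This is a sketch under stated assumptions: `masked_gather_reference` is a hypothetical name, inputs are nested lists rather than tensors, and the (N, K) versus (N, P', K) case is detected by checking whether the first row is itself a list.

```python
def masked_gather_reference(points, idx):
    """Hypothetical plain-Python sketch of masked_gather semantics.

    points: N x P x D nested lists of floats.
    idx: N x K or N x P' x K nested lists of ints, where -1 marks padding.
    Returns N x K x D or N x P' x K x D nested lists; padded slots are 0.0.
    """
    if len(idx) != len(points):
        raise ValueError("points and idx must have the same batch dimension")

    def pick(cloud, i):
        # -1 is padding: emit a zero vector with the same width D as a point.
        if i == -1:
            return [0.0] * len(cloud[0])
        # Reject other negative indices instead of letting Python wrap them.
        if i < 0:
            raise IndexError("negative index other than -1: %d" % i)
        # Copy the point so the result does not alias the input lists.
        return list(cloud[i])

    out = []
    for cloud, rows in zip(points, idx):
        if rows and isinstance(rows[0], list):
            # (P', K) case, e.g. KNN or ball query: a K x D block per query point.
            out.append([[pick(cloud, i) for i in row] for row in rows])
        else:
            # (K,) case, e.g. farthest point sampling: K points of width D.
            out.append([pick(cloud, i) for i in rows])
    return out


points = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]  # N=1, P=3, D=2
print(masked_gather_reference(points, [[2, -1]]))
# [[[5.0, 6.0], [0.0, 0.0]]]
print(masked_gather_reference(points, [[[0, 1], [-1, 2]]]))
# [[[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [5.0, 6.0]]]]
```

One way to use it is in a unit test: convert small hand-written inputs with `torch.tensor(...)` (using `dtype=torch.long` for the indices), call `masked_gather`, and compare `.tolist()` of the result against the reference output. Exactly representable values such as the ones above avoid float32 rounding differences.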