from typing import Optional

import numpy as np
import torch
from jaxtyping import Float, Int64

from fastdev.extension import FDEV_EXT

# fpsample is optional (used by sample_farthest_points_numpy): record whether it
# can be imported instead of failing when this module is imported.
try:
    import fpsample
except ImportError:
    FPSAMPLE_AVAILABLE = False
else:
    FPSAMPLE_AVAILABLE = True

# TODO: modify the function signature based on pytorch3d's implementation
def sample_farthest_points(
    points: Float[torch.Tensor, "*B N 3"],
    num_samples: int,
    random_start: bool = False,
) -> Int64[torch.Tensor, "*B num_samples"]:
    """Sample farthest points.

    Any number of leading batch dimensions is accepted. They are flattened into a
    single batch dimension for the extension kernel and restored on the output.

    Args:
        points (Tensor): input points in shape (*B, N, 3), e.g. (B, N, 3) or (N, 3)
        num_samples (int): number of samples
        random_start (bool): if True, each batch element starts from an index drawn
            uniformly from [0, N); otherwise every batch element starts from index 0.
            Defaults to False.

    Returns:
        Tensor: indices of farthest points in shape (*B, num_samples), e.g.
        (B, num_samples) or (num_samples,)

    Raises:
        ValueError: if ``points`` has fewer than 2 dimensions or its last
            dimension is not 3.
    """
    if points.ndim < 2 or points.shape[-1] != 3:
        raise ValueError("points should be in shape (*B, N, 3), e.g. (B, N, 3) or (N, 3).")

    batch_shape = points.shape[:-2]
    num_points = points.shape[-2]
    # numel() of an empty Size is 1, so (N, 3) becomes a batch of one. Computing the
    # batch size explicitly (instead of reshape(-1, ...)) keeps empty inputs valid.
    batch_size = batch_shape.numel()
    flat_points = points.reshape(batch_size, num_points, 3)
    device = points.device

    if random_start:
        start_idx = torch.randint(num_points, (batch_size,), device=device)
    else:
        start_idx = torch.zeros((batch_size,), dtype=torch.long, device=device)

    # Every batch element uses all N points and requests the same number of samples.
    lengths = torch.full((batch_size,), fill_value=num_points, dtype=torch.long, device=device)
    sample_counts = torch.full((batch_size,), fill_value=num_samples, dtype=torch.long, device=device)

    indices = FDEV_EXT.load_module("fastdev_sample_farthest_points").sample_farthest_points(
        flat_points,
        lengths,
        sample_counts,
        start_idx,
    )
    # Restore the original leading dimensions: (prod(B), K) -> (*B, K).
    return indices.reshape(batch_shape + indices.shape[1:])
def sample_farthest_points_numpy(
    points: Float[np.ndarray, "N 3"],
    num_samples: int,
    start_idx: Optional[int] = None,
    h: int = 3,
) -> Int64[np.ndarray, "num_samples"]:  # noqa: F821
    """Sample farthest points using fpsample.

    Args:
        points (np.ndarray): input points in shape (N, 3)
        num_samples (int): number of samples
        start_idx (int, optional): index of the first sampled point, forwarded
            unchanged to fpsample (``None`` leaves the choice to fpsample).
        h (int): ``h`` argument of ``fpsample.bucket_fps_kdline_sampling``
            (described as the KD-tree height in fpsample's docs). Defaults to 3,
            the value that was previously hard-coded here.

    Returns:
        np.ndarray: indices of farthest points in shape (num_samples,)

    Raises:
        ImportError: if fpsample is not installed.
        ValueError: if ``points`` is not of shape (N, 3).
    """
    if not FPSAMPLE_AVAILABLE:
        raise ImportError("fpsample is not available, please install it via `pip install fpsample`.")
    if points.ndim != 2 or points.shape[-1] != 3:
        raise ValueError("points should be in shape (N, 3), no batch support.")
    # NOTE(review): fpsample may return unsigned indices; confirm against the Int64 annotation.
    return fpsample.bucket_fps_kdline_sampling(points, num_samples, h=h, start_idx=start_idx)  # type: ignore


__all__ = ["sample_farthest_points"]