fastdev.nn.pointnet
PointNet implementation in PyTorch.
Adapted from: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/
Module Contents
- class fastdev.nn.pointnet.PointNetEncoder(input_dim: int = 3, feature_dim: int = 1024, global_feature: bool = True, point_transform: bool = False, feature_transform: bool = False, pretrained_filename: Literal['models/pointnn/pointnet_cls_241031.safetensors'] | None = None)
Bases: torch.nn.Module
Encoder for PointNet.
Unlike the original repository, where the input transformation is always applied, this implementation makes the transformation of the input points optional (point_transform, default: False). A single STNkd module is reused for both the point and the feature transformations; STN3d is treated as a special case of STNkd.
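To illustrate how one STNkd can serve both roles, here is a minimal sketch of a k-dimensional spatial transformer modeled on the STNkd in the reference repository. It is not fastdev's code: the layer widths follow the reference repository, but batch normalization is omitted for brevity. The network predicts a k x k matrix per sample and adds the identity, so the transform starts close to a no-op. With k=3 it plays the role of STN3d on xyz coordinates; with k=64 it acts on intermediate features.

```python
import torch
import torch.nn as nn


class STNkd(nn.Module):
    """Sketch of a k-dimensional spatial transformer (no batch norm)."""

    def __init__(self, k: int = 64):
        super().__init__()
        self.k = k
        # Shared per-point MLP implemented with 1x1 convolutions over (B, k, N).
        self.mlp = nn.Sequential(
            nn.Conv1d(k, 64, 1), nn.ReLU(),
            nn.Conv1d(64, 128, 1), nn.ReLU(),
            nn.Conv1d(128, 1024, 1), nn.ReLU(),
        )
        # Regress the k*k matrix entries from the pooled global feature.
        self.fc = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, k * k),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b = x.shape[0]
        # Max pooling over the point dimension is order-independent.
        feat = self.mlp(x).max(dim=2).values  # (B, 1024)
        mat = self.fc(feat).view(b, self.k, self.k)
        # Add the identity so an untrained transform is close to a no-op.
        return mat + torch.eye(self.k, device=x.device, dtype=x.dtype)


# Point transform: k=3 on raw coordinates of shape (B, 3, N).
stn = STNkd(k=3)
points = torch.randn(2, 3, 100)
trans = stn(points)  # (2, 3, 3)
# Apply it as the reference code does: (B, N, 3) @ (B, 3, 3), then back to (B, 3, N).
aligned = torch.bmm(points.transpose(2, 1), trans).transpose(2, 1)
```

The same class, instantiated with k=64, would transform the 64-channel per-point features.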
- Forward method:
- Args:
x (torch.Tensor): Input tensor of shape (batch_size, channels, num_points).
- Returns:
torch.Tensor: Extracted features.
torch.Tensor: Transformation matrix applied to the points.
torch.Tensor: Transformation matrix applied to the features.
Reference: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet_utils.py
- Parameters:
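input_dim (int) – Number of input channels per point (3 for xyz coordinates).
feature_dim (int) – Dimension of the extracted global feature.
global_feature (bool) – If True, return one global feature vector per point cloud; if False, return per-point features combined with the global feature.
point_transform (bool) – If True, apply an STNkd transformation to the input points.
feature_transform (bool) – If True, apply an STNkd transformation to the intermediate point features.
pretrained_filename (Optional[Literal['models/pointnn/pointnet_cls_241031.safetensors']]) – Optional filename of pretrained weights to load.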
- forward(x: torch.Tensor)
Forward pass.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, channels, num_points).
- Returns:
torch.Tensor: Extracted features.
torch.Tensor: Transformation matrix applied to the points.
torch.Tensor: Transformation matrix applied to the features.
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
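The core of the forward pass is a shared per-point MLP followed by max pooling over the points. The sketch below shows that idea with a hypothetical TinyPointNetEncoder; it is not fastdev's PointNetEncoder. It omits both transformations (and, as an assumption of this sketch, reports them as None) and omits batch normalization. When global_feature is False it concatenates the repeated global feature with the 64-channel per-point features, as the reference repository does.

```python
import torch
import torch.nn as nn


class TinyPointNetEncoder(nn.Module):
    """Hypothetical simplified PointNet encoder without point/feature transforms."""

    def __init__(self, input_dim: int = 3, feature_dim: int = 1024, global_feature: bool = True):
        super().__init__()
        self.global_feature = global_feature
        # 1x1 convolutions apply the same MLP to every point independently.
        self.mlp1 = nn.Sequential(nn.Conv1d(input_dim, 64, 1), nn.ReLU())
        self.mlp2 = nn.Sequential(
            nn.Conv1d(64, 128, 1), nn.ReLU(),
            nn.Conv1d(128, feature_dim, 1),
        )

    def forward(self, x: torch.Tensor):
        num_points = x.shape[2]
        point_feat = self.mlp1(x)  # (B, 64, N) per-point features
        global_feat = self.mlp2(point_feat).max(dim=2).values  # (B, feature_dim)
        trans, trans_feat = None, None  # transforms are disabled in this sketch
        if self.global_feature:
            return global_feat, trans, trans_feat
        # Repeat the global feature for every point and stack it onto the per-point features.
        expanded = global_feat.unsqueeze(2).expand(-1, -1, num_points)
        return torch.cat([expanded, point_feat], dim=1), trans, trans_feat


enc = TinyPointNetEncoder()
x = torch.randn(4, 3, 256)  # (batch_size, channels, num_points)
feat, trans, trans_feat = enc(x)
print(feat.shape)  # torch.Size([4, 1024])
```

Because max pooling is symmetric, shuffling the points along the last axis leaves the global feature unchanged (up to floating-point tolerance).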
- class fastdev.nn.pointnet.PointNetCls(k: int = 40, normal_channel: bool = True, point_transform: bool = False, feature_transform: bool = False, pretrained_filename: Literal['models/pointnn/pointnet_cls_241031.safetensors'] | None = None)
Bases: torch.nn.Module
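PointNet classification network.
- Parameters:
k (int) – Number of output classes (default 40, matching the 40 categories of ModelNet40).
normal_channel (bool) – If True, each input point carries normals in addition to xyz coordinates (6 channels instead of 3).
point_transform (bool) – If True, apply an STNkd transformation to the input points.
feature_transform (bool) – If True, apply an STNkd transformation to the intermediate point features.
pretrained_filename (Optional[Literal['models/pointnn/pointnet_cls_241031.safetensors']]) – Optional filename of pretrained weights to load.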
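A classifier on top of the encoder maps the pooled global feature to per-class scores. The sketch below uses a hypothetical TinyPointNetCls, not fastdev's PointNetCls. The input channel count (6 with normals, 3 without) and the head widths 1024 -> 512 -> 256 -> k follow the reference repository's classifier, which outputs log-probabilities; batch normalization and the transformations are omitted here, and the dropout placement is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyPointNetCls(nn.Module):
    """Hypothetical classifier sketch: global max-pooled feature -> MLP -> class log-probabilities."""

    def __init__(self, k: int = 40, normal_channel: bool = True):
        super().__init__()
        in_ch = 6 if normal_channel else 3  # xyz + normals, or xyz only
        self.encoder = nn.Sequential(
            nn.Conv1d(in_ch, 64, 1), nn.ReLU(),
            nn.Conv1d(64, 128, 1), nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
        )
        self.head = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.Dropout(p=0.4), nn.ReLU(),
            nn.Linear(256, k),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        global_feat = self.encoder(x).max(dim=2).values  # (B, 1024)
        return F.log_softmax(self.head(global_feat), dim=1)  # (B, k)


model = TinyPointNetCls(k=40, normal_channel=True).eval()  # eval() disables dropout
log_probs = model(torch.randn(2, 6, 512))  # (batch_size, 6, num_points)
pred = log_probs.argmax(dim=1)  # predicted class index per point cloud
```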
- fastdev.nn.pointnet.feature_transform_reguliarzer(trans: torch.Tensor) → torch.Tensor
- Parameters:
trans (torch.Tensor)
- Return type:
torch.Tensor
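In the reference repository, the function of this name (the misspelling is part of the identifier) penalizes a feature transformation matrix A for deviating from orthogonality: it computes the Frobenius norm of A A^T - I for each matrix in the batch and averages the results. The PointNet paper writes the squared norm instead. The sketch below reproduces the reference computation under the hypothetical name orthogonality_regularizer; it is not fastdev's function.

```python
import torch


def orthogonality_regularizer(trans: torch.Tensor) -> torch.Tensor:
    """Sketch: mean over the batch of ||A A^T - I||_F for trans of shape (B, d, d)."""
    d = trans.shape[1]
    identity = torch.eye(d, device=trans.device, dtype=trans.dtype).unsqueeze(0)  # (1, d, d)
    gram = torch.bmm(trans, trans.transpose(1, 2))  # (B, d, d), equals I when A is orthogonal
    return torch.linalg.matrix_norm(gram - identity, ord="fro").mean()


# A batch of identity matrices is perfectly orthogonal, so the penalty is zero.
print(orthogonality_regularizer(torch.eye(64).repeat(8, 1, 1)))  # tensor(0.)
```

In the reference repository's classification loss this term is added with a small weight (0.001), so it nudges the learned feature transform toward a rotation-like matrix without dominating the cross-entropy term.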