fastdev.nn.pointnet
===================

.. py:module:: fastdev.nn.pointnet

.. autoapi-nested-parse::

   PointNet implementation in PyTorch.

   Adapted from: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/


Module Contents
---------------

.. py:class:: PointNetEncoder(input_dim: int = 3, feature_dim: int = 1024, global_feature: bool = True, point_transform: bool = False, feature_transform: bool = False, pretrained_filename: Optional[Literal['models/pointnn/pointnet_cls_241031.safetensors']] = None)

   Bases: :py:obj:`torch.nn.Module`

   Encoder for PointNet.

   This implementation differs from the original repository in that the transformation of the input points is optional (``point_transform``, default: ``False``). STNkd is used for both the point and the feature transformations, with STN3d treated as a special case of STNkd.

   The forward method takes ``x`` (:py:obj:`torch.Tensor`), an input tensor of shape ``(batch_size, channels, num_points)``, and returns three tensors:

   - the extracted features;
   - the transformation matrix applied to the points;
   - the transformation matrix applied to the features.

   Reference: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet_utils.py

   .. py:attribute:: global_feat
      :value: True

   .. py:attribute:: point_transform
      :value: False

   .. py:attribute:: feature_transform
      :value: False

   .. py:attribute:: conv1

   .. py:attribute:: conv2

   .. py:attribute:: conv3

   .. py:attribute:: bn1

   .. py:attribute:: bn2

   .. py:attribute:: bn3

   .. py:attribute:: relu

   .. py:method:: forward(x: torch.Tensor)

      Forward pass.

      :param x: Input tensor of shape ``(batch_size, channels, num_points)``.
      :type x: torch.Tensor

      :returns: A tuple of the extracted features, the transformation matrix applied to the points, and the transformation matrix applied to the features.
      :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]


.. py:class:: PointNetCls(k: int = 40, normal_channel: bool = True, point_transform: bool = False, feature_transform: bool = False, pretrained_filename: Optional[Literal['models/pointnn/pointnet_cls_241031.safetensors']] = None)

   Bases: :py:obj:`torch.nn.Module`

   .. py:attribute:: feat

   .. py:attribute:: fc1

   .. py:attribute:: fc2

   .. py:attribute:: fc3

   .. py:attribute:: dropout

   .. py:attribute:: bn1

   .. py:attribute:: bn2

   .. py:attribute:: relu

   .. py:attribute:: log_softmax

   .. py:method:: forward(x)


.. py:function:: feature_transform_reguliarzer(trans: torch.Tensor) -> torch.Tensor


.. py:class:: PointNetSemSeg(num_classes: int, point_transform: bool = False, feature_transform: bool = False)

   Bases: :py:obj:`torch.nn.Module`

   .. py:attribute:: k

   .. py:attribute:: feat

   .. py:attribute:: conv1

   .. py:attribute:: conv2

   .. py:attribute:: conv3

   .. py:attribute:: conv4

   .. py:attribute:: bn1

   .. py:attribute:: bn2

   .. py:attribute:: bn3

   .. py:attribute:: relu

   .. py:attribute:: log_softmax

   .. py:method:: forward(x)
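The encoder's forward contract (channels-first input of shape ``(batch_size, channels, num_points)``, a shared per-point MLP, then max pooling over points) and the feature-transform regularizer can be illustrated with the minimal sketch below. It is not fastdev's code. ``MiniPointNetEncoder`` and ``feature_transform_regularizer_sketch`` are hypothetical names, the STN point and feature transforms are omitted, and the layer widths (64, 128) and the ``[global, per-point]`` concatenation order of the non-global branch are assumptions taken from the referenced yanx27 ``pointnet_utils.py``. The regularizer assumes the standard PointNet orthogonality penalty, the batch mean of the Frobenius norm of ``A @ A.T - I``; whether ``feature_transform_reguliarzer`` matches it exactly is also an assumption.

```python
import torch
import torch.nn as nn


class MiniPointNetEncoder(nn.Module):
    """Hypothetical sketch: shared per-point MLP + max pooling, without STN transforms."""

    def __init__(self, input_dim: int = 3, feature_dim: int = 1024, global_feature: bool = True):
        super().__init__()
        self.global_feature = global_feature
        # Conv1d with kernel_size=1 applies the same linear layer to every point.
        self.conv1 = nn.Conv1d(input_dim, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, feature_dim, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(feature_dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, channels, num_points)
        num_points = x.shape[2]
        x = self.relu(self.bn1(self.conv1(x)))
        point_feat = x  # (B, 64, N) per-point features
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))  # (B, feature_dim, N)
        # Max over the point axis is symmetric, so the result ignores point order.
        x = torch.max(x, dim=2)[0]  # (B, feature_dim)
        if self.global_feature:
            return x
        # Segmentation-style output: tile the global feature to every point
        # and concatenate it with the per-point features.
        x = x.unsqueeze(2).expand(-1, -1, num_points)
        return torch.cat([x, point_feat], dim=1)  # (B, feature_dim + 64, N)


def feature_transform_regularizer_sketch(trans: torch.Tensor) -> torch.Tensor:
    """Batch mean of ||A A^T - I||_F for a batch of (k, k) matrices A."""
    k = trans.shape[1]
    eye = torch.eye(k, device=trans.device, dtype=trans.dtype).unsqueeze(0)
    diff = torch.bmm(trans, trans.transpose(1, 2)) - eye
    # A 2-tuple dim with the default ord gives the Frobenius norm per matrix.
    return torch.linalg.norm(diff, dim=(1, 2)).mean()


if __name__ == "__main__":
    points = torch.randn(8, 2048, 3)  # (batch, num_points, 3), a common loader layout
    x = points.transpose(1, 2)  # -> (batch, 3, num_points), as the encoder expects
    print(MiniPointNetEncoder()(x).shape)  # torch.Size([8, 1024])
```

The transpose in the usage block is the step most often missed: point clouds are usually stored as ``(batch, num_points, 3)``, while the Conv1d-based encoder treats dimension 1 as channels. The regularizer is zero exactly when every matrix is orthogonal, which keeps a learned feature transform close to a rotation.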