fastdev.nn.pointnet

PointNet implementation in PyTorch.

Adapted from: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/

Module Contents

class fastdev.nn.pointnet.PointNetEncoder(input_dim: int = 3, feature_dim: int = 1024, global_feature: bool = True, point_transform: bool = False, feature_transform: bool = False, pretrained_filename: Literal['models/pointnn/pointnet_cls_241031.safetensors'] | None = None)

Bases: torch.nn.Module

Encoder for PointNet.

This implementation differs from the original repository by adding an option to apply a learned transformation to the input points (point_transform, default: False). A single STNkd class is used for both the point and the feature transformations; STN3d is treated as a special case of STNkd.
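For readers unfamiliar with STNkd, the sketch below shows a k-dimensional spatial transformer (T-Net) of the kind the paragraph above describes. It assumes the layer widths of the referenced Pointnet_Pointnet2_pytorch repository (1x1 convolutions to 64, 128 and 1024 channels, max pooling, then fully connected layers down to k*k) and that STN3d amounts to STNkd with k = 3 for xyz input. It is an illustrative sketch, not fastdev's actual STNkd.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class STNkd(nn.Module):
    """Sketch of a k-dimensional T-Net (assumed widths from the referenced repo)."""

    def __init__(self, k: int = 64):
        super().__init__()
        self.k = k
        # Shared per-point MLP implemented as 1x1 convolutions.
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        # Fully connected layers that regress the k*k matrix entries.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, k, num_points)
        batch_size = x.size(0)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Order-invariant max pooling over points -> (batch_size, 1024)
        x = torch.max(x, dim=2).values
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        # Adding the identity means a zero regression output yields the identity transform.
        identity = torch.eye(self.k, device=x.device, dtype=x.dtype).flatten()
        return (x + identity).view(batch_size, self.k, self.k)


if __name__ == "__main__":
    stn = STNkd(k=3)  # plays the role of STN3d for xyz input
    pts = torch.randn(8, 3, 1024)
    trans = stn(pts)  # (8, 3, 3)
    # Apply the transform point-wise: (B, N, 3) @ (B, 3, 3), then back to (B, 3, N).
    aligned = torch.bmm(pts.transpose(2, 1), trans).transpose(2, 1)
```

Adding the identity to the regressed matrix is the usual trick to start training from a near-identity alignment rather than from an arbitrary linear map.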

Forward method:

Args:

  • x (torch.Tensor): Input tensor of shape (batch_size, channels, num_points).

Returns:

  • torch.Tensor: Extracted features.

  • torch.Tensor: Transformation matrix applied to the points.

  • torch.Tensor: Transformation matrix applied to the features.

Reference: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet_utils.py

Parameters:
  • input_dim (int)

  • feature_dim (int)

  • global_feature (bool)

  • point_transform (bool)

  • feature_transform (bool)

  • pretrained_filename (Optional[Literal['models/pointnn/pointnet_cls_241031.safetensors']])

global_feat = True
point_transform = False
feature_transform = False
conv1
conv2
conv3
bn1
bn2
bn3
relu
forward(x: torch.Tensor)

Forward pass.

Parameters:

x (torch.Tensor) – Input tensor of shape (batch_size, channels, num_points).

Returns:

  • torch.Tensor: Extracted features.

  • torch.Tensor: Transformation matrix applied to the points.

  • torch.Tensor: Transformation matrix applied to the features.

Return type:

tuple
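To make the feature shapes concrete, here is a minimal sketch of the encoder's feature path. It assumes the layer widths of the referenced repository (64, 128, then feature_dim) and reuses the attribute names listed above (conv1 to conv3, bn1 to bn3, relu). The point and feature transforms are omitted, so this hypothetical EncoderSketch returns only the features rather than the three-tuple. With global_feature=False it follows the referenced repository in concatenating the broadcast global feature with the 64-channel per-point features.

```python
import torch
import torch.nn as nn


class EncoderSketch(nn.Module):
    """Hypothetical, transform-free sketch of the PointNet encoder feature path."""

    def __init__(self, input_dim: int = 3, feature_dim: int = 1024, global_feature: bool = True):
        super().__init__()
        self.global_feat = global_feature
        self.conv1 = nn.Conv1d(input_dim, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, feature_dim, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(feature_dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, input_dim, num_points)
        num_points = x.size(2)
        x = self.relu(self.bn1(self.conv1(x)))
        point_feat = x  # per-point features: (batch_size, 64, num_points)
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        # Max pooling over points gives an order-invariant global descriptor.
        global_feat = torch.max(x, dim=2).values  # (batch_size, feature_dim)
        if self.global_feat:
            return global_feat
        # Broadcast the global feature to every point and concatenate channels.
        expanded = global_feat.unsqueeze(2).expand(-1, -1, num_points)
        return torch.cat([expanded, point_feat], dim=1)  # (batch_size, feature_dim + 64, num_points)
```

With the defaults, the per-point output therefore has 1024 + 64 = 1088 channels, which is the width a segmentation head built on this encoder would consume.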

class fastdev.nn.pointnet.PointNetCls(k: int = 40, normal_channel: bool = True, point_transform: bool = False, feature_transform: bool = False, pretrained_filename: Literal['models/pointnn/pointnet_cls_241031.safetensors'] | None = None)

Bases: torch.nn.Module

Parameters:
  • k (int)

  • normal_channel (bool)

  • point_transform (bool)

  • feature_transform (bool)

  • pretrained_filename (Optional[Literal['models/pointnn/pointnet_cls_241031.safetensors']])

feat
fc1
fc2
fc3
dropout
bn1
bn2
relu
log_softmax
forward(x)
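The PointNetCls attributes (feat, fc1 to fc3, dropout, bn1, bn2, relu, log_softmax) suggest a fully connected head on top of the encoder's global feature. The sketch below assumes the widths used in the referenced repository (1024 to 512 to 256 to k, dropout p = 0.4) and, also following that repository, that normal_channel=True means six input channels (xyz plus normals). ClsHeadSketch and input_channels are hypothetical names, not part of fastdev.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClsHeadSketch(nn.Module):
    """Hypothetical classification head over a (batch_size, feature_dim) global feature."""

    def __init__(self, k: int = 40, feature_dim: int = 1024, dropout: float = 0.4):
        super().__init__()
        self.fc1 = nn.Linear(feature_dim, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=dropout)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, global_feat: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.bn1(self.fc1(global_feat)))
        x = self.relu(self.bn2(self.dropout(self.fc2(x))))
        # Log-probabilities over k classes, suitable for F.nll_loss.
        return self.log_softmax(self.fc3(x))


def input_channels(normal_channel: bool) -> int:
    # Hypothetical helper; assumes xyz plus optional per-point normals (nx, ny, nz).
    return 6 if normal_channel else 3


if __name__ == "__main__":
    head = ClsHeadSketch(k=40)
    log_probs = head(torch.randn(8, 1024))  # (8, 40)
    loss = F.nll_loss(log_probs, torch.randint(0, 40, (8,)))
```

Emitting log-probabilities rather than raw logits pairs the head with F.nll_loss; the combination is equivalent to applying F.cross_entropy to the logits.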
fastdev.nn.pointnet.feature_transform_reguliarzer(trans: torch.Tensor) → torch.Tensor
Parameters:

trans (torch.Tensor)

Return type:

torch.Tensor
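feature_transform_reguliarzer has no description here. Assuming it follows the PointNet paper and the referenced repository, it penalizes how far the feature transformation matrix A is from orthogonal, computing ||I - A A^T||_F per sample and averaging over the batch. The referenced repository adds this term to the NLL classification loss with a weight of 0.001. A minimal sketch under that assumption, using the hypothetical name feature_transform_regularizer_sketch:

```python
import torch


def feature_transform_regularizer_sketch(trans: torch.Tensor) -> torch.Tensor:
    """Batch-mean orthogonality penalty ||I - A A^T||_F for trans of shape (B, k, k)."""
    k = trans.size(1)
    identity = torch.eye(k, device=trans.device, dtype=trans.dtype).unsqueeze(0)
    # A A^T - I is the zero matrix exactly when A is orthogonal.
    diff = torch.bmm(trans, trans.transpose(2, 1)) - identity
    # Frobenius norm of each matrix in the batch, then the mean over the batch.
    return torch.linalg.matrix_norm(diff, ord="fro").mean()
```

For example, for A = 2I in three dimensions, A A^T - I = 3I, so the penalty is 3 * sqrt(3).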

class fastdev.nn.pointnet.PointNetSemSeg(num_classes: int, point_transform: bool = False, feature_transform: bool = False)

Bases: torch.nn.Module

Parameters:
  • num_classes (int)

  • point_transform (bool)

  • feature_transform (bool)

k
feat
conv1
conv2
conv3
conv4
bn1
bn2
bn3
relu
log_softmax
forward(x)
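The PointNetSemSeg attributes (k, feat, conv1 to conv4, bn1 to bn3, relu, log_softmax) point to a per-point classifier applied to the encoder's concatenated per-point features. The sketch below assumes the referenced repository's design: 1x1 convolutions from 1088 channels (1024 global + 64 per-point) down to 512, 256, 128 and finally num_classes, with log-softmax over classes for each point. SemSegHeadSketch is a hypothetical name, not fastdev's class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SemSegHeadSketch(nn.Module):
    """Hypothetical per-point segmentation head over (batch_size, in_channels, num_points)."""

    def __init__(self, num_classes: int, in_channels: int = 1088):
        super().__init__()
        self.k = num_classes
        self.conv1 = nn.Conv1d(in_channels, 512, 1)
        self.conv2 = nn.Conv1d(512, 256, 1)
        self.conv3 = nn.Conv1d(256, 128, 1)
        self.conv4 = nn.Conv1d(128, num_classes, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)
        self.relu = nn.ReLU()
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, point_feat: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.bn1(self.conv1(point_feat)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)  # (batch_size, num_classes, num_points)
        # Move classes last so each point gets a distribution over classes.
        return self.log_softmax(x.transpose(2, 1))  # (batch_size, num_points, num_classes)


if __name__ == "__main__":
    head = SemSegHeadSketch(num_classes=13)
    log_probs = head(torch.randn(2, 1088, 500))  # (2, 500, 13)
    labels = torch.randint(0, 13, (2, 500))
    # Flatten points into the batch dimension for the per-point NLL loss.
    loss = F.nll_loss(log_probs.reshape(-1, 13), labels.reshape(-1))
```

Placing the class dimension last makes the output line up with per-point label tensors of shape (batch_size, num_points).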