# type: ignore
"""
Point Transformer - V3 Mode1
Adapted from: https://github.com/Pointcept/Pointcept
This module requires the installation of the following packages:
- addict: pip install addict
- spconv: https://github.com/traveller59/spconv?tab=readme-ov-file#spconv-spatially-sparse-convolution-library
- torch-scatter: https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation
- flash-attention: https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features
Original Author: Xiaoyang Wu ([email protected])
Please cite their work if you use the following code in your research paper.
"""
import math
import sys
from collections import OrderedDict
from functools import partial
from typing import Optional, Union
import spconv.pytorch as spconv
import torch
import torch.nn as nn
import torch_scatter
from addict import Dict
try:
    import flash_attn
except ImportError:
    flash_attn = None

try:
    import ocnn  # optional; only needed for Point.octreetization
except ImportError:
    ocnn = None
# Adapted from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
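
# Behavior sketch (comment only, not part of the original module): with drop_prob=0.25 and
# training=True, drop_path zeroes each sample's residual branch with probability 0.25 and
# rescales the surviving samples by 1 / 0.75 (scale_by_keep=True); at inference time, or
# with drop_prob=0.0, the input is returned unchanged, so DropPath is a no-op there.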
# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/utils/serialization/z_order.py
class KeyLUT:
def __init__(self):
r256 = torch.arange(256, dtype=torch.int64)
r512 = torch.arange(512, dtype=torch.int64)
zero = torch.zeros(256, dtype=torch.int64)
device = torch.device("cpu")
self._encode = {
device: (
self.xyz2key(r256, zero, zero, 8),
self.xyz2key(zero, r256, zero, 8),
self.xyz2key(zero, zero, r256, 8),
)
}
self._decode = {device: self.key2xyz(r512, 9)}
def encode_lut(self, device=torch.device("cpu")):
if device not in self._encode:
cpu = torch.device("cpu")
self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
return self._encode[device]
def decode_lut(self, device=torch.device("cpu")):
if device not in self._decode:
cpu = torch.device("cpu")
self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
return self._decode[device]
def xyz2key(self, x, y, z, depth):
key = torch.zeros_like(x)
for i in range(depth):
mask = 1 << i
key = key | ((x & mask) << (2 * i + 2)) | ((y & mask) << (2 * i + 1)) | ((z & mask) << (2 * i + 0))
return key
def key2xyz(self, key, depth):
x = torch.zeros_like(key)
y = torch.zeros_like(key)
z = torch.zeros_like(key)
for i in range(depth):
x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
return x, y, z
_key_lut = KeyLUT()
def xyz2key(
x: torch.Tensor,
y: torch.Tensor,
z: torch.Tensor,
b: Optional[Union[torch.Tensor, int]] = None,
depth: int = 16,
):
r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
based on pre-computed look up tables. The speed of this function is much
faster than the method based on for-loop.
Args:
x (torch.Tensor): The x coordinate.
y (torch.Tensor): The y coordinate.
z (torch.Tensor): The z coordinate.
b (torch.Tensor or int): The batch index of the coordinates, and should be
smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
:attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
"""
EX, EY, EZ = _key_lut.encode_lut(x.device)
x, y, z = x.long(), y.long(), z.long()
mask = 255 if depth > 8 else (1 << depth) - 1
key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
if depth > 8:
mask = (1 << (depth - 8)) - 1
key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
key = key16 << 24 | key
if b is not None:
b = b.long()
key = b << 48 | key
return key
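
# Usage sketch (hypothetical helper, not part of the original module): the z-order key
# interleaves coordinate bits as ... x_i y_i z_i, so with depth=1 the key of
# (x, y, z) = (1, 0, 1) is 0b101 = 5.
def _example_xyz2key():
    x, y, z = torch.tensor([1]), torch.tensor([0]), torch.tensor([1])
    return xyz2key(x, y, z, b=None, depth=1)  # tensor([5])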
# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/utils/serialization/hilbert.py
def right_shift(binary, k=1, axis=-1):
"""Right shift an array of binary values.
Parameters:
-----------
binary: An ndarray of binary values.
k: The number of bits to shift. Default 1.
axis: The axis along which to shift. Default -1.
Returns:
--------
Returns an ndarray with zero prepended and the ends truncated, along
whatever axis was specified."""
# If we're shifting the whole thing, just return zeros.
if binary.shape[axis] <= k:
return torch.zeros_like(binary)
# Determine the padding pattern.
# padding = [(0,0)] * len(binary.shape)
# padding[axis] = (k,0)
# Determine the slicing pattern to eliminate just the last one.
slicing = [slice(None)] * len(binary.shape)
slicing[axis] = slice(None, -k)
shifted = torch.nn.functional.pad(binary[tuple(slicing)], (k, 0), mode="constant", value=0)
return shifted
def gray2binary(gray, axis=-1):
"""Convert an array of Gray codes back into binary values.
Parameters:
-----------
gray: An ndarray of gray codes.
axis: The axis along which to perform Gray decoding. Default=-1.
Returns:
--------
Returns an ndarray of binary values.
"""
# Loop the log2(bits) number of times necessary, with shift and xor.
shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
while shift > 0:
gray = torch.logical_xor(gray, right_shift(gray, shift))
shift = torch.div(shift, 2, rounding_mode="floor")
return gray
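
# Worked example (comment only): each output bit is the XOR of all Gray-code bits at or
# above it, so the Gray code 0b110 decodes to the binary value 0b100, e.g.
#     gray2binary(torch.tensor([[1, 1, 0]], dtype=torch.uint8))
# returns a boolean tensor equal to [[True, False, False]].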
def hilbert_encode_(locs, num_dims, num_bits):
"""Decode an array of locations in a hypercube into a Hilbert integer.
This is a vectorized-ish version of the Hilbert curve implementation by John
Skilling as described in:
Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
Params:
-------
locs - An ndarray of locations in a hypercube of num_dims dimensions, in
which each dimension runs from 0 to 2**num_bits-1. The shape can
            be arbitrary, as long as the last dimension has size
num_dims.
num_dims - The dimensionality of the hypercube. Integer.
num_bits - The number of bits for each dimension. Integer.
Returns:
--------
    The output is an ndarray of int64 integers with the same shape as the
input, excluding the last dimension, which needs to be num_dims.
"""
# Keep around the original shape for later.
orig_shape = locs.shape
bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
bitpack_mask_rev = bitpack_mask.flip(-1)
if orig_shape[-1] != num_dims:
raise ValueError(
"""
The shape of locs was surprising in that the last dimension was of size
%d, but num_dims=%d. These need to be equal.
"""
% (orig_shape[-1], num_dims)
)
if num_dims * num_bits > 63:
raise ValueError(
"""
num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
        into an int64. Are you sure you need that many points on your Hilbert
curve?
"""
% (num_dims, num_bits, num_dims * num_bits)
)
# Treat the location integers as 64-bit unsigned and then split them up into
# a sequence of uint8s. Preserve the association by dimension.
locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
# Now turn these into bits and truncate to num_bits.
gray = locs_uint8.unsqueeze(-1).bitwise_and(bitpack_mask_rev).ne(0).byte().flatten(-2, -1)[..., -num_bits:]
# Run the decoding process the other way.
# Iterate forwards through the bits.
for bit in range(0, num_bits):
# Iterate forwards through the dimensions.
for dim in range(0, num_dims):
# Identify which ones have this bit active.
mask = gray[:, dim, bit]
# Where this bit is on, invert the 0 dimension for lower bits.
gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], mask[:, None])
# Where the bit is off, exchange the lower bits with the 0 dimension.
to_flip = torch.logical_and(
torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
)
gray[:, dim, bit + 1 :] = torch.logical_xor(gray[:, dim, bit + 1 :], to_flip)
gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
# Now flatten out.
gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))
# Convert Gray back to binary.
hh_bin = gray2binary(gray)
# Pad back out to 64 bits.
extra_dims = 64 - num_bits * num_dims
padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
# Convert binary values into uint8s.
hh_uint8 = (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask).sum(2).squeeze().type(torch.uint8)
# Convert uint8s into uint64s.
hh_uint64 = hh_uint8.view(torch.int64).squeeze()
return hh_uint64
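
# Usage sketch (comment only): hilbert_encode_ maps integer grid coordinates of shape
# (N, num_dims) to one Hilbert code per row, e.g.
#     hilbert_encode_(torch.tensor([[0, 0, 0], [1, 2, 3]]), num_dims=3, num_bits=10)
# returns a tensor of shape (2,) with dtype torch.int64.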
# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/utils/serialization/default.py
@torch.inference_mode()
def encode(grid_coord, batch=None, depth=16, order="z"):
assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
if order == "z":
code = z_order_encode(grid_coord, depth=depth)
elif order == "z-trans":
code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
elif order == "hilbert":
code = hilbert_encode(grid_coord, depth=depth)
elif order == "hilbert-trans":
code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
else:
raise NotImplementedError
if batch is not None:
batch = batch.long()
code = batch << depth * 3 | code
return code
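
# Usage sketch (hypothetical helper, not part of the original module): the batch index is
# packed into the bits above the 3 * depth spatial bits, so with depth=2 the cell (1, 1, 1)
# of cloud 1 encodes to (1 << 6) | 0b000111 = 71.
def _example_encode():
    grid_coord = torch.tensor([[1, 1, 1]])
    batch = torch.tensor([1])
    return encode(grid_coord, batch=batch, depth=2, order="z")  # tensor([71])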
[docs]
def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
    # batch support is intentionally excluded here; batched codes are maintained in the Point class
code = xyz2key(x, y, z, b=None, depth=depth)
return code
[docs]
def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)
# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/utils/misc.py
@torch.inference_mode()
def offset2bincount(offset):
return torch.diff(offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long))
@torch.inference_mode()
def offset2batch(offset):
bincount = offset2bincount(offset)
return torch.arange(len(bincount), device=offset.device, dtype=torch.long).repeat_interleave(bincount)
@torch.inference_mode()
def batch2offset(batch):
return torch.cumsum(batch.bincount(), dim=0).long()
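
# Usage sketch (hypothetical helper, not part of the original module): "offset" stores the
# cumulative point count per cloud and "batch" the cloud index per point, and the two
# helpers above are inverses of each other.
def _example_offset_batch_roundtrip():
    offset = torch.tensor([2, 5])  # cloud 0 owns points 0-1, cloud 1 owns points 2-4
    batch = offset2batch(offset)   # tensor([0, 0, 1, 1, 1])
    assert torch.equal(batch2offset(batch), offset)
    return batch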
# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/utils/structure.py
class Point(Dict):
"""
Point Structure of Pointcept
    A Point (point cloud) in Pointcept is a dictionary that contains various properties of
    a batched point cloud. Properties with the following names have a specific definition:
    - "coord": original coordinates of the point cloud;
    - "grid_coord": grid coordinates for a specific grid size (related to GridSampling);
    Point also supports the following optional attributes:
    - "offset": if it does not exist, it is initialized from "batch";
    - "batch": if it does not exist, it is initialized from "offset";
    - "feat": features of the point cloud, the default input of the model;
    - "grid_size": grid size of the point cloud (related to GridSampling);
    (related to Serialization)
    - "serialized_depth": depth of serialization; 2 ** depth * grid_size describes the maximum point cloud range;
    - "serialized_code": a list of serialization codes;
    - "serialized_order": a list of serialization orders determined by the codes;
    - "serialized_inverse": a list of inverse mappings determined by the codes;
    (related to Sparsify: SpConv)
    - "sparse_shape": sparse shape for the SparseConvTensor;
    - "sparse_conv_feat": a SparseConvTensor initialized with information provided by Point;
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        # If one of "offset" or "batch" does not exist, generate it from the other
if "batch" not in self.keys() and "offset" in self.keys():
self["batch"] = offset2batch(self.offset)
elif "offset" not in self.keys() and "batch" in self.keys():
self["offset"] = batch2offset(self.batch)
def serialization(self, order="z", depth=None, shuffle_orders=False):
"""
Point Cloud Serialization
        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
"""
assert "batch" in self.keys()
if "grid_coord" not in self.keys():
            # if you do not apply GridSampling during data augmentation,
            # please add the following transform to your pipeline:
            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
            # (adjust `grid_size` to what you want)
assert {"grid_size", "coord"}.issubset(self.keys())
self["grid_coord"] = torch.div(
self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
).int()
if depth is None:
            # Adaptively determine the depth of the serialization cube (side length = 2 ** depth)
depth = int(self.grid_coord.max()).bit_length()
self["serialized_depth"] = depth
# Maximum bit length for serialization code is 63 (int64)
assert depth * 3 + len(self.offset).bit_length() <= 63
# Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.
# Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
# cube with a grid size of 0.01 meter. We consider it is enough for the current stage.
# We can unlock the limitation by optimizing the z-order encoding function if necessary.
assert depth <= 16
        # The serialization codes are arranged in the following structure:
# [Order1 ([n]),
# Order2 ([n]),
# ...
# OrderN ([n])] (k, n)
code = [encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order]
code = torch.stack(code)
order = torch.argsort(code)
inverse = torch.zeros_like(order).scatter_(
dim=1,
index=order,
src=torch.arange(0, code.shape[1], device=order.device).repeat(code.shape[0], 1),
)
if shuffle_orders:
perm = torch.randperm(code.shape[0])
code = code[perm]
order = order[perm]
inverse = inverse[perm]
self["serialized_code"] = code
self["serialized_order"] = order
self["serialized_inverse"] = inverse
def sparsify(self, pad=96):
"""
        Point Cloud Sparsification
        Point clouds are sparse; here we use "sparsify" to specifically refer to
        preparing an "spconv.SparseConvTensor" for SpConv.
        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
        pad: padding added to the sparse shape.
"""
assert {"feat", "batch"}.issubset(self.keys())
if "grid_coord" not in self.keys():
            # if you do not apply GridSampling during data augmentation,
            # please add the following transform to your pipeline:
            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
            # (adjust `grid_size` to what you want)
assert {"grid_size", "coord"}.issubset(self.keys())
self["grid_coord"] = torch.div(
self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
).int()
if "sparse_shape" in self.keys():
sparse_shape = self.sparse_shape
else:
sparse_shape = torch.add(torch.max(self.grid_coord, dim=0).values, pad).tolist()
sparse_conv_feat = spconv.SparseConvTensor(
features=self.feat,
indices=torch.cat([self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1).contiguous(),
spatial_shape=sparse_shape,
batch_size=self.batch[-1].tolist() + 1,
)
self["sparse_shape"] = sparse_shape
self["sparse_conv_feat"] = sparse_conv_feat
def octreetization(self, depth=None, full_depth=None):
"""
        Point Cloud Octreetization
        Generate an octree with OCNN.
        relies on ["grid_coord", "batch", "feat"]
"""
        assert ocnn is not None, "Please follow https://github.com/octree-nn/ocnn-pytorch to install ocnn."
assert {"grid_coord", "feat", "batch"}.issubset(self.keys())
# add 1 to make grid space support shift order
if depth is None:
if "depth" in self.keys():
depth = self.depth
else:
depth = int(self.grid_coord.max() + 1).bit_length()
if full_depth is None:
full_depth = 2
self["depth"] = depth
assert depth <= 16 # maximum in ocnn
# [0, 2**depth] -> [0, 2] -> [-1, 1]
coord = self.grid_coord / 2 ** (self.depth - 1) - 1.0
point = ocnn.octree.Points(
points=coord,
features=self.feat,
batch_id=self.batch.unsqueeze(-1),
batch_size=self.batch[-1] + 1,
)
octree = ocnn.octree.Octree(
depth=depth,
full_depth=full_depth,
batch_size=self.batch[-1] + 1,
device=coord.device,
)
octree.build_octree(point)
octree.construct_all_neigh()
self["octree"] = octree
# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/modules.py
class PointModule(nn.Module):
r"""PointModule
    Placeholder base class; every module subclassing it takes a Point as input when used in PointSequential.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class PointSequential(PointModule):
r"""A sequential container.
Modules will be added to it in the order they are passed in the constructor.
Alternatively, an ordered dict of modules can also be passed in.
"""
def __init__(self, *args, **kwargs):
super().__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)
for name, module in kwargs.items():
if sys.version_info < (3, 6):
raise ValueError("kwargs only supported in py36+")
if name in self._modules:
raise ValueError("name exists.")
self.add_module(name, module)
def __getitem__(self, idx):
if not (-len(self) <= idx < len(self)):
raise IndexError("index {} is out of range".format(idx))
if idx < 0:
idx += len(self)
it = iter(self._modules.values())
for i in range(idx):
next(it)
return next(it)
def __len__(self):
return len(self._modules)
def add(self, module, name=None):
if name is None:
name = str(len(self._modules))
if name in self._modules:
raise KeyError("name exists")
self.add_module(name, module)
def forward(self, input):
for k, module in self._modules.items():
# Point module
if isinstance(module, PointModule):
input = module(input)
# Spconv module
elif spconv.modules.is_spconv_module(module):
if isinstance(input, Point):
input.sparse_conv_feat = module(input.sparse_conv_feat)
input.feat = input.sparse_conv_feat.features
else:
input = module(input)
# PyTorch module
else:
if isinstance(input, Point):
input.feat = module(input.feat)
if "sparse_conv_feat" in input.keys():
input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(input.feat)
elif isinstance(input, spconv.SparseConvTensor):
if input.indices.shape[0] != 0:
input = input.replace_feature(module(input.features))
else:
input = module(input)
return input
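
# Usage sketch (hypothetical helper, not part of the original module): PointSequential
# dispatches on module type, so plain torch layers act on point.feat while PointModule and
# spconv modules receive the Point (or its sparse_conv_feat) directly.
def _example_point_sequential():
    seq = PointSequential(nn.Linear(6, 16), nn.GELU())
    point = Point(feat=torch.randn(10, 6), batch=torch.zeros(10, dtype=torch.long))
    return seq(point).feat  # shape (10, 16)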
# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/point_prompt_training/prompt_driven_normalization.py
class PDNorm(PointModule):
def __init__(
self,
num_features,
norm_layer,
context_channels=256,
conditions=("ScanNet", "S3DIS", "Structured3D"),
decouple=True,
adaptive=False,
):
super().__init__()
self.conditions = conditions
self.decouple = decouple
self.adaptive = adaptive
if self.decouple:
self.norm = nn.ModuleList([norm_layer(num_features) for _ in conditions])
else:
self.norm = norm_layer
if self.adaptive:
self.modulation = nn.Sequential(nn.SiLU(), nn.Linear(context_channels, 2 * num_features, bias=True))
def forward(self, point):
assert {"feat", "condition"}.issubset(point.keys())
if isinstance(point.condition, str):
condition = point.condition
else:
condition = point.condition[0]
if self.decouple:
assert condition in self.conditions
norm = self.norm[self.conditions.index(condition)]
else:
norm = self.norm
point.feat = norm(point.feat)
if self.adaptive:
assert "context" in point.keys()
shift, scale = self.modulation(point.context).chunk(2, dim=1)
point.feat = point.feat * (1.0 + scale) + shift
return point
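
# Usage sketch (hypothetical helper, not part of the original module): with decouple=True,
# PDNorm keeps one normalization layer per dataset condition and applies the one matching
# the "condition" string carried by the Point.
def _example_pdnorm():
    norm = PDNorm(num_features=8, norm_layer=nn.LayerNorm, conditions=("ScanNet", "S3DIS"))
    point = Point(feat=torch.randn(5, 8), condition="ScanNet", batch=torch.zeros(5, dtype=torch.long))
    return norm(point).feat  # normalized by the "ScanNet" LayerNorm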
# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/point_transformer_v3/point_transformer_v3m1_base.py
class RPE(torch.nn.Module):
def __init__(self, patch_size, num_heads):
super().__init__()
self.patch_size = patch_size
self.num_heads = num_heads
self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
self.rpe_num = 2 * self.pos_bnd + 1
self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)
def forward(self, coord):
idx = (
coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd
+ self.pos_bnd # relative position to positive index
+ torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride
)
out = self.rpe_table.index_select(0, idx.reshape(-1))
out = out.view(idx.shape + (-1,)).sum(3)
out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K)
return out
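
# Worked constants (comment only): for patch_size=48 the relative-position table covers
# offsets in [-pos_bnd, pos_bnd] per axis with pos_bnd = int((4 * 48) ** (1 / 3) * 2) = 11,
# so rpe_num = 2 * 11 + 1 = 23 and rpe_table has shape (3 * 23, num_heads) = (69, num_heads).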
class SerializedAttention(PointModule):
def __init__(
self,
channels,
num_heads,
patch_size,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
order_index=0,
enable_rpe=False,
enable_flash=True,
upcast_attention=True,
upcast_softmax=True,
):
super().__init__()
assert channels % num_heads == 0
self.channels = channels
self.num_heads = num_heads
self.scale = qk_scale or (channels // num_heads) ** -0.5
self.order_index = order_index
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.enable_rpe = enable_rpe
self.enable_flash = enable_flash
if enable_flash:
assert enable_rpe is False, "Set enable_rpe to False when enable Flash Attention"
assert upcast_attention is False, "Set upcast_attention to False when enable Flash Attention"
assert upcast_softmax is False, "Set upcast_softmax to False when enable Flash Attention"
assert flash_attn is not None, "Make sure flash_attn is installed."
self.patch_size = patch_size
self.attn_drop = attn_drop
else:
            # when flash attention is disabled, we still do not want to use an attention mask;
            # consequently, the patch size is automatically set to the minimum of
            # patch_size_max and the number of points
self.patch_size_max = patch_size
self.patch_size = 0
self.attn_drop = torch.nn.Dropout(attn_drop)
self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
self.proj = torch.nn.Linear(channels, channels)
self.proj_drop = torch.nn.Dropout(proj_drop)
self.softmax = torch.nn.Softmax(dim=-1)
self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None
@torch.no_grad()
def get_rel_pos(self, point, order):
K = self.patch_size
rel_pos_key = f"rel_pos_{self.order_index}"
if rel_pos_key not in point.keys():
grid_coord = point.grid_coord[order]
grid_coord = grid_coord.reshape(-1, K, 3)
point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
return point[rel_pos_key]
@torch.no_grad()
def get_padding_and_inverse(self, point):
pad_key = "pad"
unpad_key = "unpad"
cu_seqlens_key = "cu_seqlens_key"
if pad_key not in point.keys() or unpad_key not in point.keys() or cu_seqlens_key not in point.keys():
offset = point.offset
bincount = offset2bincount(offset)
bincount_pad = (
torch.div(
bincount + self.patch_size - 1,
self.patch_size,
rounding_mode="trunc",
)
* self.patch_size
)
            # only pad a sample when its number of points is larger than patch_size
mask_pad = bincount > self.patch_size
bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
_offset = nn.functional.pad(offset, (1, 0))
_offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
pad = torch.arange(_offset_pad[-1], device=offset.device)
unpad = torch.arange(_offset[-1], device=offset.device)
cu_seqlens = []
for i in range(len(offset)):
unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
if bincount[i] != bincount_pad[i]:
pad[_offset_pad[i + 1] - self.patch_size + (bincount[i] % self.patch_size) : _offset_pad[i + 1]] = (
pad[
_offset_pad[i + 1] - 2 * self.patch_size + (bincount[i] % self.patch_size) : _offset_pad[
i + 1
]
- self.patch_size
]
)
pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
cu_seqlens.append(
torch.arange(
_offset_pad[i],
_offset_pad[i + 1],
step=self.patch_size,
dtype=torch.int32,
device=offset.device,
)
)
point[pad_key] = pad
point[unpad_key] = unpad
point[cu_seqlens_key] = nn.functional.pad(torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1])
return point[pad_key], point[unpad_key], point[cu_seqlens_key]
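
    # Worked example (comment only): with patch_size=48 and per-cloud point counts [100, 30],
    # the first cloud is padded up to 144 (= ceil(100 / 48) * 48) while the second stays at 30
    # because it is smaller than one patch, so _offset_pad = [0, 144, 174] and
    # cu_seqlens = [0, 48, 96, 144, 174].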
def forward(self, point):
if not self.enable_flash:
self.patch_size = min(offset2bincount(point.offset).min().tolist(), self.patch_size_max)
H = self.num_heads
K = self.patch_size
C = self.channels
pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)
order = point.serialized_order[self.order_index][pad]
inverse = unpad[point.serialized_inverse[self.order_index]]
# padding and reshape feat and batch for serialized point patch
qkv = self.qkv(point.feat)[order]
if not self.enable_flash:
# encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
q, k, v = qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
# attn
if self.upcast_attention:
q = q.float()
k = k.float()
attn = (q * self.scale) @ k.transpose(-2, -1) # (N', H, K, K)
if self.enable_rpe:
attn = attn + self.rpe(self.get_rel_pos(point, order))
if self.upcast_softmax:
attn = attn.float()
attn = self.softmax(attn)
attn = self.attn_drop(attn).to(qkv.dtype)
feat = (attn @ v).transpose(1, 2).reshape(-1, C)
else:
feat = flash_attn.flash_attn_varlen_qkvpacked_func(
qkv.half().reshape(-1, 3, H, C // H),
cu_seqlens,
max_seqlen=self.patch_size,
dropout_p=self.attn_drop if self.training else 0,
softmax_scale=self.scale,
).reshape(-1, C)
feat = feat.to(qkv.dtype)
feat = feat[inverse]
# ffn
feat = self.proj(feat)
feat = self.proj_drop(feat)
point.feat = feat
return point
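
# Flow sketch (comment only): forward gathers qkv in the serialized order, runs attention
# within each contiguous patch of patch_size points (either the flash_attn varlen kernel on
# packed half-precision qkv, or standard softmax attention with optional relative position
# bias), and finally scatters the output back to the original point order via `inverse`.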
class MLP(nn.Module):
def __init__(
self,
in_channels,
hidden_channels=None,
out_channels=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_channels = out_channels or in_channels
hidden_channels = hidden_channels or in_channels
self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.act = act_layer()  # activation between the two linear layers; required by forward
self.fc2 = nn.Linear(hidden_channels, out_channels)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Block(PointModule):
def __init__(
self,
channels,
num_heads,
patch_size=48,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
pre_norm=True,
order_index=0,
cpe_indice_key=None,
enable_rpe=False,
enable_flash=True,
upcast_attention=True,
upcast_softmax=True,
):
super().__init__()
self.channels = channels
self.pre_norm = pre_norm
self.cpe = PointSequential(
spconv.SubMConv3d(
channels,
channels,
kernel_size=3,
bias=True,
indice_key=cpe_indice_key,
),
nn.Linear(channels, channels),
norm_layer(channels),
)
self.norm1 = PointSequential(norm_layer(channels))
self.attn = SerializedAttention(
channels=channels,
patch_size=patch_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=proj_drop,
order_index=order_index,
enable_rpe=enable_rpe,
enable_flash=enable_flash,
upcast_attention=upcast_attention,
upcast_softmax=upcast_softmax,
)
self.norm2 = PointSequential(norm_layer(channels))
self.mlp = PointSequential(
MLP(
in_channels=channels,
hidden_channels=int(channels * mlp_ratio),
out_channels=channels,
act_layer=act_layer,
drop=proj_drop,
)
)
self.drop_path = PointSequential(DropPath(drop_path) if drop_path > 0.0 else nn.Identity())
def forward(self, point: Point):
shortcut = point.feat
point = self.cpe(point)
point.feat = shortcut + point.feat
shortcut = point.feat
if self.pre_norm:
point = self.norm1(point)
point = self.drop_path(self.attn(point))
point.feat = shortcut + point.feat
if not self.pre_norm:
point = self.norm1(point)
shortcut = point.feat
if self.pre_norm:
point = self.norm2(point)
point = self.drop_path(self.mlp(point))
point.feat = shortcut + point.feat
if not self.pre_norm:
point = self.norm2(point)
point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
return point
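
# Flow sketch (comment only): with pre_norm=True a Block computes
#     feat = feat + CPE(feat)
#     feat = feat + DropPath(Attn(Norm1(feat)))
#     feat = feat + DropPath(MLP(Norm2(feat)))
# and mirrors the result into point.sparse_conv_feat so later spconv layers stay in sync.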
class SerializedPooling(PointModule):
def __init__(
self,
in_channels,
out_channels,
stride=2,
norm_layer=None,
act_layer=None,
reduce="max",
shuffle_orders=True,
traceable=True, # record parent and cluster
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
        assert stride == 2 ** (math.ceil(stride) - 1).bit_length()  # 2, 4, 8
        # TODO: add support to grid pool (any stride)
        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
self.shuffle_orders = shuffle_orders
self.traceable = traceable
self.proj = nn.Linear(in_channels, out_channels)
if norm_layer is not None:
self.norm = PointSequential(norm_layer(out_channels))
if act_layer is not None:
self.act = PointSequential(act_layer())
def forward(self, point: Point):
pooling_depth = (math.ceil(self.stride) - 1).bit_length()
if pooling_depth > point.serialized_depth:
pooling_depth = 0
assert {
"serialized_code",
"serialized_order",
"serialized_inverse",
"serialized_depth",
        }.issubset(point.keys()), "Run point.serialization() on the point cloud before SerializedPooling"
code = point.serialized_code >> pooling_depth * 3
code_, cluster, counts = torch.unique(
code[0],
sorted=True,
return_inverse=True,
return_counts=True,
)
# indices of point sorted by cluster, for torch_scatter.segment_csr
_, indices = torch.sort(cluster)
# index pointer for sorted point, for torch_scatter.segment_csr
idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
# head_indices of each cluster, for reduce attr e.g. code, batch
head_indices = indices[idx_ptr[:-1]]
# generate down code, order, inverse
code = code[:, head_indices]
order = torch.argsort(code)
inverse = torch.zeros_like(order).scatter_(
dim=1,
index=order,
src=torch.arange(0, code.shape[1], device=order.device).repeat(code.shape[0], 1),
)
if self.shuffle_orders:
perm = torch.randperm(code.shape[0])
code = code[perm]
order = order[perm]
inverse = inverse[perm]
# collect information
point_dict = Dict(
feat=torch_scatter.segment_csr(self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce),
coord=torch_scatter.segment_csr(point.coord[indices], idx_ptr, reduce="mean"),
grid_coord=point.grid_coord[head_indices] >> pooling_depth,
serialized_code=code,
serialized_order=order,
serialized_inverse=inverse,
serialized_depth=point.serialized_depth - pooling_depth,
batch=point.batch[head_indices],
)
if "condition" in point.keys():
point_dict["condition"] = point.condition
if "context" in point.keys():
point_dict["context"] = point.context
if self.traceable:
point_dict["pooling_inverse"] = cluster
point_dict["pooling_parent"] = point
point = Point(point_dict)
if self.norm is not None:
point = self.norm(point)
if self.act is not None:
point = self.act(point)
point.sparsify()
return point
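
# Worked example (comment only): with stride=2, pooling_depth = (math.ceil(2) - 1).bit_length() = 1,
# so serialization codes are shifted right by 3 bits and grid coordinates by 1 bit; points that
# fall in the same 2 x 2 x 2 cell collapse into one pooled point whose feature is the segment-wise
# reduction (e.g. "max") of the projected features.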
class SerializedUnpooling(PointModule):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
norm_layer=None,
act_layer=None,
traceable=False, # record parent and cluster
):
super().__init__()
self.proj = PointSequential(nn.Linear(in_channels, out_channels))
self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))
if norm_layer is not None:
self.proj.add(norm_layer(out_channels))
self.proj_skip.add(norm_layer(out_channels))
if act_layer is not None:
self.proj.add(act_layer())
self.proj_skip.add(act_layer())
self.traceable = traceable
def forward(self, point):
assert "pooling_parent" in point.keys()
assert "pooling_inverse" in point.keys()
parent = point.pop("pooling_parent")
inverse = point.pop("pooling_inverse")
point = self.proj(point)
parent = self.proj_skip(parent)
parent.feat = parent.feat + point.feat[inverse]
if self.traceable:
parent["unpooling_parent"] = point
return parent
class Embedding(PointModule):
def __init__(
self,
in_channels,
embed_channels,
norm_layer=None,
act_layer=None,
):
super().__init__()
self.in_channels = in_channels
self.embed_channels = embed_channels
# TODO: check remove spconv
self.stem = PointSequential(
conv=spconv.SubMConv3d(
in_channels,
embed_channels,
kernel_size=5,
padding=1,
bias=False,
indice_key="stem",
)
)
if norm_layer is not None:
self.stem.add(norm_layer(embed_channels), name="norm")
if act_layer is not None:
self.stem.add(act_layer(), name="act")
def forward(self, point: Point):
point = self.stem(point)
return point
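
# Usage sketch (comment only): Embedding expects a Point that has already been sparsified
# (point.sparsify()), since the stem's SubMConv3d runs on point.sparse_conv_feat inside
# PointSequential; it projects in_channels input features to embed_channels.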