Source code for fastdev.nn.point_transformer_v3

# type: ignore
"""
Point Transformer - V3 Mode1

Adapted from: https://github.com/Pointcept/Pointcept

This module requires the installation of the following packages:

- addict: pip install addict
- spconv: https://github.com/traveller59/spconv?tab=readme-ov-file#spconv-spatially-sparse-convolution-library
- torch-scatter: https://github.com/rusty1s/pytorch_scatter?tab=readme-ov-file#installation
- flash-attention: https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features

Original Author: Xiaoyang Wu ([email protected])
Please cite their work if you use this code in your research.
"""

import math
import sys
from collections import OrderedDict
from functools import partial
from typing import Optional, Union

import spconv.pytorch as spconv
import torch
import torch.nn as nn
import torch_scatter
from addict import Dict

try:
    import flash_attn
except ImportError:
    flash_attn = None

try:
    import ocnn  # not really necessary
except ImportError:
    ocnn = None


# Adapted from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a
    separate paper... See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer
    name and use 'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob, 3):0.3f}"


# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/utils/serialization/z_order.py
class KeyLUT:
    def __init__(self):
        r256 = torch.arange(256, dtype=torch.int64)
        r512 = torch.arange(512, dtype=torch.int64)
        zero = torch.zeros(256, dtype=torch.int64)
        device = torch.device("cpu")

        self._encode = {
            device: (
                self.xyz2key(r256, zero, zero, 8),
                self.xyz2key(zero, r256, zero, 8),
                self.xyz2key(zero, zero, r256, 8),
            )
        }
        self._decode = {device: self.key2xyz(r512, 9)}

    def encode_lut(self, device=torch.device("cpu")):
        if device not in self._encode:
            cpu = torch.device("cpu")
            self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
        return self._encode[device]

    def decode_lut(self, device=torch.device("cpu")):
        if device not in self._decode:
            cpu = torch.device("cpu")
            self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
        return self._decode[device]

    def xyz2key(self, x, y, z, depth):
        key = torch.zeros_like(x)
        for i in range(depth):
            mask = 1 << i
            key = key | ((x & mask) << (2 * i + 2)) | ((y & mask) << (2 * i + 1)) | ((z & mask) << (2 * i + 0))
        return key

    def key2xyz(self, key, depth):
        x = torch.zeros_like(key)
        y = torch.zeros_like(key)
        z = torch.zeros_like(key)
        for i in range(depth):
            x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
            y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
            z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
        return x, y, z


_key_lut = KeyLUT()


def xyz2key(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    b: Optional[Union[torch.Tensor, int]] = None,
    depth: int = 16,
):
    r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys based on
    pre-computed look up tables. This is much faster than the for-loop based method.

    Args:
        x (torch.Tensor): The x coordinate.
        y (torch.Tensor): The y coordinate.
        z (torch.Tensor): The z coordinate.
        b (torch.Tensor or int): The batch index of the coordinates, which should be
            smaller than 32768. If :attr:`b` is a :obj:`torch.Tensor`, its size must be
            the same as :attr:`x`, :attr:`y`, and :attr:`z`.
        depth (int): The depth of the shuffled key, which must be smaller than 17.
    """
    EX, EY, EZ = _key_lut.encode_lut(x.device)
    x, y, z = x.long(), y.long(), z.long()

    mask = 255 if depth > 8 else (1 << depth) - 1
    key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
    if depth > 8:
        mask = (1 << (depth - 8)) - 1
        key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
        key = key16 << 24 | key

    if b is not None:
        b = b.long()
        key = b << 48 | key

    return key

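
# A minimal sanity-check sketch (not part of the original module): for depth=1 the key
# interleaves one bit per axis with x as the most significant bit, so key = x * 4 + y * 2 + z
# and (x, y, z) = (1, 0, 1) maps to 5.
#
#   >>> xyz2key(torch.tensor([1]), torch.tensor([0]), torch.tensor([1]), depth=1)
#   tensor([5])
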
# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/utils/serialization/hilbert.py
def right_shift(binary, k=1, axis=-1):
    """Right shift an array of binary values.

    Parameters:
    -----------
     binary: An ndarray of binary values.
     k: The number of bits to shift. Default 1.
     axis: The axis along which to shift. Default -1.

    Returns:
    --------
     Returns an ndarray with zero prepended and the ends truncated, along
     whatever axis was specified.
    """
    # If we're shifting the whole thing, just return zeros.
    if binary.shape[axis] <= k:
        return torch.zeros_like(binary)

    # Determine the padding pattern.
    # padding = [(0,0)] * len(binary.shape)
    # padding[axis] = (k,0)

    # Determine the slicing pattern to eliminate just the last one.
    slicing = [slice(None)] * len(binary.shape)
    slicing[axis] = slice(None, -k)
    shifted = torch.nn.functional.pad(binary[tuple(slicing)], (k, 0), mode="constant", value=0)

    return shifted


def gray2binary(gray, axis=-1):
    """Convert an array of Gray codes back into binary values.

    Parameters:
    -----------
     gray: An ndarray of gray codes.
     axis: The axis along which to perform Gray decoding. Default=-1.

    Returns:
    --------
     Returns an ndarray of binary values.
    """
    # Loop the log2(bits) number of times necessary, with shift and xor.
    shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
    while shift > 0:
        gray = torch.logical_xor(gray, right_shift(gray, shift))
        shift = torch.div(shift, 2, rounding_mode="floor")
    return gray


def hilbert_encode_(locs, num_dims, num_bits):
    """Encode an array of locations in a hypercube into Hilbert integers.

    This is a vectorized-ish version of the Hilbert curve implementation by John Skilling
    as described in:

    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

    Params:
    -------
     locs - An ndarray of locations in a hypercube of num_dims dimensions, in
            which each dimension runs from 0 to 2**num_bits-1. The shape can be
            arbitrary, as long as the last dimension has size num_dims.
     num_dims - The dimensionality of the hypercube. Integer.
     num_bits - The number of bits for each dimension. Integer.

    Returns:
    --------
     The output is an ndarray of uint64 integers with the same shape as the
     input, excluding the last dimension, which needs to be num_dims.
    """
    # Keep around the original shape for later.
    orig_shape = locs.shape
    bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
    bitpack_mask_rev = bitpack_mask.flip(-1)

    if orig_shape[-1] != num_dims:
        raise ValueError(
            """
            The shape of locs was surprising in that the last dimension was of size
            %d, but num_dims=%d. These need to be equal.
            """
            % (orig_shape[-1], num_dims)
        )

    if num_dims * num_bits > 63:
        raise ValueError(
            """
            num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
            into an int64. Are you sure you need that many points on your Hilbert
            curve?
            """
            % (num_dims, num_bits, num_dims * num_bits)
        )

    # Treat the location integers as 64-bit unsigned and then split them up into
    # a sequence of uint8s. Preserve the association by dimension.
    locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)

    # Now turn these into bits and truncate to num_bits.
    gray = locs_uint8.unsqueeze(-1).bitwise_and(bitpack_mask_rev).ne(0).byte().flatten(-2, -1)[..., -num_bits:]

    # Run the decoding process the other way.
    # Iterate forwards through the bits.
    for bit in range(0, num_bits):
        # Iterate forwards through the dimensions.
        for dim in range(0, num_dims):
            # Identify which ones have this bit active.
            mask = gray[:, dim, bit]

            # Where this bit is on, invert the 0 dimension for lower bits.
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], mask[:, None])

            # Where the bit is off, exchange the lower bits with the 0 dimension.
            to_flip = torch.logical_and(
                torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = torch.logical_xor(gray[:, dim, bit + 1 :], to_flip)
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Now flatten out.
    gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))

    # Convert Gray back to binary.
    hh_bin = gray2binary(gray)

    # Pad back out to 64 bits.
    extra_dims = 64 - num_bits * num_dims
    padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)

    # Convert binary values into uint8s.
    hh_uint8 = (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask).sum(2).squeeze().type(torch.uint8)

    # Convert uint8s into uint64s.
    hh_uint64 = hh_uint8.view(torch.int64).squeeze()

    return hh_uint64


# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/utils/serialization/default.py
@torch.inference_mode()
def encode(grid_coord, batch=None, depth=16, order="z"):
    assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
    if order == "z":
        code = z_order_encode(grid_coord, depth=depth)
    elif order == "z-trans":
        code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
    elif order == "hilbert":
        code = hilbert_encode(grid_coord, depth=depth)
    elif order == "hilbert-trans":
        code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
    else:
        raise NotImplementedError
    if batch is not None:
        batch = batch.long()
        code = batch << depth * 3 | code
    return code


def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
    x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
    # we block the support to batch, maintain batched code in Point class
    code = xyz2key(x, y, z, b=None, depth=depth)
    return code


def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
    return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)


# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/utils/misc.py
@torch.inference_mode()
def offset2bincount(offset):
    return torch.diff(offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long))


@torch.inference_mode()
def offset2batch(offset):
    bincount = offset2bincount(offset)
    return torch.arange(len(bincount), device=offset.device, dtype=torch.long).repeat_interleave(bincount)


@torch.inference_mode()
def batch2offset(batch):
    return torch.cumsum(batch.bincount(), dim=0).long()

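
# A small worked example (illustrative only, not part of the original module): a batch of two
# clouds with 3 and 4 points is described by offset = [3, 7]. offset2bincount gives per-cloud
# counts [3, 4], offset2batch expands them to per-point batch indices, and batch2offset maps
# the indices back to cumulative offsets.
#
#   >>> offset = torch.tensor([3, 7])
#   >>> offset2batch(offset)
#   tensor([0, 0, 0, 1, 1, 1, 1])
#   >>> batch2offset(offset2batch(offset))
#   tensor([3, 7])
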
# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/utils/structure.py
class Point(Dict):
    """
    Point Structure of Pointcept

    A Point (point cloud) in Pointcept is a dictionary that contains various properties of a
    batched point cloud. Properties with the following names have a specific definition:

    - "coord": original coordinate of point cloud;
    - "grid_coord": grid coordinate for a specific grid size (related to GridSampling);

    Point also supports the following optional attributes:

    - "offset": if not present, initialized assuming a batch size of 1;
    - "batch": if not present, initialized assuming a batch size of 1;
    - "feat": feature of point cloud, default input of model;
    - "grid_size": grid size of point cloud (related to GridSampling);

    (related to Serialization)

    - "serialized_depth": depth of serialization; 2 ** depth * grid_size describes the maximum point cloud range;
    - "serialized_code": a list of serialization codes;
    - "serialized_order": a list of serialization orders determined by code;
    - "serialized_inverse": a list of inverse mappings determined by code;

    (related to Sparsify: SpConv)

    - "sparse_shape": sparse shape for the SparseConvTensor;
    - "sparse_conv_feat": SparseConvTensor initialized with information provided by Point;
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If one of "offset" or "batch" does not exist, generate it from the existing one
        if "batch" not in self.keys() and "offset" in self.keys():
            self["batch"] = offset2batch(self.offset)
        elif "offset" not in self.keys() and "batch" in self.keys():
            self["offset"] = batch2offset(self.batch)

    def serialization(self, order="z", depth=None, shuffle_orders=False):
        """
        Point Cloud Serialization

        Relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"].
        """
        assert "batch" in self.keys()
        if "grid_coord" not in self.keys():
            # if you don't want to operate GridSampling in data augmentation,
            # please add the following augmentation into your pipeline:
            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
            # (adjust `grid_size` to what you want)
            assert {"grid_size", "coord"}.issubset(self.keys())
            self["grid_coord"] = torch.div(
                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
            ).int()

        if depth is None:
            # Adaptively measure the depth of the serialization cube (length = 2 ^ depth)
            depth = int(self.grid_coord.max()).bit_length()
        self["serialized_depth"] = depth
        # Maximum bit length for serialization code is 63 (int64)
        assert depth * 3 + len(self.offset).bit_length() <= 63
        # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.
        # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
        # cube with a grid size of 0.01 meter. We consider it is enough for the current stage.
        # We can unlock the limitation by optimizing the z-order encoding function if necessary.
        assert depth <= 16

        # The serialization codes are arranged as following structures:
        # [Order1 ([n]),
        #  Order2 ([n]),
        #   ...
        #  OrderN ([n])] (k, n)
        code = [encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order]
        code = torch.stack(code)
        order = torch.argsort(code)
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(code.shape[0], 1),
        )

        if shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        self["serialized_code"] = code
        self["serialized_order"] = order
        self["serialized_inverse"] = inverse

    def sparsify(self, pad=96):
        """
        Point Cloud Sparsification

        Point cloud is sparse; here we use "sparsify" to specifically refer to preparing an
        "spconv.SparseConvTensor" for SpConv.

        Relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"].

        pad: padding added to the sparse shape.
        """
        assert {"feat", "batch"}.issubset(self.keys())
        if "grid_coord" not in self.keys():
            # if you don't want to operate GridSampling in data augmentation,
            # please add the following augmentation into your pipeline:
            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
            # (adjust `grid_size` to what you want)
            assert {"grid_size", "coord"}.issubset(self.keys())
            self["grid_coord"] = torch.div(
                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
            ).int()
        if "sparse_shape" in self.keys():
            sparse_shape = self.sparse_shape
        else:
            sparse_shape = torch.add(torch.max(self.grid_coord, dim=0).values, pad).tolist()
        sparse_conv_feat = spconv.SparseConvTensor(
            features=self.feat,
            indices=torch.cat([self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1).contiguous(),
            spatial_shape=sparse_shape,
            batch_size=self.batch[-1].tolist() + 1,
        )
        self["sparse_shape"] = sparse_shape
        self["sparse_conv_feat"] = sparse_conv_feat

    def octreetization(self, depth=None, full_depth=None):
        """
        Point Cloud Octreelization

        Generate an octree with OCNN.

        Relies on ["grid_coord", "batch", "feat"].
        """
        assert ocnn is not None, "Please follow https://github.com/octree-nn/ocnn-pytorch to install ocnn."
        assert {"grid_coord", "feat", "batch"}.issubset(self.keys())
        # add 1 to make grid space support shift order
        if depth is None:
            if "depth" in self.keys():
                depth = self.depth
            else:
                depth = int(self.grid_coord.max() + 1).bit_length()
        if full_depth is None:
            full_depth = 2
        self["depth"] = depth
        assert depth <= 16  # maximum in ocnn

        # [0, 2**depth] -> [0, 2] -> [-1, 1]
        coord = self.grid_coord / 2 ** (self.depth - 1) - 1.0
        point = ocnn.octree.Points(
            points=coord,
            features=self.feat,
            batch_id=self.batch.unsqueeze(-1),
            batch_size=self.batch[-1] + 1,
        )
        octree = ocnn.octree.Octree(
            depth=depth,
            full_depth=full_depth,
            batch_size=self.batch[-1] + 1,
            device=coord.device,
        )
        octree.build_octree(point)
        octree.construct_all_neigh()
        self["octree"] = octree

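
# A minimal usage sketch (illustrative assumption, not prescribed by the original authors):
# a single cloud with per-point features can be wrapped and serialized roughly like this.
#
#   >>> data = Point(
#   ...     coord=torch.rand(1024, 3),
#   ...     feat=torch.rand(1024, 6),
#   ...     batch=torch.zeros(1024, dtype=torch.long),
#   ...     grid_size=0.01,
#   ... )
#   >>> data.serialization(order=["z", "z-trans"])  # fills serialized_code/order/inverse
#   >>> data.sparsify()                             # fills sparse_conv_feat for SpConv
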
# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/modules.py
class PointModule(nn.Module):
    r"""PointModule placeholder.

    All modules that subclass PointModule take a Point as input inside PointSequential.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class PointSequential(PointModule):
    r"""A sequential container.

    Modules will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of modules can also be passed in.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
        for name, module in kwargs.items():
            if sys.version_info < (3, 6):
                raise ValueError("kwargs only supported in py36+")
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        if not (-len(self) <= idx < len(self)):
            raise IndexError("index {} is out of range".format(idx))
        if idx < 0:
            idx += len(self)
        it = iter(self._modules.values())
        for i in range(idx):
            next(it)
        return next(it)

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        if name is None:
            name = str(len(self._modules))
            if name in self._modules:
                raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, input):
        for k, module in self._modules.items():
            # Point module
            if isinstance(module, PointModule):
                input = module(input)
            # Spconv module
            elif spconv.modules.is_spconv_module(module):
                if isinstance(input, Point):
                    input.sparse_conv_feat = module(input.sparse_conv_feat)
                    input.feat = input.sparse_conv_feat.features
                else:
                    input = module(input)
            # PyTorch module
            else:
                if isinstance(input, Point):
                    input.feat = module(input.feat)
                    if "sparse_conv_feat" in input.keys():
                        input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(input.feat)
                elif isinstance(input, spconv.SparseConvTensor):
                    if input.indices.shape[0] != 0:
                        input = input.replace_feature(module(input.features))
                else:
                    input = module(input)
        return input


# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/point_prompt_training/prompt_driven_normalization.py
class PDNorm(PointModule):
    def __init__(
        self,
        num_features,
        norm_layer,
        context_channels=256,
        conditions=("ScanNet", "S3DIS", "Structured3D"),
        decouple=True,
        adaptive=False,
    ):
        super().__init__()
        self.conditions = conditions
        self.decouple = decouple
        self.adaptive = adaptive
        if self.decouple:
            self.norm = nn.ModuleList([norm_layer(num_features) for _ in conditions])
        else:
            # instantiate the shared norm layer so the non-decoupled path is usable in forward()
            self.norm = norm_layer(num_features)
        if self.adaptive:
            self.modulation = nn.Sequential(nn.SiLU(), nn.Linear(context_channels, 2 * num_features, bias=True))

    def forward(self, point):
        assert {"feat", "condition"}.issubset(point.keys())
        if isinstance(point.condition, str):
            condition = point.condition
        else:
            condition = point.condition[0]
        if self.decouple:
            assert condition in self.conditions
            norm = self.norm[self.conditions.index(condition)]
        else:
            norm = self.norm
        point.feat = norm(point.feat)
        if self.adaptive:
            assert "context" in point.keys()
            shift, scale = self.modulation(point.context).chunk(2, dim=1)
            point.feat = point.feat * (1.0 + scale) + shift
        return point


# Adapted from: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/point_transformer_v3/point_transformer_v3m1_base.py
class RPE(torch.nn.Module):
    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        self.rpe_num = 2 * self.pos_bnd + 1
        self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        idx = (
            coord.clamp(-self.pos_bnd, self.pos_bnd)  # clamp into bnd
            + self.pos_bnd  # relative position to positive index
            + torch.arange(3, device=coord.device) * self.rpe_num  # x, y, z stride
        )
        out = self.rpe_table.index_select(0, idx.reshape(-1))
        out = out.view(idx.shape + (-1,)).sum(3)
        out = out.permute(0, 3, 1, 2)  # (N, K, K, H) -> (N, H, K, K)
        return out

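
# Indexing note (derived from the code above; patch_size=48 is just the default used elsewhere
# in this file): with patch_size=48, pos_bnd = int((4 * 48) ** (1 / 3) * 2) = 11 and
# rpe_num = 23, so the table has 3 * 23 = 69 rows. A clamped relative offset d along axis a
# (0: x, 1: y, 2: z) selects row d + 11 + 23 * a, and the three per-axis biases are summed to
# give one bias per head for every key/query pair in the patch.
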

class SerializedAttention(PointModule):
    def __init__(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        assert channels % num_heads == 0
        self.channels = channels
        self.num_heads = num_heads
        self.scale = qk_scale or (channels // num_heads) ** -0.5
        self.order_index = order_index
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.enable_rpe = enable_rpe
        self.enable_flash = enable_flash
        if enable_flash:
            assert enable_rpe is False, "Set enable_rpe to False when enabling Flash Attention"
            assert upcast_attention is False, "Set upcast_attention to False when enabling Flash Attention"
            assert upcast_softmax is False, "Set upcast_softmax to False when enabling Flash Attention"
            assert flash_attn is not None, "Make sure flash_attn is installed."
            self.patch_size = patch_size
            self.attn_drop = attn_drop
        else:
            # when flash attention is disabled, we still don't want to use an attention mask;
            # consequently, patch size is auto-set to the min of patch_size_max and the
            # number of points
            self.patch_size_max = patch_size
            self.patch_size = 0
            self.attn_drop = torch.nn.Dropout(attn_drop)
        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
        self.proj = torch.nn.Linear(channels, channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

    @torch.no_grad()
    def get_rel_pos(self, point, order):
        K = self.patch_size
        rel_pos_key = f"rel_pos_{self.order_index}"
        if rel_pos_key not in point.keys():
            grid_coord = point.grid_coord[order]
            grid_coord = grid_coord.reshape(-1, K, 3)
            point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
        return point[rel_pos_key]

    @torch.no_grad()
    def get_padding_and_inverse(self, point):
        pad_key = "pad"
        unpad_key = "unpad"
        cu_seqlens_key = "cu_seqlens_key"
        if pad_key not in point.keys() or unpad_key not in point.keys() or cu_seqlens_key not in point.keys():
            offset = point.offset
            bincount = offset2bincount(offset)
            bincount_pad = (
                torch.div(
                    bincount + self.patch_size - 1,
                    self.patch_size,
                    rounding_mode="trunc",
                )
                * self.patch_size
            )
            # only pad point when num of points larger than patch_size
            mask_pad = bincount > self.patch_size
            bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
            _offset = nn.functional.pad(offset, (1, 0))
            _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
            pad = torch.arange(_offset_pad[-1], device=offset.device)
            unpad = torch.arange(_offset[-1], device=offset.device)
            cu_seqlens = []
            for i in range(len(offset)):
                unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
                if bincount[i] != bincount_pad[i]:
                    pad[
                        _offset_pad[i + 1] - self.patch_size + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                    ] = pad[
                        _offset_pad[i + 1]
                        - 2 * self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                        - self.patch_size
                    ]
                pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
                cu_seqlens.append(
                    torch.arange(
                        _offset_pad[i],
                        _offset_pad[i + 1],
                        step=self.patch_size,
                        dtype=torch.int32,
                        device=offset.device,
                    )
                )
            point[pad_key] = pad
            point[unpad_key] = unpad
            point[cu_seqlens_key] = nn.functional.pad(torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1])
        return point[pad_key], point[unpad_key], point[cu_seqlens_key]

    def forward(self, point):
        if not self.enable_flash:
            self.patch_size = min(offset2bincount(point.offset).min().tolist(), self.patch_size_max)

        H = self.num_heads
        K = self.patch_size
        C = self.channels

        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

        order = point.serialized_order[self.order_index][pad]
        inverse = unpad[point.serialized_inverse[self.order_index]]

        # padding and reshape feat and batch for serialized point patch
        qkv = self.qkv(point.feat)[order]

        if not self.enable_flash:
            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
            q, k, v = qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
            # attn
            if self.upcast_attention:
                q = q.float()
                k = k.float()
            attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
            if self.enable_rpe:
                attn = attn + self.rpe(self.get_rel_pos(point, order))
            if self.upcast_softmax:
                attn = attn.float()
            attn = self.softmax(attn)
            attn = self.attn_drop(attn).to(qkv.dtype)
            feat = (attn @ v).transpose(1, 2).reshape(-1, C)
        else:
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=self.patch_size,
                dropout_p=self.attn_drop if self.training else 0,
                softmax_scale=self.scale,
            ).reshape(-1, C)
            feat = feat.to(qkv.dtype)
        feat = feat[inverse]

        # ffn
        feat = self.proj(feat)
        feat = self.proj_drop(feat)
        point.feat = feat
        return point


class MLP(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(PointModule):
    def __init__(
        self,
        channels,
        num_heads,
        patch_size=48,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        pre_norm=True,
        order_index=0,
        cpe_indice_key=None,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        self.channels = channels
        self.pre_norm = pre_norm

        self.cpe = PointSequential(
            spconv.SubMConv3d(
                channels,
                channels,
                kernel_size=3,
                bias=True,
                indice_key=cpe_indice_key,
            ),
            nn.Linear(channels, channels),
            norm_layer(channels),
        )

        self.norm1 = PointSequential(norm_layer(channels))
        self.attn = SerializedAttention(
            channels=channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order_index=order_index,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
        )
        self.norm2 = PointSequential(norm_layer(channels))
        self.mlp = PointSequential(
            MLP(
                in_channels=channels,
                hidden_channels=int(channels * mlp_ratio),
                out_channels=channels,
                act_layer=act_layer,
                drop=proj_drop,
            )
        )
        self.drop_path = PointSequential(DropPath(drop_path) if drop_path > 0.0 else nn.Identity())

    def forward(self, point: Point):
        shortcut = point.feat
        point = self.cpe(point)
        point.feat = shortcut + point.feat
        shortcut = point.feat
        if self.pre_norm:
            point = self.norm1(point)
        point = self.drop_path(self.attn(point))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm1(point)

        shortcut = point.feat
        if self.pre_norm:
            point = self.norm2(point)
        point = self.drop_path(self.mlp(point))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm2(point)
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point


class SerializedPooling(PointModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        assert stride == 2 ** (math.ceil(stride) - 1).bit_length()  # 2, 4, 8
        # TODO: add support to grid pool (any stride)
        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable

        self.proj = nn.Linear(in_channels, out_channels)
        if norm_layer is not None:
            self.norm = PointSequential(norm_layer(out_channels))
        else:
            self.norm = None  # forward() checks for None before applying the norm
        if act_layer is not None:
            self.act = PointSequential(act_layer())
        else:
            self.act = None  # forward() checks for None before applying the activation

    def forward(self, point: Point):
        pooling_depth = (math.ceil(self.stride) - 1).bit_length()
        if pooling_depth > point.serialized_depth:
            pooling_depth = 0
        assert {
            "serialized_code",
            "serialized_order",
            "serialized_inverse",
            "serialized_depth",
        }.issubset(point.keys()), "Run point.serialization() before SerializedPooling"

        code = point.serialized_code >> pooling_depth * 3
        code_, cluster, counts = torch.unique(
            code[0],
            sorted=True,
            return_inverse=True,
            return_counts=True,
        )
        # indices of point sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted point, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reduce attr e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        # generate down code, order, inverse
        code = code[:, head_indices]
        order = torch.argsort(code)
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(code.shape[0], 1),
        )

        if self.shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        # collect information
        point_dict = Dict(
            feat=torch_scatter.segment_csr(self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce),
            coord=torch_scatter.segment_csr(point.coord[indices], idx_ptr, reduce="mean"),
            grid_coord=point.grid_coord[head_indices] >> pooling_depth,
            serialized_code=code,
            serialized_order=order,
            serialized_inverse=inverse,
            serialized_depth=point.serialized_depth - pooling_depth,
            batch=point.batch[head_indices],
        )

        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context

        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        point = Point(point_dict)
        if self.norm is not None:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.sparsify()
        return point

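
# Pooling note (a reading of the code above, not an addition by the original authors): a stride
# of 2 gives pooling_depth = 1, so the serialized codes are right-shifted by 3 bits (one bit per
# axis, i.e. one serialization level). Points whose grid coordinates fall in the same 2 x 2 x 2
# cell therefore share a code prefix, form one cluster, and are reduced (default "max") into a
# single coarser point.
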

class SerializedUnpooling(PointModule):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        self.proj = PointSequential(nn.Linear(in_channels, out_channels))
        self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))

        if norm_layer is not None:
            self.proj.add(norm_layer(out_channels))
            self.proj_skip.add(norm_layer(out_channels))

        if act_layer is not None:
            self.proj.add(act_layer())
            self.proj_skip.add(act_layer())

        self.traceable = traceable

    def forward(self, point):
        assert "pooling_parent" in point.keys()
        assert "pooling_inverse" in point.keys()
        parent = point.pop("pooling_parent")
        inverse = point.pop("pooling_inverse")
        point = self.proj(point)
        parent = self.proj_skip(parent)
        parent.feat = parent.feat + point.feat[inverse]

        if self.traceable:
            parent["unpooling_parent"] = point
        return parent


class Embedding(PointModule):
    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels

        # TODO: check remove spconv
        self.stem = PointSequential(
            conv=spconv.SubMConv3d(
                in_channels,
                embed_channels,
                kernel_size=5,
                padding=1,
                bias=False,
                indice_key="stem",
            )
        )
        if norm_layer is not None:
            self.stem.add(norm_layer(embed_channels), name="norm")
        if act_layer is not None:
            self.stem.add(act_layer(), name="act")

    def forward(self, point: Point):
        point = self.stem(point)
        return point


class PointTransformerV3(PointModule):
    def __init__(
        self,
        in_channels=6,
        order=("z", "z-trans"),
        stride=(2, 2, 2, 2),
        enc_depths=(2, 2, 2, 6, 2),
        enc_channels=(32, 64, 128, 256, 512),
        enc_num_head=(2, 4, 8, 16, 32),
        enc_patch_size=(48, 48, 48, 48, 48),
        dec_depths=(2, 2, 2, 2),
        dec_channels=(64, 64, 128, 256),
        dec_num_head=(4, 4, 8, 16),
        dec_patch_size=(48, 48, 48, 48),
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.3,
        pre_norm=True,
        shuffle_orders=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        cls_mode=False,
        pdnorm_bn=False,
        pdnorm_ln=False,
        pdnorm_decouple=True,
        pdnorm_adaptive=False,
        pdnorm_affine=True,
        pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"),
    ):
        super().__init__()
        self.num_stages = len(enc_depths)
        self.order = [order] if isinstance(order, str) else order
        self.cls_mode = cls_mode
        self.shuffle_orders = shuffle_orders

        assert self.num_stages == len(stride) + 1
        assert self.num_stages == len(enc_depths)
        assert self.num_stages == len(enc_channels)
        assert self.num_stages == len(enc_num_head)
        assert self.num_stages == len(enc_patch_size)
        assert self.cls_mode or self.num_stages == len(dec_depths) + 1
        assert self.cls_mode or self.num_stages == len(dec_channels) + 1
        assert self.cls_mode or self.num_stages == len(dec_num_head) + 1
        assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1

        # norm layers
        if pdnorm_bn:
            bn_layer = partial(
                PDNorm,
                norm_layer=partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine),
                conditions=pdnorm_conditions,
                decouple=pdnorm_decouple,
                adaptive=pdnorm_adaptive,
            )
        else:
            bn_layer = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        if pdnorm_ln:
            ln_layer = partial(
                PDNorm,
                norm_layer=partial(nn.LayerNorm, elementwise_affine=pdnorm_affine),
                conditions=pdnorm_conditions,
                decouple=pdnorm_decouple,
                adaptive=pdnorm_adaptive,
            )
        else:
            ln_layer = nn.LayerNorm
        # activation layers
        act_layer = nn.GELU

        self.embedding = Embedding(
            in_channels=in_channels,
            embed_channels=enc_channels[0],
            norm_layer=bn_layer,
            act_layer=act_layer,
        )

        # encoder
        enc_drop_path = [x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))]
        self.enc = PointSequential()
        for s in range(self.num_stages):
            enc_drop_path_ = enc_drop_path[sum(enc_depths[:s]) : sum(enc_depths[: s + 1])]
            enc = PointSequential()
            if s > 0:
                enc.add(
                    SerializedPooling(
                        in_channels=enc_channels[s - 1],
                        out_channels=enc_channels[s],
                        stride=stride[s - 1],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="down",
                )
            for i in range(enc_depths[s]):
                enc.add(
                    Block(
                        channels=enc_channels[s],
                        num_heads=enc_num_head[s],
                        patch_size=enc_patch_size[s],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=enc_drop_path_[i],
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                        pre_norm=pre_norm,
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{s}",
                        enable_rpe=enable_rpe,
                        enable_flash=enable_flash,
                        upcast_attention=upcast_attention,
                        upcast_softmax=upcast_softmax,
                    ),
                    name=f"block{i}",
                )
            if len(enc) != 0:
                self.enc.add(module=enc, name=f"enc{s}")

        # decoder
        if not self.cls_mode:
            dec_drop_path = [x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))]
            self.dec = PointSequential()
            dec_channels = list(dec_channels) + [enc_channels[-1]]
            for s in reversed(range(self.num_stages - 1)):
                dec_drop_path_ = dec_drop_path[sum(dec_depths[:s]) : sum(dec_depths[: s + 1])]
                dec_drop_path_.reverse()
                dec = PointSequential()
                dec.add(
                    SerializedUnpooling(
                        in_channels=dec_channels[s + 1],
                        skip_channels=enc_channels[s],
                        out_channels=dec_channels[s],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="up",
                )
                for i in range(dec_depths[s]):
                    dec.add(
                        Block(
                            channels=dec_channels[s],
                            num_heads=dec_num_head[s],
                            patch_size=dec_patch_size[s],
                            mlp_ratio=mlp_ratio,
                            qkv_bias=qkv_bias,
                            qk_scale=qk_scale,
                            attn_drop=attn_drop,
                            proj_drop=proj_drop,
                            drop_path=dec_drop_path_[i],
                            norm_layer=ln_layer,
                            act_layer=act_layer,
                            pre_norm=pre_norm,
                            order_index=i % len(self.order),
                            cpe_indice_key=f"stage{s}",
                            enable_rpe=enable_rpe,
                            enable_flash=enable_flash,
                            upcast_attention=upcast_attention,
                            upcast_softmax=upcast_softmax,
                        ),
                        name=f"block{i}",
                    )
                self.dec.add(module=dec, name=f"dec{s}")

    def forward(self, data_dict):
        point = Point(data_dict)
        point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
        point.sparsify()

        point = self.embedding(point)
        point = self.enc(point)
        if not self.cls_mode:
            point = self.dec(point)
        # else:
        #     point.feat = torch_scatter.segment_csr(
        #         src=point.feat,
        #         indptr=nn.functional.pad(point.offset, (1, 0)),
        #         reduce="mean",
        #     )
        return point

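
# A minimal end-to-end usage sketch (illustrative only; the point count, grid_size, and the
# enable_flash choice below are assumptions, not values prescribed by the original authors).
# In practice the model and tensors are usually moved to a CUDA device, since spconv and
# flash-attention target GPU execution.
#
#   >>> model = PointTransformerV3(in_channels=6, enable_flash=False)
#   >>> data_dict = {
#   ...     "coord": torch.rand(1024, 3),
#   ...     "feat": torch.rand(1024, 6),
#   ...     "batch": torch.zeros(1024, dtype=torch.long),
#   ...     "grid_size": 0.01,
#   ... }
#   >>> point = model(data_dict)  # returns a Point; per-point features are in point.feat
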