# Source code for fastdev.geom.warp_scene

# mypy: ignore-errors
# ruff: noqa: F821
import os
from functools import lru_cache
from typing import List, Optional, Union

import numpy as np
import torch
import trimesh
import warp as wp
from beartype import beartype
from jaxtyping import Float, Int

from fastdev.geom.warp_sdf_fns import query_sdf_in_scenes

# Accepted device specifiers (a string such as "cpu"/"cuda:0", a torch.device, or None).
# Mirrors ``torch.types.Device``; defined locally instead of imported to keep mypy happy.
Device = Optional[Union[str, torch.device]]
@beartype
class Scene:
    """Class to manage multiple scenes.

    Meshes (loaded from files) and boxes are registered together with the index of
    the scene they belong to. ``build()`` loads the meshes into Warp and groups all
    primitives by scene, and ``query_signed_distances()`` launches the
    ``query_sdf_in_scenes`` kernel over every scene at once.

    Args:
        num_scenes: number of scenes. Primitives whose scene index is outside
            ``[0, num_scenes)`` are ignored by ``build()``.
        device: device on which bookkeeping tensors and Warp meshes are allocated.
    """

    def __init__(self, num_scenes: int, device: Device = "cpu"):
        self.device = device
        self.num_scenes = num_scenes

        # common tensor arguments
        int_args = {"device": device, "dtype": torch.int32}
        float_args = {"device": device, "dtype": torch.float32}

        # mesh fields, kept in insertion order
        self._mesh_filenames: List[str] = []
        self._mesh_scene_indices: Int[torch.Tensor, "num_meshes"] = torch.empty((0,), **int_args)
        self._mesh_scales: Float[torch.Tensor, "num_meshes 3"] = torch.empty((0, 3), **float_args)
        self._mesh_poses: Float[torch.Tensor, "num_meshes 4 4"] = torch.empty((0, 4, 4), **float_args)
        self._wp_meshes: List[wp.Mesh] = []

        # box fields, kept in insertion order
        self._box_sizes: Float[torch.Tensor, "num_boxes 3"] = torch.empty((0, 3), **float_args)
        self._box_scene_indices: Int[torch.Tensor, "num_boxes"] = torch.empty((0,), **int_args)
        self._box_scales: Float[torch.Tensor, "num_boxes 3"] = torch.empty((0, 3), **float_args)
        self._box_poses: Float[torch.Tensor, "num_boxes 4 4"] = torch.empty((0, 4, 4), **float_args)

        # scene-grouped layout, filled in by build()
        self._is_built: bool = False
        # permutation of primitive indices listing them scene by scene
        self._scene_mesh_indices: Int[torch.Tensor, "num_meshes"] = torch.empty((0,), **int_args)
        # position in the permutation where each scene's primitives start
        self._scene_mesh_first_indices: Int[torch.Tensor, "num_scenes"] = torch.empty((0,), **int_args)
        self._scene_box_indices: Int[torch.Tensor, "num_boxes"] = torch.empty((0,), **int_args)
        self._scene_box_first_indices: Int[torch.Tensor, "num_scenes"] = torch.empty((0,), **int_args)

    @staticmethod
    def _check_count(name: str, tensor: torch.Tensor, expected: int) -> None:
        """Raise ``ValueError`` unless ``tensor`` has exactly ``expected`` rows."""
        if tensor.shape[0] != expected:
            raise ValueError(f"Expected {name} to have {expected} rows, got {tensor.shape[0]}")

    def add_meshes_from_files(
        self,
        filenames: List[str],  # filenames can be repeated since meshes are cached by filename
        scene_indices: Int[torch.Tensor, "num_meshes"],
        scales: Optional[Float[torch.Tensor, "num_meshes 3"]] = None,
        poses: Optional[Float[torch.Tensor, "num_meshes 4 4"]] = None,
    ):
        """Register meshes loaded from files.

        Args:
            filenames: mesh file paths, normalised to absolute paths before storing.
            scene_indices: scene index of each mesh.
            scales: per-axis scale of each mesh. Defaults to ones.
            poses: 4x4 pose matrix of each mesh. Defaults to identity.

        Raises:
            ValueError: if the per-mesh arguments disagree in length.

        Invalidates any previous ``build()``.
        """
        filenames = [os.path.normpath(os.path.abspath(filename)) for filename in filenames]  # normalize paths
        num_new = len(filenames)
        if scales is None:
            scales = torch.ones((num_new, 3), device=self.device, dtype=torch.float32)
        if poses is None:
            poses = torch.eye(4, device=self.device, dtype=torch.float32).repeat(num_new, 1, 1)
        # validate everything before mutating, so a bad call leaves the scene untouched
        self._check_count("scene_indices", scene_indices, num_new)
        self._check_count("scales", scales, num_new)
        self._check_count("poses", poses, num_new)

        self._mesh_filenames.extend(filenames)
        self._mesh_scene_indices = torch.cat(
            [self._mesh_scene_indices, scene_indices.to(device=self.device, dtype=torch.int32)]
        )
        self._mesh_scales = torch.cat([self._mesh_scales, scales.to(device=self.device, dtype=torch.float32)])
        self._mesh_poses = torch.cat([self._mesh_poses, poses.to(device=self.device, dtype=torch.float32)])
        self._is_built = False  # grouped data from a previous build() is now stale

    def add_boxes(
        self,
        sizes: Float[torch.Tensor, "num_boxes 3"],
        scene_indices: Int[torch.Tensor, "num_boxes"],
        scales: Optional[Float[torch.Tensor, "num_boxes 3"]] = None,
        poses: Optional[Float[torch.Tensor, "num_boxes 4 4"]] = None,
    ):
        """Register boxes.

        Args:
            sizes: box sizes, one row of 3 values per box.
            scene_indices: scene index of each box.
            scales: per-axis scale of each box. Defaults to ones.
            poses: 4x4 pose matrix of each box. Defaults to identity.

        Raises:
            ValueError: if the per-box arguments disagree in length.

        Invalidates any previous ``build()``.
        """
        num_new = sizes.shape[0]
        if scales is None:
            scales = torch.ones((num_new, 3), device=self.device, dtype=torch.float32)
        if poses is None:
            poses = torch.eye(4, device=self.device, dtype=torch.float32).repeat(num_new, 1, 1)
        self._check_count("scene_indices", scene_indices, num_new)
        self._check_count("scales", scales, num_new)
        self._check_count("poses", poses, num_new)

        self._box_sizes = torch.cat([self._box_sizes, sizes.to(device=self.device, dtype=torch.float32)])
        self._box_scene_indices = torch.cat(
            [self._box_scene_indices, scene_indices.to(device=self.device, dtype=torch.int32)]
        )
        self._box_scales = torch.cat([self._box_scales, scales.to(device=self.device, dtype=torch.float32)])
        # previously missing: poses were computed but never stored
        self._box_poses = torch.cat([self._box_poses, poses.to(device=self.device, dtype=torch.float32)])
        self._is_built = False  # grouped data from a previous build() is now stale

    def add_spheres(self):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Spheres are not implemented yet")

    def _group_by_scene(self, scene_indices: torch.Tensor) -> tuple:
        """Group primitive indices by scene.

        Returns ``(order, first_idx)``:
        - ``order`` lists primitive indices scene by scene, keeping insertion order
          within a scene. Indices outside ``[0, num_scenes)`` are dropped.
        - ``first_idx[s]`` (int32) is the position in ``order`` where scene ``s`` starts.
        """
        per_scene = [torch.nonzero(scene_indices == i, as_tuple=True)[0] for i in range(self.num_scenes)]
        counts = torch.tensor([idx.numel() for idx in per_scene], device=self.device, dtype=torch.int32)
        if per_scene:
            order = torch.cat(per_scene, dim=0)
        else:  # num_scenes == 0: torch.cat rejects an empty list
            order = torch.empty((0,), device=self.device, dtype=torch.long)
        # exclusive prefix sum: [0, c0, c0 + c1, ...]
        first_idx = torch.cumsum(counts, dim=0, dtype=torch.int32) - counts
        return order, first_idx

    def build(self) -> None:
        """Load meshes into Warp and lay out all primitives scene by scene.

        After this call:
        - mesh ids, poses and scales are stored in scene-grouped order;
        - box sizes, poses and scales are stored in scene-grouped order;
        - ``_meshes_first_idx_wp`` and ``_box_first_idx_wp`` hold the per-scene start offsets.
        """
        wp.init()
        # Load each distinct file once; repeated filenames share one wp.Mesh.
        filename_to_wp_mesh = {f: self.load_warp_mesh_cached(f, self.device) for f in set(self._mesh_filenames)}
        self._wp_meshes = [filename_to_wp_mesh[f] for f in self._mesh_filenames]

        self._scene_mesh_indices, self._scene_mesh_first_indices = self._group_by_scene(self._mesh_scene_indices)
        self._scene_box_indices, self._scene_box_first_indices = self._group_by_scene(self._box_scene_indices)

        # The first-index offsets refer to the scene-grouped order, so the per-primitive
        # data is gathered into that same order. This mirrors the query-point layout,
        # where each scene's points are contiguous.
        mesh_order = self._scene_mesh_indices
        self._mesh_ids_wp = wp.array(
            [self._wp_meshes[i].id for i in mesh_order.tolist()], dtype=wp.uint64, device=self.device
        )
        self._scene_mesh_poses = self._mesh_poses[mesh_order].contiguous()
        self._scene_mesh_scales = self._mesh_scales[mesh_order].contiguous()

        box_order = self._scene_box_indices
        self._scene_box_sizes = self._box_sizes[box_order].contiguous()
        self._scene_box_poses = self._box_poses[box_order].contiguous()
        self._scene_box_scales = self._box_scales[box_order].contiguous()

        # Offsets consumed by the kernel. They were referenced but never created before.
        self._scene_mesh_first_indices = self._scene_mesh_first_indices.contiguous()
        self._scene_box_first_indices = self._scene_box_first_indices.contiguous()
        self._meshes_first_idx_wp = wp.from_torch(self._scene_mesh_first_indices, dtype=wp.int32)
        self._box_first_idx_wp = wp.from_torch(self._scene_box_first_indices, dtype=wp.int32)
        self._is_built = True

    def query_signed_distances(
        self,
        query_points: Float[torch.Tensor, "num_points 3"],
        query_points_first_idx: Int[torch.Tensor, "num_scenes"],
        max_dist: float = 1e6,
    ) -> torch.Tensor:
        """Compute signed distances from query points to the primitives of each scene.

        Query points must be grouped contiguously by scene. Scene ``s`` owns
        ``query_points[query_points_first_idx[s]:query_points_first_idx[s + 1]]``,
        and the last scene runs to the end of ``query_points``.

        Args:
            query_points: points to query. Expected float32, since they are viewed as ``wp.vec3``.
            query_points_first_idx: start offset of each scene's points, one entry per scene.
            max_dist: maximum query distance forwarded to the kernel.

        Returns:
            Signed distance per query point, shape ``(num_points,)``.

        Raises:
            RuntimeError: if ``build()`` has not been called since the last primitive was added.
            ValueError: if ``query_points_first_idx`` does not have ``num_scenes`` entries.
        """
        if not self._is_built:
            raise RuntimeError("Scene is not built. Call build() first.")
        if query_points_first_idx.shape[0] != self.num_scenes:
            raise ValueError(
                f"Expected query_points_first_idx to have {self.num_scenes} entries, "
                f"got {query_points_first_idx.shape[0]}"
            )
        wp.init()

        first_idx = query_points_first_idx.to(device=query_points.device, dtype=torch.int32).contiguous()
        end = torch.tensor([query_points.shape[0]], device=query_points.device, dtype=torch.int32)
        points_per_scene = torch.diff(torch.cat([first_idx, end]))
        max_num_pts_per_scene = int(points_per_scene.max().item()) if points_per_scene.numel() > 0 else 0

        sdf = torch.empty_like(query_points[..., 0], requires_grad=False)
        # NOTE(review): normals and closest points are written by the kernel but not
        # returned; this keeps the declared single-tensor return type.
        normals = torch.empty_like(query_points, requires_grad=False)
        clst_pts = torch.empty_like(query_points, requires_grad=False)

        # NOTE(review): poses/scales/sizes are passed as torch tensors, as in the original
        # code. Meshes live on ``self.device`` and the launch uses ``query_points.device``,
        # which assumes they match. Confirm both against query_sdf_in_scenes' signature.
        wp.launch(
            kernel=query_sdf_in_scenes,
            dim=(max_num_pts_per_scene, self.num_scenes),
            inputs=[
                wp.from_torch(query_points.contiguous().view(-1, 3), dtype=wp.vec3, requires_grad=False),
                wp.from_torch(first_idx, dtype=wp.int32, requires_grad=False),
                self._mesh_ids_wp,
                self._meshes_first_idx_wp,
                self._scene_mesh_poses,
                self._scene_mesh_scales,
                self._scene_box_sizes,
                self._box_first_idx_wp,
                self._scene_box_poses,
                self._scene_box_scales,
                max_dist,
                wp.from_torch(sdf.view(-1), dtype=wp.float32),
                wp.from_torch(normals.view(-1, 3), dtype=wp.vec3),
                wp.from_torch(clst_pts.view(-1, 3), dtype=wp.vec3),
            ],
            device=wp.device_from_torch(query_points.device),
        )
        # previously missing: the method fell off the end and returned None
        return sdf

    @property
    def num_meshes(self) -> int:
        """Number of registered meshes (repeated filenames count separately)."""
        return len(self._mesh_filenames)

    @property
    def num_boxes(self) -> int:
        """Number of registered boxes."""
        return len(self._box_sizes)

    def __repr__(self) -> str:
        return f"Scene(num_scenes={self.num_scenes}, num_meshes={self.num_meshes}, num_boxes={self.num_boxes})"

    def __str__(self) -> str:
        return self.__repr__()

    @staticmethod
    @lru_cache(maxsize=1024)
    @beartype
    def load_warp_mesh_cached(mesh_path: str, device: Device = "cpu") -> wp.Mesh:
        """Load a mesh file into a ``wp.Mesh``, memoised on ``(mesh_path, device)``.

        The cache is keyed on the path string only. A file that changes on disk after
        its first load keeps returning the cached mesh.
        """
        mesh = trimesh.load(mesh_path, process=False, force="mesh")
        # plain float32 / int32 arrays match wp.vec3 and wp.int32 without implicit casts
        vertices = np.asarray(mesh.vertices, dtype=np.float32)  # type: ignore
        faces = np.asarray(mesh.faces, dtype=np.int32).ravel()  # type: ignore
        return wp.Mesh(
            points=wp.array(vertices, dtype=wp.vec3, device=device),
            indices=wp.array(faces, dtype=wp.int32, device=device),
        )