Source code for fastdev.sim_webui.webui

# ruff: noqa: F821
import random
import string
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from threading import Event
from typing import Dict, List, Literal, Optional, Tuple, Union, cast, overload

import numpy as np
import torch
import trimesh
import viser
from jaxtyping import Float, Integer, UInt

from fastdev.robo.articulation import Articulation
from fastdev.utils.tensor import atleast_nd, to_number, to_numpy
from fastdev.xform.warp_rotation import matrix_to_quaternion_numpy

# Named colors accepted by `to_color_array` when a color is given as a string;
# each value is an 8-bit RGB triple.
_COMMON_COLORS: Dict[str, List[int]] = dict(
    red=[255, 0, 0],
    green=[0, 255, 0],
    blue=[0, 0, 255],
    yellow=[255, 255, 0],
    cyan=[0, 255, 255],
    magenta=[255, 0, 255],
    white=[255, 255, 255],
    black=[0, 0, 0],
    gray=[128, 128, 128],
    orange=[255, 128, 0],
    purple=[128, 0, 128],
    pink=[255, 0, 128],
    brown=[128, 64, 0],
    teal=[0, 128, 128],
    navy=[0, 0, 128],
    silver=[192, 192, 192],
    gold=[255, 215, 0],
    indigo=[74, 0, 130],
    violet=[238, 130, 238],
    skyblue=[135, 206, 250],
)
# Color (uint8 RGB, light gray) used for robot links when none is specified.
_DEFAULT_ROBOT_COLOR = np.full(3, 192, dtype=np.uint8)

# Kinds of assets tracked by `AssetLibrary`.
AssetType = Literal["mesh", "robot", "point_cloud", "axes"]

# Joint values: the canonical form has one row per frame; `*Like` lists accepted inputs.
JointValues = Float[np.ndarray, "num_frames num_dofs"]
JointValuesLike = Union[
    Float[np.ndarray, "*num_frames num_dofs"],
    Float[torch.Tensor, "*num_frames num_dofs"],
]

# Poses: 4x4 homogeneous matrices, or 7-vectors (position followed by a wxyz quaternion).
Poses = Union[
    Float[np.ndarray, "*num_frames 4 4"],
    Float[np.ndarray, "*num_frames 7"],
]
PosesLike = Union[
    Float[np.ndarray, "*num_frames 4 4"],
    Float[np.ndarray, "*num_frames 7"],
    Float[torch.Tensor, "*num_frames 4 4"],
    Float[torch.Tensor, "*num_frames 7"],
]

# Geometry.
Vertices = Float[np.ndarray, "num_vertices 3"]
VerticesLike = Union[
    Float[np.ndarray, "*num_vertices 3"],
    Float[torch.Tensor, "*num_vertices 3"],
]
FacesLike = Union[
    Integer[np.ndarray, "num_faces 3"],
    Integer[torch.Tensor, "num_faces 3"],
]

# Colors: canonical uint8 RGB plus every input form accepted by `to_color_array`.
Color = UInt[np.ndarray, "... 3"]
ColorLike = Union[
    str,
    List[int],
    List[float],
    Integer[np.ndarray, "... 3"],
    Float[np.ndarray, "... 3"],
    Integer[torch.Tensor, "... 3"],
    Float[torch.Tensor, "... 3"],
]
ScaleLike = Union[
    float,
    Float[np.ndarray, "..."],
    Float[torch.Tensor, "..."],
]

# Batched axes: per-axis positions and orientations (quaternions or 3x3 rotation matrices).
AxesPositions = Float[np.ndarray, "num_axes 3"]
AxesPositionsLike = Union[
    Float[np.ndarray, "*num_axes 3"],
    Float[torch.Tensor, "*num_axes 3"],
]
AxesWXYZs = Float[np.ndarray, "num_axes 4"]
AxesWXYZsLike = Union[
    Float[np.ndarray, "*num_axes 4"],
    Float[np.ndarray, "*num_axes 3 3"],
    Float[torch.Tensor, "*num_axes 4"],
    Float[torch.Tensor, "*num_axes 3 3"],
]
@dataclass
class ViserAsset:
    """Render-ready description of one node handed to the viser scene.

    Callers fill in only the fields that matter for ``viser_asset_type``; the
    remaining fields keep their defaults.
    """

    # key of the node inside the viser scene
    viser_asset_id: str
    # which viser primitive renders this asset
    viser_asset_type: Literal["trimesh", "point_cloud", "axes"]

    # ---- shared by every asset type ----
    color: Optional[Color] = None

    # ---- used when viser_asset_type == "trimesh" ----
    scale: float = 1.0
    trimesh_mesh: Optional[trimesh.Trimesh] = None

    # ---- used when viser_asset_type == "point_cloud" ----
    points: Optional[Vertices] = None
    point_size: float = 0.02

    # ---- used when viser_asset_type == "axes" ----
    axes_positions: Optional[AxesPositions] = None
    axes_wxyzs: Optional[AxesWXYZs] = None
    axes_length: float = 0.1
    axes_radius: float = 0.005
@dataclass
class ViserAssetState:
    """Placement of a single viser asset for one frame."""

    # id of the library asset this placement was derived from
    asset_id: str
    # id of the render-ready viser asset being placed
    viser_asset_id: str
    # translation of the scene node
    position: Float[np.ndarray, "3"]
    # orientation of the scene node as a (w, x, y, z) quaternion
    wxyz: Float[np.ndarray, "4"]
@dataclass
class Asset(ABC):
    """Common base of every asset stored in the asset library.

    Subclasses add their own geometry fields on top of the identifier below.
    """

    # library-wide identifier under which the asset is registered
    asset_id: str
@dataclass
class MeshAsset(Asset):
    """Triangle-mesh asset; derives one colored viser asset per (color, scale, postfix)."""

    trimesh_mesh: trimesh.Trimesh
    # viser asset id -> render-ready viser asset (cache)
    _viser_assets: Dict[str, ViserAsset] = field(default_factory=dict)
    # random color drawn on first use and reused whenever no color is given
    _color_if_not_provided: Optional[Color] = None

    def get_viser_asset(self, viser_asset_id: str) -> ViserAsset:
        """Return a viser asset previously created by `get_or_create_asset_id`."""
        return self._viser_assets[viser_asset_id]

    def get_or_create_asset_id(self, color: Optional[Color], scale: float, postfix: int = 0) -> str:
        """Return the viser asset id for this mesh with the given color/scale, creating it on first use.

        Args:
            color: RGB color baked into the mesh face colors; ``None`` uses a random
                color that is drawn once and then reused.
            scale: Scale stored on the viser asset.
            postfix: Extra integer appended to the id (only affects the id).

        Returns:
            The viser asset id, which encodes the color as hex, the scale and the postfix.
        """
        if color is None:
            if self._color_if_not_provided is None:
                self._color_if_not_provided = get_random_color()
            color = self._color_if_not_provided

        red, green, blue = color[0], color[1], color[2]
        viser_id = f"mesh/{self.asset_id}/#{red:02x}{green:02x}{blue:02x}_{scale:.6f}_{postfix}"
        if viser_id in self._viser_assets:
            return viser_id

        # color a private copy so the shared source mesh is left untouched
        colored_mesh = self.trimesh_mesh.copy()
        colored_mesh.visual.face_colors = color
        self._viser_assets[viser_id] = ViserAsset(
            viser_asset_id=viser_id,
            viser_asset_type="trimesh",
            trimesh_mesh=colored_mesh,
            scale=scale,
        )
        return viser_id
@dataclass
class PointCloudAsset(Asset):
    """Point-cloud asset; derives one viser asset per (color, point_size)."""

    points: Vertices
    # viser asset id -> render-ready viser asset (cache)
    _viser_assets: Dict[str, ViserAsset] = field(default_factory=dict)
    # random color drawn on first use and reused whenever no color is given
    _color_if_not_provided: Optional[Color] = None

    def get_viser_asset(self, viser_asset_id: str) -> ViserAsset:
        """Return a viser asset previously created by `get_or_create_asset_id`."""
        return self._viser_assets[viser_asset_id]

    def get_or_create_asset_id(self, color: Optional[Color], point_size: float) -> str:
        """Return the viser asset id for these points with the given color/size, creating it on first use.

        Args:
            color: A single RGB color, or per-point colors (``ndim > 1``); ``None``
                uses a random color that is drawn once and then reused.
            point_size: Rendered point size.

        Returns:
            The viser asset id. A single color is encoded as hex; per-point colors
            are encoded by ``hash()`` of their raw bytes.
        """
        if color is None:
            if self._color_if_not_provided is None:
                self._color_if_not_provided = get_random_color()
            color = self._color_if_not_provided

        color_tag = (
            str(hash(color.tobytes()))
            if color.ndim > 1
            else f"#{color[0]:02x}{color[1]:02x}{color[2]:02x}"
        )
        viser_id = f"pc/{self.asset_id}/{color_tag}_{point_size:.6f}"
        if viser_id in self._viser_assets:
            return viser_id

        self._viser_assets[viser_id] = ViserAsset(
            viser_asset_id=viser_id,
            viser_asset_type="point_cloud",
            points=self.points,
            point_size=point_size,
            color=color,
        )
        return viser_id
@dataclass
class AxesAsset(Asset):
    """Batched coordinate-axes asset; derives one viser asset per (length, radius)."""

    positions: AxesPositions
    wxyzs: AxesWXYZs
    # viser asset id -> render-ready viser asset (cache)
    _viser_assets: Dict[str, ViserAsset] = field(default_factory=dict)

    def get_viser_asset(self, viser_asset_id: str) -> ViserAsset:
        """Return a viser asset previously created by `get_or_create_asset_id`."""
        return self._viser_assets[viser_asset_id]

    def get_or_create_asset_id(self, axes_length: float, axes_radius: float) -> str:
        """Return the viser asset id for these axes drawn with the given length/radius.

        The viser asset is created and cached the first time a combination is requested.
        """
        viser_id = f"axes/{self.asset_id}/{axes_length:.6f}_{axes_radius:.6f}"
        if viser_id in self._viser_assets:
            return viser_id

        self._viser_assets[viser_id] = ViserAsset(
            viser_asset_id=viser_id,
            viser_asset_type="axes",
            axes_positions=self.positions,
            axes_wxyzs=self.wxyzs,
            axes_length=axes_length,
            axes_radius=axes_radius,
        )
        return viser_id
@dataclass
class RobotAsset(Asset):
    """Articulated robot asset; derives one colored viser mesh per (link, color)."""

    robot_model: Articulation
    # link name -> link mesh, loaded lazily on the first request
    _link_trimesh_meshes: Optional[Dict[str, trimesh.Trimesh]] = None
    # viser asset id -> render-ready viser asset (cache)
    _viser_assets: Dict[str, ViserAsset] = field(default_factory=dict)

    def get_viser_asset(self, viser_asset_id: str) -> ViserAsset:
        """Return a viser asset previously created by `get_or_create_asset_id`."""
        return self._viser_assets[viser_asset_id]

    def get_or_create_asset_id(self, link_name: str, color: Color) -> str:
        """Return the viser asset id for link `link_name` drawn in `color`, creating it on first use."""
        viser_id = f"robot/{self.asset_id}/{link_name}/#{color[0]:02x}{color[1]:02x}{color[2]:02x}"
        if viser_id in self._viser_assets:
            return viser_id

        if self._link_trimesh_meshes is None:
            # defer loading link meshes until a link is actually rendered
            self._link_trimesh_meshes = self.robot_model.first_spec.get_link_trimesh_meshes()
        # color a private copy so the cached link mesh keeps its original colors
        link_mesh = self._link_trimesh_meshes[link_name].copy()
        link_mesh.visual.face_colors = color
        self._viser_assets[viser_id] = ViserAsset(
            viser_asset_id=viser_id,
            viser_asset_type="trimesh",
            trimesh_mesh=link_mesh,
        )
        return viser_id
@dataclass
class AssetLibrary:
    """Asset library for multiple assets.

    Every asset is registered under a random 8-character id. Robots loaded from
    the same (path, mesh_dir), meshes with identical vertices/faces, and identical
    point/axes arrays are deduplicated through the content caches below.
    """

    # ---- registered assets, keyed by asset id ----
    _robot_assets: Dict[str, RobotAsset] = field(default_factory=dict)
    _mesh_assets: Dict[str, MeshAsset] = field(default_factory=dict)
    _pc_assets: Dict[str, PointCloudAsset] = field(default_factory=dict)
    _axes_assets: Dict[str, AxesAsset] = field(default_factory=dict)
    _asset_id_to_asset_type: Dict[str, AssetType] = field(default_factory=dict)
    # render-ready viser assets, keyed by viser asset id
    _viser_assets: Dict[str, ViserAsset] = field(default_factory=dict)

    # ---- dedup caches: content key -> asset id ----
    _urdf_or_mjcf_path_mesh_dir_to_robot_asset_id: Dict[Tuple[str, Optional[str]], str] = field(default_factory=dict)
    _trimesh_hash_to_mesh_asset_id: Dict[int, str] = field(default_factory=dict)
    _pc_array_hash_to_pc_asset_id: Dict[int, str] = field(default_factory=dict)
    _axes_hash_to_axes_asset_id: Dict[int, str] = field(default_factory=dict)

    def _asset_store(self, asset_type: Optional[str]) -> Optional[dict]:
        """Return the container holding assets of `asset_type`, or None for an unknown type."""
        stores = {
            "robot": self._robot_assets,
            "mesh": self._mesh_assets,
            "point_cloud": self._pc_assets,
            "axes": self._axes_assets,
        }
        return stores.get(asset_type)  # type: ignore[arg-type]

    def _register(self, asset: Asset, asset_type: AssetType) -> str:
        """Store `asset` under its id, record its type, and return the id."""
        self._asset_store(asset_type)[asset.asset_id] = asset  # type: ignore[index]
        self._asset_id_to_asset_type[asset.asset_id] = asset_type
        return asset.asset_id

    def add_robot_asset(
        self,
        urdf_or_mjcf_path: Optional[str] = None,
        mesh_dir: Optional[str] = None,
        articulation: Optional["Articulation"] = None,
    ) -> str:
        """Register a robot and return its asset id.

        A path takes precedence over `articulation`; robots loaded from the same
        (path, mesh_dir) pair share one asset.

        Raises:
            ValueError: if neither `urdf_or_mjcf_path` nor `articulation` is given.
        """
        if urdf_or_mjcf_path is not None:
            cache = self._urdf_or_mjcf_path_mesh_dir_to_robot_asset_id
            cache_key = (urdf_or_mjcf_path, mesh_dir)
            if cache_key not in cache:
                robot_model = Articulation.from_urdf_or_mjcf_paths(urdf_or_mjcf_path, mesh_dir)
                cache[cache_key] = self._register(
                    RobotAsset(asset_id=self.get_random_asset_id(), robot_model=robot_model), "robot"
                )
            return cache[cache_key]
        if articulation is None:
            raise ValueError("Either urdf_or_mjcf_path or articulation must be provided")
        # NOTE do not check if the articulation is already in the library for now
        return self._register(RobotAsset(asset_id=self.get_random_asset_id(), robot_model=articulation), "robot")

    def add_mesh_asset(self, trimesh_mesh: trimesh.Trimesh, disable_cache: bool = False) -> str:
        """Register a mesh and return its asset id, reusing an existing asset for identical geometry.

        With `disable_cache`, a new asset is always created and the cache is neither read nor written.
        """
        if disable_cache:
            return self._register(MeshAsset(asset_id=self.get_random_asset_id(), trimesh_mesh=trimesh_mesh), "mesh")

        # NOTE: We use a basic hash function for trimesh meshes since
        # trimesh.Trimesh.identifier_hash can return identical values for
        # meshes that differ by rigid transformation
        geometry = np.concatenate([trimesh_mesh.vertices.flatten(), trimesh_mesh.faces.flatten()])
        mesh_key = hash(geometry.tobytes())
        cache = self._trimesh_hash_to_mesh_asset_id
        if mesh_key not in cache:
            cache[mesh_key] = self._register(
                MeshAsset(asset_id=self.get_random_asset_id(), trimesh_mesh=trimesh_mesh), "mesh"
            )
        return cache[mesh_key]

    def add_point_cloud_asset(self, points: Vertices) -> str:
        """Register a point cloud and return its asset id, reusing the asset for identical point bytes."""
        points_key = hash(points.tobytes())
        cache = self._pc_array_hash_to_pc_asset_id
        if points_key not in cache:
            cache[points_key] = self._register(
                PointCloudAsset(asset_id=self.get_random_asset_id(), points=points), "point_cloud"
            )
        return cache[points_key]

    def add_axes_asset(
        self,
        positions: AxesPositions,
        wxyzs: AxesWXYZs,
    ) -> str:
        """Register batched axes and return the asset id, reusing the asset for identical arrays."""
        axes_key = hash(positions.tobytes() + wxyzs.tobytes())
        cache = self._axes_hash_to_axes_asset_id
        if axes_key not in cache:
            cache[axes_key] = self._register(
                AxesAsset(asset_id=self.get_random_asset_id(), positions=positions, wxyzs=wxyzs), "axes"
            )
        return cache[axes_key]

    def asset_exists(self, asset_id: str) -> bool:
        """Return whether `asset_id` is registered in this library."""
        store = self._asset_store(self._asset_id_to_asset_type.get(asset_id))
        return store is not None and asset_id in store

    def get_asset(self, asset_id: str) -> Asset:
        """Return the asset registered as `asset_id`.

        Raises:
            ValueError: if the id is unknown (its type cannot be resolved).
        """
        asset_type = self._asset_id_to_asset_type.get(asset_id)
        store = self._asset_store(asset_type)
        if store is None:
            raise ValueError(f"Invalid asset type: {asset_type}")
        return store[asset_id]

    def get_viser_asset(self, asset_id: str, viser_asset_id: str) -> ViserAsset:
        """Return viser asset `viser_asset_id` of asset `asset_id`, caching it library-wide.

        Raises:
            KeyError: if the viser asset cannot be found.
        """
        if viser_asset_id not in self._viser_assets:
            store = self._asset_store(self._asset_id_to_asset_type.get(asset_id))
            if store is not None:
                self._viser_assets[viser_asset_id] = store[asset_id].get_viser_asset(viser_asset_id)
        return self._viser_assets[viser_asset_id]

    @staticmethod
    def get_random_asset_id() -> str:
        """Return a random 8-character id made of lowercase letters and digits."""
        # ref: https://stackoverflow.com/a/56398787
        return "".join(random.choices(string.ascii_lowercase + string.digits, k=8))


# Module-wide library that the asset states below resolve their asset ids against.
ASSET_LIBRARY = AssetLibrary()
def to_wxyz(rot: Optional[AxesWXYZsLike]) -> Float[np.ndarray, "... 4"]:
    """Convert a rotation to a scalar-first (w, x, y, z) quaternion array.

    Args:
        rot: ``None``, quaternions with a trailing dimension of 4 (returned as-is),
            or rotation matrices with trailing shape ``(3, 3)``.

    Returns:
        The identity quaternion for ``None``; otherwise wxyz quaternions.

    Raises:
        ValueError: if the trailing shape is neither ``(3, 3)`` nor ``(4,)``.
    """
    rot = to_numpy(rot)
    if rot is None:
        return np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)

    is_rotation_matrix = rot.ndim >= 2 and rot.shape[-2:] == (3, 3)
    if is_rotation_matrix:
        return matrix_to_quaternion_numpy(rot, scalar_first=True)
    if rot.shape[-1] != 4:
        raise ValueError(f"Invalid rotation shape: {rot.shape}")
    return rot
def to_position_wxyz(pose: Optional[PosesLike]) -> Tuple[Float[np.ndarray, "... 3"], Float[np.ndarray, "... 4"]]:
    """Split a pose into a translation and a scalar-first (w, x, y, z) quaternion.

    Args:
        pose: ``None``, homogeneous matrices with trailing shape ``(4, 4)``, or
            7-vectors laid out as ``[x, y, z, qw, qx, qy, qz]``. The 7-vector
            quaternion is assumed to already be scalar-first and is not reordered.

    Returns:
        ``(position, wxyz)``. For ``None`` this is the origin and the identity quaternion.

    Raises:
        ValueError: if the trailing shape is neither ``(4, 4)`` nor ``(7,)``.
    """
    pose = to_numpy(pose)
    if pose is None:
        return np.zeros(3, dtype=np.float32), np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
    if pose.ndim >= 2 and pose.shape[-2:] == (4, 4):
        position = pose[..., :3, 3]
        # Request scalar-first explicitly (as `to_wxyz` does) so the result is wxyz
        # regardless of the conversion helper's default ordering.
        wxyz = matrix_to_quaternion_numpy(pose[..., :3, :3], scalar_first=True)
        return position, wxyz
    if pose.shape[-1] == 7:
        return pose[..., :3], pose[..., 3:]
    raise ValueError(f"Invalid pose shape: {pose.shape}")
def get_random_color(
    r: Tuple[int, int] = (51, 180), g: Tuple[int, int] = (102, 204), b: Tuple[int, int] = (0, 102)
) -> Color:
    """Draw a random RGB color.

    Each channel is sampled uniformly from its ``(low, high)`` range (red, then
    green, then blue) and truncated to ``np.uint8``.

    Returns:
        A ``(3,)`` uint8 array.
    """
    channels = [np.random.uniform(low, high, 1) for low, high in (r, g, b)]
    return np.concatenate(channels).astype(np.uint8)
@overload
def to_color_array(color: None) -> None: ...


@overload
def to_color_array(color: ColorLike) -> Color: ...


def to_color_array(color: Optional[ColorLike]) -> Optional[Color]:
    """Normalize a color specification to a ``np.uint8`` RGB array.

    Args:
        color: ``None``, a name from ``_COMMON_COLORS``, or a list/array/tensor of
            RGB values. Floating-point values (any float dtype, including float16)
            are treated as fractions in ``[0, 1]`` and scaled by 255. Integer values
            are used as-is.

    Returns:
        ``None`` if `color` is ``None``. Otherwise a uint8 array with the input's
        shape (``(3,)`` for a color name). Out-of-range values are clipped to
        ``[0, 255]`` instead of wrapping around.

    Raises:
        KeyError: if `color` is a string that is not a known color name.
    """
    if color is None:
        return None
    if isinstance(color, str):
        return np.array(_COMMON_COLORS[color]).astype(np.uint8)
    color = to_numpy(color, preserve_list=False)
    if np.issubdtype(color.dtype, np.floating):
        color = color * 255
    # Clip before casting: out-of-range float->uint8 casts are platform-dependent
    # and integer casts silently wrap (e.g. 256 -> 0, -1 -> 255).
    return np.clip(color, 0, 255).astype(np.uint8)


@dataclass
class AssetState(ABC):
    """Asset state for multiple frames."""

    # id of the library asset this state places
    asset_id: str

    # NOTE: for states with a frame range this is the end frame index (exclusive)
    # rather than a frame count; it is named `num_frames` so that every asset state
    # exposes the same attribute for computing the scene length.
    @property
    @abstractmethod
    def num_frames(self) -> int: ...

    @abstractmethod
    def get_frame_viser_asset_states(self, frame_index: int) -> List[ViserAssetState]: ...
@dataclass
class MeshState(AssetState):
    """Placement of a mesh asset over a contiguous range of frames."""

    # per-frame poses relative to the range start; a single pose is shared by every frame
    poses: Optional[Poses] = None
    # None lets the mesh asset pick (and reuse) a random color
    color: Optional[Color] = None
    scale: float = 1.0
    # (start, end), start is inclusive, end is exclusive
    _frame_range: Tuple[int, int] = (0, 1)

    @property
    def frame_range(self) -> Tuple[int, int]:
        """Frames ``[start, end)`` in which this mesh is shown."""
        return self._frame_range

    @frame_range.setter
    def frame_range(self, frame_range: Optional[Tuple[int, int]]):
        """Set the visible frame range; assign `poses` first because it is validated against them.

        ``None`` derives the range from the poses: ``(0, len(poses))``, or ``(0, 1)``
        when there are no poses.

        Raises:
            ValueError: if an explicit range's length differs from ``len(poses)`` and
                ``len(poses) != 1``.
        """
        if frame_range is None:
            # Reset to the static default rather than keeping a stale range from an earlier call.
            self._frame_range = (0, self.poses.shape[0]) if self.poses is not None else (0, 1)
            return
        start, end = frame_range
        if self.poses is not None and self.poses.shape[0] not in (1, end - start):
            raise ValueError(
                f"poses has {self.poses.shape[0]} frames but frame_range {frame_range} spans {end - start} frames"
            )
        self._frame_range = frame_range

    @property
    def num_frames(self) -> int:
        """End frame index (exclusive) of this mesh's frame range."""
        return self.frame_range[1]

    def get_frame_viser_asset_states(self, frame_index: int) -> List[ViserAssetState]:
        """Return this mesh's placement at absolute `frame_index`, or ``[]`` outside the frame range."""
        if frame_index not in range(*self.frame_range):
            return []
        if self.poses is None:
            pose = None
        elif self.poses.shape[0] == 1:
            pose = self.poses[0]
        else:
            # Poses are indexed relative to the range start (the setter validates this length).
            pose = self.poses[frame_index - self.frame_range[0]]
        pos, wxyz = to_position_wxyz(pose)
        mesh_asset: MeshAsset = cast(MeshAsset, ASSET_LIBRARY.get_asset(self.asset_id))
        viser_asset_state = ViserAssetState(
            asset_id=self.asset_id,
            viser_asset_id=mesh_asset.get_or_create_asset_id(self.color, self.scale),
            position=pos,
            wxyz=wxyz,
        )
        return [viser_asset_state]
@dataclass
class PointCloudState(AssetState):
    """Placement of a point-cloud asset over a contiguous range of frames."""

    # per-frame poses relative to the range start; a single pose is shared by every frame
    poses: Optional[Poses] = None
    # None lets the point-cloud asset pick (and reuse) a random color
    color: Optional[Color] = None
    point_size: float = 0.02
    # (start, end), start is inclusive, end is exclusive
    _frame_range: Tuple[int, int] = (0, 1)

    @property
    def frame_range(self) -> Tuple[int, int]:
        """Frames ``[start, end)`` in which this point cloud is shown."""
        return self._frame_range

    @frame_range.setter
    def frame_range(self, frame_range: Optional[Tuple[int, int]]):
        """Set the visible frame range; assign `poses` first because it is validated against them.

        ``None`` derives the range from the poses: ``(0, len(poses))``, or ``(0, 1)``
        when there are no poses.

        Raises:
            ValueError: if an explicit range's length differs from ``len(poses)`` and
                ``len(poses) != 1``.
        """
        if frame_range is None:
            # Reset to the static default rather than keeping a stale range from an earlier call.
            self._frame_range = (0, self.poses.shape[0]) if self.poses is not None else (0, 1)
            return
        start, end = frame_range
        if self.poses is not None and self.poses.shape[0] not in (1, end - start):
            raise ValueError(
                f"poses has {self.poses.shape[0]} frames but frame_range {frame_range} spans {end - start} frames"
            )
        self._frame_range = frame_range

    @property
    def num_frames(self) -> int:
        """End frame index (exclusive) of this point cloud's frame range."""
        return self.frame_range[1]

    def get_frame_viser_asset_states(self, frame_index: int) -> List[ViserAssetState]:
        """Return this point cloud's placement at absolute `frame_index`, or ``[]`` outside the frame range."""
        if frame_index not in range(*self.frame_range):
            return []
        pc_asset: PointCloudAsset = cast(PointCloudAsset, ASSET_LIBRARY.get_asset(self.asset_id))
        if self.poses is None:
            pose = None
        elif self.poses.shape[0] == 1:
            # a single pose is shared by every frame in the range
            pose = self.poses[0]
        else:
            # Poses are indexed relative to the range start (the setter validates this length).
            pose = self.poses[frame_index - self.frame_range[0]]
        pos, wxyz = to_position_wxyz(pose)
        viser_asset_state = ViserAssetState(
            asset_id=self.asset_id,
            viser_asset_id=pc_asset.get_or_create_asset_id(self.color, self.point_size),
            position=pos,
            wxyz=wxyz,
        )
        return [viser_asset_state]
@dataclass
class AxesState(AssetState):
    """Per-frame root pose of a batched-axes asset."""

    # one pose per frame; None places the axes at the origin
    poses: Optional[Poses] = None
    axes_length: float = 0.1
    axes_radius: float = 0.005

    @property
    def num_frames(self) -> int:
        """Number of frames covered by `poses` (1 when there are none)."""
        if self.poses is None:
            return 1
        return self.poses.shape[0]

    def get_frame_viser_asset_states(self, frame_index: int) -> List[ViserAssetState]:
        """Return the axes placement for `frame_index` (the caller is expected to bound-check)."""
        axes_asset = cast(AxesAsset, ASSET_LIBRARY.get_asset(self.asset_id))
        frame_pose = None if self.poses is None else self.poses[frame_index]
        position, wxyz = to_position_wxyz(frame_pose)
        viser_id = axes_asset.get_or_create_asset_id(
            axes_length=self.axes_length,
            axes_radius=self.axes_radius,
        )
        return [ViserAssetState(asset_id=self.asset_id, viser_asset_id=viser_id, position=position, wxyz=wxyz)]
@dataclass
class RobotState(AssetState):
    """Joint values and root poses of a robot asset over frames.

    A per-frame array with a single frame is shared by every frame, matching the
    single-pose convention used by the mesh state.
    """

    # one row of joint values per frame; None uses the model's zero joint values
    joint_values: Optional[JointValues] = None
    # one root pose per frame; None is forwarded as-is to forward kinematics
    root_poses: Optional[Poses] = None
    # copy so instances never share the module-level default array
    color: Color = field(default_factory=lambda: _DEFAULT_ROBOT_COLOR.copy())

    @property
    def num_frames(self) -> int:
        """Number of frames to show.

        This is ``len(joint_values)``, or ``len(root_poses)`` when the joint values
        are absent or have a single (shared) frame. With neither set it is 1.
        """
        joint_frames = self.joint_values.shape[0] if self.joint_values is not None else 1
        if joint_frames == 1 and self.root_poses is not None:
            return self.root_poses.shape[0]
        return joint_frames

    @staticmethod
    def _frame_of(values: Optional[np.ndarray], frame_index: int) -> Optional[np.ndarray]:
        """Select `frame_index` from per-frame `values`; a single frame is shared by all frames."""
        if values is None:
            return None
        return values[0] if values.shape[0] == 1 else values[frame_index]

    def get_frame_viser_asset_states(self, frame_index: int) -> List[ViserAssetState]:
        """Run forward kinematics for `frame_index` and return one placement per robot link."""
        robot_asset: RobotAsset = cast(RobotAsset, ASSET_LIBRARY.get_asset(self.asset_id))
        robot_model = robot_asset.robot_model
        frame_joint_values = self._frame_of(self.joint_values, frame_index)
        if frame_joint_values is None:
            frame_joint_values = robot_model.get_packed_zero_joint_values(return_tensors="np")
        frame_root_poses = self._frame_of(self.root_poses, frame_index)
        link_poses = robot_model.forward_kinematics_numpy(
            joint_values=frame_joint_values,  # type: ignore
            root_poses=frame_root_poses,  # type: ignore
        )
        viser_asset_states = []
        for link_name, link_pose in zip(robot_model.first_spec.link_names, link_poses):
            link_pos, link_wxyz = to_position_wxyz(link_pose)
            viser_asset_states.append(
                ViserAssetState(
                    asset_id=self.asset_id,
                    viser_asset_id=robot_asset.get_or_create_asset_id(link_name, self.color),
                    position=link_pos,
                    wxyz=link_wxyz,
                )
            )
        return viser_asset_states
@dataclass
class SceneState:
    """Scene state for multiple assets and frames."""

    # scene length: the largest `num_frames` over all asset states (0 when empty)
    _num_frames: int = 0
    # time.time() of the most recent set_* call (0.0 if never updated)
    _last_updated: float = 0.0
    _robot_states: Dict[str, RobotState] = field(default_factory=dict)
    _mesh_states: Dict[str, MeshState] = field(default_factory=dict)
    _pc_states: Dict[str, PointCloudState] = field(default_factory=dict)
    _axes_states: Dict[str, AxesState] = field(default_factory=dict)

    def _mark_updated(self) -> None:
        """Recompute the scene length from every asset state and bump the update timestamp."""
        # Recompute instead of only growing, so replacing a state with a shorter one
        # does not leave a stale scene length behind.
        all_states = chain(
            self._robot_states.values(),
            self._mesh_states.values(),
            self._pc_states.values(),
            self._axes_states.values(),
        )
        self._num_frames = max((state.num_frames for state in all_states), default=0)
        self._last_updated = time.time()

    def set_robot_state(
        self,
        asset_id: str,
        joint_values: Optional[JointValues] = None,
        root_poses: Optional[Poses] = None,
        color: Color = _DEFAULT_ROBOT_COLOR,
    ):
        """Create or overwrite the state of robot `asset_id` in this scene."""
        robot_state = self._robot_states.get(asset_id, RobotState(asset_id=asset_id))
        robot_state.joint_values = joint_values
        robot_state.root_poses = root_poses
        robot_state.color = color
        self._robot_states[asset_id] = robot_state
        self._mark_updated()

    def set_mesh_state(
        self,
        asset_id: str,
        poses: Optional[Poses] = None,
        scale: float = 1.0,
        color: Optional[Color] = None,
        frame_range: Optional[Tuple[int, int]] = None,
    ):
        """Create or overwrite the state of mesh `asset_id` in this scene."""
        mesh_state = self._mesh_states.get(asset_id, MeshState(asset_id=asset_id))
        mesh_state.poses = poses
        mesh_state.scale = scale
        mesh_state.color = color
        # assigned after `poses` because the setter validates the range against them
        mesh_state.frame_range = frame_range  # type: ignore
        self._mesh_states[asset_id] = mesh_state
        self._mark_updated()

    def set_point_cloud_state(
        self,
        asset_id: str,
        poses: Optional[Poses] = None,
        point_size: float = 0.02,
        color: Optional[Color] = None,
        frame_range: Optional[Tuple[int, int]] = None,
    ):
        """Create or overwrite the state of point cloud `asset_id` in this scene."""
        pc_state = self._pc_states.get(asset_id, PointCloudState(asset_id=asset_id))
        pc_state.poses = poses
        pc_state.point_size = point_size
        pc_state.color = color
        # assigned after `poses` because the setter validates the range against them
        pc_state.frame_range = frame_range  # type: ignore
        self._pc_states[asset_id] = pc_state
        self._mark_updated()

    def set_axes_state(
        self,
        asset_id: str,
        poses: Optional[Poses] = None,
        axes_length: float = 0.1,
        axes_radius: float = 0.005,
    ):
        """Create or overwrite the state of axes `asset_id` in this scene."""
        axes_state = self._axes_states.get(asset_id, AxesState(asset_id=asset_id))
        axes_state.poses = poses
        axes_state.axes_length = axes_length
        axes_state.axes_radius = axes_radius
        self._axes_states[asset_id] = axes_state
        self._mark_updated()

    @property
    def num_frames(self) -> int:
        """Number of frames in the scene (max over all asset states)."""
        return self._num_frames

    @property
    def last_updated(self) -> float:
        """Timestamp of the most recent state change."""
        return self._last_updated

    def get_frame_viser_asset_states(self, frame_index: int) -> List[ViserAssetState]:
        """Collect the viser placements of every asset visible at `frame_index`."""
        viser_asset_states = []
        for asset_state in chain(
            self._robot_states.values(),
            self._axes_states.values(),
        ):
            if asset_state.num_frames <= frame_index:
                continue
            viser_asset_states.extend(asset_state.get_frame_viser_asset_states(frame_index=frame_index))
        for asset_state in chain(self._mesh_states.values(), self._pc_states.values()):
            if frame_index not in range(*asset_state.frame_range):  # type: ignore
                continue
            viser_asset_states.extend(asset_state.get_frame_viser_asset_states(frame_index=frame_index))
        return viser_asset_states

    def __repr__(self) -> str:
        return (
            f"SceneState(num_frames={self.num_frames}, num_robot_states={len(self._robot_states)}, num_mesh_states={len(self._mesh_states)}, "
            f"num_pc_states={len(self._pc_states)}, num_axes_states={len(self._axes_states)})"
        )

    def __str__(self) -> str:
        return self.__repr__()
@dataclass
class StateManager:
    """State manager for multiple scenes.

    Scenes are addressed by index. Writing to index ``num_scenes`` appends a new
    scene, while read-only queries on that index return empty defaults without
    creating it.
    """

    _scene_states: List[SceneState] = field(default_factory=list)

    def set_robot_state(
        self,
        asset_id: str,
        scene_index: int,
        joint_values: Optional[JointValues] = None,
        root_poses: Optional[Poses] = None,
        color: Color = _DEFAULT_ROBOT_COLOR,
    ):
        """Set a robot state in scene `scene_index` (appending the scene if it is the next index)."""
        self._get_scene_state(scene_index).set_robot_state(asset_id, joint_values, root_poses, color)

    def set_mesh_state(
        self,
        asset_id: str,
        scene_index: int,
        poses: Optional[Poses] = None,
        scale: float = 1.0,
        color: Optional[Color] = None,
        frame_range: Optional[Tuple[int, int]] = None,
    ):
        """Set a mesh state in scene `scene_index` (appending the scene if it is the next index)."""
        self._get_scene_state(scene_index).set_mesh_state(asset_id, poses, scale, color, frame_range)

    def set_point_cloud_state(
        self,
        asset_id: str,
        scene_index: int,
        poses: Optional[Poses] = None,
        point_size: float = 1.0,
        color: Optional[Color] = None,
        frame_range: Optional[Tuple[int, int]] = None,
    ):
        """Set a point-cloud state in scene `scene_index` (appending the scene if it is the next index)."""
        # NOTE(review): this default point_size (1.0) differs from SceneState.set_point_cloud_state (0.02);
        # kept unchanged for compatibility — confirm which default is intended.
        self._get_scene_state(scene_index).set_point_cloud_state(asset_id, poses, point_size, color, frame_range)

    def set_axes_state(
        self,
        asset_id: str,
        scene_index: int,
        poses: Optional[Poses] = None,
        axes_length: float = 0.1,
        axes_radius: float = 0.005,
    ):
        """Set an axes state in scene `scene_index` (appending the scene if it is the next index)."""
        self._get_scene_state(scene_index).set_axes_state(asset_id, poses, axes_length, axes_radius)

    def _get_scene_state(self, scene_index: int) -> SceneState:
        """Return scene `scene_index`, appending a new empty scene when it equals `num_scenes`.

        Raises:
            ValueError: if `scene_index` is outside ``[0, num_scenes]``.
        """
        if not self.validate_scene_index(scene_index):
            raise ValueError(f"Invalid scene index: {scene_index}")
        if scene_index == self.num_scenes:  # new scenes to be added
            self._scene_states.append(SceneState())
        return self._scene_states[scene_index]

    def _peek_scene_state(self, scene_index: int) -> Optional[SceneState]:
        """Like `_get_scene_state` but never creates a scene: returns None for index `num_scenes`.

        Raises:
            ValueError: if `scene_index` is outside ``[0, num_scenes]``.
        """
        if not self.validate_scene_index(scene_index):
            raise ValueError(f"Invalid scene index: {scene_index}")
        return self._scene_states[scene_index] if scene_index < self.num_scenes else None

    def get_scene_num_frames(self, scene_index: int) -> int:
        """Return the number of frames in scene `scene_index` (0 for a scene that does not exist yet)."""
        if len(self._scene_states) == 0:
            return 0
        scene = self._peek_scene_state(scene_index)
        return scene.num_frames if scene is not None else 0

    def get_scene_last_updated(self, scene_index: int) -> float:
        """Return the last-update timestamp of scene `scene_index` (0.0 for a scene that does not exist yet)."""
        scene = self._peek_scene_state(scene_index)
        return scene.last_updated if scene is not None else 0.0

    def get_frame_viser_asset_states(self, scene_index: int, frame_index: int) -> List[ViserAssetState]:
        """Return the viser placements of scene `scene_index` at `frame_index` (``[]`` for a scene that does not exist yet)."""
        scene = self._peek_scene_state(scene_index)
        if scene is None:
            return []
        return scene.get_frame_viser_asset_states(frame_index=frame_index)

    @property
    def num_scenes(self) -> int:
        """Number of scenes created so far."""
        return len(self._scene_states)

    def validate_scene_index(self, scene_index: int) -> bool:
        """Validate the scene index.

        Valid scene index should be in the range [0, num_scenes], including both ends.
        When the scene index equals to num_scenes, it means the scene is the new scene to be added.

        Args:
            scene_index (int): Scene index.
        """
        return 0 <= scene_index <= self.num_scenes

    def reset(self):
        """Drop every scene."""
        self._scene_states = []

    def __getitem__(self, scene_index: int) -> SceneState:
        """Return scene `scene_index`, appending a new scene when it equals `num_scenes`."""
        return self._get_scene_state(scene_index)

    def __repr__(self) -> str:
        return f"StateManager(num_scenes={self.num_scenes})"

    def __str__(self) -> str:
        return self.__repr__()
# thread-safe event for playing status
[docs] IS_PLAYING_EVENT = Event()
[docs] class ViserHelper: """Helper class for Viser server.""" def __init__( self, state_manager: StateManager, host: str = "localhost", port: int = 8080, enable_geometry_option: bool = False, ): self._state_manager = state_manager self._viser_server = viser.ViserServer(host=host, port=port) self._viser_server.gui.configure_theme(control_width="large") self._gui_scene_folder = self._viser_server.gui.add_folder("Scene") with self._gui_scene_folder: self._gui_scene_index = self._viser_server.gui.add_slider("Index", min=0, max=0, step=1, initial_value=0) self._gui_scene_index.on_update(self.update_server) self._gui_frame_folder = self._viser_server.gui.add_folder("Frame") with self._gui_frame_folder: self._gui_is_playing = self._viser_server.gui.add_checkbox("Playing", False) self._gui_frame_index = self._viser_server.gui.add_slider("Index", min=0, max=0, step=1, initial_value=0) self._gui_is_playing.on_update(self.update_server) self._gui_frame_index.on_update(self.update_server) # TODO: support max_value property reading for sliders in viser self._scene_index = 0 self._frame_index = 0 self._max_scene_index_on_gui = 0 self._max_frame_index_on_gui = 0 self._is_playing = False self._scene_last_updated = 0.0 IS_PLAYING_EVENT.clear() self._viser_asset_handles: Dict[str, viser.SceneNodeHandle] = {} self._visble_asset_handles: Dict[str, viser.SceneNodeHandle] = {}
[docs] def update_server(self, *args, **kwargs): """Update the Viser server.""" # TODO support max-value modification message on viser # TODO support client-side frame update for playing # i.e., save frame history on the client side and update the frame index based on the playing status # client only fetches the frame changes in the background and updates the changes to the history asynchronously # update max scene index on GUI if necessary max_scene_index_changed = self._max_scene_index_on_gui != self._state_manager.num_scenes - 1 if max_scene_index_changed: with self._viser_server.atomic(): with self._gui_scene_folder: _cur_scene_index = self._gui_scene_index.value self._gui_scene_index.remove() self._gui_scene_index = self._viser_server.gui.add_slider( "Index", min=0, max=max(self._state_manager.num_scenes - 1, 0), step=1, initial_value=max(0, min(_cur_scene_index, self._state_manager.num_scenes - 1)), ) self._gui_scene_index.on_update(self.update_server) self._max_scene_index_on_gui = self._state_manager.num_scenes - 1 # retrieve num_frames of the current scene scene_num_frames = self._state_manager.get_scene_num_frames(self._gui_scene_index.value) # remove & add `gui_frame_index` based on `is_playing` is_playing_changed = self._gui_is_playing.value != self._is_playing if is_playing_changed: with self._viser_server.atomic(): if self._gui_is_playing.value: # do not allow changing frame index when playing, remove the slider self._gui_frame_index.remove() IS_PLAYING_EVENT.set() else: # add the slider back with self._gui_frame_folder: self._gui_frame_index = self._viser_server.gui.add_slider( "Index", min=0, max=max(scene_num_frames - 1, 0), step=1, initial_value=max(0, min(self._frame_index, scene_num_frames - 1)), ) self._gui_frame_index.on_update(self.update_server) self._max_frame_index_on_gui = scene_num_frames - 1 IS_PLAYING_EVENT.clear() self._is_playing = self._gui_is_playing.value # update max frame index on GUI if necessary and not playing 
max_frame_index_changed = self._max_frame_index_on_gui != scene_num_frames - 1 and not self._is_playing if max_frame_index_changed: with self._viser_server.atomic(): self._gui_frame_index.remove() with self._gui_frame_folder: self._gui_frame_index = self._viser_server.gui.add_slider( "Index", min=0, max=max(scene_num_frames - 1, 0), step=1, initial_value=max(0, min(self._frame_index, scene_num_frames - 1)), ) self._gui_frame_index.on_update(self.update_server) self._max_frame_index_on_gui = scene_num_frames - 1 # -------------------- core logic -------------------- # update asset states if necessary expected_num_frames = ( self._gui_frame_index.value if not self._is_playing else max(0, min(self._frame_index + 1, scene_num_frames - 1)) ) scene_index_changed = self._scene_index != self._gui_scene_index.value frame_index_changed = self._frame_index != expected_num_frames scene_last_updated_changed = self._scene_last_updated != self._state_manager.get_scene_last_updated( self._gui_scene_index.value ) if scene_index_changed or frame_index_changed or scene_last_updated_changed: new_visible_asset_handles: Dict[str, viser.SceneNodeHandle] = {} # type: ignore frame_asset_states = self._state_manager.get_frame_viser_asset_states( scene_index=self._gui_scene_index.value, frame_index=expected_num_frames ) for asset_state in frame_asset_states: if asset_state.viser_asset_id not in self._viser_asset_handles: self._add_viser_asset_from_state(asset_state) else: self._update_viser_asset_from_state(asset_state) new_visible_asset_handles[asset_state.viser_asset_id] = self._viser_asset_handles[ asset_state.viser_asset_id ] # hide invisible assets for viser_asset_id, viser_asset_handle in self._visble_asset_handles.items(): if viser_asset_id not in new_visible_asset_handles: viser_asset_handle.visible = False self._visble_asset_handles = new_visible_asset_handles self._scene_index = self._gui_scene_index.value self._frame_index = expected_num_frames self._scene_last_updated = 
self._state_manager.get_scene_last_updated(self._gui_scene_index.value)
def _add_viser_asset_from_state(self, asset_state: ViserAssetState):
    """Create a viser scene node for ``asset_state`` and register its handle.

    The node kind is chosen from the library asset's ``viser_asset_type``. The
    resulting handle is stored in ``self._viser_asset_handles`` under the viser
    asset ID, so later frames update it in place via
    ``_update_viser_asset_from_state``.

    Args:
        asset_state (ViserAssetState): Per-frame state of the asset to add.

    Raises:
        ValueError: If the asset's ``viser_asset_type`` is not ``"trimesh"``,
            ``"point_cloud"`` or ``"axes"``.
    """
    viser_asset = ASSET_LIBRARY.get_viser_asset(asset_state.asset_id, asset_state.viser_asset_id)
    asset_type = viser_asset.viser_asset_type
    if asset_type == "trimesh":
        asset_handle = self._viser_server.scene.add_mesh_trimesh(
            name=viser_asset.viser_asset_id,
            mesh=viser_asset.trimesh_mesh,  # type: ignore
            position=asset_state.position,
            wxyz=asset_state.wxyz,
            scale=viser_asset.scale,
        )
    elif asset_type == "point_cloud":
        asset_handle = self._viser_server.scene.add_point_cloud(
            name=viser_asset.viser_asset_id,
            points=viser_asset.points,  # type: ignore
            point_size=viser_asset.point_size,
            point_shape="circle",
            colors=viser_asset.color,  # type: ignore
            wxyz=asset_state.wxyz,
            position=asset_state.position,
        )
    elif asset_type == "axes":
        asset_handle = self._viser_server.scene.add_batched_axes(
            name=viser_asset.viser_asset_id,
            batched_positions=viser_asset.axes_positions,  # type: ignore
            batched_wxyzs=viser_asset.axes_wxyzs,  # type: ignore
            axes_length=viser_asset.axes_length,
            axes_radius=viser_asset.axes_radius,
            # apply the frame pose at creation time as well; the update path sets
            # position/wxyz on every handle, so creation must agree with it
            wxyz=asset_state.wxyz,
            position=asset_state.position,
        )
    else:
        # fail loudly instead of reaching the registration below with `asset_handle` unbound
        raise ValueError(f"Unsupported viser asset type: {asset_type!r}")
    self._viser_asset_handles[viser_asset.viser_asset_id] = asset_handle


def _update_viser_asset_from_state(self, asset_state: ViserAssetState):
    """Push ``asset_state``'s pose to its existing scene node and make it visible.

    Only position and orientation are updated; scale, color, etc. are not
    propagated to an already-created node (see TODO).

    Args:
        asset_state (ViserAssetState): Per-frame state of an asset whose handle is
            already registered in ``self._viser_asset_handles``.
    """
    viser_asset_handle = self._viser_asset_handles[asset_state.viser_asset_id]
    # batch the three property changes into a single server update
    with self._viser_server.atomic():
        # TODO support scale, color, etc in viser
        viser_asset_handle.position = asset_state.position
        viser_asset_handle.wxyz = asset_state.wxyz
        viser_asset_handle.visible = True
def reset(self):
    """Rewind the cached scene/frame cursor and hide every viser scene node.

    Handles stay registered in ``self._viser_asset_handles`` (they are only
    hidden), so the next update can re-show them without recreating them.
    Playback state (``_is_playing`` / ``IS_PLAYING_EVENT``) is left unchanged.
    """
    # zeroing the cached cursor makes the next update re-sync from the state manager
    self._scene_index, self._frame_index = 0, 0
    self._scene_last_updated = 0.0

    for handle in self._viser_asset_handles.values():
        handle.visible = False
    # nothing is visible any more, so the visible-set starts empty
    self._visble_asset_handles = {}
class SimWebUI:
    """Browser-based viewer for simulator states and 3D scenes.

    Assets are registered once with the ``add_*_asset`` methods and then given
    per-scene, per-frame states with the ``set_*_state`` methods; after each
    state change the viser-backed view is refreshed.
    """

    def __init__(self, host: str = "localhost", port: int = 8080, enable_geometry_option: bool = False):
        """Create the state manager and the viser helper that renders it.

        Args:
            host (str, optional): Host forwarded to ``ViserHelper``. Defaults to "localhost".
            port (int, optional): Port forwarded to ``ViserHelper``. Defaults to 8080.
            enable_geometry_option (bool, optional): Forwarded to ``ViserHelper``. Defaults to False.
        """
        # default scene targeted by set_*_state calls that omit `scene_index`
        self._scene_index = 0
        shared_state = StateManager()
        self._state_manager = shared_state
        # the helper reads frame states from the same manager this UI writes to
        self._viser_helper = ViserHelper(
            state_manager=shared_state,
            host=host,
            port=port,
            enable_geometry_option=enable_geometry_option,
        )
def __repr__(self) -> str:
    """Return a compact summary: the default scene index and the number of scenes."""
    current = self._scene_index
    total = self._state_manager.num_scenes
    return f"SimWebUI(scene_index={current}, num_scenes={total})"
def __str__(self) -> str:
    """Return the same text as ``repr(self)``."""
    return repr(self)
@property
def scene_index(self) -> int:
    """int: Scene targeted by ``set_*_state`` calls that pass ``scene_index=None``."""
    current_scene = self._scene_index
    return current_scene
def set_scene_index(self, value: int):
    """Choose the default scene for subsequent ``set_*_state`` calls.

    Args:
        value (int): New default scene index, stored as-is without range checks.
    """
    self._scene_index = value
def add_robot_asset(
    self,
    urdf_or_mjcf_path: Optional[Union[str, Path]] = None,
    mesh_dir: Optional[Union[str, Path]] = None,
    articulation: Optional["Articulation"] = None,
) -> str:
    """Add a robot asset to the asset library.

    Args:
        urdf_or_mjcf_path (Optional[Union[str, Path]], optional): Path to the URDF or MJCF file of the
            robot. Defaults to None.
        mesh_dir (Optional[Union[str, Path]], optional): Directory path of the robot meshes. Will use the
            directory of the URDF/MJCF file if not provided. Defaults to None.
        articulation (Optional[Articulation], optional): A pre-built articulation, forwarded unchanged to
            ``ASSET_LIBRARY.add_robot_asset`` (which decides how it interacts with the path arguments).
            Defaults to None.

    Returns:
        str: Asset ID of the robot asset.
    """
    # normalise `Path` arguments to plain strings before handing them to the asset library
    urdf_or_mjcf_path, mesh_dir = (
        str(candidate) if isinstance(candidate, Path) else candidate for candidate in (urdf_or_mjcf_path, mesh_dir)
    )
    return ASSET_LIBRARY.add_robot_asset(urdf_or_mjcf_path, mesh_dir, articulation)
def add_mesh_asset(
    self,
    vertices: Optional[VerticesLike] = None,
    faces: Optional[FacesLike] = None,
    trimesh_mesh: Optional[trimesh.Trimesh] = None,
    mesh_path: Optional[Union[str, Path]] = None,
    disable_cache: bool = False,
) -> str:
    """Add a mesh asset to the asset library.

    Args:
        vertices (Optional[VerticesLike], optional): Vertices of the mesh. Must be given together with
            ``faces``. Defaults to None.
        faces (Optional[FacesLike], optional): Faces of the mesh. Must be given together with
            ``vertices``. Defaults to None.
        trimesh_mesh (Optional[trimesh.Trimesh], optional): Trimesh mesh object. Defaults to None.
        mesh_path (Optional[Union[str, Path]], optional): Path to the mesh file, loaded with
            ``trimesh.load(..., process=False, force="mesh")``. Defaults to None.
        disable_cache (bool, optional): Forwarded to ``ASSET_LIBRARY.add_mesh_asset``. Defaults to False.

    Returns:
        str: Asset ID of the mesh asset.

    Raises:
        ValueError: If only one of ``vertices``/``faces`` is given, or if not exactly one mesh source is
            provided.

    .. note::
        Exactly one mesh source must be provided: ``trimesh_mesh``, ``vertices`` together with ``faces``,
        or ``mesh_path``.
    """
    if (vertices is None) != (faces is None):
        raise ValueError("vertices and faces must be provided together.")
    trimesh_provided = trimesh_mesh is not None
    vertices_faces_provided = vertices is not None and faces is not None
    mesh_path_provided = mesh_path is not None
    # count the sources explicitly: `a ^ b ^ c` is also True when all three are set
    if sum((trimesh_provided, vertices_faces_provided, mesh_path_provided)) != 1:
        raise ValueError("Either trimesh_mesh or vertices and faces or mesh_path must be provided.")
    if vertices_faces_provided:
        trimesh_mesh = trimesh.Trimesh(vertices=to_numpy(vertices), faces=to_numpy(faces))
    elif mesh_path_provided:
        trimesh_mesh = trimesh.load(mesh_path, process=False, force="mesh")  # type: ignore
    return ASSET_LIBRARY.add_mesh_asset(trimesh_mesh, disable_cache=disable_cache)  # type: ignore
def add_sphere_asset(self, radius: ScaleLike, subdivisions: int = 3, disable_cache: bool = False) -> str:
    """Register an icosphere mesh as a mesh asset.

    Args:
        radius (ScaleLike): Sphere radius, converted with ``to_number``.
        subdivisions (int, optional): Subdivision level passed to ``trimesh.creation.icosphere``.
            Defaults to 3.
        disable_cache (bool, optional): Forwarded to ``ASSET_LIBRARY.add_mesh_asset``. Defaults to False.

    Returns:
        str: Asset ID of the sphere asset.
    """
    radius_value = to_number(radius)
    sphere = trimesh.creation.icosphere(radius=radius_value, subdivisions=subdivisions)
    asset_id = ASSET_LIBRARY.add_mesh_asset(sphere, disable_cache=disable_cache)  # type: ignore
    return asset_id
def add_point_cloud_asset(self, points: VerticesLike) -> str:
    """Register a point cloud asset.

    Args:
        points (VerticesLike): Point coordinates; converted with ``to_numpy`` and passed through
            ``atleast_nd(..., expected_ndim=2)`` before being stored.

    Returns:
        str: Asset ID of the point cloud asset.
    """
    points_np = to_numpy(points)
    points_2d = atleast_nd(points_np, expected_ndim=2)
    return ASSET_LIBRARY.add_point_cloud_asset(points_2d)
def add_axes_asset(self, positions: AxesPositionsLike, rotations: AxesWXYZsLike) -> str:
    """Register a batch of coordinate axes as a single asset.

    Args:
        positions (AxesPositionsLike): Origins of the axes.
        rotations (AxesWXYZsLike): Orientations of the axes, as rotation matrices or wxyz quaternions;
            converted with ``to_wxyz``.

    Returns:
        str: Asset ID of the axes asset.
    """
    axes_positions = atleast_nd(to_numpy(positions), expected_ndim=2)
    axes_wxyzs = atleast_nd(to_wxyz(rotations), expected_ndim=2)
    return ASSET_LIBRARY.add_axes_asset(positions=axes_positions, wxyzs=axes_wxyzs)
def set_robot_state(
    self,
    asset_id: str,
    scene_index: Optional[int] = None,
    joint_values: Optional[JointValuesLike] = None,
    root_poses: Optional[PosesLike] = None,
    color: ColorLike = "silver",
):
    """Set the state of a robot asset.

    Args:
        asset_id (str): Asset ID of the robot asset.
        scene_index (Optional[int], optional): Scene index; falls back to :attr:`scene_index` when None.
            Defaults to None.
        joint_values (Optional[JointValuesLike], optional): Multi-frame (or single-frame) joint values.
            Defaults to None.
        root_poses (Optional[PosesLike], optional): Multi-frame (or single-frame) root poses, as 4x4
            matrices or 7-element vectors. Defaults to None.
        color (ColorLike, optional): Color of the robot, converted with ``to_color_array``.
            Defaults to "silver".

    Raises:
        ValueError: If no asset with ``asset_id`` exists, or if ``joint_values`` and ``root_poses``
            have different numbers of frames.

    .. note::
        The number of frames of the asset state is determined by the number of joint values provided.
    """
    if not ASSET_LIBRARY.asset_exists(asset_id):
        raise ValueError(f"Asset with ID '{asset_id}' does not exist.")
    scene_index = scene_index if scene_index is not None else self.scene_index
    joint_values = atleast_nd(to_numpy(joint_values), expected_ndim=2)
    # same rule as `set_mesh_state`: a 7-element pose has one fewer trailing axis than a 4x4
    # matrix, so it needs a lower target ndim to end up with exactly one leading frame axis
    poses_ndim = 2 if root_poses is not None and root_poses.shape[-1] == 7 else 3
    root_poses = atleast_nd(to_numpy(root_poses), expected_ndim=poses_ndim)
    color = to_color_array(color)
    if joint_values is not None and root_poses is not None:
        if joint_values.shape[0] != root_poses.shape[0]:
            raise ValueError(
                f"Number of frames mismatch, joint_values: {joint_values.shape[0]}, root_poses: {root_poses.shape[0]}"  # type: ignore
            )
    self._state_manager.set_robot_state(
        asset_id=asset_id,
        scene_index=scene_index,
        joint_values=joint_values,
        root_poses=root_poses,
        color=color,
    )
    self._viser_helper.update_server()
def set_mesh_state(
    self,
    asset_id: str,
    scene_index: Optional[int] = None,
    poses: Optional[PosesLike] = None,
    scale: ScaleLike = 1.0,
    color: Optional[ColorLike] = None,
    frame_range: Optional[Union[int, Tuple[int, int]]] = None,
):
    """Place a mesh asset into a scene with per-frame poses.

    Args:
        asset_id (str): Asset ID of the mesh asset.
        scene_index (Optional[int], optional): Target scene; falls back to :attr:`scene_index` when None.
            Defaults to None.
        poses (Optional[PosesLike], optional): Multi-frame (or single-frame) poses, as 4x4 matrices or
            7-element vectors. Defaults to None.
        scale (ScaleLike, optional): Scale factor, converted with ``to_number``. Defaults to 1.0.
        color (Optional[ColorLike], optional): Mesh color, resolved by ``to_color_array``; a random color
            is used when not provided. Defaults to None.
        frame_range (Optional[Union[int, Tuple[int, int]]], optional): Frame range forwarded to the state
            manager; a non-tuple value ``n`` is shorthand for ``(0, n)``. Defaults to None.

    Raises:
        ValueError: If no asset with ``asset_id`` exists.

    .. note::
        The number of frames of the asset state is determined by the number of poses provided.
    """
    if not ASSET_LIBRARY.asset_exists(asset_id):
        raise ValueError(f"Asset with ID '{asset_id}' does not exist.")
    if scene_index is None:
        scene_index = self.scene_index

    # a 7-element pose has one fewer trailing axis than a 4x4 matrix
    target_ndim = 3
    if poses is not None and poses.shape[-1] == 7:
        target_ndim = 2
    poses = atleast_nd(to_numpy(poses), expected_ndim=target_ndim)

    if frame_range is not None and not isinstance(frame_range, tuple):
        frame_range = (0, frame_range)

    self._state_manager.set_mesh_state(
        asset_id=asset_id,
        scene_index=scene_index,
        poses=poses,
        scale=to_number(scale),
        color=to_color_array(color),
        frame_range=frame_range,
    )
    self._viser_helper.update_server()
def set_sphere_state(
    self,
    asset_id: str,
    scene_index: Optional[int] = None,
    poses: Optional[PosesLike] = None,
    scale: ScaleLike = 1.0,
    color: Optional[ColorLike] = None,
    frame_range: Optional[Union[int, Tuple[int, int]]] = None,
):
    """Set the state of a sphere asset (a thin alias of :meth:`set_mesh_state`).

    Spheres are registered as mesh assets, so every argument is forwarded
    unchanged to :meth:`set_mesh_state`.

    Args:
        asset_id (str): Asset ID of the sphere asset.
        scene_index (Optional[int], optional): Scene index. Defaults to None.
        poses (Optional[PosesLike], optional): Multi-frame (or single-frame) poses. Defaults to None.
        scale (ScaleLike, optional): Scale factor. Defaults to 1.0.
        color (Optional[ColorLike], optional): Color of the sphere. Defaults to None.
        frame_range (Optional[Union[int, Tuple[int, int]]], optional): Frame range. Defaults to None.

    .. note::
        The number of frames of the asset state is determined by the number of poses provided.
    """
    self.set_mesh_state(
        asset_id,
        scene_index=scene_index,
        poses=poses,
        scale=scale,
        color=color,
        frame_range=frame_range,
    )
def set_point_cloud_state(
    self,
    asset_id: str,
    scene_index: Optional[int] = None,
    poses: Optional[PosesLike] = None,
    point_size: ScaleLike = 0.02,
    color: Optional[ColorLike] = None,
    frame_range: Optional[Union[int, Tuple[int, int]]] = None,
):
    """Set the state of a point cloud asset.

    Args:
        asset_id (str): Asset ID of the point cloud asset.
        scene_index (Optional[int], optional): Scene index; falls back to :attr:`scene_index` when None.
            Defaults to None.
        poses (Optional[PosesLike], optional): Multi-frame (or single-frame) poses, as 4x4 matrices or
            7-element vectors. Defaults to None.
        point_size (ScaleLike, optional): Size of the rendered points, converted with ``to_number``.
            Defaults to 0.02.
        color (Optional[ColorLike], optional): Color of the point cloud, resolved by ``to_color_array``.
            Defaults to None.
        frame_range (Optional[Union[int, Tuple[int, int]]], optional): Frame range forwarded to the state
            manager; a non-tuple value ``n`` is shorthand for ``(0, n)``. Defaults to None.

    Raises:
        ValueError: If no asset with ``asset_id`` exists.

    .. note::
        The number of frames of the asset state is determined by the number of poses provided.
    """
    if not ASSET_LIBRARY.asset_exists(asset_id):
        raise ValueError(f"Asset with ID '{asset_id}' does not exist.")
    scene_index = scene_index if scene_index is not None else self.scene_index
    # same rule as `set_mesh_state`: a 7-element pose has one fewer trailing axis than a 4x4
    # matrix, so it needs a lower target ndim to end up with exactly one leading frame axis
    poses_ndim = 2 if poses is not None and poses.shape[-1] == 7 else 3
    poses = atleast_nd(to_numpy(poses), expected_ndim=poses_ndim)
    point_size = to_number(point_size)
    color = to_color_array(color)
    if frame_range is not None:
        frame_range = frame_range if isinstance(frame_range, tuple) else (0, frame_range)
    self._state_manager.set_point_cloud_state(
        asset_id=asset_id,
        scene_index=scene_index,
        poses=poses,
        point_size=point_size,
        color=color,
        frame_range=frame_range,
    )
    self._viser_helper.update_server()
def set_axes_state(
    self,
    asset_id: str,
    scene_index: Optional[int] = None,
    axes_length: ScaleLike = 0.1,
    axes_radius: ScaleLike = 0.005,
    poses: Optional[PosesLike] = None,
):
    """Set the state of an axes asset.

    Args:
        asset_id (str): Asset ID of the axes asset.
        scene_index (Optional[int], optional): Scene index; falls back to :attr:`scene_index` when None.
            Defaults to None.
        axes_length (ScaleLike, optional): Length of the axes, converted with ``to_number``.
            Defaults to 0.1.
        axes_radius (ScaleLike, optional): Radius of the axes, converted with ``to_number``.
            Defaults to 0.005.
        poses (Optional[PosesLike], optional): Multi-frame (or single-frame) poses, as 4x4 matrices or
            7-element vectors. Defaults to None.

    Raises:
        ValueError: If no asset with ``asset_id`` exists.

    .. note::
        The number of frames of the asset state is determined by the number of poses provided.
    """
    if not ASSET_LIBRARY.asset_exists(asset_id):
        raise ValueError(f"Asset with ID '{asset_id}' does not exist.")
    scene_index = scene_index if scene_index is not None else self.scene_index
    # same rule as `set_mesh_state`: a 7-element pose has one fewer trailing axis than a 4x4
    # matrix, so it needs a lower target ndim to end up with exactly one leading frame axis
    poses_ndim = 2 if poses is not None and poses.shape[-1] == 7 else 3
    poses = atleast_nd(to_numpy(poses), expected_ndim=poses_ndim)
    axes_length = to_number(axes_length)
    axes_radius = to_number(axes_radius)
    self._state_manager.set_axes_state(
        asset_id=asset_id,
        scene_index=scene_index,
        poses=poses,
        axes_length=axes_length,
        axes_radius=axes_radius,
    )
    self._viser_helper.update_server()
def reset(self):
    """Clear all scene states and hide everything in the viewer.

    Resets the state manager, then the viser helper, rewinds the default scene
    index to 0, and finally asks the viser helper to refresh the server.
    """
    for component in (self._state_manager, self._viser_helper):
        component.reset()
    self._scene_index = 0
    self._viser_helper.update_server()