# Source code for fastdev.sim_webui.maniskill_webui

import time
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Counter as CounterT
from typing import Dict, List, cast

import numpy as np
import sapien.physx as physx
import trimesh
import trimesh.creation
from gymnasium import Env
from mani_skill.envs.scene import ManiSkillScene
from mani_skill.utils.geometry.trimesh_utils import merge_meshes
from sapien.physx import PhysxRigidBaseComponent

from fastdev.sim_webui.webui import (
    ASSET_LIBRARY,
    IS_PLAYING_EVENT,
    Color,
    MeshAsset,
    SimWebUI,
    StateManager,
    ViserAssetState,
    ViserHelper,
    get_random_color,
    to_color_array,
    to_position_wxyz,
)
from fastdev.utils.profile import timeit
from fastdev.utils.tensor import to_numpy
from fastdev.utils.tui import log_once


# NOTE: `get_component_meshes` from `mani_skill.utils.geometry.trimesh_utils` does not appear to produce
# correct meshes, so a local replacement is defined below.
def get_component_meshes(component: physx.PhysxRigidBaseComponent) -> List[trimesh.Trimesh]:
    """Get component (collision) meshes in the component's frame.

    Every collision shape of ``component`` is converted to a ``trimesh.Trimesh`` and then
    transformed by the shape's local pose. Plane shapes are skipped.

    Args:
        component: rigid component whose collision shapes are converted.

    Returns:
        One mesh per non-plane collision shape, in the order returned by
        ``component.get_collision_shapes()``.

    Raises:
        TypeError: if a collision shape has an unsupported type.
    """
    # Rotation mapping +Z (the axis trimesh builds capsules/cylinders along) onto +X,
    # the axis used for the PhysX capsule/cylinder shapes here.
    axis_tf = np.array([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]], dtype=np.float32)

    meshes = []
    for geom in component.get_collision_shapes():
        if isinstance(geom, physx.PhysxCollisionShapeBox):
            mesh = trimesh.creation.box(extents=2 * geom.half_size)
        elif isinstance(geom, physx.PhysxCollisionShapeCapsule):
            mesh = trimesh.creation.capsule(radius=geom.radius, height=2 * geom.half_length, transform=axis_tf)
        elif isinstance(geom, physx.PhysxCollisionShapeCylinder):
            mesh = trimesh.creation.cylinder(radius=geom.radius, height=2 * geom.half_length, transform=axis_tf)
        elif isinstance(geom, physx.PhysxCollisionShapeSphere):
            mesh = trimesh.creation.icosphere(radius=geom.radius)
        elif isinstance(geom, physx.PhysxCollisionShapePlane):
            continue
        elif isinstance(geom, (physx.PhysxCollisionShapeConvexMesh, physx.PhysxCollisionShapeTriangleMesh)):
            # Convex and triangle meshes are handled identically: scale the stored vertices
            # and rebuild the mesh from the shape's triangles.
            vertices = geom.vertices * geom.scale
            mesh = trimesh.Trimesh(vertices=vertices, faces=geom.get_triangles())
        else:
            raise TypeError(type(geom))
        # Move the mesh from the shape's local frame into the component's frame.
        mesh.apply_transform(geom.get_local_pose().to_transformation_matrix())
        meshes.append(mesh)
    return meshes
@dataclass
class ManiSkillStateManger(StateManager):
    """State manager that serves poses recorded from a ManiSkill scene to the web UI.

    All scenes share one pose history: ``component_poses_history[frame][component_name]`` holds
    the poses of that component for every scene it belongs to. The row for a given scene is found
    at the position of that scene index in ``component_name_to_scene_indices[component_name]``.
    """

    _num_scenes: int = 0
    _scene_last_updated: float = field(default_factory=time.time)
    # One dict per recorded frame: component name -> poses for the component's scenes.
    component_poses_history: List[Dict[str, np.ndarray]] = field(default_factory=list)
    # Component name -> display color (falls back to "silver" when missing).
    component_colors: Dict[str, Color] = field(default_factory=dict)
    # Component name -> id of its mesh asset in ASSET_LIBRARY.
    component_name_to_asset_id: Dict[str, str] = field(default_factory=dict)
    # Component name -> scene indices it appears in (order defines the pose row offsets).
    component_name_to_scene_indices: Dict[str, List[int]] = field(default_factory=lambda: defaultdict(list))
    # Scene index -> component names to render in that scene.
    scene_index_to_component_names: Dict[int, List[str]] = field(default_factory=lambda: defaultdict(list))

    def get_scene_num_frames(self, scene_index: int) -> int:
        """Return the number of recorded frames; all scenes share the same history length."""
        return len(self.component_poses_history)

    # NOTE(review): ``lru_cache`` on a method keys on ``self`` (hashed by ``__hash__`` below) and is
    # shared by every instance; ``reset`` clears it. Requesting a frame that has not been recorded yet
    # caches ``[]`` for that key until the next ``reset`` -- confirm callers stay within
    # ``get_scene_num_frames``.
    @lru_cache(maxsize=256)
    def get_frame_viser_asset_states(self, scene_index: int, frame_index: int) -> List[ViserAssetState]:
        """Build the viser asset states of every component in ``scene_index`` at ``frame_index``.

        Args:
            scene_index: scene whose components are rendered.
            frame_index: index into the recorded pose history.

        Returns:
            One ``ViserAssetState`` per component of the scene, or an empty list when
            ``frame_index`` is past the last recorded frame.
        """
        if frame_index >= len(self.component_poses_history):
            return []
        frame_poses = self.component_poses_history[frame_index]
        viser_asset_states: List[ViserAssetState] = []
        # Occurrence count per asset within this frame, passed as ``postfix`` so repeated uses of
        # the same mesh asset can be told apart (presumably yields distinct viser asset ids).
        asset_ids: CounterT[str] = Counter()
        for component_name in self.scene_index_to_component_names[scene_index]:
            asset_id = self.component_name_to_asset_id[component_name]
            # Row of this scene inside the component's stacked poses.
            data_offset = self.component_name_to_scene_indices[component_name].index(scene_index)
            pose = frame_poses[component_name][data_offset]
            mesh_asset: MeshAsset = cast(MeshAsset, ASSET_LIBRARY.get_asset(asset_id))
            pos, wxyz = to_position_wxyz(pose)
            color = self.component_colors.get(component_name, to_color_array("silver"))
            viser_asset_states.append(
                ViserAssetState(
                    asset_id=asset_id,
                    viser_asset_id=mesh_asset.get_or_create_asset_id(color, 1.0, postfix=asset_ids[asset_id]),
                    position=pos,
                    wxyz=wxyz,
                )
            )
            asset_ids[asset_id] += 1
        return viser_asset_states

    @property
    def num_scenes(self) -> int:
        """Number of parallel scenes in the environment."""
        return self._num_scenes

    @num_scenes.setter
    def num_scenes(self, num_scenes: int):
        self._num_scenes = num_scenes

    def get_scene_last_updated(self, scene_index: int) -> float:
        """Return the timestamp of the last ``reset`` (or creation); identical for every scene."""
        return self._scene_last_updated

    def reset(self):
        """Drop all recorded components and poses, refresh the timestamp, and clear the frame cache."""
        self.component_poses_history = []
        self.component_name_to_asset_id = {}
        self.component_colors = {}
        self.component_name_to_scene_indices = defaultdict(list)
        self.scene_index_to_component_names = defaultdict(list)
        self._scene_last_updated = time.time()
        self.get_frame_viser_asset_states.cache_clear()

    def set_robot_state(self, *args, **kwargs):
        raise ValueError("set_robot_state is not supported in ManiSkillStateManger")

    def set_mesh_state(self, *args, **kwargs):
        raise ValueError("set_mesh_state is not supported in ManiSkillStateManger")

    def set_point_cloud_state(self, *args, **kwargs):
        raise ValueError("set_point_cloud_state is not supported in ManiSkillStateManger")

    def __getitem__(self, scene_index: int):
        raise ValueError("ManiSkillStateManger does not support __getitem__")

    def __hash__(self) -> int:
        # Depends only on the scene count and the last reset time; appending frames does not change it.
        return hash((self._num_scenes, self._scene_last_updated))

    def __repr__(self) -> str:
        return f"ManiSkillStateManger(num_scenes={self._num_scenes})"

    def __str__(self) -> str:
        return self.__repr__()
class ManiSkillWebUI(SimWebUI):
    """Web UI that mirrors a ManiSkill ``gymnasium.Env`` in a viser-based viewer.

    On construction the collision meshes of all articulation links and actors in the env's
    ``ManiSkillScene`` are registered as mesh assets. The env's ``step`` and ``reset`` are then
    monkeypatched so that every call records the current poses and updates the server.
    """

    def __init__(self, env: Env, host: str = "localhost", port: int = 8080, disable_cache: bool = False):
        """Wrap ``env`` and start mirroring it.

        Args:
            env: gymnasium env whose ``unwrapped.scene`` is a ``ManiSkillScene``.
            host: host passed to ``ViserHelper``.
            port: port passed to ``ViserHelper``.
            disable_cache: forwarded to ``add_mesh_asset`` for every registered mesh.

        Raises:
            ValueError: if ``env`` is not a ``gymnasium.Env`` or its scene is not a ``ManiSkillScene``.
        """
        # Validate before building the ViserHelper: it is given host/port and may bind a server,
        # which would otherwise be left behind when the env is rejected.
        if not isinstance(env, Env):
            raise ValueError(f"env must be an instance of gymnasium.Env, got {type(env)}")
        if not isinstance(env.unwrapped.scene, ManiSkillScene):  # type: ignore
            raise ValueError(f"env must have a scene of type ManiSkillScene, got {type(env.unwrapped.scene)}")  # type: ignore

        self._state_manager = ManiSkillStateManger()
        self._viser_helper = ViserHelper(
            state_manager=self._state_manager,
            host=host,
            port=port,
            enable_geometry_option=False,
        )
        self._disable_cache = disable_cache
        self._env: Env = env
        self._state_manager.num_scenes = self._env.unwrapped.num_envs  # type: ignore

        # get assets and poses from the simulation
        self._get_assets_from_sim()
        self._get_asset_poses_from_sim()
        # override step/reset so each call refreshes the web UI
        self._override_step()
        self._override_reset()
        # update the viser server
        self._viser_helper.update_server()

    def _override_step(self):
        """Monkeypatch ``env.step`` to wait for playback, record poses, and update the server."""
        ori_step_fn = self._env.step

        def step_fn(*args, **kwargs):
            log_once("The simulation will be paused if the web UI is not playing")
            # Block until the UI is playing.
            IS_PLAYING_EVENT.wait()
            ori_ret = ori_step_fn(*args, **kwargs)
            self._get_asset_poses_from_sim()
            self._viser_helper.update_server()
            return ori_ret

        # monkeypatch the step function
        self._env.step = step_fn

    def _override_reset(self):
        """Monkeypatch ``env.reset`` to rebuild all assets and poses after the env resets."""
        ori_reset_fn = self._env.reset

        def reset_fn(*args, **kwargs):
            ori_ret = ori_reset_fn(*args, **kwargs)
            # The scene may have changed, so drop everything and re-extract it.
            self._state_manager.reset()
            self._viser_helper.reset()
            self._get_assets_from_sim()
            self._get_asset_poses_from_sim()
            self._viser_helper.update_server()
            return ori_ret

        # monkeypatch the reset function in sapien
        self._env.reset = reset_fn

    def _register_component(self, component_name: str, mesh: "trimesh.Trimesh", color: "Color", scene_idxs: "List[int]") -> None:
        """Add ``mesh`` as a mesh asset and record the component's asset id, color, and scenes."""
        asset_id = self.add_mesh_asset(trimesh_mesh=mesh, disable_cache=self._disable_cache)
        state = self._state_manager
        state.component_name_to_asset_id[component_name] = asset_id
        state.component_colors[component_name] = color
        state.component_name_to_scene_indices[component_name] = scene_idxs
        for scene_idx in scene_idxs:
            state.scene_index_to_component_names[scene_idx].append(component_name)

    @timeit("ManiSkillWebUI.get_assets_from_sim")
    def _get_assets_from_sim(self):
        """Register collision meshes of all articulation links and actors as mesh assets.

        Links and actors without any collision mesh are skipped.
        """
        scene: ManiSkillScene = self._env.unwrapped.scene  # type: ignore
        # add robots
        for arti_name, articulation in scene.articulations.items():
            art_scene_idxs = articulation._scene_idxs.tolist()
            for link_name, link in articulation.links_map.items():
                # it manages multiple sapien articulation objects, we only need the first one
                link_mesh = merge_meshes(get_component_meshes(link._objs[0]))
                if link_mesh is None:
                    continue
                self._register_component(
                    f"arti_{arti_name}/{link_name}", link_mesh, to_color_array("silver"), art_scene_idxs
                )
        # add objects
        for actor_name, actor in scene.actors.items():
            act_scene_idxs = actor._scene_idxs.tolist()
            act = actor._objs[0]
            act_meshes = []
            for comp in act.components:
                if isinstance(comp, PhysxRigidBaseComponent):
                    comp_mesh = merge_meshes(get_component_meshes(comp))
                    if comp_mesh is not None:
                        act_meshes.append(comp_mesh)
            if not act_meshes:
                continue
            self._register_component(f"actor_{actor_name}", merge_meshes(act_meshes), get_random_color(), act_scene_idxs)

    def _get_asset_poses_from_sim(self):
        """Append the current poses of all registered components as a new frame."""
        scene: ManiSkillScene = self._env.unwrapped.scene  # type: ignore
        registered = self._state_manager.component_name_to_asset_id
        component_poses = {}
        for arti_name, articulation in scene.articulations.items():
            for link_name, link in articulation.links_map.items():
                component_name = f"arti_{arti_name}/{link_name}"
                # links without collision meshes were never registered
                if component_name in registered:
                    component_poses[component_name] = to_numpy(link.pose.raw_pose)
        for actor_name, actor in scene.actors.items():
            component_name = f"actor_{actor_name}"
            # skip unregistered actors, matching the link handling above
            if component_name in registered:
                component_poses[component_name] = to_numpy(actor.pose.raw_pose)
        self._state_manager.component_poses_history.append(component_poses)

    def set_robot_state(self, *args, **kwargs):
        raise ValueError("set_robot_state is not supported in ManiSkillWebUI")

    def set_mesh_state(self, *args, **kwargs):
        raise ValueError("set_mesh_state is not supported in ManiSkillWebUI")

    def set_point_cloud_state(self, *args, **kwargs):
        raise ValueError("set_point_cloud_state is not supported in ManiSkillWebUI")

    def __repr__(self) -> str:
        return f"ManiSkillWebUI(num_scenes={self._state_manager.num_scenes})"

    def __str__(self) -> str:
        return self.__repr__()