Source code for fastdev.datasets.dexycb

import os
import pickle
from dataclasses import dataclass
from glob import glob
from typing import Dict, List, Literal, Optional

import cv2
import numpy as np
import torch
import trimesh
from huggingface_hub import hf_hub_download
from rich.progress import track
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, default_collate
from torchvision.io import read_image

from fastdev.constants import FDEV_DATASET_ROOT, FDEV_HF_REPO_ID
from fastdev.io import load
from fastdev.io.archive import extract_archive
from fastdev.utils import timeit

_SUBJECTS = [
    "20200709-subject-01",
    "20200813-subject-02",
    "20200820-subject-03",
    "20200903-subject-04",
    "20200908-subject-05",
    "20200918-subject-06",
    "20200928-subject-07",
    "20201002-subject-08",
    "20201015-subject-09",
    "20201022-subject-10",
]

_SERIALS = [
    "836212060125",
    "839512060362",
    "840412060917",
    "841412060263",
    "932122060857",
    "932122060861",
    "932122061900",
    "932122062010",
]

_YCB_CLASSES = {
    1: "002_master_chef_can",
    2: "003_cracker_box",
    3: "004_sugar_box",
    4: "005_tomato_soup_can",
    5: "006_mustard_bottle",
    6: "007_tuna_fish_can",
    7: "008_pudding_box",
    8: "009_gelatin_box",
    9: "010_potted_meat_can",
    10: "011_banana",
    11: "019_pitcher_base",
    12: "021_bleach_cleanser",
    13: "024_bowl",
    14: "025_mug",
    15: "035_power_drill",
    16: "036_wood_block",
    17: "037_scissors",
    18: "040_large_marker",
    19: "051_large_clamp",
    20: "052_extra_large_clamp",
    21: "061_foam_brick",
}


_DEXYCB_SIMPLIFIED_HF_FILENAME = "dexycb_simplified.zip"


@dataclass(frozen=True)
class DexYCBDatasetConfig:
    """Configuration for DexYCBDataset."""

    data_root: str = os.path.join(FDEV_DATASET_ROOT, "dexycb")
    hand_side: Literal["left", "right", "both"] = "left"
    setup: Literal["s0", "s1", "s2", "s3"] = "s0"
    """
    - [s0] Seen subjects, camera views, grasped objects; unseen frames
    - [s1] Unseen subjects
    - [s2] Unseen camera views
    - [s3] Unseen grasped objects

    Reference_

    .. _Reference: https://github.com/NVlabs/dex-ycb-toolkit/blob/64551b001d360ad83bc383157a559ec248fb9100/dex_ycb_toolkit/dex_ycb.py#L126
    """
    serials: Optional[List[str]] = None  # specify serials to load
    return_color: bool = False
    return_depth: bool = False
    return_object_mesh: bool = False  # return object mesh vertices and faces in the object space
    return_as_sequence: bool = False  # return the whole sequence instead of a single frame
    return_no_hand_frames: bool = True
    # if False, filter out frames where the hand is not fully visible in the image
    # (i.e., keep only frames with a non-zero MANO pose and all joints inside the frame)
    download_simplified_data_if_not_exist: bool = False
    # download the simplified data (pose, MANO calibration, object meshes) if it does not exist
    force_rebuild_cache: bool = False  # pkl is faster than npz, suitable for caching

    def __post_init__(self) -> None:
        if not os.path.exists(self.data_root) and self.download_simplified_data_if_not_exist:
            if self.return_color or self.return_depth:
                raise ValueError("Cannot return color or depth with simplified data.")
        if self.serials is not None:
            assert all(s in _SERIALS for s in self.serials), f"Serials {self.serials} invalid."
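
# Usage sketch (illustrative, not part of the original module): a configuration
# that restricts loading to two of the RealSense serials listed above and also
# returns object meshes; the chosen values are assumptions for demonstration.
#
#   config = DexYCBDatasetConfig(
#       hand_side="right",
#       serials=["836212060125", "839512060362"],
#       return_object_mesh=True,
#   )
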

class DexYCBDataset(Dataset):
    """DexYCB dataset.

    Args:
        config (DexYCBDatasetConfig): configuration for the dataset.
        split (str, optional): dataset split, one of "train", "val", "test" or "all". Defaults to "train".

    Examples:
        >>> # doctest: +SKIP
        >>> from fastdev.datasets.dexycb import DexYCBDataset, DexYCBDatasetConfig
        >>> dataset = DexYCBDataset(DexYCBDatasetConfig())
    """

    ycb_classes = _YCB_CLASSES

    def __init__(self, config: DexYCBDatasetConfig, split: Literal["train", "val", "test", "all"] = "train"):
        self.config = config
        if not os.path.exists(self.config.data_root) and self.config.download_simplified_data_if_not_exist:
            self.download_data(self.config.data_root)
        self.split = split

        self._data_dir = self.config.data_root
        self._calib_dir = os.path.join(self._data_dir, "calibration")
        self._model_dir = os.path.join(self._data_dir, "models")
        self._color_format = "color_{:06d}.jpg"
        self._depth_format = "aligned_depth_to_color_{:06d}.png"
        self._label_format = "labels_{:06d}.npz"
        self._h = 480
        self._w = 640

        self._obj_file = {k: os.path.join(self._model_dir, v, "textured_simple.obj") for k, v in _YCB_CLASSES.items()}
        self._obj_mesh_cache = {}

        metadata_cache_path = os.path.join(self._data_dir, "metadata_cache.pkl")

        # ----------- Cache Metadata -----------
        if (not os.path.exists(metadata_cache_path)) or self.config.force_rebuild_cache:
            intrinsics = {}
            for serial in _SERIALS:
                intr_file = os.path.join(
                    self._calib_dir,
                    "intrinsics",
                    "{}_{}x{}.yml".format(serial, self._w, self._h),
                )
                intr = load(intr_file)
                intrinsics[serial] = intr["color"]

            extrinsics = {}
            extr_dirs = sorted(glob(os.path.join(self._calib_dir, "extrinsics_*")))
            for extr_dir in extr_dirs:
                extr = load(os.path.join(extr_dir, "extrinsics.yml"))
                extrinsics[os.path.basename(extr_dir).split("extrinsics_")[-1]] = extr

            mano_calibs = {}
            mano_calib_dirs = sorted(glob(os.path.join(self._calib_dir, "mano_*")))
            for mano_calib_dir in mano_calib_dirs:
                mano_calib = load(os.path.join(mano_calib_dir, "mano.yml"))
                mano_calibs[os.path.basename(mano_calib_dir).split("mano_")[-1]] = mano_calib

            seq_metas = {}
            for subject in track(_SUBJECTS, description="Caching seq metadata"):
                seqs = sorted(os.listdir(os.path.join(self._data_dir, subject)))
                seqs = [os.path.join(subject, seq) for seq in seqs]
                assert len(seqs) == 100, "Each subject in DexYCB should have 100 sequences."
                for seq in seqs:
                    seq_meta = load(os.path.join(self._data_dir, seq, "meta.yml"))
                    seq_metas[seq] = seq_meta
                    seq_metas[seq]["valid_frame_range"] = {}
                    seq_metas[seq]["joint_2d"], seq_metas[seq]["joint_3d"] = {}, {}
                    seq_metas[seq]["pose_m"], seq_metas[seq]["pose_y"] = {}, {}

                    seq_pose_m = np.load(os.path.join(self._data_dir, seq, "pose.npz"))["pose_m"][:, 0]
                    non_zero_pose_frames = np.any(seq_pose_m != 0, axis=-1)

                    for serial in _SERIALS:
                        label_files = sorted(glob(os.path.join(self._data_dir, seq, serial, "labels_*.npz")))
                        label_data = [np.load(label_file) for label_file in label_files]
                        joint_2d = np.stack([label["joint_2d"][0] for label in label_data])
                        joints_inside_frames = np.all(
                            np.logical_and(
                                np.logical_and(joint_2d[..., 0] >= 0, joint_2d[..., 0] < self._w),
                                np.logical_and(joint_2d[..., 1] >= 0, joint_2d[..., 1] < self._h),
                            ),
                            axis=-1,
                        )
                        valid_frames = np.logical_and(non_zero_pose_frames, joints_inside_frames)
                        if np.any(valid_frames):
                            valid_begin_frame, valid_end_frame = np.where(valid_frames)[0][[0, -1]]
                        else:  # no valid frames
                            valid_begin_frame, valid_end_frame = -1, -1
                        # range is always half-open
                        seq_metas[seq]["valid_frame_range"][serial] = (
                            valid_begin_frame,
                            valid_end_frame + 1,
                        )
                        seq_metas[seq]["joint_2d"][serial] = joint_2d
                        seq_metas[seq]["joint_3d"][serial] = np.stack([label["joint_3d"][0] for label in label_data])
                        seq_metas[seq]["pose_m"][serial] = np.stack([label["pose_m"][0] for label in label_data])
                        seq_metas[seq]["pose_y"][serial] = np.stack([label["pose_y"][0] for label in label_data])

            os.makedirs(os.path.dirname(metadata_cache_path), exist_ok=True)
            with open(metadata_cache_path, "wb") as f:
                pickle.dump(
                    {
                        "intrinsics": intrinsics,
                        "extrinsics": extrinsics,
                        "mano_calibs": mano_calibs,
                        "seq_metas": seq_metas,
                    },
                    f,
                )

        with timeit(fn_or_print_tmpl="Loading metadata cache: {:.4f}s"):
            with open(metadata_cache_path, "rb") as f:
                metadata_cache = pickle.load(f)
        self._intrinsics = metadata_cache["intrinsics"]
        self._extrinsics = metadata_cache["extrinsics"]
        self._mano_calibs = metadata_cache["mano_calibs"]
        cached_seq_metas = metadata_cache["seq_metas"]

        # ----------- Filter sequences, serials, frames -----------
        # Seen subjects, camera views, grasped objects.
        if self.config.setup == "s0":
            if self.split == "train":
                subject_ind = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
                serial_ind = [0, 1, 2, 3, 4, 5, 6, 7]
                sequence_ind = [i for i in range(100) if i % 5 != 4]
            if self.split == "val":
                subject_ind = [0, 1]
                serial_ind = [0, 1, 2, 3, 4, 5, 6, 7]
                sequence_ind = [i for i in range(100) if i % 5 == 4]
            if self.split == "test":
                subject_ind = [2, 3, 4, 5, 6, 7, 8, 9]
                serial_ind = [0, 1, 2, 3, 4, 5, 6, 7]
                sequence_ind = [i for i in range(100) if i % 5 == 4]
        # Unseen subjects.
        if self.config.setup == "s1":
            if self.split == "train":
                subject_ind = [0, 1, 2, 3, 4, 5, 9]
                serial_ind = [0, 1, 2, 3, 4, 5, 6, 7]
                sequence_ind = list(range(100))
            if self.split == "val":
                subject_ind = [6]
                serial_ind = [0, 1, 2, 3, 4, 5, 6, 7]
                sequence_ind = list(range(100))
            if self.split == "test":
                subject_ind = [7, 8]
                serial_ind = [0, 1, 2, 3, 4, 5, 6, 7]
                sequence_ind = list(range(100))
        # Unseen camera views.
        if self.config.setup == "s2":
            if self.split == "train":
                subject_ind = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
                serial_ind = [0, 1, 2, 3, 4, 5]
                sequence_ind = list(range(100))
            if self.split == "val":
                subject_ind = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
                serial_ind = [6]
                sequence_ind = list(range(100))
            if self.split == "test":
                subject_ind = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
                serial_ind = [7]
                sequence_ind = list(range(100))
        # Unseen grasped objects.
        if self.config.setup == "s3":
            if self.split == "train":
                subject_ind = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
                serial_ind = [0, 1, 2, 3, 4, 5, 6, 7]
                sequence_ind = [i for i in range(100) if i // 5 not in (3, 7, 11, 15, 19)]
            if self.split == "val":
                subject_ind = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
                serial_ind = [0, 1, 2, 3, 4, 5, 6, 7]
                sequence_ind = [i for i in range(100) if i // 5 in (3, 19)]
            if self.split == "test":
                subject_ind = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
                serial_ind = [0, 1, 2, 3, 4, 5, 6, 7]
                sequence_ind = [i for i in range(100) if i // 5 in (7, 11, 15)]
        if self.split == "all":
            subject_ind = list(range(10))
            serial_ind = list(range(8))
            sequence_ind = list(range(100))

        self._subjects = [_SUBJECTS[i] for i in subject_ind]  # type: ignore
        self._serials = [_SERIALS[i] for i in serial_ind] if self.config.serials is None else self.config.serials  # type: ignore

        self._sequences, self._seq_metas = [], []
        for subject in self._subjects:
            seqs = sorted([seq for seq in cached_seq_metas.keys() if subject in seq])
            seqs = [seqs[i] for i in sequence_ind]  # type: ignore
            for seq in seqs:
                meta = cached_seq_metas[seq]
                assert len(meta["mano_sides"]) == 1, "Each sequence in DexYCB only has one mano side."
                if self.config.hand_side != "both" and meta["mano_sides"][0] != self.config.hand_side:
                    continue
                else:
                    self._sequences.append(seq)
                    self._seq_metas.append(meta)

        self._mapping = []  # type: ignore
        for seq_idx, (seq, meta) in enumerate(zip(self._sequences, self._seq_metas)):
            for serial_idx, serial in enumerate(self._serials):
                # range is always half-open
                valid_frame_range = meta["valid_frame_range"][serial]
                if valid_frame_range == (-1, 0) and not self.config.return_no_hand_frames:
                    # no valid frames
                    continue
                if not self.config.return_no_hand_frames:
                    frame_range = valid_frame_range
                else:
                    frame_range = (0, meta["num_frames"])
                if not self.config.return_as_sequence:
                    frame_indices = np.arange(*frame_range)
                    seq_indices = seq_idx * np.ones_like(frame_indices)
                    serial_indices = serial_idx * np.ones_like(frame_indices)
                    mapping = np.vstack((seq_indices, serial_indices, frame_indices)).T
                    self._mapping.append(mapping)  # type: ignore
                else:
                    self._mapping.append(np.array([seq_idx, serial_idx]))  # type: ignore
        self._mapping: np.ndarray = np.vstack(self._mapping)  # type: ignore
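
        # Layout note (illustrative, not part of the original source): in single-frame
        # mode each row of `self._mapping` is (seq_idx, serial_idx, frame_idx); in
        # sequence mode each row is (seq_idx, serial_idx). For example, one sequence
        # observed by two serials with three frames each yields, in single-frame mode:
        #
        #   [[0, 0, 0], [0, 0, 1], [0, 0, 2],
        #    [0, 1, 0], [0, 1, 1], [0, 1, 2]]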

    def __len__(self):
        return len(self._mapping)

    def _load_frame_data(self, data_dir, frame_idx):
        frame_data = {}
        if self.config.return_color:
            color_file = os.path.join(data_dir, self._color_format.format(frame_idx))
            frame_data["color"] = read_image(color_file)
        if self.config.return_depth:
            depth_file = os.path.join(data_dir, self._depth_format.format(frame_idx))
            depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH)
            frame_data["depth"] = torch.from_numpy(depth.astype(np.float32) / 1000)
        return frame_data

    def _load_object_mesh(self, ycb_id: int) -> Dict[Literal["vertices", "faces"], torch.Tensor]:
        if ycb_id not in self._obj_mesh_cache:
            obj_file = self._obj_file[ycb_id]
            obj_mesh = trimesh.load(obj_file, process=False, force="mesh")
            vertices = torch.tensor(obj_mesh.vertices, dtype=torch.float32)  # type: ignore
            faces = torch.tensor(obj_mesh.faces, dtype=torch.int32)  # type: ignore
            self._obj_mesh_cache[ycb_id] = {"vertices": vertices, "faces": faces}
        return self._obj_mesh_cache[ycb_id]

    def __getitem__(self, idx):
        if not self.config.return_as_sequence:
            seq_idx, serial_idx, frame_idx = self._mapping[idx]
            # range is always half-open
            frame_range = (frame_idx, frame_idx + 1)
        else:
            seq_idx, serial_idx = self._mapping[idx]
            if self.config.return_no_hand_frames:
                frame_range = (0, self._seq_metas[seq_idx]["num_frames"])
            else:
                frame_range = self._seq_metas[seq_idx]["valid_frame_range"][self._serials[serial_idx]]

        data_dir = os.path.join(self._data_dir, self._sequences[seq_idx], self._serials[serial_idx])
        seq_meta = self._seq_metas[seq_idx]

        extr = self._extrinsics[seq_meta["extrinsics"]]["extrinsics"][self._serials[serial_idx]]
        extr = torch.tensor(extr, dtype=torch.float32).reshape(3, 4)

        sample = {
            "sequence": self._sequences[seq_idx],
            "serial": self._serials[serial_idx],
            "frame_range": torch.tensor(frame_range, dtype=torch.int32),
            "intrinsics": self._intrinsics[self._serials[serial_idx]],
            "extrinsics": extr,
            "ycb_grasp_id": seq_meta["ycb_ids"][seq_meta["ycb_grasp_ind"]],
            "mano_side": seq_meta["mano_sides"][0],
            "mano_betas": torch.tensor(
                self._mano_calibs[seq_meta["mano_calib"][0]]["betas"],
                dtype=torch.float32,
            ),
        }

        pose_y = torch.from_numpy(
            seq_meta["pose_y"][self._serials[serial_idx]][frame_range[0] : frame_range[1]]
        ).float()
        sample["object_pose"] = pose_y

        # only attach the grasped object's mesh when requested via the config flag
        if self.config.return_object_mesh:
            object_mesh = self._load_object_mesh(sample["ycb_grasp_id"])
            sample["object_vertices"] = object_mesh["vertices"]
            sample["object_faces"] = object_mesh["faces"]

        frames_data = [self._load_frame_data(data_dir, frame_idx) for frame_idx in range(*frame_range)]
        for key in frames_data[0].keys():
            sample[key] = torch.stack([frame_data[key] for frame_data in frames_data], dim=0)
            if not self.config.return_as_sequence:
                sample[key].squeeze_(0)

        pose_m = torch.from_numpy(
            seq_meta["pose_m"][self._serials[serial_idx]][frame_range[0] : frame_range[1]]
        ).float()
        sample["mano_global_orient"] = pose_m[..., :3]
        sample["mano_hand_pose"] = pose_m[..., 3:48]
        sample["mano_transl"] = pose_m[..., 48:51]

        if not self.config.return_as_sequence:
            sample["mano_global_orient"].squeeze_(0)
            sample["mano_hand_pose"].squeeze_(0)
            sample["mano_transl"].squeeze_(0)
            sample["object_pose"].squeeze_(0)

        return sample
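
    # Sample sketch (illustrative, not verified output): with the default
    # single-frame config, the MANO fields follow the 51-dim pose_m slicing above,
    # so the shapes below are assumptions derived from that slicing.
    #
    #   sample = dataset[0]
    #   sample["mano_global_orient"].shape  # torch.Size([3])
    #   sample["mano_hand_pose"].shape      # torch.Size([45])
    #   sample["mano_transl"].shape         # torch.Size([3])
    #   sample["extrinsics"].shape          # torch.Size([3, 4])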

    @property
    def h(self):
        return self._h

    @property
    def w(self):
        return self._w

    @property
    def subjects(self):
        return self._subjects

    @property
    def serials(self):
        return self._serials

    @property
    def sequences(self):
        return self._sequences

    def select_indices(self, seq: str, serial: Optional[str] = None):
        assert seq in self._sequences, f"Sequence {seq} not found in the dataset."
        if serial is None:
            return np.where(self._mapping[:, 0] == self._sequences.index(seq))[0]
        else:
            assert serial in self._serials, f"Serial {serial} not found in the dataset."
            return np.where(
                np.logical_and(
                    self._mapping[:, 0] == self._sequences.index(seq),
                    self._mapping[:, 1] == self._serials.index(serial),
                )
            )[0]
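
    # Usage sketch (illustrative): collect all dataset indices belonging to one
    # sequence and one camera serial; `dataset` is assumed to be a constructed
    # DexYCBDataset instance.
    #
    #   seq = dataset.sequences[0]
    #   serial = dataset.serials[0]
    #   indices = dataset.select_indices(seq, serial=serial)
    #   frames = [dataset[int(i)] for i in indices]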

    def build_collate_fn(self):
        def _seq_collate_fn(batch):
            collated_batch = {}
            for key in batch[0].keys():
                if key in [
                    "color",
                    "depth",
                    "mano_global_orient",
                    "mano_hand_pose",
                    "mano_transl",
                ]:
                    collated_batch[key] = pad_sequence([sample[key] for sample in batch], batch_first=True)
                else:
                    collated_batch[key] = default_collate([sample[key] for sample in batch])
            return collated_batch

        if not self.config.return_as_sequence:
            return default_collate
        else:
            return _seq_collate_fn
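
    # Usage sketch (illustrative): in sequence mode the returned collate function
    # pads variable-length per-frame tensors, so pass it to the DataLoader; the
    # batch size below is a placeholder.
    #
    #   from torch.utils.data import DataLoader
    #   loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.build_collate_fn())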

    @staticmethod
    def download_data(data_root: str):
        os.makedirs(data_root, exist_ok=True)
        hf_hub_download(
            repo_id=FDEV_HF_REPO_ID, filename=_DEXYCB_SIMPLIFIED_HF_FILENAME, repo_type="dataset", local_dir=data_root
        )
        extract_archive(os.path.join(data_root, _DEXYCB_SIMPLIFIED_HF_FILENAME), data_root, remove_top_dir=True)
        os.remove(os.path.join(data_root, _DEXYCB_SIMPLIFIED_HF_FILENAME))
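
    # Usage sketch (illustrative): the simplified data can also be fetched
    # explicitly before constructing the dataset; the path is a placeholder.
    #
    #   DexYCBDataset.download_data("/path/to/datasets/dexycb")
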
__all__ = ["DexYCBDataset", "DexYCBDatasetConfig"]