from __future__ import annotations

import glob
import logging
import os
import shutil
from typing import Any, Dict, Literal

import rich

from fastdev.constants import FDEV_CSRC_ROOT
from fastdev.utils.cuda import cuda_toolkit_available, current_cuda_arch
from fastdev.utils.profile import timeit

logger = logging.getLogger("fastdev")
console = rich.get_console()

# Import-time side effect: TORCH_CUDA_ARCH_LIST is overwritten with whatever
# current_cuda_arch() returns.
# NOTE(review): presumably this restricts JIT compilation to the local GPU
# architecture -- confirm against fastdev.utils.cuda.
os.environ["TORCH_CUDA_ARCH_LIST"] = current_cuda_arch()

# Extension module name -> sub-directory of FDEV_CSRC_ROOT holding its sources.
_MODULE_SRC_DIR: Dict[str, str] = {
    "fastdev_point_mesh": "point_mesh",
    "fastdev_sample_farthest_points": "sample_farthest_points",
    "fastdev_knn": "knn",
    "fastdev_ball_query": "ball_query",
}


class LazyExtension:
    """Load fastdev native extensions on first use and cache the result.

    Each extension is loaded through ``torch.utils.cpp_extension.load`` the
    first time it is requested via :meth:`load_module`. The outcome, including
    a failed load (stored as ``None``), is remembered per module name, so a
    later request for the same name does not call ``load`` again.
    """

    def __init__(self) -> None:
        """Prepare compiler flags and the empty per-module cache."""
        self._cuda_toolkit_available = cuda_toolkit_available()
        # The WITH_CUDA define is only passed when the CUDA toolkit was detected.
        cuda_define = ["-DWITH_CUDA"] if self._cuda_toolkit_available else []

        self._extra_include_paths: list[str] = [FDEV_CSRC_ROOT]
        self._extra_cflags = ["-O3", *cuda_define]
        self._extra_cuda_cflags = ["-O3", *cuda_define]
        # module name -> loaded extension object, or None if loading failed.
        self._module: dict[str, Any] = {}

    def _collect_sources(self, module_name: str) -> list[str]:
        """Return the source files of ``module_name`` found under FDEV_CSRC_ROOT.

        ``*.cpp`` files are always collected (recursively); ``*.cu`` files are
        appended after them only when the CUDA toolkit is available.

        Raises:
            KeyError: if ``module_name`` is not a key of ``_MODULE_SRC_DIR``.
        """
        src_root = os.path.join(FDEV_CSRC_ROOT, _MODULE_SRC_DIR[module_name])
        patterns = ["**/*.cpp"]
        if self._cuda_toolkit_available:
            patterns.append("**/*.cu")
        return [
            path
            for pattern in patterns
            for path in glob.glob(os.path.join(src_root, pattern), recursive=True)
        ]

    def load_module(
        self,
        module_name: Literal[
            "fastdev_point_mesh",
            "fastdev_sample_farthest_points",
            "fastdev_knn",
            "fastdev_ball_query",
        ],
    ) -> Any:
        """Return the extension called ``module_name``, loading it if needed.

        Args:
            module_name: Name of the extension; must be a key of
                ``_MODULE_SRC_DIR``.

        Returns:
            The object returned by ``torch.utils.cpp_extension.load``, or
            ``None`` when loading/building raised an exception (the error is
            logged and the ``None`` result is cached).

        Raises:
            KeyError: if ``module_name`` has no entry in ``_MODULE_SRC_DIR``.
        """
        # Imported lazily, at call time rather than at module import time.
        from torch.utils.cpp_extension import _get_build_directory, load

        # Cache hit: covers both successful loads and remembered failures.
        if module_name in self._module:
            return self._module[module_name]

        build_dir = _get_build_directory(module_name, verbose=False)
        load_kwargs = dict(
            name=module_name,
            sources=self._collect_sources(module_name),
            extra_cflags=self._extra_cflags,
            extra_cuda_cflags=self._extra_cuda_cflags,
            extra_include_paths=self._extra_include_paths,
        )

        try:
            if os.listdir(build_dir):
                # A non-empty build directory is taken to mean the extension
                # was built before, so this is reported as a plain load.
                with timeit(f"Loading {module_name} extension"):
                    self._module[module_name] = load(**load_kwargs)
            else:
                # Fresh build. The (empty) build directory is removed first as
                # a precaution: the PyTorch JIT may hang when it already exists.
                shutil.rmtree(build_dir, ignore_errors=True)
                with timeit(f"Building {module_name} extension"), console.status(
                    f"Building {module_name} extension (This may take a few minutes the first time)",
                    spinner="bouncingBall",
                ):
                    self._module[module_name] = load(**load_kwargs)
        except Exception as e:
            # Best effort: report the failure and cache None instead of raising.
            logger.error(f"[bold red]Error building {module_name} extension: {e}")
            logger.error(f"The {module_name} will not be available.")
            self._module[module_name] = None

        return self._module[module_name]