[docs]defget_human_readable_count(number:Union[int,float])->str:"""Abbreviates an integer number with K, M, B, T for thousands, millions, billions and trillions, respectively. Examples: >>> get_human_readable_count(123) '123 ' >>> get_human_readable_count(1234) # (one thousand) '1.2 K' >>> get_human_readable_count(2e6) # (two million) '2.0 M' >>> get_human_readable_count(3e9) # (three billion) '3.0 B' >>> get_human_readable_count(4e14) # (four hundred trillion) '400 T' >>> get_human_readable_count(5e15) # (more than trillion) '5,000 T' Args: number: a positive integer number Return: A string formatted according to the pattern described above. """assertnumber>=0labels=PARAMETER_NUM_UNITSnum_digits=int(math.floor(math.log10(number))+1ifnumber>0else1)num_groups=int(math.ceil(num_digits/3))num_groups=min(num_groups,len(labels))# don't abbreviate beyond trillionsshift=-3*(num_groups-1)number=number*(10**shift)index=num_groups-1ifindex<1ornumber>=100:returnf"{int(number):,d}{labels[index]}"returnf"{number:,.1f}{labels[index]}"
[docs]defsummarize_model(model:nn.Module,max_depth:int=1):ifnotisinstance(max_depth,int)ormax_depth<-1:raiseValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")mods:List[Tuple[str,nn.Module]]ifmax_depth==0:mods=[]elifmax_depth==1:# the children are the top-level modulesmods=list(model.named_children())else:mods=model.named_modules()mods=list(mods)[1:]# do not include root module (nn.Module)ifmax_depth>=1:mods=[mforminmodsifm[0].count(".")<max_depth]# do not count modules without parametersmod_num_params={name:sum(p.numel()forpinmod.parameters())forname,modinmods}mod_num_trainable_params={name:sum(p.numel()forpinmod.parameters()ifp.requires_grad)forname,modinmods}table=Table(title=f"Model Summary: [bold magenta]{model.__class__.__name__}[/]",header_style="bold magenta")table.add_column("",style="dim")table.add_column("Name",justify="left",no_wrap=True)table.add_column("Type")table.add_column("Params",justify="right")table.add_column("Trainable Params",justify="right")foridx,(name,mod)inenumerate(mods):ifmod_num_params[name]==0:continuetable.add_row(str(idx),name,mod.__class__.__name__,get_human_readable_count(mod_num_params[name]),get_human_readable_count(mod_num_trainable_params[name]),)console.print(table)total_params=sum(mod_num_params.values())total_trainable_params=sum(mod_num_trainable_params.values())grid=Table.grid(expand=True)grid.add_column()grid.add_column()grid.add_row(f"[bold]Total params[/]: {get_human_readable_count(total_params)}")grid.add_row(f"[bold]Trainable params[/]: {get_human_readable_count(total_trainable_params)}")grid.add_row(f"[bold]Non-trainable params[/]: {get_human_readable_count(total_params-total_trainable_params)}")console.print(grid)