layers = []
for i in range(len(self.dims) - 1):
    block = [nn.Linear(self.dims[i], self.dims[i + 1])]
    if use_normalization:
        block.append(normalization_layer(self.dims[i + 1]))  # type: ignore
    # Only add activation if it's not the last layer (or if activation_on_output is True).
    if i < len(self.dims) - 2 or activation_on_output:
        if activation_layer is not None:
            block.append(activation_layer(inplace=True))  # type: ignore
    layers.append(nn.Sequential(*block))
# Register the blocks as submodules so their parameters are tracked
# (forward() below iterates over self.layers).
self.layers = nn.ModuleList(layers)
# Set up optional output skip connection
if residual_on_output:
    self.skip_output = (
        nn.Linear(input_dim, output_dim) if input_dim != output_dim else nn.Identity()
    )
else:
    self.skip_output = None  # type: ignore
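# Note (assumption, not shown in this excerpt): `self.dims` is presumably the full
# width list, e.g. `self.dims = [input_dim, *hidden_dims, output_dim]`, so that the
# i-th block above maps self.dims[i] -> self.dims[i + 1] and the hidden residuals in
# forward() can compare consecutive entries.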
def forward(self, x: torch.Tensor) -> torch.Tensor:
    original_shape = x.shape
    if x.ndim > 2:
        # Flatten to 2D if needed: (B, ..., D) -> (B * ..., D)
        x = x.reshape(-1, x.shape[-1])
    x_in = x  # Keep a copy of the (flattened) input for the output residual
    for i, layer in enumerate(self.layers):
        out = layer(x)
        # Optional residual on hidden layers when consecutive widths match.
        # Out-of-place add: the block output may be saved for backward by an
        # inplace activation, so it must not be modified in place.
        if self.residual_on_hidden and self.dims[i] == self.dims[i + 1]:
            out = out + x
        x = out
    if self.skip_output is not None:
        # Optional residual from the (projected) network input to the output
        x = x + self.skip_output(x_in)
    if len(original_shape) > 2:
        # Restore the original leading dimensions; the last dimension is now
        # self.dims[-1] (output_dim), which may differ from the input width.
        x = x.reshape(original_shape[:-1] + (self.dims[-1],))
    return x
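# Usage sketch (illustrative only): the class name and constructor signature are not
# shown in this excerpt, so `MLP(input_dim=..., hidden_dims=..., output_dim=..., ...)`
# below is a hypothetical stand-in for however the module is actually constructed.
#
#     mlp = MLP(input_dim=16, hidden_dims=[64, 64], output_dim=8,
#               residual_on_output=True)
#     x = torch.randn(4, 10, 16)   # (B, T, D); flattened to (40, 16) internally
#     y = mlp(x)                   # -> torch.Size([4, 10, 8]) after the reshape-back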