fastdev.nn.mlp

Module Contents

class fastdev.nn.mlp.MLP(input_dim: int, output_dim: int, hidden_dims: List[int], activation_layer: Type[torch.nn.Module] | None = nn.ReLU, activation_on_output: bool = False, residual_on_output: bool = False, residual_on_hidden: bool = False, use_normalization: bool = False, normalization_layer: Type[torch.nn.Module] | None = nn.LayerNorm)

Bases: torch.nn.Module

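A flexible multilayer perceptron (MLP) with configurable hidden layer sizes, activation, normalization, and residual connections.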

Parameters:
  • input_dim (int)

  • output_dim (int)

  • hidden_dims (List[int])

  • activation_layer (Optional[Type[torch.nn.Module]])

  • activation_on_output (bool)

  • residual_on_output (bool)

  • residual_on_hidden (bool)

  • use_normalization (bool)

  • normalization_layer (Optional[Type[torch.nn.Module]])
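The parameter list gives only names and types. As a rough guide to how the arguments might combine, here is a minimal pure-Python sketch that turns them into a list of block descriptions. It is not fastdev's implementation: `plan_mlp` and `BlockSpec` are hypothetical names, layer types are passed as strings so the sketch runs without PyTorch, and it assumes that `dims` is `[input_dim, *hidden_dims, output_dim]`, that normalization (when enabled) is applied to hidden blocks only, that `activation_on_output` controls only the final block's activation, and that a residual connection is added only where a block's input and output widths match.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BlockSpec:
    """Hypothetical summary of one Linear block in the stack."""
    in_dim: int
    out_dim: int
    norm: Optional[str]        # normalization layer name, or None
    activation: Optional[str]  # activation layer name, or None
    residual: bool             # whether the block computes x + f(x)


def plan_mlp(
    input_dim: int,
    output_dim: int,
    hidden_dims: List[int],
    activation_layer: Optional[str] = "ReLU",
    activation_on_output: bool = False,
    residual_on_output: bool = False,
    residual_on_hidden: bool = False,
    use_normalization: bool = False,
    normalization_layer: Optional[str] = "LayerNorm",
) -> List[BlockSpec]:
    # Assumption: layer widths chain input -> hidden layers -> output.
    dims = [input_dim, *hidden_dims, output_dim]
    num_blocks = len(dims) - 1
    blocks: List[BlockSpec] = []
    for i in range(num_blocks):
        in_dim, out_dim = dims[i], dims[i + 1]
        is_output = i == num_blocks - 1
        # Assumption: hidden blocks always get the activation; the output
        # block gets it only when activation_on_output is True.
        use_act = not is_output or activation_on_output
        # Assumption: normalization is applied to hidden blocks only.
        use_norm = use_normalization and not is_output
        wants_residual = residual_on_output if is_output else residual_on_hidden
        blocks.append(
            BlockSpec(
                in_dim=in_dim,
                out_dim=out_dim,
                norm=normalization_layer if use_norm else None,
                activation=activation_layer if use_act else None,
                # A skip connection x + f(x) needs equal input/output widths.
                residual=wants_residual and in_dim == out_dim,
            )
        )
    return blocks
```

For example, under these assumptions `plan_mlp(16, 4, [32, 32], residual_on_hidden=True)` yields blocks 16→32, 32→32 and 32→4, and only the middle block gets a residual connection, because it is the only hidden block whose widths match.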

dims
residual_on_hidden = False
residual_on_output = False
layers
forward(x: torch.Tensor) → torch.Tensor
Parameters:
  • x (torch.Tensor)

Return type:
  torch.Tensor
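`forward` takes and returns a `torch.Tensor`, and the reference does not spell out how the residual flags act inside it. A common reading of `residual_on_hidden` / `residual_on_output` is a skip connection, y = x + f(x), around a block whose input and output widths match. The sketch below shows only that control flow, on plain Python lists so it runs without PyTorch; `mlp_forward`, `double` and `relu` are hypothetical stand-ins, not fastdev functions.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Layer = Callable[[Vector], Vector]


def mlp_forward(x: Vector, blocks: Sequence[Layer], residual: Sequence[bool]) -> Vector:
    """Hypothetical forward pass: apply each block in order, optionally as x + f(x)."""
    if len(blocks) != len(residual):
        raise ValueError("need one residual flag per block")
    for block, use_residual in zip(blocks, residual):
        out = block(x)
        if use_residual:
            if len(out) != len(x):
                raise ValueError("residual connection needs matching widths")
            # Skip connection: add the block's input back onto its output.
            out = [o + i for o, i in zip(out, x)]
        x = out
    return x


# Toy stand-ins for Linear + activation layers (hypothetical).
def double(v: Vector) -> Vector:
    return [2.0 * a for a in v]


def relu(v: Vector) -> Vector:
    return [max(0.0, a) for a in v]


print(mlp_forward([1.0, -2.0], [double, relu], [True, False]))  # [3.0, 0.0]
```

Here the first block maps [1.0, -2.0] to [2.0, -4.0], the residual adds the input back to give [3.0, -6.0], and the ReLU block then produces [3.0, 0.0].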

fastdev.nn.mlp.mlp