
Latent step model

info

The terms latent step model and DSR model are used interchangeably in the documentation.

The latent step model is the core of the DSR model. It is the RNN that predicts the next latent state from the current latent state and, possibly, the external inputs. As the name suggests, these RNNs operate in the latent space, so their central hyperparameter is the size of the latent space. This can be set in the config through the latent_dim parameter.

note

The latent_dim parameter is declared outside of the latent_step dictionary as it is essential for multiple parts of the model.
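
The following sketch shows where latent_dim sits relative to the latent_step dictionary. The key names are from this page; the value 20 and the surrounding config structure are assumptions, and the contents of latent_step are described in the next section.

{
  "latent_dim": 20,
  "latent_step": { }
}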

Implemented models

The repository implements several different PLRNN variants, all of which have been used extensively in previous papers of the Durstewitz lab. The model is selected via the latent_step parameter, which expects a dictionary whose name key specifies the choice. Options are:

  • PLRNN (Vanilla Piecewise Linear RNN)
  • ALRNN (Almost Linear PLRNN)
  • shPLRNN (shallow PLRNN)
  • clipped_shPLRNN (clipped shallow PLRNN)

Hyperparameters for the chosen model can be set using the hyperparameters key. This dictionary includes model-agnostic and model-specific keys. Model-agnostic keys are the following:

  • learnable_C: Whether the external input matrix $\mathbf{C}$ is learnable or not. If not, the external inputs $\mathbf{s}_t$ are projected into the "end" of the latent state using an identity matrix. In this case the user has to make sure that the latent dimension is larger than the external input dimension (see the sketch below).
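
As an illustration, a minimal latent_step entry with this model-agnostic key might look as follows. The name and hyperparameters keys are taken from this page; the chosen model and the value of learnable_C are placeholders.

{
  "name": "PLRNN",
  "hyperparameters": {
    "learnable_C": true
  }
}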

For the mathematical update rules of these options, please refer to the mathematical background chapter.

PLRNN

{
"name": "PLRNN"
}

Specific hyperparameters are:

  • mean_centering (bool, default=True): Whether the latent state should be mean centered before applying the RNN. Should be kept true.
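
Putting the model-agnostic and model-specific keys together, a complete latent_step block for the PLRNN could look like the following sketch; the values shown are illustrative assumptions (mean_centering simply restates its default).

{
  "name": "PLRNN",
  "hyperparameters": {
    "learnable_C": true,
    "mean_centering": true
  }
}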

(Clipped) Shallow PLRNN

{
"name": "shPLRNN" | "clipped_shPLRNN"
}

The clipped shallow PLRNN is a variant of the shPLRNN with the following update rule:

$$\mathbf{z}_{t+1} = \mathbf{A}\mathbf{z}_t + \mathbf{W}_1\left[\phi(\mathbf{W}_2\mathbf{z}_t + \mathbf{h}_2) - \phi(\mathbf{W}_2\mathbf{z}_t)\right] + \mathbf{h}_1 + \mathbf{C}\mathbf{s}_t,$$

which is a reformulation of the standard shPLRNN.

Both variants share the following specific hyperparameters:

  • hidden_dim (int): The dimension of the shallow PLRNN's hidden layer.

In the case of the shallow PLRNN, latent_dim plays a less important role, since its hidden layer can implement any non-linear transformation.
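
As a hedged example, a shallow PLRNN block could be configured as below; hidden_dim is the key documented above, while the value 100 is only an assumption (use "clipped_shPLRNN" as the name for the clipped variant).

{
  "name": "shPLRNN",
  "hyperparameters": {
    "hidden_dim": 100
  }
}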

ALRNN

{
"name": "ALRNN"
}

Specific hyperparameters:

  • num_relus (int): The number of non-linearities used. These are applied to the last num_relus positions of the state vector.
  • ar_epsilon (Optional[int], default=None): The $\varepsilon$ used for the autoregressive convergence loss (see Brenner, Hemmer et al. (2024) for details), a regularization loss which pulls $A + \text{diag}(W)$ to $1-\varepsilon$. Its use is advised if the ALRNN is observed to diverge during training. If None, the loss is not applied.
  • off_diagonal_W (bool, default=False): Whether the $W$ matrix is explicitly parametrized to contain only off-diagonal elements.
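
For illustration, an ALRNN latent_step block might look like the following sketch; the key names are documented above, while the concrete values are assumptions (ar_epsilon is left at its default of None, so no convergence loss is applied).

{
  "name": "ALRNN",
  "hyperparameters": {
    "num_relus": 5,
    "ar_epsilon": null,
    "off_diagonal_W": false
  }
}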