
Decoders and Evaluators

The system is extensible: you can add custom decoders and evaluators by following the existing patterns. This guide covers the file structure, the configuration resolution process, and the steps to implement and register new modules.

File Structure

  • Decoders: Located in src/models/decoder/models/.
    • They are grouped by type (e.g., gaussian/, categorical/).
  • Evaluators: Located in src/evaluation/evaluators/.
    • modality_specific/: For evaluators paired with a specific decoder (modality).
    • modality_independent/: For evaluators checking latent state properties or cross-modality properties.
    • general/: For global analysis that requires only the model and the dataset and is run directly during the evaluation loop.

Configuration Resolution

The system uses a configuration resolution mechanism to validate and process user inputs before initializing the models. This ensures that all necessary parameters are present and correctly formatted.

  • Process: The resolve_config static method is called for both Decoder and Evaluator classes during startup (via configuration_resolver.py).
  • Merging: User-provided parameters in the config file are merged with the default_hyperparameters (for decoders) or default_params (for evaluators) defined in your class.
  • Registration: When adding a new module, you must register it in two places within the respective manager class (Decoder or Evaluator):
    1. __init__: To instantiate the class with the resolved parameters.
    2. resolve_config: To map the string name to the class and merge parameters.
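
The merging step can be pictured with a minimal sketch. The helper name merge_with_defaults is hypothetical, and rejecting unknown keys is an assumption about what "validate" means here; the actual logic in configuration_resolver.py may differ.

```python
from typing import Any, Mapping


def merge_with_defaults(
    defaults: Mapping[str, Any], user: Mapping[str, Any]
) -> dict[str, Any]:
    """Overlay user-provided parameters on top of class defaults.

    Hypothetical helper: unknown keys are rejected so that typos in the
    config file surface at startup (an assumption, not confirmed behavior).
    """
    unknown = set(user) - set(defaults)
    if unknown:
        raise KeyError(f"Unknown parameter(s): {sorted(unknown)}")
    # Later entries win, so user values override the defaults.
    return {**defaults, **user}


defaults = {"hidden_dim": 32, "dropout": 0.0}
resolved = merge_with_defaults(defaults, {"dropout": 0.1})
```

In this sketch, parameters the user omits keep their default value, and the defaults themselves are never mutated.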

Adding a Decoder

1. Implementation

In src/models/decoder/models/<your_type>/, create a new class that inherits from BaseDecoder (defined in src/models/decoder/models/base.py).

  • Inheritance: Your class should inherit from BaseDecoder[T], where T is a TypedDict inheriting from DecoderHyperparameters.
  • Required Methods:
    • @staticmethod default_hyperparameters() -> T: Return a dictionary of default hyperparameter values. Configuration resolution merges user-provided values into these defaults, so every hyperparameter the decoder reads should appear here.
    • forward(self, Z: tc.Tensor, expected: bool = False) -> tc.Tensor: Return the decoded value from latent state Z.
    • log_likelihood(self, Y: tc.Tensor, Z: tc.Tensor) -> tc.Tensor: Return the log-likelihood of target Y given latent state Z (for loss calculation).
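
As an illustration, a decoder with the three required methods might look like the sketch below. BaseDecoder and DecoderHyperparameters are re-declared as simplified stand-ins so the example runs on its own; the real base class may require more. LinearGaussianDecoder, its hyperparameter names, and the reading of expected (True returns the mean, False returns a sample) are assumptions made for this example.

```python
from typing import Generic, TypedDict, TypeVar

import torch as tc
from torch import nn


class DecoderHyperparameters(TypedDict, total=False):
    """Stand-in for the project's DecoderHyperparameters (assumed fields)."""
    latent_dim: int
    output_dim: int


class LinearGaussianHyperparameters(DecoderHyperparameters, total=False):
    log_sigma_init: float  # hypothetical: initial log standard deviation


T = TypeVar("T", bound=DecoderHyperparameters)


class BaseDecoder(nn.Module, Generic[T]):
    """Stand-in for the class in src/models/decoder/models/base.py."""

    def __init__(self, hyperparameters: T) -> None:
        super().__init__()
        self.hyperparameters = hyperparameters


class LinearGaussianDecoder(BaseDecoder[LinearGaussianHyperparameters]):
    """Hypothetical decoder: linear readout with learned diagonal noise."""

    def __init__(self, hyperparameters: LinearGaussianHyperparameters) -> None:
        super().__init__(hyperparameters)
        out_dim = hyperparameters["output_dim"]
        self.readout = nn.Linear(hyperparameters["latent_dim"], out_dim)
        self.log_sigma = nn.Parameter(
            tc.full((out_dim,), hyperparameters["log_sigma_init"])
        )

    @staticmethod
    def default_hyperparameters() -> LinearGaussianHyperparameters:
        return {"latent_dim": 8, "output_dim": 3, "log_sigma_init": 0.0}

    def forward(self, Z: tc.Tensor, expected: bool = False) -> tc.Tensor:
        mean = self.readout(Z)
        if expected:
            return mean  # assumed meaning of expected=True: the mean
        # Otherwise draw one sample around the mean.
        return mean + tc.exp(self.log_sigma) * tc.randn_like(mean)

    def log_likelihood(self, Y: tc.Tensor, Z: tc.Tensor) -> tc.Tensor:
        dist = tc.distributions.Normal(self.readout(Z), tc.exp(self.log_sigma))
        # Sum over output dimensions: one log-likelihood per batch element.
        return dist.log_prob(Y).sum(dim=-1)


decoder = LinearGaussianDecoder(LinearGaussianDecoder.default_hyperparameters())
Z = tc.zeros(4, 8)
mean = decoder(Z, expected=True)
ll = decoder.log_likelihood(tc.zeros(4, 3), Z)
```

Whether log_likelihood should return per-element values or a reduced scalar depends on how the loss is assembled in the training code, so check an existing decoder before settling on a reduction.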

2. Registration

You need to edit src/models/decoder/decoder_models.py in two places:

  1. In resolve_config:
    • Add a case to the match decoder_name: block.
    • Set decoder_class = YourNewClass.
    • Note: This allows the system to look up your class and call default_hyperparameters().
  2. In __init__:
    • Add a case to the match decoder["name"]: block.
    • Set decoder_class = YourNewClass.
    • Note: This instantiates your class with the fully resolved configuration.

Adding an Evaluator

Evaluators are categorized into three types. Choose the one that fits your needs.

1. Modality Specific Evaluator

These operate on data from a specific modality and are often paired with a specific decoder type.

Implementation:

  • Location: src/evaluation/evaluators/modality_specific/.
  • Inheritance: BaseEvaluator from src.evaluation.evaluators.modality_specific.base.
  • Parameters: Define a TypedDict inheriting from ModalitySpecificEvaluatorParams.
  • Required Methods:
    • @staticmethod default_params() -> T: Return default parameters.
    • get_tensorboard_figures(self, args: GetTensorboardFiguresArguments) -> dict[str, plt.Figure]: Return figures.
    • compute_metrics(self, args: ComputeMetricArguments) -> dict[str, float]: Return metrics.
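
A skeleton for such an evaluator is sketched below. All base types are simplified stand-ins declared inline: the fields of ComputeMetricArguments and GetTensorboardFiguresArguments are hypothetical (the real objects most likely carry tensors rather than lists), and MSEEvaluator with its metric_prefix parameter is invented for illustration.

```python
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar

if TYPE_CHECKING:  # imported for type checkers only; not needed at runtime
    import matplotlib.pyplot as plt


class ModalitySpecificEvaluatorParams(TypedDict, total=False):
    """Stand-in for the project's parameter base type."""


class MSEEvaluatorParams(ModalitySpecificEvaluatorParams, total=False):
    metric_prefix: str  # hypothetical parameter


@dataclass
class ComputeMetricArguments:
    """Hypothetical stand-in for the real argument object."""
    targets: list[float]
    predictions: list[float]


@dataclass
class GetTensorboardFiguresArguments:
    """Hypothetical stand-in for the real argument object."""
    targets: list[float]
    predictions: list[float]


T = TypeVar("T", bound=ModalitySpecificEvaluatorParams)


class BaseEvaluator(Generic[T]):
    """Stand-in for modality_specific.base.BaseEvaluator."""

    def __init__(self, params: T) -> None:
        self.params = params


class MSEEvaluator(BaseEvaluator[MSEEvaluatorParams]):
    @staticmethod
    def default_params() -> MSEEvaluatorParams:
        return {"metric_prefix": "my_modality"}

    def compute_metrics(self, args: ComputeMetricArguments) -> dict[str, float]:
        # Mean squared error between targets and predictions.
        squared = [(y - p) ** 2 for y, p in zip(args.targets, args.predictions)]
        return {f"{self.params['metric_prefix']}/mse": sum(squared) / len(squared)}

    def get_tensorboard_figures(
        self, args: GetTensorboardFiguresArguments
    ) -> "dict[str, plt.Figure]":
        # No figures in this sketch; a real evaluator would build them here.
        return {}


evaluator = MSEEvaluator(MSEEvaluator.default_params())
metrics = evaluator.compute_metrics(ComputeMetricArguments([1.0, 2.0], [1.0, 4.0]))
```

Returning an empty dictionary from get_tensorboard_figures is assumed to be acceptable for an evaluator that produces only metrics; confirm this against an existing evaluator.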

Registration (src/evaluation/evaluator.py):

  1. In resolve_config:
    • If you want your evaluator to be automatically selected for a specific decoder, add a case to the first match decoder_option["name"]: block (e.g., map MyDecoder to MyEvaluator).
    • Add a case to the second match modality_specific_option["name"]: block to handle explicit configuration. Set evaluator_class = YourEvaluator.
  2. In __init__:
    • Add a case to the modality_evaluators loop's match modality_specific_option["name"]: block. Set selected_evaluator = YourEvaluator.

2. Modality Independent Evaluator

These operate on data independent of any specific modality (e.g., latent space analysis).

Implementation:

  • Location: src/evaluation/evaluators/modality_independent/.
  • Inheritance: BaseEvaluator from src.evaluation.evaluators.modality_independent.base.
  • Parameters: Define a TypedDict inheriting from ModalityIndependentEvaluatorParams.
  • Required Methods:
    • @staticmethod default_params() -> T: Return default parameters.
    • get_tensorboard_figures(self, args: ModalityIndependentEvaluatorFunctionArguments) -> dict[str, plt.Figure]: Return figures.
    • compute_metrics(self, args: ModalityIndependentEvaluatorFunctionArguments) -> dict[str, float]: Return metrics.

Registration (src/evaluation/evaluator.py):

  1. In resolve_config:
    • Add a case to the no_modality_evaluators loop's match block. Set selected_evaluator = YourEvaluator.
  2. In __init__:
    • Add a case to the no_modality_evaluators loop's match block. Set selected_evaluator = YourEvaluator.

3. General Evaluator

These are for global analysis and are called once per evaluation step.

Implementation:

  • Location: src/evaluation/evaluators/general/.
  • Inheritance: BaseEvaluator from src.evaluation.evaluators.general.base.
  • Parameters: Define a TypedDict inheriting from GeneralEvaluatorParams.
  • Required Methods:
    • @staticmethod default_params() -> T: Return default parameters.
    • compute_metrics_and_figures(self, args: GeneralEvaluatorFunctionArguments) -> GeneralEvaluatorReturns: Return metrics and figures.

Registration (src/evaluation/evaluator.py):

  1. In resolve_config:
    • Add a case to the general_evaluators loop's match block. Set selected_evaluator = YourEvaluator.
  2. In __init__:
    • Add a case to the general_evaluators loop's match block. Set selected_evaluator = YourEvaluator.

Configuration

Once registered, you can use your new module by referencing its name in the configuration. The resolve_config process will merge your provided parameters with the defaults.