Decoders and Evaluators
The system is designed to be extensible. You can add custom decoders and evaluators by following the existing patterns. This detailed guide explains the file structure, the configuration resolution process, and the steps to implement and register new modules.
File Structure
- Decoders: Located in `src/models/decoder/models/`.
  - They are grouped by type (e.g., `gaussian/`, `categorical/`).
- Evaluators: Located in `src/evaluation/evaluators/`.
  - `modality_specific/`: For evaluators paired with a specific decoder (modality).
  - `modality_independent/`: For evaluators checking latent state properties or cross-modality properties.
  - `general/`: For global analysis requiring only the model and dataset used directly during the evaluation loop.
Configuration Resolution
The system uses a configuration resolution mechanism to validate and process user inputs before initializing the models. This ensures that all necessary parameters are present and correctly formatted.
- Process: The `resolve_config` static method is called for both the `Decoder` and `Evaluator` classes during startup (via `configuration_resolver.py`).
- Merging: User-provided parameters in the config file are merged with the `default_hyperparameters` (for decoders) or `default_params` (for evaluators) defined in your class.
- Registration: When adding a new module, you must register it in two places within the respective manager class (`Decoder` or `Evaluator`):
  - `__init__`: To instantiate the class with the resolved parameters.
  - `resolve_config`: To map the string name to the class and merge parameters.
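The merge step works like a dictionary update: values supplied by the user override the class defaults. The sketch below is a minimal illustration that assumes a shallow merge. It also adds one validation rule, rejecting keys the class does not declare, and that rule is an assumption rather than documented behaviour. `ExampleParams`, `default_params`, and `merge_with_defaults` are hypothetical names, not the project's API.

```python
from typing import Any, TypedDict


class ExampleParams(TypedDict):
    """Hypothetical parameter set standing in for a decoder/evaluator TypedDict."""
    num_samples: int
    threshold: float


def default_params() -> ExampleParams:
    # Plays the role of default_hyperparameters() / default_params().
    return {"num_samples": 100, "threshold": 0.5}


def merge_with_defaults(defaults: dict[str, Any], user: dict[str, Any]) -> dict[str, Any]:
    """Shallow merge where user values win; unknown keys are rejected (assumed rule)."""
    unknown = set(user) - set(defaults)
    if unknown:
        raise KeyError(f"Unknown parameters: {sorted(unknown)}")
    return {**defaults, **user}


resolved = merge_with_defaults(dict(default_params()), {"threshold": 0.9})
print(resolved)  # {'num_samples': 100, 'threshold': 0.9}
```

Because this merge is shallow, a nested dictionary in the user config would replace the matching default entirely. If the real defaults are nested, a recursive merge would be needed instead.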
Adding a Decoder
1. Implementation
Create a new class in `src/models/decoder/models/<your_type>/` that inherits from `BaseDecoder` (defined in `src/models/decoder/models/base.py`).
- Inheritance: Your class should inherit from `BaseDecoder[T]`, where `T` is a `TypedDict` inheriting from `DecoderHyperparameters`.
- Required Methods:
  - `@staticmethod default_hyperparameters() -> T`: Return a dictionary of default hyperparameter values. This is crucial for configuration resolution.
  - `forward(self, Z: tc.Tensor, expected: bool = False) -> tc.Tensor`: Return the decoded value from latent state `Z`.
  - `log_likelihood(self, Y: tc.Tensor, Z: tc.Tensor) -> tc.Tensor`: Return the log-likelihood of target `Y` given latent state `Z` (for loss calculation).
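The sketch below shows a minimal decoder that provides these three methods. It assumes PyTorch imported as `tc`, matching the signatures above. `DecoderHyperparameters` and `BaseDecoder` are simplified stand-ins so the block runs on its own; the real base class in `src/models/decoder/models/base.py` may require more. Two behaviours are assumptions:

- `expected=True` returns the mean, and `expected=False` draws a sample.
- The log-likelihood is summed over observation dimensions.

`LinearGaussianDecoder` and its hyperparameters are hypothetical.

```python
from typing import Generic, TypedDict, TypeVar

import torch as tc
from torch import nn
from torch.distributions import Normal


# --- Simplified stand-ins for the project's base types (assumed) ---
class DecoderHyperparameters(TypedDict):
    pass


T = TypeVar("T", bound=DecoderHyperparameters)


class BaseDecoder(nn.Module, Generic[T]):
    def __init__(self, hyperparameters: T) -> None:
        super().__init__()
        self.hyperparameters = hyperparameters


# --- Hypothetical new decoder: Y ~ Normal(W Z + b, exp(log_std)) per dimension ---
class LinearGaussianHyperparameters(DecoderHyperparameters):
    latent_dim: int
    obs_dim: int
    init_log_std: float


class LinearGaussianDecoder(BaseDecoder[LinearGaussianHyperparameters]):
    def __init__(self, hyperparameters: LinearGaussianHyperparameters) -> None:
        super().__init__(hyperparameters)
        self.linear = nn.Linear(hyperparameters["latent_dim"], hyperparameters["obs_dim"])
        # One learnable log standard deviation per observation dimension.
        self.log_std = nn.Parameter(
            tc.full((hyperparameters["obs_dim"],), float(hyperparameters["init_log_std"]))
        )

    @staticmethod
    def default_hyperparameters() -> LinearGaussianHyperparameters:
        return {"latent_dim": 16, "obs_dim": 8, "init_log_std": 0.0}

    def forward(self, Z: tc.Tensor, expected: bool = False) -> tc.Tensor:
        mean = self.linear(Z)
        if expected:
            return mean  # assumed meaning of expected=True: the conditional mean
        return mean + self.log_std.exp() * tc.randn_like(mean)  # otherwise: one sample

    def log_likelihood(self, Y: tc.Tensor, Z: tc.Tensor) -> tc.Tensor:
        dist = Normal(self.linear(Z), self.log_std.exp())
        # Sum over observation dimensions -> one log-likelihood per batch element.
        return dist.log_prob(Y).sum(dim=-1)
```

Keeping `default_hyperparameters()` static lets `resolve_config` read the defaults from the class before any decoder instance exists.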
2. Registration
You need to edit `src/models/decoder/decoder_models.py` in two places:
- In `resolve_config`:
  - Add a case to the `match decoder_name:` block.
  - Set `decoder_class = YourNewClass`.
  - Note: This allows the system to look up your class and call `default_hyperparameters()`.
- In `__init__`:
  - Add a case to the `match decoder["name"]:` block.
  - Set `decoder_class = YourNewClass`.
  - Note: This instantiates your class with the fully resolved configuration.
Adding an Evaluator
Evaluators are categorized into three types. Choose the one that fits your needs.
1. Modality Specific Evaluator
These operate on data from a specific modality and are often paired with a specific decoder type.
Implementation:
- Location: `src/evaluation/evaluators/modality_specific/`.
- Inheritance: `BaseEvaluator` from `src.evaluation.evaluators.modality_specific.base`.
- Parameters: Define a `TypedDict` inheriting from `ModalitySpecificEvaluatorParams`.
- Required Methods:
  - `@staticmethod default_params() -> T`: Return default parameters.
  - `get_tensorboard_figures(self, args: GetTensorboardFiguresArguments) -> dict[str, plt.Figure]`: Return figures.
  - `compute_metrics(self, args: ComputeMetricArguments) -> dict[str, float]`: Return metrics.
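The sketch below is a minimal modality-specific evaluator. It assumes the argument objects carry a ground-truth tensor and a prediction tensor for one modality. The following are all simplified stand-ins or illustrations:

- `ModalitySpecificEvaluatorParams`, `BaseEvaluator`, and both argument classes are stand-ins.
- Their fields (`Y_true`, `Y_pred`) are hypothetical.
- `ResidualEvaluator` and its `num_bins` parameter are illustrative only.

Figures are built with `matplotlib.figure.Figure`, the same class `plt.Figure` refers to.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypedDict, TypeVar

import torch as tc
from matplotlib.figure import Figure  # same class as plt.Figure


# --- Simplified stand-ins for the project's types (assumed fields) ---
class ModalitySpecificEvaluatorParams(TypedDict):
    pass


P = TypeVar("P", bound=ModalitySpecificEvaluatorParams)


@dataclass
class GetTensorboardFiguresArguments:
    Y_true: tc.Tensor  # hypothetical: observed data for one modality
    Y_pred: tc.Tensor  # hypothetical: decoder output for the same modality


@dataclass
class ComputeMetricArguments:
    Y_true: tc.Tensor
    Y_pred: tc.Tensor


class BaseEvaluator(ABC, Generic[P]):
    def __init__(self, params: P) -> None:
        self.params = params

    @abstractmethod
    def get_tensorboard_figures(self, args: GetTensorboardFiguresArguments) -> dict[str, Figure]: ...

    @abstractmethod
    def compute_metrics(self, args: ComputeMetricArguments) -> dict[str, float]: ...


# --- Hypothetical new evaluator ---
class ResidualEvaluatorParams(ModalitySpecificEvaluatorParams):
    num_bins: int


class ResidualEvaluator(BaseEvaluator[ResidualEvaluatorParams]):
    @staticmethod
    def default_params() -> ResidualEvaluatorParams:
        return {"num_bins": 30}

    def get_tensorboard_figures(self, args: GetTensorboardFiguresArguments) -> dict[str, Figure]:
        residuals = (args.Y_pred - args.Y_true).detach().flatten().cpu().numpy()
        fig = Figure()
        ax = fig.subplots()
        ax.hist(residuals, bins=self.params["num_bins"])
        ax.set_title("Prediction residuals")
        return {"residual_histogram": fig}

    def compute_metrics(self, args: ComputeMetricArguments) -> dict[str, float]:
        # Mean squared error over all elements.
        mse = tc.mean((args.Y_pred - args.Y_true) ** 2).item()
        return {"mse": mse}
```

Creating figures with `Figure()` instead of `plt.figure()` keeps them out of pyplot's global figure registry. As a result, figures produced at every evaluation step are not kept alive by pyplot.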
Registration (`src/evaluation/evaluator.py`):
- In `resolve_config`:
  - If you want your evaluator to be automatically selected for a specific decoder, add a case to the first `match decoder_option["name"]:` block (e.g., map `MyDecoder` to `MyEvaluator`).
  - Add a case to the second `match modality_specific_option["name"]:` block to handle explicit configuration. Set `evaluator_class = YourEvaluator`.
- In `__init__`:
  - Add a case to the `modality_evaluators` loop's `match modality_specific_option["name"]:` block. Set `selected_evaluator = YourEvaluator`.
2. Modality Independent Evaluator
These operate on data independent of any specific modality (e.g., latent space analysis).
Implementation:
- Location: `src/evaluation/evaluators/modality_independent/`.
- Inheritance: `BaseEvaluator` from `src.evaluation.evaluators.modality_independent.base`.
- Parameters: Define a `TypedDict` inheriting from `ModalityIndependentEvaluatorParams`.
- Required Methods:
  - `@staticmethod default_params() -> T`: Return default parameters.
  - `get_tensorboard_figures(self, args: ModalityIndependentEvaluatorFunctionArguments) -> dict[str, plt.Figure]`: Return figures.
  - `compute_metrics(self, args: ModalityIndependentEvaluatorFunctionArguments) -> dict[str, float]`: Return metrics.
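A modality-independent evaluator has the same structure, except that both methods receive the same argument object. The sketch assumes that object exposes the latent trajectory as a tensor `Z` of shape `(time, latent_dim)`. That field, the stand-in base types, and `LatentSummaryEvaluator` are all hypothetical.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypedDict, TypeVar

import torch as tc
from matplotlib.figure import Figure


# --- Simplified stand-ins for the project's types (assumed fields) ---
class ModalityIndependentEvaluatorParams(TypedDict):
    pass


P = TypeVar("P", bound=ModalityIndependentEvaluatorParams)


@dataclass
class ModalityIndependentEvaluatorFunctionArguments:
    Z: tc.Tensor  # hypothetical: latent trajectory, shape (time, latent_dim)


class BaseEvaluator(ABC, Generic[P]):
    def __init__(self, params: P) -> None:
        self.params = params

    @abstractmethod
    def get_tensorboard_figures(self, args: ModalityIndependentEvaluatorFunctionArguments) -> dict[str, Figure]: ...

    @abstractmethod
    def compute_metrics(self, args: ModalityIndependentEvaluatorFunctionArguments) -> dict[str, float]: ...


# --- Hypothetical new evaluator: simple latent-state summaries ---
class LatentSummaryParams(ModalityIndependentEvaluatorParams):
    max_dims_to_plot: int


class LatentSummaryEvaluator(BaseEvaluator[LatentSummaryParams]):
    @staticmethod
    def default_params() -> LatentSummaryParams:
        return {"max_dims_to_plot": 3}

    def get_tensorboard_figures(self, args: ModalityIndependentEvaluatorFunctionArguments) -> dict[str, Figure]:
        Z = args.Z.detach().cpu()
        fig = Figure()
        ax = fig.subplots()
        # Plot at most max_dims_to_plot latent dimensions over time.
        for d in range(min(self.params["max_dims_to_plot"], Z.shape[1])):
            ax.plot(Z[:, d].numpy(), label=f"z[{d}]")
        ax.set_xlabel("time step")
        ax.legend()
        return {"latent_trajectories": fig}

    def compute_metrics(self, args: ModalityIndependentEvaluatorFunctionArguments) -> dict[str, float]:
        Z = args.Z.detach()
        return {
            "latent_mean_abs": Z.abs().mean().item(),
            "latent_std": Z.std().item(),  # unbiased std over all entries
        }
```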
Registration (`src/evaluation/evaluator.py`):
- In `resolve_config`:
  - Add a case to the `no_modality_evaluators` loop's `match` block. Set `selected_evaluator = YourEvaluator`.
- In `__init__`:
  - Add a case to the `no_modality_evaluators` loop's `match` block. Set `selected_evaluator = YourEvaluator`.
3. General Evaluator
These are for global analysis and are called once per evaluation step.
Implementation:
- Location: `src/evaluation/evaluators/general/`.
- Inheritance: `BaseEvaluator` from `src.evaluation.evaluators.general.base`.
- Parameters: Define a `TypedDict` inheriting from `GeneralEvaluatorParams`.
- Required Methods:
  - `@staticmethod default_params() -> T`: Return default parameters.
  - `compute_metrics_and_figures(self, args: GeneralEvaluatorFunctionArguments) -> GeneralEvaluatorReturns`: Return metrics and figures.
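A general evaluator returns both metrics and figures from a single call. This sketch rests on several assumptions:

- `GeneralEvaluatorFunctionArguments` holds the model and the dataset, matching the description of general evaluators above.
- `GeneralEvaluatorReturns` is a dictionary with `metrics` and `figures` entries.

Both shapes, the stand-in base class, and `ModelSummaryEvaluator` are hypothetical.

```python
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Generic, TypedDict, TypeVar

import torch as tc
from matplotlib.figure import Figure
from torch import nn


# --- Simplified stand-ins for the project's types (assumed fields) ---
class GeneralEvaluatorParams(TypedDict):
    pass


P = TypeVar("P", bound=GeneralEvaluatorParams)


@dataclass
class GeneralEvaluatorFunctionArguments:
    model: nn.Module  # hypothetical field: the trained model
    dataset: Sequence[Any]  # hypothetical field: the evaluation dataset


class GeneralEvaluatorReturns(TypedDict):  # hypothetical shape
    metrics: dict[str, float]
    figures: dict[str, Figure]


class BaseEvaluator(ABC, Generic[P]):
    def __init__(self, params: P) -> None:
        self.params = params

    @abstractmethod
    def compute_metrics_and_figures(self, args: GeneralEvaluatorFunctionArguments) -> GeneralEvaluatorReturns: ...


# --- Hypothetical new evaluator: model and dataset summary ---
class ModelSummaryParams(GeneralEvaluatorParams):
    num_bins: int


class ModelSummaryEvaluator(BaseEvaluator[ModelSummaryParams]):
    @staticmethod
    def default_params() -> ModelSummaryParams:
        return {"num_bins": 50}

    def compute_metrics_and_figures(self, args: GeneralEvaluatorFunctionArguments) -> GeneralEvaluatorReturns:
        # Flatten every parameter tensor into one vector (assumes the model has parameters).
        weights = tc.cat([p.detach().flatten().cpu() for p in args.model.parameters()])
        fig = Figure()
        ax = fig.subplots()
        ax.hist(weights.numpy(), bins=self.params["num_bins"])
        ax.set_title("Parameter values")
        return {
            "metrics": {
                "num_parameters": float(weights.numel()),
                "dataset_size": float(len(args.dataset)),
            },
            "figures": {"parameter_histogram": fig},
        }
```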
Registration (`src/evaluation/evaluator.py`):
- In `resolve_config`:
  - Add a case to the `general_evaluators` loop's `match` block. Set `selected_evaluator = YourEvaluator`.
- In `__init__`:
  - Add a case to the `general_evaluators` loop's `match` block. Set `selected_evaluator = YourEvaluator`.
Configuration
Once registered, you can use your new module by referencing its name in the configuration. The `resolve_config` process will merge your provided parameters with the defaults.