Learning rates and parameter freezing
The training process allows for fine-grained control over the learning rates of different model components (Encoder, Decoder, DSR) and their hierarchical levels (Group-level "shared" parameters vs. Subject-specific parameters). Additionally, parameters can be frozen after a certain number of epochs.
Configuration Structure
The configuration is handled via the learning_rate key in the config. Depending on the desired granularity, you can provide a single value, a list, or a dictionary.
1. Global Learning Rate (Simple)
The easiest way is to set a single floating-point value. This learning rate is applied to all components.
- Note: For the DSR model, the learning rate of the subject-specific parameters is automatically set to 10x the global rate (e.g. with a global rate of 1e-3, the subject vectors are trained at 1e-2). A higher learning rate for the subject vectors has been found to improve the model's ability to learn.
"learning_rate": 1e-3
2. Tuple / List
Allows for setting the learning rate separately for the subject vectors of the DSR model.
- Format: [general_lr, dsr_subject_lr]
- general_lr applies to all parameters.
- dsr_subject_lr applies only to the subject vectors of the DSR model.
"learning_rate": [1e-3, 1e-2]
3. Component-Wise Configuration (Advanced)
For full control, use a dictionary. You can configure each model component (encoder, decoder, dsr) individually.
Simplified Setup (Leaf Nodes)
You can configure a component directly by providing lr and freeze_after (optional; use null if freezing is not needed). These settings apply to all parts (shared and subject-specific) of that component.
"learning_rate": {
# Train encoder with specific rate, freeze after 50 epochs
"encoder": { "lr": 1e-3, "freeze_after": 50 },
# Different rate for decoder
"decoder": { "lr": 2e-3, "freeze_after": null },
# DSR gets another rate
"dsr": { "lr": 1e-3, "freeze_after": null }
}
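Conceptually, freeze_after means that the affected parameters stop receiving updates once training reaches the given epoch. A minimal sketch of what this amounts to in PyTorch, assuming freezing is implemented by disabling gradients (the framework's internal handling may differ):

import torch.nn as nn

def apply_freezing(component: nn.Module, freeze_after, epoch: int) -> None:
    # Illustrative only: stop updating the component's parameters once the
    # configured epoch is reached; null/None means the component is never frozen.
    if freeze_after is not None and epoch >= freeze_after:
        for p in component.parameters():
            p.requires_grad = False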
Detailed Setup (Hierarchical)
If you need different settings for the shared and subject-specific parts of a component, you can utilize the shared and subjects keys.
- Encoder: supports shared, subjects
- DSR: supports shared, subjects
- Decoder: supports shared
"learning_rate": {
"encoder": {
"shared": { "lr": 1e-4, "freeze_after": null },
# Subject parameters often need higher learning rates
"subjects": { "lr": 1e-2, "freeze_after": 20 }
},
"decoder": {
# Decoder usually only has shared parameters
"shared": { "lr": 1e-3, "freeze_after": null }
},
"dsr": {
"shared": { "lr": 1e-3, "freeze_after": null },
"subjects": { "lr": 1e-2, "freeze_after": null }
}
}
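For intuition, such a hierarchical configuration typically translates into one optimizer parameter group per entry, each with its own learning rate. A hypothetical sketch of that mapping in PyTorch, using dummy modules in place of the real components (the module names below are assumptions for illustration, not the framework's actual API):

import torch
import torch.nn as nn

# Dummy stand-ins for the actual model parts (illustration only).
encoder_shared, encoder_subjects = nn.Linear(4, 4), nn.Embedding(10, 4)
decoder_shared = nn.Linear(4, 4)
dsr_shared, dsr_subjects = nn.Linear(4, 4), nn.Embedding(10, 4)

# One parameter group per entry of the "learning_rate" dictionary above.
optimizer = torch.optim.Adam([
    {"params": encoder_shared.parameters(),   "lr": 1e-4},
    {"params": encoder_subjects.parameters(), "lr": 1e-2},
    {"params": decoder_shared.parameters(),   "lr": 1e-3},
    {"params": dsr_shared.parameters(),       "lr": 1e-3},
    {"params": dsr_subjects.parameters(),     "lr": 1e-2},
])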
Mixed Setup
You can mix simplified and detailed configurations.
"learning_rate": {
"encoder": { "lr": 1e-3, "freeze_after": null }, # Simplified
"decoder": { "lr": 1e-3, "freeze_after": null },
"dsr": { # Detailed
"shared": { "lr": 1e-4, "freeze_after": null },
"subjects": { "lr": 5e-3, "freeze_after": null }
}
}
Learning Rate Scheduling
The global learning rate schedule is configured via the lr_scheduling key. It uses a scheduler from torch.optim.lr_scheduler.
Supported schedulers:
- MultiStepLR (default)
- ReduceLROnPlateau
- ExponentialLR
- StepLR
MultiStepLR
Decays the learning rate by gamma once the number of epochs reaches one of the milestones.
- Important: milestones should be provided as fractions of the total n_epochs (between 0 and 1).
"lr_scheduling": {
"name": "MultiStepLR",
"hyperparameters": {
"gamma": 0.1,
"milestones": [0.2, 0.8] # e.g. at 20% and 80% of training
}
}
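Because the milestones are fractions, they are converted to absolute epoch numbers using the total n_epochs. A quick sketch of that conversion (n_epochs = 1000 is an assumed value; the exact rounding used internally may differ):

n_epochs = 1000                                   # assumed total number of training epochs
milestone_fractions = [0.2, 0.8]                  # as given in the config above
milestones = [int(f * n_epochs) for f in milestone_fractions]
print(milestones)                                 # [200, 800]: LR is multiplied by gamma at these epochs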
ReduceLROnPlateau
Reduces the learning rate when a monitored metric (here, the training loss) has stopped improving.
"lr_scheduling": {
"name": "ReduceLROnPlateau",
"hyperparameters": {
"mode": "min",
"factor": 0.1,
"patience": 10
}
}
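This corresponds to PyTorch's torch.optim.lr_scheduler.ReduceLROnPlateau, which is stepped with the monitored value. A standalone illustration with the hyperparameters above (dummy model and loss; the actual training loop is handled by the framework):

import torch

model = torch.nn.Linear(4, 4)                     # dummy model for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=10
)

for epoch in range(30):
    train_loss = 1.0 / (epoch + 1)                # placeholder for the real training loss
    scheduler.step(train_loss)                    # LR is multiplied by `factor` after `patience` epochs without improvement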
ExponentialLR
Decays the learning rate of each parameter group by gamma every epoch.
Instead of gamma, you can specify initial and final learning rates; the system then automatically calculates the gamma needed to interpolate between them over n_epochs.
"lr_scheduling": {
"name": "ExponentialLR",
"hyperparameters": {
"gamma": 0.95
# OR
# "initial": 1e-3,
# "final": 1e-5
}
}
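When initial and final are given, the decay factor follows from final = initial * gamma ** n_epochs, i.e. gamma = (final / initial) ** (1 / n_epochs). A small sketch of that calculation (n_epochs = 100 is an assumed value for illustration):

initial, final = 1e-3, 1e-5                       # values from the config above
n_epochs = 100                                    # assumed total number of training epochs
gamma = (final / initial) ** (1.0 / n_epochs)
print(round(gamma, 4))                            # ~0.955: per-epoch factor that reaches `final` after n_epochs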
The scheduler applies the same schedule to every learning rate; there is no option to set different schedules for different parameter groups of the model.