Conversation
…iex/dev/warm-and-frozen-teachers
clessig left a comment:
Overall this looks fine; I pushed some minor changes. config_jepa.yml has a 2D RoPE param but it is not used here. It should be removed (it was also one of the things that caused problems for me).
```python
    cf, target_and_aux_calc_params.get("model_param_overrides", {})
)
prepare_encoder_teacher(
    meta_ema_model, cf.training_config, cf_overridden.ae_global_dim_embed
```
It would be more generic to pass cf_overridden to prepare_encoder_teacher(); in the future there might be more params from the config that are useful beyond cf_overridden.ae_global_dim_embed.
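A minimal sketch of the suggested change, assuming the names from the diff above; the function body and config fields are hypothetical stand-ins, not the actual project code. The teacher-preparation function takes the whole overridden config, so new per-teacher params can be consumed later without touching any call site:

```python
from types import SimpleNamespace

def prepare_encoder_teacher(model, training_config, cf_overridden):
    # Read what is needed from the full overridden config; additional
    # fields can be consumed here later without signature changes.
    model.dim_embed = cf_overridden.ae_global_dim_embed
    return model

# Hypothetical usage: stand-in objects in place of the real model/config.
cf_overridden = SimpleNamespace(ae_global_dim_embed=256)
model = SimpleNamespace(dim_embed=None)
prepare_encoder_teacher(model, training_config=None, cf_overridden=cf_overridden)
```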
```python
        self.batch_size = batch_size
        self.reset()

    def _forward_teacher(self, model_params, batch):
```
I don't think it's a "private" function since it's called from the base class. We also usually don't use the '_' convention, so I would remove it.
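A sketch of the naming point, with illustrative class bodies (only the class names come from the PR): because the base class invokes the hook, it is part of the subclass contract rather than an internal helper, so the leading underscore is misleading:

```python
class EncoderTeacher:
    def compute_targets(self, model_params, batch):
        # The base class calls the hook, so it is public API for subclasses.
        return self.forward_teacher(model_params, batch)

    def forward_teacher(self, model_params, batch):
        raise NotImplementedError

class FrozenTeacher(EncoderTeacher):
    def forward_teacher(self, model_params, batch):
        # Hypothetical stand-in for the frozen-encoder forward pass.
        return ("frozen", batch)

out = FrozenTeacher().compute_targets(None, {"x": 1})
```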
```python
class FrozenTeacher(EncoderTeacher):
    """SSL teacher using a frozen pre-trained encoder.

    The encoder is loaded from a checkpoint and never updated. Non-encoder
```
The teacher_model is assumed to have its non-encoder parts discarded, no?
```python
        self.teacher_model.eval()

    @classmethod
    def from_pretrained(cls, cf: Config, dataset, device, params: dict) -> FrozenTeacher:
```
This function is inconsistent with what is done for EMATeacher in model_interface. Either we have from_pretrained() for both classes or we have the functionality in model_interface.py.
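One consistent option, sketched with hypothetical bodies (only the class names and the from_pretrained signature appear in the PR): both teacher classes expose the same classmethod, so construction looks identical regardless of teacher type:

```python
class EncoderTeacher:
    def __init__(self, cf, device):
        self.cf, self.device = cf, device

    @classmethod
    def from_pretrained(cls, cf, dataset, device, params):
        # Shared construction path; checkpoint or EMA state would be
        # loaded here in the real implementation.
        teacher = cls(cf, device)
        teacher.params = params
        return teacher

class EMATeacher(EncoderTeacher): ...
class FrozenTeacher(EncoderTeacher): ...

# Hypothetical usage: both teachers are built the same way.
ema = EMATeacher.from_pretrained(cf="cfg", dataset=None, device="cpu", params={})
frozen = FrozenTeacher.from_pretrained(cf="cfg", dataset=None, device="cpu", params={})
```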
```python
logger = logging.getLogger(__name__)


def _create_head(name: str, head_type: str, dim_embed: int, loss_conf, cf=None) -> nn.Module:
```
If this is for teacher_heads then the function name should say so.
```python
    model.pred_heads = nn.ModuleDict()

    # Ensure latent_pre_norm exists (teacher may not have had SSL training)
    if model.latent_pre_norm is None:
```
When/why wouldn't this exist?
```python
    3. Creates fresh latent_heads based on the student's SSL loss config
    """
    # Strip non-encoder components
    model.forecast_engine = None
```
Can we formulate it as "is not part of the encoder" so that we are robust to changes in the model design? E.g., we discussed having a decoder-type model for the stream-specific prediction heads, and we will most likely forget this hidden dependency here. Otherwise, we might have a function in the model that reduces it to the encoder, which is called here.
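A sketch of the "function in the model that reduces it to the encoder" idea; the class, attribute names, and component list are illustrative assumptions, not the actual model. Instead of nulling known non-encoder attributes one by one, the model keeps an explicit list of encoder components and drops everything else, so newly added heads or decoders are stripped automatically:

```python
class Model:
    # Assumed allow-list of encoder components; everything else is stripped.
    ENCODER_COMPONENTS = ("encoder", "latent_pre_norm")

    def __init__(self):
        self.encoder = "enc"
        self.latent_pre_norm = "norm"
        self.forecast_engine = "fcst"    # non-encoder: should be stripped
        self.pred_heads = {"t2m": "head"}  # non-encoder: should be stripped

    def reduce_to_encoder(self):
        # Null every attribute not on the encoder allow-list.
        for name in list(vars(self)):
            if name not in self.ENCODER_COMPONENTS:
                setattr(self, name, None)
        return self

m = Model().reduce_to_encoder()
```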
Something similar to:

```python
encoder_params = {
    k: v for k, v in params.items() if k.startswith(("encoder.", "latent_pre_norm"))
}
```
```python
        logger.warning(f"Unknown SSL loss type {name!r} in teacher setup, skipping.")


def load_encoder_from_checkpoint(
```
Why do we need this as well as the first part of prepare_encoder_teacher()? It seems to be the same functionality.
```
@@ -0,0 +1,16 @@
training_config:
```
How is this config to be used? Maybe we can give an example at the top of what pretraining it can be used with. The copyright header is also missing.
```
@@ -0,0 +1,7 @@
training_config:
```
How is this config to be used? Maybe we can give an example at the top of what pretraining it can be used with. The copyright header is also missing.
Description
Allow warm start with EMA and Frozen teachers.
Issue Number
Closes #1881
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- `./scripts/actions.sh lint`
- `./scripts/actions.sh unit-test`
- `./scripts/actions.sh integration-test`
- `launch-slurm.py --time 60`