Warm start and frozen teachers#1876

Open
sophie-xhonneux wants to merge 8 commits into develop from sophiex/dev/warm-and-frozen-teachers

Conversation


@sophie-xhonneux (Contributor) commented Feb 18, 2026

Description

Allow warm starts with EMA and frozen teachers.

Issue Number

Closes #1881

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hedgedoc in the github issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@sophie-xhonneux sophie-xhonneux changed the title Write first solution with Claude Warm start and frozen teachers Feb 19, 2026
@github-actions github-actions bot added the model Related to model training or definition (not generic infra) label Feb 19, 2026
@clessig (Collaborator) left a comment


Overall looks fine. I pushed some minor changes. config_jepa.yml has a 2D RoPE param, but it's not used in here. It should be removed (it was also one of the things that caused problems for me).

cf, target_and_aux_calc_params.get("model_param_overrides", {})
)
prepare_encoder_teacher(
meta_ema_model, cf.training_config, cf_overridden.ae_global_dim_embed

It would be more generic to pass cf_overridden to prepare_encoder_teacher(); there might be more params in the future from the config that are useful beyond cf_overridden.ae_global_dim_embed
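A minimal sketch of that suggestion, assuming a hypothetical `prepare_encoder_teacher` signature (only `cf_overridden.ae_global_dim_embed` appears in the diff above; everything else here is illustrative):

```python
from types import SimpleNamespace

def prepare_encoder_teacher(teacher_model, training_config, cf_overridden):
    # Take the full overridden config so future teacher-relevant params can
    # be read here without changing every call site.
    dim_embed = cf_overridden.ae_global_dim_embed
    # Real code would use dim_embed (and any other config fields) to set up
    # the teacher model; returning it keeps this sketch runnable.
    return dim_embed

cf_overridden = SimpleNamespace(ae_global_dim_embed=256)
print(prepare_encoder_teacher(None, None, cf_overridden))  # prints 256
```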

self.batch_size = batch_size
self.reset()

def _forward_teacher(self, model_params, batch):

I don't think it's a "private" function since it's called from the base class. We also usually don't use the '_' convention, so I would remove it.

class FrozenTeacher(EncoderTeacher):
"""SSL teacher using a frozen pre-trained encoder.

The encoder is loaded from a checkpoint and never updated. Non-encoder

The teacher_model is assumed to have non-encoder parts discarded, not?

self.teacher_model.eval()

@classmethod
def from_pretrained(cls, cf: Config, dataset, device, params: dict) -> FrozenTeacher:

This function is inconsistent with what is done for EMATeacher in model_interface. Either we have from_pretrained() for both classes, or we put the functionality in model_interface.py.
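One way to make the two classes consistent would be a shared `from_pretrained()` entry point on the base class. A hedged sketch (the class names `EncoderTeacher`/`EMATeacher`/`FrozenTeacher` and the `from_pretrained` signature are from the diff; the shared-constructor pattern and elided loading logic are assumptions):

```python
class EncoderTeacher:
    """Base class; owns the shared construction path."""

    def __init__(self, params: dict):
        self.params = params

    @classmethod
    def from_pretrained(cls, cf, dataset, device, params: dict):
        teacher = cls(params)
        # Shared checkpoint/encoder loading would happen here, once,
        # instead of being duplicated per subclass or in model_interface.py.
        return teacher

class EMATeacher(EncoderTeacher):
    pass

class FrozenTeacher(EncoderTeacher):
    pass

t = FrozenTeacher.from_pretrained(None, None, None, {"ckpt": "path"})
print(type(t).__name__)  # prints FrozenTeacher
```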

logger = logging.getLogger(__name__)


def _create_head(name: str, head_type: str, dim_embed: int, loss_conf, cf=None) -> nn.Module:

If this is for teacher_heads then the function name should say so.

model.pred_heads = nn.ModuleDict()

# Ensure latent_pre_norm exists (teacher may not have had SSL training)
if model.latent_pre_norm is None:

When/why wouldn't this exist?

3. Creates fresh latent_heads based on the student's SSL loss config
"""
# Strip non-encoder components
model.forecast_engine = None

Can we formulate it as "is not encoder" so that we are robust to changes in the model design? E.g., we discussed having a decoder-type model for the stream-specific prediction heads, and we would most likely forget this hidden dependency here. Otherwise, we might have a function in the model that reduces it to the encoder, which is called here.


Something similar to

    encoder_params = {
        k: v for k, v in params.items() if k.startswith(("encoder.", "latent_pre_norm"))
    }
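A runnable sketch of this keep-list idea applied to module attributes rather than state-dict keys (pure-Python stand-in; `forecast_engine`, `pred_heads`, and `latent_pre_norm` follow the diff above, the rest is assumed):

```python
# Keep-list of encoder components; everything not listed is stripped, so a
# future non-encoder part (e.g. a decoder for prediction heads) is dropped
# automatically instead of relying on a hidden per-attribute dependency.
ENCODER_ATTRS = {"encoder", "latent_pre_norm"}

def reduce_to_encoder(model):
    # Null out everything that is not on the encoder keep-list.
    for name in list(vars(model)):
        if name not in ENCODER_ATTRS:
            setattr(model, name, None)
    return model

class DummyModel:
    def __init__(self):
        self.encoder = "enc"
        self.latent_pre_norm = "norm"
        self.forecast_engine = "fc"
        self.pred_heads = "heads"

m = reduce_to_encoder(DummyModel())
print(m.encoder, m.forecast_engine)  # prints: enc None
```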

logger.warning(f"Unknown SSL loss type {name!r} in teacher setup, skipping.")


def load_encoder_from_checkpoint(

Why do we need this as well as the first part of prepare_encoder_teacher()? It seems to be the same functionality.

@@ -0,0 +1,16 @@
training_config:

How is this config used? Maybe we can give an example at the top of what pretraining it can be used with. The copyright header is also missing.

@@ -0,0 +1,7 @@
training_config:

How is this config used? Maybe we can give an example at the top of what pretraining it can be used with. The copyright header is also missing.


Labels

model Related to model training or definition (not generic infra)


Development

Successfully merging this pull request may close these issues.

Allow for frozen teachers and warm starts

2 participants