
[1914][1596] Implemented step_callback for incremental output of forecast outputs #1916

Draft
evenmn wants to merge 3 commits into ecmwf:develop from metno:evenmn/fix/incremental-rollout

Conversation


@evenmn evenmn commented Feb 24, 2026

Description

Enables incremental output writing during inference by writing forecast predictions after each step (or every N steps) instead of accumulating all steps in memory. Maintains constant memory usage regardless of forecast length, addressing limitations for high-resolution, multi-step forecasts.

Changes

  • model.py: Added optional step_callback parameter to Model.forward()
  • validation_io.py: Refactored with modular architecture (_prepare_batch_data(), _process_timestep(), StreamingOutputWriter)
  • trainer.py: Updated validate() for streaming callback support
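A minimal sketch of what such a streaming writer could look like — the writer internals below are hypothetical and only the step_callback(step, output) calling convention is taken from this PR:

```python
# Sketch of a streaming step callback (hypothetical internals; only the
# step_callback(step, output) calling convention comes from this PR).
class StreamingOutputWriter:
    """Writes forecast steps as they are produced, keeping memory bounded."""

    def __init__(self, flush_every: int = 1):
        self.flush_every = flush_every   # write after every N forecast steps
        self.buffer = []                 # steps accumulated since the last flush
        self.written = []                # stand-in for the on-disk output

    def __call__(self, step: int, output) -> None:
        # invoked by the forecasting loop after each step
        self.buffer.append((step, output))
        if len(self.buffer) >= self.flush_every:
            self.flush()

    def flush(self) -> None:
        # the real code would write to disk here; we just record the steps
        self.written.extend(self.buffer)
        self.buffer.clear()              # memory stays bounded by flush_every

writer = StreamingOutputWriter(flush_every=2)
for step in range(5):
    writer(step, f"prediction_{step}")
writer.flush()  # flush any remainder at the end of the roll-out
```
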

Issue Number

Closes #1914
Closes #1596

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the github issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

Signed-off-by: evenmn <evenmn@mn.uio.no>
@evenmn
Author

evenmn commented Feb 24, 2026

Fixes #1914

@clessig clessig self-requested a review February 25, 2026 07:03
Contributor

@grassesi grassesi left a comment


Thanks for this implementation, I really like the general design. However, I have a few things I think will streamline this PR a lot:

  • Let's consider the non-streaming case a special case of streaming, where n_forecast_steps == n_steps_to_accumulate_before_write. This way we get rid of the decision branches in write_output and Trainer.validate, reducing the overall complexity.
  • Let's simplify the configuration: adding one new option (the number of fsteps to accumulate before writing) to validation_config.output / training_config.output should suffice:
    • output.fstep_chunk_size: 1 => current default behaviour
    • output.fstep_chunk_size: n => accumulate n chunks before writing
    • output.fstep_chunk_size: $training_config.forecast.num_steps => no streaming
  • Finally, I think the default for this new argument should be None, and if it is None then we don't do streaming (set output.fstep_chunk_size = training_config.forecast.num_steps)
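The None-default resolution described above could be as simple as the following sketch (option and variable names mirror the suggestion, not actual WeatherGenerator code):

```python
# Sketch of resolving the proposed output.fstep_chunk_size option.
# Names follow the suggestion in this comment, not actual code.
def resolve_chunk_size(fstep_chunk_size, num_forecast_steps: int) -> int:
    """None means no streaming: accumulate the whole forecast before writing."""
    if fstep_chunk_size is None:
        # non-streaming becomes a special case of streaming
        return num_forecast_steps
    return fstep_chunk_size

# chunk size 1: write after every step; chunk size == num_steps: no streaming
```
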

If you don't mind, I would already take this design and provide a callback function for model.forward in the new io (#1908)

Comment on lines 617 to 660
-    def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput:
+    def forward(
+        self, model_params: ModelParams, batch: ModelBatch, step_callback=None
+    ) -> ModelOutput:
         """Forward pass of the model

         Tokens are processed through the model components, which were defined in the create method.
         Args:
             model_params : Query and embedding parameters
-            batch
+            batch : Batch of data
+            step_callback : Optional callback function to be called after each forecast step.
+                Called as step_callback(step, output) where step is the forecast step
+                index and output is the partial ModelOutput for that step.
+                Used for streaming output writing - allows writing after each step
+                instead of accumulating all steps in memory.
         Returns:
             A list containing all prediction results
         """

         output = ModelOutput(batch.get_output_len())

         tokens, posteriors = self.encoder(model_params, batch)
         output.add_latent_prediction(0, "posteriors", posteriors)

         # recover batch dimension and separate input_steps
         shape = (len(batch), batch.get_num_steps(), *tokens.shape[1:])
         # collapse along input step dimension
         tokens = tokens.reshape(shape).sum(axis=1)

         # roll-out in latent space, iterate and generate output over requested output steps
         for step in batch.get_output_idxs():
             # apply forecasting engine (if present)
             if self.forecast_engine:
                 tokens = self.forecast_engine(tokens, step, coords=model_params.rope_coords)

             # decoder predictions
             output = self.predict_decoders(model_params, step, tokens, batch, output)
             # latent predictions (raw and with SSL heads)
             output = self.predict_latent(model_params, step, tokens, batch, output)

+            # invoke callback for streaming output if provided
+            if step_callback is not None:
+                step_callback(step, output)
+
         return output
Contributor


I like this design: having a hook into the forecasting loop is super flexible and can also be used in the future to facilitate all kinds of interactions.

@github-project-automation github-project-automation bot moved this to In Progress in WeatherGen-dev Feb 25, 2026
Collaborator

@clessig clessig left a comment


The model forward function should not have callbacks. The control flow should be with the trainer, and the model forward just executes. Having an IO operation (which potentially is very slow, with a very large variance per worker) in forward will likely cause all kinds of havoc with model parallelism because it leads to highly asynchronous execution.

@evenmn
Author

evenmn commented Feb 25, 2026

The model forward function should not have callbacks. The control flow should be with the trainer, and the model forward just executes. Having an IO operation (which potentially is very slow, with a very large variance per worker) in forward will likely cause all kinds of havoc with model parallelism because it leads to highly asynchronous execution.

I didn't think about that, but you're totally right. I will move the logic to trainer.py and push a new version.

…cial case of streaming

Signed-off-by: evenmn <evenmn@mn.uio.no>
@grassesi
Contributor

The model forward function should not have callbacks. The control flow should be with the trainer, and the model forward just executes. Having an IO operation (which potentially is very slow, with a very large variance per worker) in forward will likely cause all kinds of havoc with model parallelism because it leads to highly asynchronous execution.

The execution of the forecasting loop has to be interrupted by IO, since otherwise predictions will accumulate in memory. So I don't see an implementation without asynchronous-parallel IO that does not suffer from the same concerns regarding model parallelism/asynchronous execution. Having a callback is a nice and simple interface; alternatively, the forecasting loop would need to be moved outside the forward function, or the forward function would need to be turned into some kind of yielding Generator (but this would probably require more logic in the loss calculation/optimization).
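For illustration, the yielding-generator alternative mentioned above might look like the following toy sketch — all names here are illustrative, not actual model code:

```python
# Toy sketch of the generator alternative: the model yields one prediction
# per forecast step, and the trainer decides when to do IO. All names are
# illustrative, not actual WeatherGenerator code.
def forward_steps(tokens, num_steps):
    for step in range(num_steps):
        tokens = tokens + 1          # stand-in for one latent roll-out step
        yield step, tokens           # hand control back to the caller

def validate(num_steps=3):
    written = []
    for step, prediction in forward_steps(0, num_steps):
        written.append((step, prediction))   # trainer-side IO; model stays IO-free
    return written
```

This keeps control flow with the trainer (addressing the review concern), at the cost of the loss/optimization code having to drive the generator itself.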


Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

  • Incremental forward function for long-term roll-out
  • Forecast output needs to be written incrementally

3 participants