[1914][1596] Implemented step_callback for incremental output of forecast outputs #1916
evenmn wants to merge 3 commits into ecmwf:develop
Conversation
Signed-off-by: evenmn <evenmn@mn.uio.no>

Fixes #1914
grassesi left a comment:

Thanks for this implementation, I really like the general design. However, I have two suggestions that I think will streamline this PR a lot:
- Let's consider the non-streaming case a special case of streaming where n_forecast_steps == n_steps_to_accumulate_before_write. This way we get rid of the decision branches in write_output and Trainer.validate, reducing the overall complexity.
- Let's simplify the configuration: adding one new option (the number of fsteps to accumulate before writing) to validation_config.output/training_config.output should suffice:
  - output.fstep_chunk_size: 1 => current default behaviour
  - output.fstep_chunk_size: n => accumulate n chunks before writing
  - output.fstep_chunk_size: $training_config.forecast.num_steps => no streaming
- Finally, I think the default for this new argument should be None, and if it is None then we don't do streaming (set output.fstep_chunk_size = training_config.forecast.num_steps).
If you don't mind, I would already take this design and provide a callback function for model.forward in the new io (#1908).
```diff
-    def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput:
+    def forward(
+        self, model_params: ModelParams, batch: ModelBatch, step_callback=None
+    ) -> ModelOutput:
         """Forward pass of the model

         Tokens are processed through the model components, which were defined in the create method.

         Args:
             model_params : Query and embedding parameters
-            batch
+            batch : Batch of data
+            step_callback : Optional callback function to be called after each forecast step.
+                Called as step_callback(step, output) where step is the forecast step
+                index and output is the partial ModelOutput for that step.
+                Used for streaming output writing - allows writing after each step
+                instead of accumulating all steps in memory.
         Returns:
             A list containing all prediction results
         """

         output = ModelOutput(batch.get_output_len())

         tokens, posteriors = self.encoder(model_params, batch)
         output.add_latent_prediction(0, "posteriors", posteriors)

         # recover batch dimension and separate input_steps
         shape = (len(batch), batch.get_num_steps(), *tokens.shape[1:])
         # collapse along input step dimension
         tokens = tokens.reshape(shape).sum(axis=1)

         # roll-out in latent space, iterate and generate output over requested output steps
         for step in batch.get_output_idxs():
             # apply forecasting engine (if present)
             if self.forecast_engine:
                 tokens = self.forecast_engine(tokens, step, coords=model_params.rope_coords)

             # decoder predictions
             output = self.predict_decoders(model_params, step, tokens, batch, output)
             # latent predictions (raw and with SSL heads)
             output = self.predict_latent(model_params, step, tokens, batch, output)

+            # invoke callback for streaming output if provided
+            if step_callback is not None:
+                step_callback(step, output)

         return output
```
I like this design: having a hook into the forecasting loop is super flexible and can be also used in the future to facilitate all kinds of interactions.
clessig left a comment:
The model forward function should not have callbacks. The control flow should be with the trainer, and the model forward just executes. Having an IO operation (which potentially is very slow, with a very large variance per worker) in forward will likely cause all kinds of havoc with model parallelism because it leads to highly asynchronous execution.
I didn't think about that, but you're totally right. I will move the logic to trainer.py and push a new version.
…cial case of streaming Signed-off-by: evenmn <evenmn@mn.uio.no>
The execution of the forecasting loop has to be interrupted by io, since otherwise predictions will accumulate in memory. So I don't see an implementation without asynchronous-parallel io that does not suffer from the same concerns regarding model parallelism/asynchronous execution. Having a callback is a nice and simple interface; alternatively, the forecasting loop would need to be moved outside the forward function, or the forward function would need to be turned into some kind of yielding generator (but this would probably require more logic in the loss calculation/optimization).
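The generator alternative mentioned above could look roughly like this. All names here are illustrative stand-ins, not the actual WeatherGenerator API; the point is only that yielding inverts control so the trainer, not forward, decides when io happens.

```python
from typing import Callable, Iterator


def forecast_steps(
    tokens: object,
    output_idxs: list[int],
    predict_step: Callable[[object, int], tuple[object, object]],
) -> Iterator[tuple[int, object]]:
    """Yield (step, partial_output) after each forecast step.

    The caller (e.g. the trainer) drives the loop and can interleave
    writing, keeping slow io operations out of the model's forward pass.
    """
    for step in output_idxs:
        # advance the latent state and produce this step's predictions
        tokens, partial = predict_step(tokens, step)
        yield step, partial
```

The trainer would then drive the loop itself, e.g. `for step, partial in forecast_steps(...): writer.write(step, partial)`, at the cost of restructuring how the loss over all steps is assembled.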
Description
Enables incremental output writing during inference by writing forecast predictions after each step (or every N steps) instead of accumulating all steps in memory. Maintains constant memory usage regardless of forecast length, addressing limitations for high-resolution, multi-step forecasts.
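The accumulate-then-write behaviour described above can be sketched minimally as follows. `ChunkedWriter` is hypothetical shorthand for illustration, not the `StreamingOutputWriter` from this PR; `flush_fn` stands in for whatever actually writes predictions to disk.

```python
class ChunkedWriter:
    """Buffer up to chunk_size forecast steps, then flush them in one write.

    Memory usage stays bounded by chunk_size regardless of forecast length.
    """

    def __init__(self, chunk_size: int, flush_fn):
        self.chunk_size = chunk_size
        self.flush_fn = flush_fn  # e.g. writes accumulated predictions to disk
        self.buffer = []

    def add(self, step: int, output) -> None:
        self.buffer.append((step, output))
        if len(self.buffer) >= self.chunk_size:
            self.flush()

    def flush(self) -> None:
        # flush any remaining steps (also called once after the last step)
        if self.buffer:
            self.flush_fn(self.buffer)
            self.buffer = []
```

With chunk_size equal to the total number of forecast steps, this degenerates to the current non-streaming behaviour of a single write at the end.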
Changes
- step_callback parameter added to Model.forward()
- _prepare_batch_data(), _process_timestep(), StreamingOutputWriter
- validate() updated for streaming callback support

Issue Number
Closes #1914
Closes #1596
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60