Prevent memory explosion during GeoTransolver inference on large meshes #1361
PhysicsNeMo Pull Request
Description
This PR fixes an out-of-memory (OOM) issue when running GeoTransolver inference on large meshes (10M+ cells) that was causing the process to be killed.
During inference on full car meshes, the broadcast_global_features: true setting caused fx (the global features: air density and stream velocity) to be replicated to every mesh point before sub-batching.
Combined with downstream processing in the ContextProjector, this exhausted GPU memory before the first forward pass even began.
The GeoTransolver model uses a global_tokenizer (ContextProjector) that processes the global features through linear projections and multi-head attention. When fx is broadcast to 2M+ tokens upfront, the intermediate activations and attention computations scale linearly with mesh size, causing OOM.
Solution:
inference_on_vtk.py: Force broadcast_global_features: false in the datapipe for inference, regardless of the training config. This keeps fx as a single token (B, 1, 2).
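A minimal sketch of the shape difference this avoids (the point count and feature values here are illustrative stand-ins, not the actual config):

```python
import numpy as np

# With broadcasting disabled, fx stays a single global token of shape
# (B, 1, 2) rather than one copy per mesh point.
B, n_points = 1, 10_000            # stand-in for a large mesh
fx = np.array([[[1.205, 30.0]]])   # (B, 1, 2): air density, stream velocity
assert fx.shape == (B, 1, 2)

# The broadcast variant would be n_points times larger in memory:
broadcast_bytes = B * n_points * 2 * fx.itemsize
assert broadcast_bytes == n_points * fx.nbytes
```

On a 10M-point mesh the broadcast tensor is seven orders of magnitude larger than the single token, before any intermediate activations are counted.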
inference_on_zarr.py: Modified batched_inference_loop to broadcast fx per sub-batch dynamically:
If fx is single-token → expand to match sub-batch size
If fx is full-mesh (legacy path) → slice for sub-batch
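The two branches above can be sketched as follows. This is a hypothetical helper mirroring the per-sub-batch logic in batched_inference_loop; the function name and signature are assumptions, not the actual PhysicsNeMo API:

```python
import numpy as np

def broadcast_fx_for_subbatch(fx, start, end):
    """Return the fx slice for the sub-batch covering points [start, end).

    fx: (B, 1, F) single global token, or (B, N, F) legacy full-mesh tensor.
    """
    n = end - start
    if fx.shape[1] == 1:
        # Single-token path: replicate the global token only for this
        # sub-batch, never for the full mesh at once.
        return np.broadcast_to(fx, (fx.shape[0], n, fx.shape[2]))
    # Legacy full-mesh path: features were already broadcast; just slice.
    return fx[:, start:end, :]

fx_single = np.array([[[1.205, 30.0]]])           # (1, 1, 2)
sub = broadcast_fx_for_subbatch(fx_single, 0, 4)  # (1, 4, 2)
```

Note that np.broadcast_to returns a view, so the single-token path also avoids materializing the copies; the torch equivalent would be Tensor.expand.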
Why This Doesn't Affect Inference Quality
Since every token in the broadcast fx carries identical values, aggregating over a sub-batch yields exactly the same result as aggregating over the full broadcast tensor. The model sees the same sub-batch size it was trained on; the sub-batches are simply processed sequentially instead of all at once.
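The equivalence is easy to check numerically. Mean-pooling is used here as an illustrative aggregation; pooling over any number of identical copies of the global feature vector returns that vector exactly:

```python
import numpy as np

g = np.array([1.205, 30.0])            # global features
full = np.broadcast_to(g, (1_000, 2))  # broadcast to the full token set
sub = np.broadcast_to(g, (250, 2))     # broadcast to one sub-batch

# Both aggregations recover the original vector bit-for-bit.
assert np.allclose(full.mean(axis=0), g)
assert np.allclose(sub.mean(axis=0), g)
```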
Checklist
Dependencies
None
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.