
Conversation

@aayushg55 (Contributor) commented Feb 11, 2026

PhysicsNeMo Pull Request

Description

Adds obs sensor embedding modules used in HealDA.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@aayushg55 (Contributor Author)

@NickGeneva @pzharrington Here are the sensor embedding modules.

@aayushg55 aayushg55 marked this pull request as ready for review February 11, 2026 01:28
@greptile-apps bot (Contributor) commented Feb 11, 2026

Greptile Overview

Greptile Summary

Adds sensor embedding modules for HealDA, implementing multi-sensor observation tokenization and aggregation onto HEALPix grids.

  • Implements ObsTokenizer for converting observations with metadata into feature vectors
  • Implements SensorEmbedder for embedding single-sensor observations onto HEALPix grids with scatter aggregation
  • Implements MultiSensorObsEmbedding for fusing multiple sensor embeddings with reordering to HEALPIX_PAD_XY format
  • Implements UniformFusion for averaging sensor embeddings with variance-preserving scaling (see the sketch after this list)
  • Includes comprehensive test coverage with gradient checks, non-regression tests, and edge case handling
  • Previous review comments have flagged missing jaxtyping annotations and shape validation in several methods (MOD-005, MOD-006)
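
Since the variance-preserving scaling may not be obvious, here is a small illustrative sketch (a generic averaging pattern, not necessarily the exact UniformFusion implementation):

# Sketch: averaging n independent embeddings shrinks their variance by a factor
# of n; rescaling the mean by sqrt(n) restores the original scale.
import torch

embeddings = torch.randn(5, 1024, 128)  # 5 sensors; shapes are hypothetical
fused = embeddings.mean(dim=0) * embeddings.shape[0] ** 0.5

print(embeddings.std().item(), fused.std().item())  # both approximately 1.0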

Important Files Changed

Filename: Overview
physicsnemo/experimental/models/healda/__init__.py: Adds exports for new sensor embedding modules; clean implementation.
physicsnemo/experimental/models/healda/point_embed.py: Implements multi-sensor observation embedding with comprehensive validation and documentation; previously flagged jaxtyping annotation and shape validation issues already noted.
test/models/healda/test_point_embed.py: Comprehensive test coverage including gradient checks, forward-pass accuracy, and edge-case handling.

@greptile-apps bot (Contributor) left a comment

4 files reviewed, 10 comments



Parameters
----------
sensor_configs : list[dict[str, Any]]
@aayushg55 (Contributor Author):

The list of dicts sensor_configs feels a bit messy. Previously, we encapsulated this in a dataclass. Open to any suggestions.

Collaborator:

Not sure if @NickGeneva agrees, but my 2c is that the dataclass is not necessarily the worst thing to have as a constructor argument; I assume you refactored because this module subclasses physicsnemo.Module and thus the constructor args need to be JSON-serializable for the .from_checkpoint() functionality. If this is more of a helper module specific to HealDA, imo it is ok to have it just subclass torch.nn.Module, in which case it could accept a dataclass if that's preferred. Not sure that makes sense, though, e.g. if the dataclass is intended to be a user-configurable thing that is passed to the top-level HealDA model (which should be a physicsnemo.Module).

@aayushg55 (Contributor Author):

Yea, I removed the dataclass due to JSON-serialization issues, and it seemed custom dataclasses were not preferred. For now, I switched to having separate parameters, each being a list, validated for matching lengths, as suggested by @NickGeneva.

return out


def _offsets_to_batch_idx(offsets: torch.Tensor) -> torch.Tensor:
Collaborator:

General comment for all tensor inputs: can these all get updated to use jaxtyping? This will allow you to spec the dimensions.

Here's an example:
https://github.com/NVIDIA/physicsnemo/blob/main/physicsnemo/models/srrn/super_res_net.py#L297

@aayushg55 (Contributor Author):

Thanks, added jaxtyping annotations everywhere
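
As a rough illustration of the jaxtyping style being adopted (the function name, shapes, and logic here are hypothetical, not the actual point_embed.py signatures):

# Minimal sketch of jaxtyping-annotated tensor arguments with explicit shape specs.
import torch
from jaxtyping import Float, Int64

def scatter_obs_features(
    feats: Float[torch.Tensor, "n_obs channels"],
    pix: Int64[torch.Tensor, "n_obs"],
    npix: int,
) -> Float[torch.Tensor, "npix channels"]:
    """Scatter-add per-observation features onto a flat HEALPix grid."""
    out = feats.new_zeros(npix, feats.shape[-1])
    out.index_add_(0, pix, feats)
    return out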

local_platform: torch.Tensor,
obs_type: torch.Tensor,
offsets: torch.Tensor,
expected_num_sensors: int,
Collaborator:

Why is this needed here? Can we just assume offsets.shape[0] is the right dim? A bit confused about the check below with expected_num_sensors.

@aayushg55 (Contributor Author):

Sure, we can remove this check


for sensor_idx in range(nsensors):
end = offsets[sensor_idx, -1, -1].item()
start = 0 if sensor_idx == 0 else offsets[sensor_idx - 1, -1, -1].item()
Collaborator:

Confused about these two lines. We have:

offsets[sensor_idx, -1, -1].item()
offsets[sensor_idx - 1, -1, -1].item()

so different dimensions? Also, why the conditional start = 0 if sensor_idx == 0?

@aayushg55 (Contributor Author):

offsets[sensor_idx, -1, -1] refers to the end of the current sensor.
offsets[sensor_idx - 1, -1, -1] refers to the end of the previous sensor at sensor_idx - 1. The alignment of the two lines makes it seem like we are indexing into an extra dimension. I can clean up the if statement and add some comments.

@aayushg55 (Contributor Author):

Simplified logic and added comments to make this clearer
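
For readers following the thread, a minimal sketch of the per-sensor slicing pattern being discussed (variable names are illustrative, not the merged code):

# offsets[s, -1, -1] holds the cumulative end index of sensor s in the packed
# observation tensor, so consecutive entries delimit each sensor's slice.
import torch

def iter_sensor_slices(obs_flat: torch.Tensor, offsets: torch.Tensor):
    start = 0
    for sensor_idx in range(offsets.shape[0]):
        end = int(offsets[sensor_idx, -1, -1])
        yield obs_flat[start:end]  # all observations belonging to this sensor
        start = end  # the next sensor starts where this one ended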


def __init__(
self,
*,
Collaborator:

Why is the bare * (keyword-only args) needed here?

@aayushg55 (Contributor Author):

No strong preference for using it. I included the keyword-only * since nchannel and nplatform are both int, so keyword-only arguments force the user to be clear about what they are setting each to.
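
A tiny sketch of the pattern being described (the class name and parameters are assumptions for illustration, not the merged code):

# With the bare *, callers must name each integer argument, which avoids
# silently swapping same-typed values such as nchannel and nplatform.
import torch

class ObsTokenizer(torch.nn.Module):  # hypothetical name
    def __init__(self, *, nchannel: int, nplatform: int) -> None:
        super().__init__()
        self.nchannel = nchannel
        self.nplatform = nplatform

# ObsTokenizer(16, 4)                         # TypeError: positional args rejected
tok = ObsTokenizer(nchannel=16, nplatform=4)  # unambiguous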

) -> torch.Tensor:
"""Aggregate observations to spatial grid and project to output dimension."""
# Convert observation pixels to aggregator grid resolution
aggregation_pix = pix // int(4.0 ** (hpx_level - self.hpx_level))
Collaborator:

Why is it hpx_level - self.hpx_level? Not following exactly what this line is doing.

Collaborator:

Maybe these can have better names. self.hpx_level is the target emb level, right... what is hpx_level then? The representative hpx resolution of the sensor?

@aayushg55 (Contributor Author):

self.hpx_level is the grid level of the sensor/model, whereas hpx_level is the level corresponding to the incoming pix tensor. This accounts for the case where the loader calculated pix at a higher resolution. Since we use the HEALPix NEST format, converting from a higher to a lower resolution can be done by dividing by 4^(difference in level). But we don't need this added complexity and can assume pix is at the model level. Can simplify to not pass in hpx_level / remove this.
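
For context, a small worked example of the NEST-ordering property described above (values are illustrative; the merged code drops this conversion):

# In HEALPix NEST ordering each pixel at level L contains 4 children at level
# L+1, so mapping indices down by d levels is integer division by 4**d.
import torch

pix_level = 6    # level of the incoming pix tensor (assumed)
model_level = 4  # the embedding grid level, i.e. self.hpx_level (assumed)
pix = torch.tensor([0, 17, 255, 4095])

coarse_pix = pix // (4 ** (pix_level - model_level))
print(coarse_pix)  # tensor([  0,   1,  15, 255])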

@aayushg55 (Contributor Author):

Removed

)


def _prod(shape):
Collaborator:

Does math.prod not work? You already import the math module.

@aayushg55 (Contributor Author):

Using .numel() instead

offsets: torch.Tensor,
hpx_level: int,
) -> torch.Tensor:
if self.use_checkpoint:
Collaborator:

What's this condition for?

Collaborator:

What is use_checkpoint?

@aayushg55 (Contributor Author):

This is for gradient checkpointing. It could be useful for memory savings, although I haven't tested what the saving actually is. Open to removing.
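
For readers unfamiliar with the flag, a minimal sketch of the gradient checkpointing pattern in PyTorch (module and attribute names are assumptions, not the merged code):

# Trade compute for memory: activations of the wrapped block are not stored
# during forward and are recomputed in the backward pass.
import torch
from torch.utils.checkpoint import checkpoint

class EmbedBlock(torch.nn.Module):  # hypothetical block
    def __init__(self, dim: int, use_checkpoint: bool = False) -> None:
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * dim, dim),
        )
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint and self.training:
            return checkpoint(self.mlp, x, use_reentrant=False)
        return self.mlp(x)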


def __init__(
self,
sensor_configs: list[dict[str, Any]],
Collaborator:

If the sensor config is like

  • name: sensor name (bookkeeping, unused).
  • nchannel: number of sensor channels.
  • nplatform: number of sensor platforms.

Could this just be three separate parameters of list[str], list[int], and list[int] which get validated that the lengths are the same?

@aayushg55 (Contributor Author):

Yea, I think the separate lists make the most sense. The sensor names are not actually used by the model; I had initially thought to keep them around as metadata so it is easier to map each embedder network to what it actually is, but I'm not sure it is worth keeping.
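
A minimal sketch of the separate-list constructor being proposed in this thread (class and parameter names are assumptions, not the merged code):

# Three parallel lists instead of a list of dicts / dataclass, keeping the
# constructor args JSON-serializable; lengths are validated to match.
import torch

class MultiSensorObsEmbedding(torch.nn.Module):  # name taken from the summary above
    def __init__(
        self,
        sensor_names: list[str],
        nchannels: list[int],
        nplatforms: list[int],
    ) -> None:
        super().__init__()
        if not (len(sensor_names) == len(nchannels) == len(nplatforms)):
            raise ValueError(
                "sensor_names, nchannels, and nplatforms must have the same length"
            )
        self.sensor_names = sensor_names  # bookkeeping only, unused by the model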

Flattened local platform ids of each observation with shape :math:`(N_{obs},)`.
obs_type : torch.Tensor
Flattened observation type ids with shape :math:`(N_{obs},)`.
offsets : torch.Tensor
Collaborator:

Not following how this tensor (S, B, T) is mapped to (N_{obs},).

Why would this not just be of size (S,)?

Collaborator:

Is the int input of the forward actually (B, T, N_obs)?

@aayushg55 (Contributor Author) commented Feb 11, 2026

All of our inputs are this packed/flattened nobs tensor, where observations are packed across time, then batch, then sensor. The offsets then indicate where each "window" or sample in the full flat tensor ends. Basically, it describes a (N_sensor, B, T, N_sensor_obs) tensor, but flattened so that we don't need to use padding. The batch and time support is there to make the module more general, although we only really use B=1 and T=1.
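
To make the packing concrete, a small illustrative example (shapes and values invented for illustration):

# Two sensors, B=1, T=2, with ragged per-window observation counts. All
# observations live in one flat tensor; offsets stores cumulative end indices
# with shape (n_sensors, B, T), so no padding is needed.
import torch

obs = torch.arange(7.0)            # 7 observations packed end to end
offsets = torch.tensor([[[2, 3]],  # sensor 0: windows end at 2 and 3
                        [[5, 7]]]) # sensor 1: windows end at 5 and 7

# Observations for sensor 1, batch 0, time 1:
start = int(offsets[1, 0, 0])      # end of the previous window (5)
end = int(offsets[1, 0, 1])        # 7
print(obs[start:end])              # tensor([5., 6.])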

Comment on lines 30 to 49
HEALPIXPAD_AVAILABLE = check_version_spec("earth2grid", "0.1.0", hard_fail=False)

if HEALPIXPAD_AVAILABLE:
    _healpix_mod = importlib.import_module("earth2grid.healpix")
    hpx_grid = _healpix_mod.Grid
    HEALPIX_PAD_XY = _healpix_mod.HEALPIX_PAD_XY
    HEALPIX_NEST = _healpix_mod.NEST
else:
    HEALPIX_PAD_XY = None
    HEALPIX_NEST = None

    def hpx_grid(*args, **kwargs):
        """Dummy symbol for missing earth2grid backend."""
        raise ImportError(
            (
                "earth2grid is not installed, cannot use it as a backend for HEALPix padding.\n"
                "Install earth2grid from https://github.com/NVlabs/earth2grid.git to enable the accelerated path.\n"
                "pip install --no-build-isolation https://github.com/NVlabs/earth2grid/archive/main.tar.gz"
            )
        )
Collaborator:

Some recent updates made an effort to reduce this boilerplate. See, for example, this PR: #1390

The functionality is already merged, that one is just applying it. It may simplify your life here.

@aayushg55 (Contributor Author):

Thanks, updated to use the OptionalImport functionality.

@aayushg55 (Contributor Author):

@greptileai

@greptile-apps bot (Contributor) left a comment

4 files reviewed, no comments

