HealDA Sensor Embedder #1397
Conversation
@NickGeneva @pzharrington Here are the sensor embedding modules.
Greptile Summary: Adds sensor embedding modules for HealDA, implementing multi-sensor observation tokenization and aggregation onto HEALPix grids.
4 files reviewed, 10 comments
```
Parameters
----------
sensor_configs : list[dict[str, Any]]
```
The list of dicts sensor_configs feels a bit messy. Previously, we encapsulated this in a dataclass. Open to any suggestions.
Not sure if @NickGeneva agrees, but my 2c is that the dataclass is not necessarily the worst thing to have if it is a constructor argument. I assume you refactored because this module subclasses physicsnemo.Module and thus the constructor args need to be JSON-serializable for the .from_checkpoint() functionality. If this is more of a helper module specific to HealDA, imo it is ok to have it just subclass torch.nn.Module, and then it could accept a dataclass if that's preferred. Not sure if that makes sense though, e.g. if the dataclass is intended to be a user-configurable thing that is passed to the top-level HealDA model (which should be a physicsnemo.Module)...
Yea, I removed the dataclass due to JSON-serialization issues, and it seemed custom dataclasses were not preferred. For now, I switched to having separate parameters, each being a list, validated for matching lengths, as suggested by @NickGeneva.
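A minimal sketch of what that resolution looks like; the class and parameter names here are illustrative, not the merged code:

```python
import torch


class MultiSensorEmbedder(torch.nn.Module):  # hypothetical name for illustration
    def __init__(
        self,
        sensor_names: list[str],
        nchannels: list[int],
        nplatforms: list[int],
    ):
        super().__init__()
        # All three lists describe the same sensors, so their lengths must match
        if not (len(sensor_names) == len(nchannels) == len(nplatforms)):
            raise ValueError(
                "sensor_names, nchannels, and nplatforms must have matching "
                f"lengths, got {len(sensor_names)}, {len(nchannels)}, {len(nplatforms)}"
            )
        self.sensor_names = sensor_names  # metadata only; not used by the model
```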
```python
    return out


def _offsets_to_batch_idx(offsets: torch.Tensor) -> torch.Tensor:
```
General comment for all Tensor inputs: can these all get updated to jaxtyping annotations? This will allow you to spec the dimensions.
Here's an example:
https://github.com/NVIDIA/physicsnemo/blob/main/physicsnemo/models/srrn/super_res_net.py#L297
Thanks, added jaxtyping annotations everywhere
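For anyone unfamiliar, a minimal sketch of the jaxtyping pattern being referenced; the dimension names are assumptions, not the actual annotations in the PR:

```python
import torch
from jaxtyping import Int


def _offsets_to_batch_idx(
    offsets: Int[torch.Tensor, "n_sensor batch time"],
) -> Int[torch.Tensor, "n_obs"]:
    """The string in each annotation documents the expected tensor shape."""
    ...
```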
```python
    local_platform: torch.Tensor,
    obs_type: torch.Tensor,
    offsets: torch.Tensor,
    expected_num_sensors: int,
```
Why is this needed here? Can we just assume offsets.shape[0] is the right dim? Bit confused about the check below with expected_num_sensors.
Sure, we can remove this check
```python
for sensor_idx in range(nsensors):
    end = offsets[sensor_idx, -1, -1].item()
    start = 0 if sensor_idx == 0 else offsets[sensor_idx - 1, -1, -1].item()
```
Confused about these two lines, we have:
offsets[sensor_idx, -1, -1].item()
offsets[sensor_idx - 1, -1, -1].item()
so different dimensions... also, why the conditional statement start = 0 if sensor_idx == 0?
offsets[sensor_idx, -1, -1] refers to the end of the current sensor.
offsets[sensor_idx - 1, -1, -1] refers to the end of the previous sensor at sensor_idx - 1. The alignment of the two lines makes it seem like we are indexing into an extra dimension. Can clean up the if statement and add some comments.
Simplified logic and added comments to make this clearer
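For context, a runnable toy sketch of the simplified slicing logic; flat_obs and the offset values are invented for illustration:

```python
import torch

# Toy data: 2 sensors, B = T = 1; offsets has shape (n_sensor, B, T)
flat_obs = torch.arange(7)              # hypothetical flattened observation tensor
offsets = torch.tensor([[[3]], [[7]]])  # per-sensor window ends in the flat tensor

nsensors = offsets.shape[0]  # implied by offsets itself; no separate check needed
start = 0
for sensor_idx in range(nsensors):
    # offsets[s, -1, -1] marks where sensor s's packed window ends, so each
    # sensor spans [end of previous sensor, its own end) in the flat tensor
    end = int(offsets[sensor_idx, -1, -1])
    sensor_obs = flat_obs[start:end]  # sensor 0 -> obs[0:3], sensor 1 -> obs[3:7]
    start = end
```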
```python
def __init__(
    self,
    *,
```
Why is the bare * (keyword-only marker) needed here?
No strong preference for using it. I included the keyword-only * since nchannel and nplatform are both int, so with keyword-only, it forces the user to be clear about what they are setting each to.
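Concretely (toy stand-in class, not the PR's module), the bare * makes calls self-documenting and rejects ambiguous positional ints:

```python
class SensorEmbedder:  # hypothetical stand-in for the real module
    def __init__(self, *, nchannel: int, nplatform: int):
        self.nchannel, self.nplatform = nchannel, nplatform


emb = SensorEmbedder(nchannel=7, nplatform=12)  # OK: intent is explicit
# SensorEmbedder(7, 12)  # would raise TypeError: arguments are keyword-only
```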
```python
) -> torch.Tensor:
    """Aggregate observations to spatial grid and project to output dimension."""
    # Convert observation pixels to aggregator grid resolution
    aggregation_pix = pix // int(4.0 ** (hpx_level - self.hpx_level))
```
Why is it hpx_level - self.hpx_level? Not following exactly what this line is doing.
Maybe these can have better names. self.hpx_level is the target emb level, right... what is hpx_level then? The representative hpx resolution of the sensor?
self.hpx_level is the grid level of the sensor/model, whereas hpx_level is the level corresponding to the incoming pix tensor. This accounts for the case where the loader calculated pix at a higher resolution. Since we use the HEALPix NEST format, converting from a higher to a lower resolution can be done by dividing by 4^(difference in level). But we don't need this added complexity and can assume pix is at the model level. Can simplify to not pass in hpx_level / remove this.
Removed
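For anyone following along, a small worked example of the NEST-ordering property described above; the levels and pixel values are invented for illustration:

```python
import torch

# In NEST ordering each pixel at level L has exactly 4 children at level L + 1,
# so a fine pixel index maps to its coarse ancestor by dividing by 4**(level gap).
pix_level = 8                     # level of the incoming pix tensor
model_level = 6                   # the sensor/model grid level (self.hpx_level)
pix = torch.tensor([0, 17, 255])  # pixel indices at pix_level
coarse_pix = pix // (4 ** (pix_level - model_level))  # -> tensor([0, 1, 15])
```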
```python
)


def _prod(shape):
```
does math.prod not work? You import the math module already
Using .numel() instead
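i.e., something like this (illustrative):

```python
import math

import torch

x = torch.zeros(2, 3, 4)
assert x.numel() == 24            # built-in element count on the tensor
assert math.prod(x.shape) == 24   # equivalent via the shape tuple
```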
```python
    offsets: torch.Tensor,
    hpx_level: int,
) -> torch.Tensor:
    if self.use_checkpoint:
```
What's this condition for?
What is use_checkpoint?
This is for gradient checkpointing. It could be useful for memory savings, although I haven't tested what the savings actually are. Open to removing.
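For readers unfamiliar with the pattern, a minimal self-contained sketch (toy module, not the PR's code):

```python
import torch
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(torch.nn.Module):
    def __init__(self, dim: int, use_checkpoint: bool = False):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint:
            # Activations are recomputed during the backward pass instead of
            # being stored, trading extra compute for lower peak memory.
            return checkpoint(self.proj, x, use_reentrant=False)
        return self.proj(x)
```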
```python
def __init__(
    self,
    sensor_configs: list[dict[str, Any]],
```
If the sensor config is like:
name: sensor name (bookkeeping, unused)
nchannel: number of sensor channels
nplatform: number of sensor platforms
Could this just be three separate parameters of list[str], list[int], and list[int] which get validated that the lengths are the same?
Yea, I think the separate lists make the most sense. The sensor names are not actually used by the model; I had initially thought to keep them around as metadata so it is easier to map each embedder network to what it actually is, but not sure if it is worth keeping.
```
    Flattened local platform ids of each observation with shape :math:`(N_{obs},)`.
obs_type : torch.Tensor
    Flattened observation type ids with shape :math:`(N_{obs},)`.
offsets : torch.Tensor
```
Not following how this tensor (S, B, T) is mapped to (N_{obs},). Why would this not just be of size (S,)?
Is the input of the forward actually (B, T, N_obs)?
All of our inputs are this packed/flattened nobs tensor, where it is packed across time, then batch, then sensor. The offsets then indicate where each "window" or sample in the full flat tensor ends. Basically, it is describing a (N_sensor, B, T, N_sensor_obs) tensor, but all flattened so that we don't need to use padding. The batch and time support is to make the module more general, although we only really use B=1 and T=1.
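A toy illustration of that packing, with invented numbers (2 sensors, B = 1, T = 2):

```python
import torch

# Observations are packed time-first, then batch, then sensor, so the flat
# tensor holds [s0/t0, s0/t1, s1/t0, s1/t1] windows back to back (B = 1 here).
obs = torch.randn(10)  # all observations, flattened
offsets = torch.tensor([
    [[2, 5]],   # sensor 0: windows end at flat index 2 (t0) and 5 (t1)
    [[8, 10]],  # sensor 1: windows end at flat index 8 (t0) and 10 (t1)
])              # shape (N_sensor, B, T); offsets[s, -1, -1] is sensor s's overall end
```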
```python
HEALPIXPAD_AVAILABLE = check_version_spec("earth2grid", "0.1.0", hard_fail=False)


if HEALPIXPAD_AVAILABLE:
    _healpix_mod = importlib.import_module("earth2grid.healpix")
    hpx_grid = _healpix_mod.Grid
    HEALPIX_PAD_XY = _healpix_mod.HEALPIX_PAD_XY
    HEALPIX_NEST = _healpix_mod.NEST
else:
    HEALPIX_PAD_XY = None
    HEALPIX_NEST = None

    def hpx_grid(*args, **kwargs):
        """Dummy symbol for missing earth2grid backend."""
        raise ImportError(
            "earth2grid is not installed, cannot use it as a backend for HEALPix padding.\n"
            "Install earth2grid from https://github.com/NVlabs/earth2grid.git to enable the accelerated path.\n"
            "pip install --no-build-isolation https://github.com/NVlabs/earth2grid/archive/main.tar.gz"
        )
```
Some recent updates made an effort to reduce this boilerplate. See, for example, this PR: #1390. The functionality is already merged; that one is just applying it. It may simplify your life here.
Thanks, updated to use the OptionalImport functionality.
4 files reviewed, no comments
PhysicsNeMo Pull Request
Description
Adds obs sensor embedding modules used in HealDA.