[1849][model] Add optional latent Zarr writer#1860
[1849][model] Add optional latent Zarr writer#1860evenmn wants to merge 8 commits intoecmwf:developfrom
Conversation
…ams: ['..., latent']) Signed-off-by: evenmn <evenmn@mn.uio.no>
|
@evenmn : can you please make sure the linter and unit tests pass. |
|
@grassesi : can you have a look at the changes in io.py. How far are you with refactoring the output writing? |
Signed-off-by: evenmn <evenmn@mn.uio.no>
Signed-off-by: evenmn <evenmn@mn.uio.no>
…Generator into feature/latent-zarr-writer
grassesi
left a comment
There was a problem hiding this comment.
Overall I like the idea of treating latent output just as another stream. One thing that needs to be addressed is that there is some masking logic that prevents the latent stream from trying to be processed during evaluation (esp. the to_xarray call will not work). Otherwise just some stylistic remarks.
I saw in the Issue that the latents should be eventually exposed via an JSON API? Would it make sense to already implement this and not try to piggyback on ZarrIO?
| # additionally yield latent output items if a latent stream name was provided | ||
| if self.latent_stream_name is not None and self.latents: | ||
| for s, fo_s in itertools.product(self.samples, self.forecast_steps): | ||
| key = ItemKey(int(s), int(fo_s), self.latent_stream_name) | ||
| latent_item = self._make_latent_item(key) | ||
| if latent_item is not None: | ||
| yield latent_item | ||
|
|
There was a problem hiding this comment.
Please try to wrap this logic into the above for iteration: having two yielding for loops works but is confusing at best. I see two alternatives:
- Mix the writing of latent items into the normal loop. Something like
for s, fo_s, fi_s in itertools.product(
self.samples, self.forecast_steps, self.streams.keys()
):
key = ItemKey(int(s), int(fo_s), fi_s)
if fi_s == LATENT_STREAM:
latent_item = self._make_latent_item(key)
if latent_item is not None:
yield latent_item
else:
yield self.extract(ItemKey(int(s), int(fo_s), fi_s))- have the writing of latent items in a separate method:
def latent_items(self):
if self.latents:
for s, fo_s in itertools.product(self.samples, self.forecast_steps):
key = ItemKey(int(s), int(fo_s), LATENT_STREAM)
latent_item = self._make_latent_item(key)
if latent_item is not None:
yield latent_item
...
with zarrio_writer(config.get_path_results(cf, mini_epoch)) as zio:
for subset in data.items():
zio.write_zarr(subset)
for latent in data.latent_items():
zio.write_zarr(latent)Option 2. is maybe a bit more clearer and more equivalent to the current solution and also provides more flexibility for the future. But Option 1 would be also fine with me.
| # optional name to use for latent pseudo-stream when yielding latent items | ||
| latent_stream_name: str | None = None |
There was a problem hiding this comment.
Please use a named constant for this.
Thanks for your feedback. Exposing the latent space via an JSON API is useful when running the model operationally. However, I still think we should export the latent state as a Zarr file, since this is useful for other applications, for instance explanable AI. |
Yes, json API is a separate issue and this PR should address writing the latent space to disc as zarr. |
clessig
left a comment
There was a problem hiding this comment.
Overall looks good. But validation_io will be refactored and it should be discussed how to best do the latent output going forward.
| output_streams = {} | ||
| for name in output_stream_names: | ||
| if name == "latent": | ||
| latent_stream_name = name |
There was a problem hiding this comment.
I don't understand the logic here. Wouldn't it be enough to have
if "latent" in output_stream_names
in l 158? Do we expect to have have multiple latent states?
| per_sample = {} | ||
| for lname, lval in latent_pred.items(): | ||
| if isinstance(lval, LatentState): | ||
| for field in ("z_pre_norm", "patch_tokens", "register_tokens", "class_token"): |
There was a problem hiding this comment.
The latent state that should be relevant for the output are the patch_tokens. These are used for the decoder. To be fully future proof we could have an argument which part of LatentState is written although it might be over-engineering.
|
|
||
| # collect latent outputs per forecast step and per sample (optional) | ||
| latents_all = [] | ||
| if latent_stream_name is not None: |
There was a problem hiding this comment.
This should go to a separate function.
Signed-off-by: evenmn <evenmn@mn.uio.no>
Signed-off-by: evenmn <evenmn@mn.uio.no>
| latent_stream_name: str | None = None | ||
| latent_stream_name: str | None = LATENT_STREAM |
There was a problem hiding this comment.
Please remove: since you are always using the default here anyway it makes no difference if self.latent_stream_name or LATENT_STREAM is used. But using using latent_stream_name clutters up the namespace/interface of OutputBatchData
There was a problem hiding this comment.
self.latent_stream_name is currently an alias to LATENT_STREAM which is never None. So please just use if self.latents (This should not disregard my previous comment on these lines.)
| stream_names = [stream.name for stream in cf.streams] | ||
| # include known pseudo-stream names (e.g. latent) so they are treated as known | ||
| if io.LATENT_STREAM not in stream_names: | ||
| stream_names.append(io.LATENT_STREAM) |
There was a problem hiding this comment.
Please remove and just use:
if io.LATENT_STREAM in output_stream_names:
output_streams[io.LATENT_STREAM] = Nonein ll. 136
| for name in output_stream_names: | ||
| if name == "latent": | ||
| if name == io.LATENT_STREAM: | ||
| latent_stream_name = name |
There was a problem hiding this comment.
Please remove this if clause, instead implement the suggestion I commented above
There was a problem hiding this comment.
Please use io.LATENT_STREAM here.
The choosen approach should translate relatively well into the refactored version. |
…get_latent_output' Signed-off-by: evenmn <evenmn@mn.uio.no>
|
Thanks to both of you for the feedback, it truly improved this PR. I believe I have incorporated the suggested changes in my latest commit, but I still need to test the implementation before it gets merged in. |
Can you please test it as far as you can, and then I have a final look. |
Description
This PR introduces a writer for the latent vector. To avoid additional config options, I decided to add "latent" as a special case of a stream output, e.g:
output.streams: ["ERA5", "latent"]Another option would be to have a dedicated config option,
write_latent=Trueor similar.Issue Number
Closes #1849
Is this PR a draft? Mark it as draft.
Checklist before asking for review
./scripts/actions.sh lint./scripts/actions.sh unit-test./scripts/actions.sh integration-testlaunch-slurm.py --time 60