Conversation
Pull request overview
This PR fixes a bug in FP8 optimizer state allocation that caused failures when model weight tensors had a first dimension not divisible by 128 × world_size. The fix introduces a flat_view / flattened_view utility to correctly shape FP8 momentum scaling tensors by flattening each weight tensor before sharding, rather than directly dividing the first dimension of multi-dimensional tensors.
Changes:
- New `flat_view(Tensor)` and `flattened_view(GenericTensorContainer)` utility functions to produce 1D views of tensors and containers
- Rewrote FP8 scale tensor shaping in `AdamWStateManager::allocate_state` to flatten first, shard by `mWorld`, then shard by 128
- Added early-exit and null-pointer guards to `fill_imp` to safely handle zero-element tensors
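For intuition, the shape arithmetic behind the fix can be sketched with a small hypothetical helper (names and the ceil-division grouping are assumptions, not code from this PR): flattening first makes the element count shardable regardless of whether the first dimension is divisible by 128 × world.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch of the flatten-then-shard arithmetic for FP8 scales;
// not the actual PR code.
int64_t scales_per_rank(int64_t rows, int64_t cols, int64_t world) {
    int64_t total = rows * cols;     // flatten: total element count
    int64_t local = total / world;   // shard by world (assumes even division)
    return (local + 127) / 128;      // group local shard: 1 scale per 128 weights
}
```

For example, rows = 100 is not divisible by 128 × world = 1024 (the previously failing case), but flattening gives 100 × 256 = 25600 elements, 3200 per rank across 8 ranks, and 25 scales per rank.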
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/utilities/tensor.h | Declares new `flat_view(const Tensor&)` function |
| src/utilities/tensor.cpp | Implements `flat_view` and `flattened_view(GenericTensorContainer)` |
| src/utilities/tensor_container.h | Declares `flattened_view(GenericTensorContainer)` |
| src/training/adamw_optimizer.cpp | Replaces direct `shard_empty_container(_, 128*mWorld)` with a correct flatten-then-shard approach for FP8 scale tensors |
| src/kernels/fill.cu | Adds `count == 0` early return and null-pointer guard to `fill_imp` |
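The guard added to `fill_imp` can be pictured with a host-side stand-in; the real function launches a CUDA kernel and its actual signature may differ (this one is an assumption):

```cpp
#include <cassert>
#include <cstddef>

// Host-side stand-in for the guarded fill; the actual fill_imp launches a
// CUDA kernel. Signature and return convention are assumptions.
bool fill_imp(float* data, std::size_t count, float value) {
    if (count == 0) return true;        // early return: zero-element tensor, nothing to fill
    if (data == nullptr) return false;  // null pointer with nonzero count is invalid
    for (std::size_t i = 0; i < count; ++i) data[i] = value;  // stands in for the kernel
    return true;
}
```

The order of the two checks matters: a zero-element tensor may legitimately carry a null data pointer, so the size check must come first.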
```cpp
// flatten the local shard
auto flattened = flattened_view(sharded);
// and group into scaling groups
auto grouped = shard_empty_container(std::move(flattened), 128);
```
The second call to flattened_view(sharded) on this line is redundant. After shard_empty_container(flattened_view(c), mWorld), all tensors in the container are already 1D (rank 1), so applying flattened_view again produces the same shapes. The sharded container can be passed directly to shard_empty_container on the next line. Removing this call would simplify the logic and avoid allocating an unnecessary intermediate container.
Suggested change:
```diff
-// flatten the local shard
-auto flattened = flattened_view(sharded);
-// and group into scaling groups
-auto grouped = shard_empty_container(std::move(flattened), 128);
+// tensors in 'sharded' are already 1D; directly group into scaling groups
+auto grouped = shard_empty_container(std::move(sharded), 128);
```
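The "already 1D" claim can be checked with a tiny shape sketch (the shape representation here is an assumption for illustration): a flat view of `{d0, d1, ...}` is `{d0*d1*...}`, so applying it to a rank-1 shape is the identity.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Assumed shape representation: a flat view collapses all dims into one.
// On a rank-1 shape this returns the same shape, which is why a second
// flatten after the per-rank shard is a no-op.
std::vector<int64_t> flat_shape(const std::vector<int64_t>& dims) {
    int64_t total = 1;
    for (int64_t d : dims) total *= d;  // product of all dimensions
    return {total};
}
```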
```cpp
auto grouped = shard_empty_container(std::move(flattened), 128);
return grouped;
};
// we "shard" for 128 as many GPUs, so that we get 1 scale per 128 weights.
```
The comment on this line ("we 'shard' for 128 as many GPUs, so that we get 1 scale per 128 weights") is outdated. The old code directly divided by 128 * mWorld; the new code first shards by mWorld (matching main weight sharding), then flattens the local shard, and shards by 128 to get one scale per 128 weights. The comment should be updated to reflect the two-step sharding in prepare_shape_for_scales.
Suggested change:
```diff
-// we "shard" for 128 as many GPUs, so that we get 1 scale per 128 weights.
+// we first shard by mWorld (matching main weights), then shard the local
+// flattened view by 128 to get 1 scale per 128 weights.
```
```cpp
//! are `nullptr`, but sizes have been set up.
GenericTensorContainer shard_empty_container(GenericTensorContainer&& c, int world);

//! Flattens all tensors is the container.
```
The doc comment says "Flattens all tensors is the container" — "is" should be "in".
Suggested change:
```diff
-//! Flattens all tensors is the container.
+//! Flattens all tensors in the container.
```
seems to work again for 1.5b