Fp8 opt bugfix #68
base: dev
```
@@ -177,17 +177,27 @@ void AdamWStateManager::allocate_state(IModel& model, cudaStream_t stream, EAllo
 }

 mBlocksMScales.resize(mConfig.NumLayers);

 if(mMType == ETensorDType::FP8_E4M3) {
     auto prepare_shape_for_scales = [&](auto&& c) {
         // creates shards same as main weight
         auto sharded = shard_empty_container(flattened_view(c), mWorld);
         // flatten the local shard
         auto flattened = flattened_view(sharded);
         // and group into scaling groups
         auto grouped = shard_empty_container(std::move(flattened), 128);
         return grouped;
     };
     // we "shard" for 128 as many GPUs, so that we get 1 scale per 128 weights.
```
Suggested change:
```
- // we "shard" for 128 as many GPUs, so that we get 1 scale per 128 weights.
+ // we first shard by mWorld (matching main weights), then shard the local
+ // flattened view by 128 to get 1 scale per 128 weights.
```
```
@@ -61,6 +61,9 @@ class GenericTensorContainer final : public SimpleTensorContainer {
 //! are `nullptr`, but sizes have been set up.
 GenericTensorContainer shard_empty_container(GenericTensorContainer&& c, int world);

 //! Flattens all tensors is the container.
```
Suggested change:
```
- //! Flattens all tensors is the container.
+ //! Flattens all tensors in the container.
```
The second call to `flattened_view(sharded)` on this line is redundant. After `shard_empty_container(flattened_view(c), mWorld)`, all tensors in the container are already 1D (rank 1), so applying `flattened_view` again produces the same shapes. The `sharded` container can be passed directly to `shard_empty_container` on the next line. Removing this call would simplify the logic and avoid allocating an unnecessary intermediate container.