Pull request overview
This PR refactors the gradient management system from a LLaMA-model-specific LLamaGradsManager class into a generic IGradientManager interface with two reusable base implementations (UnshardedGradientManager and ShardedBlocksGradientManager). LLaMA-specific behavior (gradient zeroing patterns, communication strategies) is now injected through virtual hooks, making the gradient management infrastructure reusable across model architectures.
Changes:
- New `IGradientManager` interface (and two base implementations) in `src/training/gradients.{h,cpp}`, with a generic `get_non_block_full/shard` and `get_block_full/shard` API using integer weight indices instead of named methods.
- New `shard_empty_container`/`shard_view` utilities for `GenericTensorContainer` in `tensor_container.h`/`tensor.cpp`, and two new `reduce_scatter` overloads on `NCCLCommunicator`.
- Replacement of `LLamaGradsManager` throughout `llama_model.cpp`, `llama_gradients.h/cpp`, `adamw_optimizer.cpp`, and `py_train.cpp` with the new generic interface.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `src/training/gradients.h` | New generic `IGradientManager`, `UnshardedGradientManager`, `ShardedBlocksGradientManager` classes |
| `src/training/gradients.cpp` | Implementation of the three gradient manager classes |
| `src/utilities/tensor_container.h` | Adds `shard_empty_container` and `shard_view(GenericTensorContainer)` declarations |
| `src/utilities/tensor.cpp` | Implements `shard_empty_container` and `shard_view(GenericTensorContainer)` |
| `src/utilities/comm.h` | Adds `reduce_scatter` overloads for a single tensor and a container |
| `src/utilities/comm.cpp` | Implements the new `reduce_scatter` overloads |
| `src/models/llama_gradients.h` | Restructures LLaMA gradient classes to extend the new generic base classes |
| `src/models/llama_gradients.cpp` | Removes large amounts of boilerplate; retains only LLaMA-specific hook implementations |
| `src/models/llama_model.h` | Grads member changed to `IGradientManager`; `grads()` returns `IGradientManager&` |
| `src/models/llama_model.cpp` | All gradient API calls updated to use the new generic index-based API |
| `src/training/adamw_optimizer.cpp` | Replaces the local `shard_container` with the new `shard_empty_container` utility |
| `src/binding/py_train.cpp` | Updates `get_gradients` to use the new generic API |
| `CMakeLists.txt` | Adds `src/training/gradients.cpp` to the build |
Pull request overview
Copilot reviewed 13 out of 13 changed files in this pull request and generated 1 comment.
The comment below is attached to this snippet from `get_gradients`:

```cpp
using namespace LLamaWeightID;

std::vector<std::pair<std::string, Tensor>> result;
// TODO make this work with generalized gradients
```
The TODO comment here indicates that get_gradients still hardcodes LLaMA-specific weight IDs (using LLamaWeightID::EMBEDDING, LM_HEAD, LNF_W, etc.) rather than leveraging the generic IGradientManager interface. This means the function is LLaMA-specific and won't work correctly for other model architectures that may use IGradientManager. Since the PR's goal is to provide a generic gradient manager implementation, this is a known gap that should be tracked.