Generic gradients #67

Merged
ngc92 merged 7 commits into dev from generic-gradients on Mar 5, 2026

ngc92 (Collaborator) commented Mar 5, 2026

generic implementation for gradient manager

Copilot AI review requested due to automatic review settings March 5, 2026 01:02
Copilot AI (Contributor) left a comment

Pull request overview

This PR refactors the gradient management system from a LLaMA-model-specific LLamaGradsManager class into a generic IGradientManager interface with two reusable base implementations (UnshardedGradientManager and ShardedBlocksGradientManager). LLaMA-specific behavior (gradient zeroing patterns, communication strategies) is now injected through virtual hooks, making the gradient management infrastructure reusable across model architectures.
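
A minimal sketch of the "virtual hooks" idea described above, assuming a simplified design: a generic base class drives the per-step logic and asks the derived class which gradient buffers need zeroing. The names `GradientManagerSketch`, `should_zero`, and `ToyModelGradients` are hypothetical and are not taken from the PR; the real `IGradientManager` declaration is not shown on this page.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a generic gradient manager base class.
// Each "gradient" is a single float here to keep the sketch small.
class GradientManagerSketch {
public:
    virtual ~GradientManagerSketch() = default;

    // Generic driver: on the first micro-step, zero every buffer that the
    // model-specific hook selects. The base class knows nothing about the model.
    void start_micro_step(bool first_step) {
        for (std::size_t i = 0; i < grads_.size(); ++i) {
            if (first_step && should_zero(i)) {
                grads_[i] = 0.0f;
            }
        }
    }

    float& grad(std::size_t idx) {
        assert(idx < grads_.size());
        return grads_[idx];
    }

protected:
    // All buffers start at 1.0f so that zeroing is observable.
    explicit GradientManagerSketch(std::size_t count) : grads_(count, 1.0f) {}

    // Hook (assumed): the model decides which buffers need explicit zeroing.
    virtual bool should_zero(std::size_t idx) const = 0;

private:
    std::vector<float> grads_;
};

// A toy model that zeroes every buffer except index 1.
class ToyModelGradients final : public GradientManagerSketch {
public:
    ToyModelGradients() : GradientManagerSketch(3) {}

protected:
    bool should_zero(std::size_t idx) const override { return idx != 1; }
};
```

With this shape, supporting another architecture would only require a new derived class that overrides the hook; the driver loop stays untouched.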

Changes:

  • New IGradientManager interface (and two base implementations) in src/training/gradients.{h,cpp}, with generic get_non_block_full/shard and get_block_full/shard API using integer weight indices instead of named methods.
  • New shard_empty_container/shard_view utilities for GenericTensorContainer in tensor_container.h/tensor.cpp, and two new reduce_scatter overloads on NCCLCommunicator.
  • Replacement of LLamaGradsManager throughout llama_model.cpp, llama_gradients.h/cpp, adamw_optimizer.cpp, and py_train.cpp with the new generic interface.
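
To illustrate what a shard view over a container could look like, here is a hedged sketch that assumes each tensor is split evenly along its flat element range, one contiguous slice per rank. `ViewSketch`, `shard_view_sketch`, and `shard_container_sketch` are hypothetical names; the actual signatures of `shard_view` and `shard_empty_container` are not shown in this PR page and may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical non-owning view; stands in for the project's Tensor type.
struct ViewSketch {
    float* data;
    std::size_t size;
};

// Returns the contiguous slice owned by `rank` when `full` is split evenly
// across `world` ranks. Assumes full.size is divisible by world.
inline ViewSketch shard_view_sketch(ViewSketch full, int rank, int world) {
    assert(world > 0 && rank >= 0 && rank < world);
    assert(full.size % static_cast<std::size_t>(world) == 0);
    const std::size_t shard = full.size / static_cast<std::size_t>(world);
    return ViewSketch{full.data + shard * static_cast<std::size_t>(rank), shard};
}

// Applies the same sharding to every tensor in a container. Entries keep
// their integer index, so index-based lookups work on full and sharded forms.
inline std::vector<ViewSketch> shard_container_sketch(
        const std::vector<ViewSketch>& full, int rank, int world) {
    std::vector<ViewSketch> out;
    out.reserve(full.size());
    for (const ViewSketch& tensor : full) {
        out.push_back(shard_view_sketch(tensor, rank, world));
    }
    return out;
}
```

Because the sharded container preserves the ordering of the full one, a caller can use the same integer weight index for both the full and the shard accessor.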

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| src/training/gradients.h | New generic IGradientManager, UnshardedGradientManager, ShardedBlocksGradientManager classes |
| src/training/gradients.cpp | Implementation of the three gradient manager classes |
| src/utilities/tensor_container.h | Adds shard_empty_container and shard_view(GenericTensorContainer) declarations |
| src/utilities/tensor.cpp | Implements shard_empty_container and shard_view(GenericTensorContainer) |
| src/utilities/comm.h | Adds reduce_scatter overloads for single tensor and container |
| src/utilities/comm.cpp | Implements the new reduce_scatter overloads |
| src/models/llama_gradients.h | Restructures LLaMA gradient classes to extend the new generic base classes |
| src/models/llama_gradients.cpp | Removes large amounts of boilerplate; retains only LLaMA-specific hook implementations |
| src/models/llama_model.h | Grads member changed to IGradientManager; grads() returns IGradientManager& |
| src/models/llama_model.cpp | All gradient API calls updated to use new generic index-based API |
| src/training/adamw_optimizer.cpp | Replaces local shard_container with the new shard_empty_container utility |
| src/binding/py_train.cpp | Updates get_gradients to use the new generic API |
| CMakeLists.txt | Adds src/training/gradients.cpp to the build |

Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 1 comment.


using namespace LLamaWeightID;

std::vector<std::pair<std::string, Tensor>> result;
// TODO make this work with generalized gradients

Copilot AI commented Mar 5, 2026

The TODO comment here indicates that get_gradients still hardcodes LLaMA-specific weight IDs (using LLamaWeightID::EMBEDDING, LM_HEAD, LNF_W, etc.) rather than leveraging the generic IGradientManager interface. This means the function is LLaMA-specific and won't work correctly for other model architectures that may use IGradientManager. Since the PR's goal is to provide a generic gradient manager implementation, this is a known gap that should be tracked.
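
One possible direction for closing this gap, sketched under stated assumptions: if the manager could report how many non-block gradients it holds and a name for each index, `get_gradients` could iterate instead of naming LLaMA weight IDs. The methods `num_non_block` and `non_block_name` are hypothetical and do not appear in the PR; `get_non_block_full` is named in the PR summary, but its signature here is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the project's Tensor type.
struct TensorSketch {
    std::vector<float> values;
};

// Assumed minimal interface; only get_non_block_full is named in the PR,
// and its real signature may differ.
class IndexedGradientsSketch {
public:
    virtual ~IndexedGradientsSketch() = default;
    virtual std::size_t num_non_block() const = 0;                  // hypothetical
    virtual std::string non_block_name(std::size_t idx) const = 0;  // hypothetical
    virtual TensorSketch& get_non_block_full(std::size_t idx) = 0;
};

// Model-agnostic collection: no weight IDs are hardcoded here.
inline std::vector<std::pair<std::string, TensorSketch>>
collect_gradients_sketch(IndexedGradientsSketch& grads) {
    std::vector<std::pair<std::string, TensorSketch>> result;
    result.reserve(grads.num_non_block());
    for (std::size_t i = 0; i < grads.num_non_block(); ++i) {
        // Copies the tensor so the result does not alias the manager's storage.
        result.emplace_back(grads.non_block_name(i), grads.get_non_block_full(i));
    }
    return result;
}

// Toy implementation with two named gradients.
class ToyGradients final : public IndexedGradientsSketch {
public:
    ToyGradients()
        : names_{"embedding", "lm_head"},
          tensors_{TensorSketch{{1.0f, 2.0f}}, TensorSketch{{3.0f}}} {}

    std::size_t num_non_block() const override { return tensors_.size(); }
    std::string non_block_name(std::size_t idx) const override {
        assert(idx < names_.size());
        return names_[idx];
    }
    TensorSketch& get_non_block_full(std::size_t idx) override {
        assert(idx < tensors_.size());
        return tensors_[idx];
    }

private:
    std::vector<std::string> names_;
    std::vector<TensorSketch> tensors_;
};
```

Per-block gradients could be handled the same way with a second loop over layers and block indices, but that part is left out of the sketch.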

ngc92 merged commit 79c8c9f into dev on Mar 5, 2026
60 of 61 checks passed