Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Contributor
Pull request overview
This PR introduces an optional handwritten Tensor Core GEMM path (targeting Ada) and threads it through training/runtime configuration so users can switch between cuBLASLt and the custom kernel.
Changes:
- Add `EMatmulBackend` plumbing across C++ training code, CLI, and Python bindings/config to select cuBLASLt vs the custom GEMM.
- Implement a custom TN GEMM kernel (`gemm_mma_tn`) for BF16/FP8→BF16 and dispatch it from `matmul`.
- Add/adjust tests and CI coverage (new GEMM unit test; RoPE test updated to match fp16 freqs).
Reviewed changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| train.cpp | Adds --custom-matmul CLI flag. |
| src/utilities/sol.cpp | Updates matmul calls for new backend argument. |
| src/training/model.h | Adds MatmulBackend to run state. |
| src/testing/test_utils.h | Adds float→fp16 helper used by tests. |
| src/testing/test-rope.cu | Updates RoPE test to use fp16 freqs. |
| src/testing/test-gemm.cpp | New unit tests comparing custom GEMM vs cuBLAS. |
| src/models/llama_model.h | Adds UseCustomMatmul option. |
| src/models/llama_model.cpp | Threads backend selection through all matmul call sites; initializes backend from options. |
| src/kernels/tensor_core_utils.cuh | New low-level helpers for ldmatrix/mma fragments. |
| src/kernels/matmul.cpp | Adds backend-aware dispatch and hooks custom GEMM. |
| src/kernels/kernels.h | Adds EMatmulBackend and extends matmul APIs with backend parameter. |
| src/kernels/kernels.cpp | Threads backend through Tensor-based matmul wrapper. |
| src/kernels/gemm_mma.cu | Implements the custom GEMM kernel and launcher. |
| src/binding/python/training.py | Adds custom_matmul to TrainingConfig. |
| src/binding/python/tests/run.py | Wires config → LLamaOptions; adds argparse flag. |
| src/binding/kernel_binding.cpp | Extends Python matmul binding with a backend parameter. |
| src/binding/binding.cpp | Exposes use_custom_matmul on LLamaOptions in Python. |
| scripts/train.py | Adds custom_matmul toggle and passes it into options. |
| CMakeLists.txt | Builds new kernel and new unit test. |
| .github/workflows/wheel.yml | Adds a Modal CI job exercising --custom-matmul. |
Contributor
Pull request overview
Copilot reviewed 20 out of 20 changed files in this pull request and generated 7 comments.
Comments suppressed due to low confidence (1)
src/testing/test-rope.cu:215
- The bf16 RoPE test now sends `half` (fp16) freqs to the GPU, but the CPU baseline still quantizes `h_freqs_f` with `round_bf16`. This makes the reference computation use different frequency precision than the kernel and can cause incorrect comparisons. Update the CPU baseline to quantize/emulate freqs in fp16 (or keep freqs in fp32) to match the kernel's `half* freqs_cis` contract.
```cpp
// Prepare freqs and quantize to fp16 (kernel expects fp16 freqs)
std::vector<float> h_freqs_f(size_freqs);
precompute_freqs_cis(h_freqs_f.data(), HD, T, 10000.0f);
std::vector<half> h_freqs_fp16 = to_fp16(h_freqs_f);
// CPU baseline with bf16 emulation: quantize inputs/freqs to bf16, do math in float, quantize outputs
std::vector<float> h_inp_q = round_bf16(h_inp_f);
std::vector<float> h_freqs_q = round_bf16(h_freqs_f);
```
0.5% speedup for 1.5B on 4x4090. YMMV for other models and cards.