
Handwritten gemm for ADA#66

Merged
ngc92 merged 16 commits into dev from gemm on Mar 5, 2026
Conversation

Collaborator

@ngc92 commented Mar 4, 2026

A 0.5% speedup for the 1.5B model on 4x 4090. YMMV for other models and cards.

Copilot AI review requested due to automatic review settings March 4, 2026 15:13
Contributor

Copilot AI left a comment


Pull request overview

This PR introduces an optional handwritten Tensor Core GEMM path (targeting Ada) and threads it through training/runtime configuration so users can switch between cuBLASLt and the custom kernel.

Changes:

  • Add EMatmulBackend plumbing across C++ training code, CLI, and Python bindings/config to select cuBLASLt vs custom GEMM.
  • Implement a custom TN GEMM kernel (gemm_mma_tn) for BF16/FP8→BF16 and dispatch it from matmul.
  • Add/adjust tests and CI coverage (new GEMM unit test; RoPE test updated to match fp16 freqs).

Reviewed changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated 5 comments.

Summary per file:
train.cpp Adds --custom-matmul CLI flag.
src/utilities/sol.cpp Updates matmul calls for new backend argument.
src/training/model.h Adds MatmulBackend to run state.
src/testing/test_utils.h Adds float→fp16 helper used by tests.
src/testing/test-rope.cu Updates RoPE test to use fp16 freqs.
src/testing/test-gemm.cpp New unit tests comparing custom GEMM vs cuBLAS.
src/models/llama_model.h Adds UseCustomMatmul option.
src/models/llama_model.cpp Threads backend selection through all matmul call sites; initializes backend from options.
src/kernels/tensor_core_utils.cuh New low-level helpers for ldmatrix/mma fragments.
src/kernels/matmul.cpp Adds backend-aware dispatch and hooks custom GEMM.
src/kernels/kernels.h Adds EMatmulBackend and extends matmul APIs with backend parameter.
src/kernels/kernels.cpp Threads backend through Tensor-based matmul wrapper.
src/kernels/gemm_mma.cu Implements the custom GEMM kernel and launcher.
src/binding/python/training.py Adds custom_matmul to TrainingConfig.
src/binding/python/tests/run.py Wires config → LLamaOptions; adds argparse flag.
src/binding/kernel_binding.cpp Extends Python matmul binding with a backend parameter.
src/binding/binding.cpp Exposes use_custom_matmul on LLamaOptions in Python.
scripts/train.py Adds custom_matmul toggle and passes it into options.
CMakeLists.txt Builds new kernel and new unit test.
.github/workflows/wheel.yml Adds a Modal CI job exercising --custom-matmul.


Copilot AI review requested due to automatic review settings March 4, 2026 23:58
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 20 out of 20 changed files in this pull request and generated 7 comments.

Comments suppressed due to low confidence (1)

src/testing/test-rope.cu:215

  • The bf16 RoPE test now sends half (fp16) freqs to the GPU, but the CPU baseline still quantizes h_freqs_f with round_bf16. This makes the reference computation use different frequency precision than the kernel and can cause incorrect comparisons. Update the CPU baseline to quantize/emulate freqs in fp16 (or keep freqs in fp32) to match the kernel’s half* freqs_cis contract.
    // Prepare freqs and quantize to fp16 (kernel expects fp16 freqs)
    std::vector<float> h_freqs_f(size_freqs);
    precompute_freqs_cis(h_freqs_f.data(), HD, T, 10000.0f);
    std::vector<half> h_freqs_fp16 = to_fp16(h_freqs_f);

    // CPU baseline with bf16 emulation: quantize inputs/freqs to bf16, do math in float, quantize outputs
    std::vector<float> h_inp_q = round_bf16(h_inp_f);
    std::vector<float> h_freqs_q = round_bf16(h_freqs_f);



@ngc92 ngc92 merged commit 69bc1f3 into dev Mar 5, 2026
34 checks passed
