
ggml: Improve NVFP4 vecdot error #20435

Open
michaelw9999 wants to merge 1 commit into ggml-org:master from michaelw9999:nvfp4-improve-vec-dot

Conversation

@michaelw9999
Contributor

This update modifies:
const uint8_t ue = ggml_fp32_to_ue4m3(amax / 6.0f);

That may not be the best scale once all 16 weights in the sub-block are considered. This change also evaluates nearby UE4M3 codes, computes the reconstruction error for each, and picks the option with the lowest error.

This reduces the vecdot error from:

absolute quantization error:  0.002337
dot product error: 0.019774

to:

absolute quantization error: 0.002029
dot product error: 0.002411

This will help keep error down when using e2m1 x e2m1 on the GPU side, or if a future quantizer is implemented.

@michaelw9999 michaelw9999 requested a review from ggerganov as a code owner March 12, 2026 00:05
Copilot AI review requested due to automatic review settings March 12, 2026 00:05
@github-actions github-actions bot added the ggml changes relating to the ggml tensor library for machine learning label Mar 12, 2026

Copilot AI left a comment


Pull request overview

Improves NVFP4 scale selection in quantize_row_nvfp4_ref by searching nearby UE4M3 encodings to minimize sub-block quantization error, reducing downstream vecdot error.

Changes:

  • Replaces single-shot UE4M3 scale (amax / 6) with a small neighborhood search around the initial encoding.
  • Chooses the candidate scale that minimizes summed squared reconstruction error over the sub-block.


Comment on lines +330 to +347
    float lowest_err = INFINITY;
    for (int difference = -2; difference <= 2; ++difference) {
        const int candidate = (int) first_ue + difference;
        if (candidate < 0 || candidate > 0x7E) {
            continue;
        }
        const float test_scale = ggml_ue4m3_to_fp32((uint8_t) candidate);
        float test_scale_error = 0.0f;
        for (int j = 0; j < qk_sub; ++j) {
            const int qi = best_index_mxfp4(xb[j], test_scale);
            const float err = xb[j] - kvalues_mxfp4[qi] * test_scale;
            test_scale_error += err * err;
        }
        if (test_scale_error < lowest_err) {
            lowest_err = test_scale_error;
            ue = (uint8_t) candidate;
        }
    }

Copilot AI Mar 12, 2026


This adds up to 5 full passes over qk_sub per sub-block (re-running best_index_mxfp4 and error accumulation each time), which can noticeably increase quantization CPU cost. Consider adding an early-exit when test_scale_error reaches 0 (or below a tiny epsilon), and/or tightening the candidate search adaptively (e.g., evaluate difference=0 first, then only expand outward while error improves) to reduce worst-case work while keeping the accuracy benefit.
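A minimal sketch of the adaptive search suggested above. It is not the PR's implementation: `candidate_error` is a toy stand-in for the per-candidate sub-block error (in the PR this would re-run `best_index_mxfp4` over `qk_sub` weights), and `search_scale_adaptive` is a hypothetical helper name. The idea is to evaluate `difference = 0` first, expand outward only while the error keeps improving, and stop early once the error is effectively zero.

```c
#include <math.h>
#include <stdint.h>

// Toy stand-in for the per-candidate reconstruction error. Here it is a
// convex bowl with its minimum at code 64, so the search should settle
// there; the real version would accumulate squared error over the sub-block.
static float candidate_error(int code) {
    const float d = (float)(code - 64);
    return d * d;
}

// Adaptive variant of the PR's fixed +/-2 scan: start at the initial UE4M3
// code, grow the search radius only while a candidate improves on the best
// error so far, and exit early when the error is effectively zero.
static uint8_t search_scale_adaptive(uint8_t first_ue, int max_radius) {
    const float eps = 1e-12f;
    int   best     = (int) first_ue;
    float best_err = candidate_error(best);
    for (int r = 1; r <= max_radius && best_err > eps; ++r) {
        int improved = 0;
        for (int sgn = -1; sgn <= 1; sgn += 2) {
            const int c = (int) first_ue + sgn * r;
            if (c < 0 || c > 0x7E) continue;   // 0x7F is not a valid finite code
            const float e = candidate_error(c);
            if (e < best_err) { best_err = e; best = c; improved = 1; }
        }
        if (!improved) break;                  // stop expanding once no candidate gains
    }
    return (uint8_t) best;
}
```

Note the trade-off: this does strictly less work than the fixed scan in the common case, but because it stops expanding when a radius brings no improvement, it keeps the same worst-case behavior only within `max_radius`.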

Comment on lines +331 to +333
    for (int difference = -2; difference <= 2; ++difference) {
        const int candidate = (int) first_ue + difference;
        if (candidate < 0 || candidate > 0x7E) {

Copilot AI Mar 12, 2026


The search window (-2..2) and the upper bound (0x7E) are magic constants here. Please add a short comment or named constants explaining (1) why a ±2 neighborhood is sufficient, and (2) why 0x7E is the maximum valid finite UE4M3 code (and what 0x7F represents). This will make the intent easier to maintain and less error-prone if the encoding rules change.
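To illustrate the 0x7E bound: UE4M3 is an unsigned E4M3 layout (4 exponent bits, 3 mantissa bits, bias 7), where the all-ones code 0x7F (E=1111, M=111) is reserved for NaN, making 0x7E (E=1111, M=110) the largest finite code, 1.75 * 2^8 = 448. The decoder below is a sketch written from that layout, not ggml's actual `ggml_ue4m3_to_fp32`:

```c
#include <math.h>
#include <stdint.h>

// Sketch of an unsigned E4M3 (UE4M3) decoder: 4 exponent bits, 3 mantissa
// bits, exponent bias 7. Assumes the OFP8-style E4M3 rule that only the
// all-ones code is NaN, so 0x7E is the maximum finite value (448.0f).
static float ue4m3_to_fp32_sketch(uint8_t code) {
    const int e = (code >> 3) & 0x0F;
    const int m = code & 0x07;
    if (e == 0x0F && m == 0x07) {
        return NAN;                           // 0x7F: the single NaN encoding
    }
    if (e == 0) {
        return (m / 8.0f) * 0x1p-6f;          // subnormals: m/8 * 2^(1-7)
    }
    return (1.0f + m / 8.0f) * ldexpf(1.0f, e - 7);
}
```

Under these assumptions, a ±2 code step moves the scale by one or two mantissa ULPs (or across an exponent boundary), which is why a small neighborhood around the initial `amax / 6` encoding can already contain a better scale.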
