From d7ceb8c02229ae75db31a71fbf31231253901d68 Mon Sep 17 00:00:00 2001 From: Paul Flynn Date: Wed, 11 Mar 2026 21:28:09 -0400 Subject: [PATCH 1/2] ggml : transpose fused GDN state access for coalesced memory reads (#20436) The fused Gated Delta Net kernel accessed the [S_v, S_v] state matrix column-wise on row-major storage, causing strided reads (stride S_v = 128 floats = 512 bytes) that waste GPU cache bandwidth. This produced a 39% regression on Qwen3.5-9B (Metal, M4 Max) compared to the unfused path. Transpose the state indexing so threads read contiguously: - Metal: s_ptr[is*S_v] -> s_ptr[is] (stride 1 vs S_v) - CUDA: curr_state[i*S_v+col] -> curr_state[col*S_v+i] (coalesced) - CPU: restructured loops for row-wise transposed access Also add --fused-gdn [on|off|auto] CLI flag (mirrors --flash-attn) so users can control fused GDN independently of auto-detection. All GATED_DELTA_NET backend-ops tests pass. Co-Authored-By: Claude Opus 4.6 --- common/arg.cpp | 15 ++++++++++ common/common.cpp | 1 + common/common.h | 1 + ggml/src/ggml-cpu/ops.cpp | 42 +++++++++++++++++---------- ggml/src/ggml-cuda/gated_delta_net.cu | 7 +++-- ggml/src/ggml-metal/ggml-metal.metal | 9 +++--- include/llama.h | 9 ++++++ src/llama-context.cpp | 8 +++-- src/llama.cpp | 12 ++++++++ 9 files changed, 79 insertions(+), 25 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 41da8563d63..0f7caf1c8f4 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1339,6 +1339,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str())); } }).set_env("LLAMA_ARG_FLASH_ATTN")); + add_opt(common_arg({ "-fgdn", "--fused-gdn" }, "[on|off|auto]", + string_format("set fused Gated Delta Net kernel use ('on', 'off', or 'auto', default: '%s')", + llama_fused_gdn_type_name(params.fused_gdn_type)), + [](common_params & params, const std::string & value) { + if (is_truthy(value)) { + params.fused_gdn_type = LLAMA_FUSED_GDN_TYPE_ENABLED; + } else if (is_falsey(value)) { + params.fused_gdn_type = LLAMA_FUSED_GDN_TYPE_DISABLED; + } else if (is_autoy(value)) { + params.fused_gdn_type = LLAMA_FUSED_GDN_TYPE_AUTO; + } else { + throw std::runtime_error( + string_format("error: unknown value for --fused-gdn: '%s'\n", value.c_str())); + } + }).set_env("LLAMA_ARG_FUSED_GDN")); add_opt(common_arg( {"-p", "--prompt"}, "PROMPT", "prompt to start generation with; for system message, use -sys", diff --git a/common/common.cpp b/common/common.cpp index cc423d3439f..2e7ae7b0f31 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1369,6 +1369,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.pooling_type = params.pooling_type; cparams.attention_type = params.attention_type; cparams.flash_attn_type = params.flash_attn_type; + cparams.fused_gdn_type = params.fused_gdn_type; cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; diff --git a/common/common.h b/common/common.h index c5645bba460..0b30f0c8aab 100644 --- a/common/common.h +++ b/common/common.h @@ -411,6 +411,7 @@ struct common_params { enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention + enum llama_fused_gdn_type fused_gdn_type = LLAMA_FUSED_GDN_TYPE_AUTO; // whether to use fused Gated Delta Net kernel struct common_params_sampling sampling; struct common_params_speculative speculative; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index fa9d27046b5..13ca087faf2 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10477,34 +10477,46 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1); const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); + // state is stored transposed: s_out[j*S_v + i] = S[i][j] + // so row j of s_out = column j of S (contiguous access) + if (kda) { + // precompute exp(g) into delta scratch (reused below) for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i])); + delta[i] = expf(g_d[i]); + } + // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i]) + for (int64_t j = 0; j < S_v; ++j) { + ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta); } } else { ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0])); } - // delta[j] = sum_i S[j][i] * k[i] - memset(delta, 0, S_v * sizeof(float)); - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]); - } + // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k) for (int64_t j = 0; j < S_v; ++j) { - delta[j] = (v_d[j] - delta[j]) * beta_val; + const float * row_j = s_out + j * S_v; + float sum = 0.0f; + for (int64_t i = 0; i < S_v; ++i) { + sum += row_j[i] * k_d[i]; + } + delta[j] = (v_d[j] - sum) * beta_val; } - // outer product: S[j][i] += k[i] * delta[j] - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]); + // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i] + for (int64_t j = 0; j < S_v; ++j) { + ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]); } - // attn_out[j] = sum_i S[j][i] * q[i] - memset(attn_data, 0, S_v * sizeof(float)); - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]); + // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q) + for (int64_t j = 0; j < S_v; ++j) { + const float * row_j = s_out + j * S_v; + float sum = 0.0f; + for (int64_t i = 0; i < S_v; ++i) { + sum += row_j[i] * q_d[i]; + } + attn_data[j] = sum * scale; } - ggml_vec_scale_f32(S_v, attn_data, scale); attn_data += S_v * H; // advance to next token } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 5f0fa8e58df..eb38588202a 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -45,10 +45,11 @@ __global__ void gated_delta_net_cuda(const float * q, static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; float s_shard[rows_per_lane]; + // state is stored transposed: M[col][i] = S[i][col], row col is contiguous #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[i * S_v + col]; + s_shard[r] = curr_state[col * S_v + i]; } for (int t = 0; t < n_tokens; t++) { @@ -126,11 +127,11 @@ __global__ void gated_delta_net_cuda(const float * q, attn_data += S_v * H; } - // Write state back to global memory + // Write state back to global memory (transposed layout) #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - state[i * S_v + col] = s_shard[r]; + state[col * S_v + i] = s_shard[r]; } } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0b77d5349b8..e06f266b533 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2466,13 +2466,14 @@ kernel void kernel_gated_delta_net_impl( const float scale = 1.0f / sqrt((float)S_v); - device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; + // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; float ls[NSG]; FOR_UNROLL (short j = 0; j < NSG; j++) { const short is = tx*NSG + j; - ls[j] = s_ptr[is*S_v]; + ls[j] = s_ptr[is]; } device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; @@ -2533,11 +2534,11 @@ kernel void kernel_gated_delta_net_impl( g_ptr += args.ne21*G; } - device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; FOR_UNROLL (short j = 0; j < NSG; j++) { const short is = tx*NSG + j; - dst_state[is*S_v] = ls[j]; + dst_state[is] = ls[j]; } #undef S_v diff --git a/include/llama.h b/include/llama.h index c6e102abe51..05e5152d0db 100644 --- a/include/llama.h +++ b/include/llama.h @@ -190,6 +190,14 @@ extern "C" { LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); + enum llama_fused_gdn_type { + LLAMA_FUSED_GDN_TYPE_AUTO = -1, + LLAMA_FUSED_GDN_TYPE_DISABLED = 0, + LLAMA_FUSED_GDN_TYPE_ENABLED = 1, + }; + + LLAMA_API const char * llama_fused_gdn_type_name(enum llama_fused_gdn_type fused_gdn_type); + enum llama_split_mode { LLAMA_SPLIT_MODE_NONE = 0, // single GPU LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs @@ -338,6 +346,7 @@ extern "C" { enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention + enum llama_fused_gdn_type fused_gdn_type; // when to enable fused Gated Delta Net kernel // ref: https://github.com/ggml-org/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 0be94939102..5deb81f8b17 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -150,9 +150,9 @@ llama_context::llama_context( cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; - cparams.fused_gdn_ar = true; - cparams.fused_gdn_ch = true; - cparams.auto_fgdn = true; + cparams.fused_gdn_ar = params.fused_gdn_type != LLAMA_FUSED_GDN_TYPE_DISABLED; + cparams.fused_gdn_ch = params.fused_gdn_type != LLAMA_FUSED_GDN_TYPE_DISABLED; + cparams.auto_fgdn = params.fused_gdn_type == LLAMA_FUSED_GDN_TYPE_AUTO; // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -200,6 +200,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type)); + LLAMA_LOG_INFO("%s: fused_gdn = %s\n", __func__, llama_fused_gdn_type_name(params.fused_gdn_type)); LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -2869,6 +2870,7 @@ llama_context_params llama_context_default_params() { /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO, + /*.fused_gdn_type =*/ LLAMA_FUSED_GDN_TYPE_AUTO, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, diff --git a/src/llama.cpp b/src/llama.cpp index 872e659edca..4bc996bbee9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -45,6 +45,18 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_ty GGML_ABORT("fatal error"); } +const char * llama_fused_gdn_type_name(enum llama_fused_gdn_type fused_gdn_type) { + switch (fused_gdn_type) { + case LLAMA_FUSED_GDN_TYPE_AUTO: + return "auto"; + case LLAMA_FUSED_GDN_TYPE_DISABLED: + return "disabled"; + case LLAMA_FUSED_GDN_TYPE_ENABLED: + return "enabled"; + } + GGML_ABORT("fatal error"); +} + struct llama_device_memory_data { int64_t total; int64_t free; From 6b1b5cf7abf31e01f7e3075d99551d171e1429dd Mon Sep 17 00:00:00 2001 From: Paul Flynn Date: Wed, 11 Mar 2026 22:53:50 -0400 Subject: [PATCH 2/2] ggml : use SIMD dot products in CPU GDN kernel, couple AR/chunked fused flags - Replace scalar inner loops with ggml_vec_dot_f32 for SIMD-optimized dot products in the CPU fused GDN kernel (delta and attention output) - Couple fused_gdn_ar and fused_gdn_ch flags in auto-detection: if one path lacks device support, disable both to prevent state layout mismatch between transposed (fused) and non-transposed (unfused) formats Co-Authored-By: Claude Opus 4.6 --- ggml/src/ggml-cpu/ops.cpp | 10 ++-------- src/llama-context.cpp | 8 ++++++++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 13ca087faf2..e027325cd0a 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10495,11 +10495,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k) for (int64_t j = 0; j < S_v; ++j) { - const float * row_j = s_out + j * S_v; float sum = 0.0f; - for (int64_t i = 0; i < S_v; ++i) { - sum += row_j[i] * k_d[i]; - } + ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1); delta[j] = (v_d[j] - sum) * beta_val; } @@ -10510,11 +10507,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q) for (int64_t j = 0; j < S_v; ++j) { - const float * row_j = s_out + j * S_v; float sum = 0.0f; - for (int64_t i = 0; i < S_v; ++i) { - sum += row_j[i] * q_d[i]; - } + ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1); attn_data[j] = sum * scale; } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5deb81f8b17..4586bd07192 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -538,6 +538,14 @@ void llama_context::sched_reserve() { } } + // the fused kernel uses a transposed state layout, so both paths must agree + // to avoid a state layout mismatch when switching between AR and chunked + if (cparams.fused_gdn_ar != cparams.fused_gdn_ch) { + cparams.fused_gdn_ar = false; + cparams.fused_gdn_ch = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net AR/chunked support mismatch, disabling both\n", __func__); + } + cparams.auto_fgdn = false; }