From b0dbb39e1047f39756fe882c8db4d8fa6b77e921 Mon Sep 17 00:00:00 2001 From: Paul Flynn Date: Wed, 11 Mar 2026 21:28:09 -0400 Subject: [PATCH 1/6] ggml : transpose fused GDN state access for coalesced memory reads (#20436) The fused Gated Delta Net kernel accessed the [S_v, S_v] state matrix column-wise on row-major storage, causing strided reads (stride S_v = 128 floats = 512 bytes) that waste GPU cache bandwidth. This produced a 39% regression on Qwen3.5-9B (Metal, M4 Max) compared to the unfused path. Transpose the state indexing so threads read contiguously: - Metal: s_ptr[is*S_v] -> s_ptr[is] (stride 1 vs S_v) - CUDA: curr_state[i*S_v+col] -> curr_state[col*S_v+i] (coalesced) - CPU: restructured loops for row-wise transposed access Also add --fused-gdn [on|off|auto] CLI flag (mirrors --flash-attn) so users can control fused GDN independently of auto-detection. All GATED_DELTA_NET backend-ops tests pass. Co-Authored-By: Claude Opus 4.6 --- common/arg.cpp | 15 ++++++++++ common/common.cpp | 1 + common/common.h | 1 + ggml/src/ggml-cpu/ops.cpp | 42 +++++++++++++++++---------- ggml/src/ggml-cuda/gated_delta_net.cu | 7 +++-- ggml/src/ggml-metal/ggml-metal.metal | 9 +++--- include/llama.h | 9 ++++++ src/llama-context.cpp | 8 +++-- src/llama.cpp | 12 ++++++++ 9 files changed, 79 insertions(+), 25 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 10aa1b5e4fe..42a9a0752a2 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1344,6 +1344,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str())); } }).set_env("LLAMA_ARG_FLASH_ATTN")); + add_opt(common_arg({ "-fgdn", "--fused-gdn" }, "[on|off|auto]", + string_format("set fused Gated Delta Net kernel use ('on', 'off', or 'auto', default: '%s')", + llama_fused_gdn_type_name(params.fused_gdn_type)), + [](common_params & params, const std::string & value) { + if (is_truthy(value)) { + params.fused_gdn_type = LLAMA_FUSED_GDN_TYPE_ENABLED; + } else if (is_falsey(value)) { + params.fused_gdn_type = LLAMA_FUSED_GDN_TYPE_DISABLED; + } else if (is_autoy(value)) { + params.fused_gdn_type = LLAMA_FUSED_GDN_TYPE_AUTO; + } else { + throw std::runtime_error( + string_format("error: unknown value for --fused-gdn: '%s'\n", value.c_str())); + } + }).set_env("LLAMA_ARG_FUSED_GDN")); add_opt(common_arg( {"-p", "--prompt"}, "PROMPT", "prompt to start generation with; for system message, use -sys", diff --git a/common/common.cpp b/common/common.cpp index cc423d3439f..2e7ae7b0f31 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1369,6 +1369,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.pooling_type = params.pooling_type; cparams.attention_type = params.attention_type; cparams.flash_attn_type = params.flash_attn_type; + cparams.fused_gdn_type = params.fused_gdn_type; cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; diff --git a/common/common.h b/common/common.h index ee7a2d805e0..f693d56279d 100644 --- a/common/common.h +++ b/common/common.h @@ -412,6 +412,7 @@ struct common_params { enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention + enum llama_fused_gdn_type fused_gdn_type = LLAMA_FUSED_GDN_TYPE_AUTO; // whether to use fused Gated Delta Net kernel struct common_params_sampling sampling; struct common_params_speculative speculative; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 85db02d92f1..85de92eee56 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10477,34 +10477,46 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1); const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); + // state is stored transposed: s_out[j*S_v + i] = S[i][j] + // so row j of s_out = column j of S (contiguous access) + if (kda) { + // precompute exp(g) into delta scratch (reused below) for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i])); + delta[i] = expf(g_d[i]); + } + // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i]) + for (int64_t j = 0; j < S_v; ++j) { + ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta); } } else { ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0])); } - // delta[j] = sum_i S[j][i] * k[i] - memset(delta, 0, S_v * sizeof(float)); - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]); - } + // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k) for (int64_t j = 0; j < S_v; ++j) { - delta[j] = (v_d[j] - delta[j]) * beta_val; + const float * row_j = s_out + j * S_v; + float sum = 0.0f; + for (int64_t i = 0; i < S_v; ++i) { + sum += row_j[i] * k_d[i]; + } + delta[j] = (v_d[j] - sum) * beta_val; } - // outer product: S[j][i] += k[i] * delta[j] - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]); + // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i] + for (int64_t j = 0; j < S_v; ++j) { + ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]); } - // attn_out[j] = sum_i S[j][i] * q[i] - memset(attn_data, 0, S_v * sizeof(float)); - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]); + // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q) + for (int64_t j = 0; j < S_v; ++j) { + const float * row_j = s_out + j * S_v; + float sum = 0.0f; + for (int64_t i = 0; i < S_v; ++i) { + sum += row_j[i] * q_d[i]; + } + attn_data[j] = sum * scale; } - ggml_vec_scale_f32(S_v, attn_data, scale); attn_data += S_v * H; // advance to next token } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 5f0fa8e58df..eb38588202a 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -45,10 +45,11 @@ __global__ void gated_delta_net_cuda(const float * q, static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; float s_shard[rows_per_lane]; + // state is stored transposed: M[col][i] = S[i][col], row col is contiguous #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[i * S_v + col]; + s_shard[r] = curr_state[col * S_v + i]; } for (int t = 0; t < n_tokens; t++) { @@ -126,11 +127,11 @@ __global__ void gated_delta_net_cuda(const float * q, attn_data += S_v * H; } - // Write state back to global memory + // Write state back to global memory (transposed layout) #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - state[i * S_v + col] = s_shard[r]; + state[col * S_v + i] = s_shard[r]; } } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 107e7cf2ff3..d4b129ed756 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2469,13 +2469,14 @@ kernel void kernel_gated_delta_net_impl( const float scale = 1.0f / sqrt((float)S_v); - device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; + // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; float ls[NSG]; FOR_UNROLL (short j = 0; j < NSG; j++) { const short is = tx*NSG + j; - ls[j] = s_ptr[is*S_v]; + ls[j] = s_ptr[is]; } device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; @@ -2536,11 +2537,11 @@ kernel void kernel_gated_delta_net_impl( g_ptr += args.ne21*G; } - device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; FOR_UNROLL (short j = 0; j < NSG; j++) { const short is = tx*NSG + j; - dst_state[is*S_v] = ls[j]; + dst_state[is] = ls[j]; } #undef S_v diff --git a/include/llama.h b/include/llama.h index c6e102abe51..05e5152d0db 100644 --- a/include/llama.h +++ b/include/llama.h @@ -190,6 +190,14 @@ extern "C" { LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); + enum llama_fused_gdn_type { + LLAMA_FUSED_GDN_TYPE_AUTO = -1, + LLAMA_FUSED_GDN_TYPE_DISABLED = 0, + LLAMA_FUSED_GDN_TYPE_ENABLED = 1, + }; + + LLAMA_API const char * llama_fused_gdn_type_name(enum llama_fused_gdn_type fused_gdn_type); + enum llama_split_mode { LLAMA_SPLIT_MODE_NONE = 0, // single GPU LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs @@ -338,6 +346,7 @@ extern "C" { enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention + enum llama_fused_gdn_type fused_gdn_type; // when to enable fused Gated Delta Net kernel // ref: https://github.com/ggml-org/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 14dccac5b55..f20fce74c4e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -151,9 +151,9 @@ llama_context::llama_context( cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; - cparams.fused_gdn_ar = true; - cparams.fused_gdn_ch = true; - cparams.auto_fgdn = true; + cparams.fused_gdn_ar = params.fused_gdn_type != LLAMA_FUSED_GDN_TYPE_DISABLED; + cparams.fused_gdn_ch = params.fused_gdn_type != LLAMA_FUSED_GDN_TYPE_DISABLED; + cparams.auto_fgdn = params.fused_gdn_type == LLAMA_FUSED_GDN_TYPE_AUTO; // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -201,6 +201,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type)); + LLAMA_LOG_INFO("%s: fused_gdn = %s\n", __func__, llama_fused_gdn_type_name(params.fused_gdn_type)); LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -2878,6 +2879,7 @@ llama_context_params llama_context_default_params() { /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO, + /*.fused_gdn_type =*/ LLAMA_FUSED_GDN_TYPE_AUTO, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, diff --git a/src/llama.cpp b/src/llama.cpp index 872e659edca..4bc996bbee9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -45,6 +45,18 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_ty GGML_ABORT("fatal error"); } +const char * llama_fused_gdn_type_name(enum llama_fused_gdn_type fused_gdn_type) { + switch (fused_gdn_type) { + case LLAMA_FUSED_GDN_TYPE_AUTO: + return "auto"; + case LLAMA_FUSED_GDN_TYPE_DISABLED: + return "disabled"; + case LLAMA_FUSED_GDN_TYPE_ENABLED: + return "enabled"; + } + GGML_ABORT("fatal error"); +} + struct llama_device_memory_data { int64_t total; int64_t free; From fb32cd487b0cd1ba506a00f1c1781aed387e594e Mon Sep 17 00:00:00 2001 From: Paul Flynn Date: Wed, 11 Mar 2026 22:53:50 -0400 Subject: [PATCH 2/6] ggml : use SIMD dot products in CPU GDN kernel, couple AR/chunked fused flags - Replace scalar inner loops with ggml_vec_dot_f32 for SIMD-optimized dot products in the CPU fused GDN kernel (delta and attention output) - Couple fused_gdn_ar and fused_gdn_ch flags in auto-detection: if one path lacks device support, disable both to prevent state layout mismatch between transposed (fused) and non-transposed (unfused) formats Co-Authored-By: Claude Opus 4.6 --- ggml/src/ggml-cpu/ops.cpp | 10 ++-------- src/llama-context.cpp | 8 ++++++++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 85de92eee56..314cc1088a0 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10495,11 +10495,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k) for (int64_t j = 0; j < S_v; ++j) { - const float * row_j = s_out + j * S_v; float sum = 0.0f; - for (int64_t i = 0; i < S_v; ++i) { - sum += row_j[i] * k_d[i]; - } + ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1); delta[j] = (v_d[j] - sum) * beta_val; } @@ -10510,11 +10507,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q) for (int64_t j = 0; j < S_v; ++j) { - const float * row_j = s_out + j * S_v; float sum = 0.0f; - for (int64_t i = 0; i < S_v; ++i) { - sum += row_j[i] * q_d[i]; - } + ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1); attn_data[j] = sum * scale; } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index f20fce74c4e..3feac99a302 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -547,6 +547,14 @@ void llama_context::sched_reserve() { } } + // the fused kernel uses a transposed state layout, so both paths must agree + // to avoid a state layout mismatch when switching between AR and chunked + if (cparams.fused_gdn_ar != cparams.fused_gdn_ch) { + cparams.fused_gdn_ar = false; + cparams.fused_gdn_ch = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net AR/chunked support mismatch, disabling both\n", __func__); + } + cparams.auto_fgdn = false; } From d9a7ab365fe7f3ea49b31cfbcde879d9f42d67ff Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 12 Mar 2026 08:29:28 +0200 Subject: [PATCH 3/6] llama : rever fgdn argument changes --- common/arg.cpp | 15 --------------- common/common.cpp | 1 - common/common.h | 1 - include/llama.h | 9 --------- src/llama-context.cpp | 16 +++------------- src/llama.cpp | 12 ------------ 6 files changed, 3 insertions(+), 51 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 42a9a0752a2..10aa1b5e4fe 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1344,21 +1344,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str())); } }).set_env("LLAMA_ARG_FLASH_ATTN")); - add_opt(common_arg({ "-fgdn", "--fused-gdn" }, "[on|off|auto]", - string_format("set fused Gated Delta Net kernel use ('on', 'off', or 'auto', default: '%s')", - llama_fused_gdn_type_name(params.fused_gdn_type)), - [](common_params & params, const std::string & value) { - if (is_truthy(value)) { - params.fused_gdn_type = LLAMA_FUSED_GDN_TYPE_ENABLED; - } else if (is_falsey(value)) { - params.fused_gdn_type = LLAMA_FUSED_GDN_TYPE_DISABLED; - } else if (is_autoy(value)) { - params.fused_gdn_type = LLAMA_FUSED_GDN_TYPE_AUTO; - } else { - throw std::runtime_error( - string_format("error: unknown value for --fused-gdn: '%s'\n", value.c_str())); - } - }).set_env("LLAMA_ARG_FUSED_GDN")); add_opt(common_arg( {"-p", "--prompt"}, "PROMPT", "prompt to start generation with; for system message, use -sys", diff --git a/common/common.cpp b/common/common.cpp index 2e7ae7b0f31..cc423d3439f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1369,7 +1369,6 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.pooling_type = params.pooling_type; cparams.attention_type = params.attention_type; cparams.flash_attn_type = params.flash_attn_type; - cparams.fused_gdn_type = params.fused_gdn_type; cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; diff --git a/common/common.h b/common/common.h index f693d56279d..ee7a2d805e0 100644 --- a/common/common.h +++ b/common/common.h @@ -412,7 +412,6 @@ struct common_params { enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention - enum llama_fused_gdn_type fused_gdn_type = LLAMA_FUSED_GDN_TYPE_AUTO; // whether to use fused Gated Delta Net kernel struct common_params_sampling sampling; struct common_params_speculative speculative; diff --git a/include/llama.h b/include/llama.h index 05e5152d0db..c6e102abe51 100644 --- a/include/llama.h +++ b/include/llama.h @@ -190,14 +190,6 @@ extern "C" { LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); - enum llama_fused_gdn_type { - LLAMA_FUSED_GDN_TYPE_AUTO = -1, - LLAMA_FUSED_GDN_TYPE_DISABLED = 0, - LLAMA_FUSED_GDN_TYPE_ENABLED = 1, - }; - - LLAMA_API const char * llama_fused_gdn_type_name(enum llama_fused_gdn_type fused_gdn_type); - enum llama_split_mode { LLAMA_SPLIT_MODE_NONE = 0, // single GPU LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs @@ -346,7 +338,6 @@ extern "C" { enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention - enum llama_fused_gdn_type fused_gdn_type; // when to enable fused Gated Delta Net kernel // ref: https://github.com/ggml-org/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3feac99a302..14dccac5b55 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -151,9 +151,9 @@ llama_context::llama_context( cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; - cparams.fused_gdn_ar = params.fused_gdn_type != LLAMA_FUSED_GDN_TYPE_DISABLED; - cparams.fused_gdn_ch = params.fused_gdn_type != LLAMA_FUSED_GDN_TYPE_DISABLED; - cparams.auto_fgdn = params.fused_gdn_type == LLAMA_FUSED_GDN_TYPE_AUTO; + cparams.fused_gdn_ar = true; + cparams.fused_gdn_ch = true; + cparams.auto_fgdn = true; // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -201,7 +201,6 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type)); - LLAMA_LOG_INFO("%s: fused_gdn = %s\n", __func__, llama_fused_gdn_type_name(params.fused_gdn_type)); LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -547,14 +546,6 @@ void llama_context::sched_reserve() { } } - // the fused kernel uses a transposed state layout, so both paths must agree - // to avoid a state layout mismatch when switching between AR and chunked - if (cparams.fused_gdn_ar != cparams.fused_gdn_ch) { - cparams.fused_gdn_ar = false; - cparams.fused_gdn_ch = false; - LLAMA_LOG_WARN("%s: fused Gated Delta Net AR/chunked support mismatch, disabling both\n", __func__); - } - cparams.auto_fgdn = false; } @@ -2887,7 +2878,6 @@ llama_context_params llama_context_default_params() { /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO, - /*.fused_gdn_type =*/ LLAMA_FUSED_GDN_TYPE_AUTO, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, diff --git a/src/llama.cpp b/src/llama.cpp index 4bc996bbee9..872e659edca 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -45,18 +45,6 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_ty GGML_ABORT("fatal error"); } -const char * llama_fused_gdn_type_name(enum llama_fused_gdn_type fused_gdn_type) { - switch (fused_gdn_type) { - case LLAMA_FUSED_GDN_TYPE_AUTO: - return "auto"; - case LLAMA_FUSED_GDN_TYPE_DISABLED: - return "disabled"; - case LLAMA_FUSED_GDN_TYPE_ENABLED: - return "enabled"; - } - GGML_ABORT("fatal error"); -} - struct llama_device_memory_data { int64_t total; int64_t free; From fe3ef4a0bcf395213b390cec7918ef47c0d8cd49 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 12 Mar 2026 08:30:40 +0200 Subject: [PATCH 4/6] graph : remove GDN state transposes --- src/models/delta-net-base.cpp | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index a62dbc15dd0..6bc989c9509 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -225,9 +225,8 @@ std::pair llm_build_delta_net_base::build_delta_ne ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); cb(kg_t, "key_gdiff_t", il); - ggml_tensor * s_t = ggml_transpose(ctx0, s); - s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs); - cb(s_t, "dnet_add_ch_state", il); + s = ggml_reshape_4d(ctx0, s, S_v, S_v, 1, H_v * n_seqs); + cb(s, "dnet_add_ch_state", il); // [CS, S_v, n_chunks, H_v * n_seqs] ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); @@ -240,7 +239,7 @@ std::pair llm_build_delta_net_base::build_delta_ne ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] // [CS, S_v, 1, H_v * n_seqs] - ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t); + ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s); cb(v_t_p, "v_prime", il); // [CS, S_v, 1, H_v * n_seqs] @@ -252,7 +251,7 @@ std::pair llm_build_delta_net_base::build_delta_ne cb(v_attn, "v_attn", il); // [S_v, CS, 1, H_v * n_seqs] - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s, ch_q_g_exp); cb(attn_inter, "attn_inter", il); // [S_v, CS, 1, H_v * n_seqs] @@ -268,13 +267,11 @@ std::pair llm_build_delta_net_base::build_delta_ne // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk); - s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t); - s_t = ggml_add(ctx0, s_t, kgv); - cb(s_t, "dnet_add_ch_state", il); + s = ggml_mul(ctx0, s, ch_g_last_exp_t); + s = ggml_add(ctx0, s, kgv); + cb(s, "dnet_add_ch_state", il); } - s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs); - // truncate padded tokens ggml_tensor * o = ggml_view_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs, @@ -282,7 +279,7 @@ std::pair llm_build_delta_net_base::build_delta_ne ggml_row_size(v->type, S_v * CS * n_chunks), ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] - s = ggml_transpose(ctx0, s_t); + s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs); cb(s, "output_state", il); return {o, s}; @@ -341,11 +338,9 @@ std::pair llm_build_delta_net_base::build_delta_ne g = ggml_exp(ctx0, g); s = ggml_mul(ctx0, s, g); - ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s)); - // [1, S_v, H_v, n_seqs] ggml_tensor * sk; - sk = ggml_mul (ctx0, s_t, k); + sk = ggml_mul (ctx0, s, k); sk = ggml_sum_rows(ctx0, sk); // [S_v, 1, H_v, n_seqs] @@ -362,15 +357,14 @@ std::pair llm_build_delta_net_base::build_delta_ne k = ggml_repeat(ctx0, k, s); kd = ggml_mul (ctx0, k, d_t); - s_t = ggml_add(ctx0, s_t, kd); + s = ggml_add(ctx0, s, kd); - cb(s_t, "dnet_add_ar_state", il); + cb(s, "dnet_add_ar_state", il); - ggml_tensor * s_q = ggml_mul (ctx0, s_t, q); + ggml_tensor * s_q = ggml_mul (ctx0, s, q); ggml_tensor * o = ggml_sum_rows(ctx0, s_q); o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] - s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] return {o, s}; } From 2882a4b882b6560bebd87a7da2c7f908efe92ca8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 13 Mar 2026 18:48:25 +0200 Subject: [PATCH 5/6] vulkan : adapt --- ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 1fdf889e824..f008859b99d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -44,7 +44,7 @@ void main() { FLOAT_TYPE state[S_V]; [[unroll]] for (uint i = 0; i < S_V; i++) { - state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]); + state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]); } uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; @@ -123,6 +123,6 @@ void main() { } [[unroll]] for (uint i = 0; i < S_V; i++) { - data_dst[s_off + state_base + i * S_V + col] = state[i]; + data_dst[s_off + state_base + col * S_V + i] = state[i]; } } From d466d89b609d3aa66ff480024a6fc0b696eb82af Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Fri, 13 Mar 2026 19:20:08 +0200 Subject: [PATCH 6/6] cuda : remove obsolete smem code --- ggml/src/ggml-cuda/gated_delta_net.cu | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index eb38588202a..1ce6d5f31b5 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -135,15 +135,6 @@ __global__ void gated_delta_net_cuda(const float * q, } } -static size_t calculate_smem(const int sv, int cc) -{ - size_t smem = 0; - if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) { - smem = sv * sv * sizeof(float); - } - return smem; -} - template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, @@ -180,18 +171,14 @@ static void launch_gated_delta_net( sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; case 64: { - constexpr int sv = 64; - size_t smem = calculate_smem(sv, cc); - gated_delta_net_cuda<<>>( + gated_delta_net_cuda<64, KDA><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; } case 128: { - constexpr int sv = 128; - size_t smem = calculate_smem(sv, cc); - gated_delta_net_cuda<<>>( + gated_delta_net_cuda<128, KDA><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);