36 changes: 21 additions & 15 deletions ggml/src/ggml-cpu/ops.cpp
@@ -10477,34 +10477,40 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);

// state is stored transposed: s_out[j*S_v + i] = S[i][j]
// so row j of s_out = column j of S (contiguous access)

if (kda) {
// precompute exp(g) into delta scratch (reused below)
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i]));
delta[i] = expf(g_d[i]);
}
// S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
for (int64_t j = 0; j < S_v; ++j) {
ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
}
} else {
ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
}

// delta[j] = sum_i S[j][i] * k[i]
memset(delta, 0, S_v * sizeof(float));
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]);
}
// delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
for (int64_t j = 0; j < S_v; ++j) {
delta[j] = (v_d[j] - delta[j]) * beta_val;
float sum = 0.0f;
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
delta[j] = (v_d[j] - sum) * beta_val;
}

// outer product: S[j][i] += k[i] * delta[j]
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]);
// outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
for (int64_t j = 0; j < S_v; ++j) {
ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
}

// attn_out[j] = sum_i S[j][i] * q[i]
memset(attn_data, 0, S_v * sizeof(float));
for (int64_t i = 0; i < S_v; ++i) {
ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]);
// attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
for (int64_t j = 0; j < S_v; ++j) {
float sum = 0.0f;
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
attn_data[j] = sum * scale;
}
ggml_vec_scale_f32(S_v, attn_data, scale);

attn_data += S_v * H; // advance to next token
}
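For readers following the ops.cpp hunk above, here is a minimal scalar sketch of the per-token update it performs, written against the transposed state layout (`M[j*S_v + i]` holds `S[i][j]`, so row `j` of the buffer is column `j` of the state). It is a plain C++ reference with illustrative names, not the ggml implementation:

```cpp
#include <cmath>
#include <vector>

// One gated delta-rule step on a state stored transposed.
// M: S_v*S_v floats, row j of M = column j of S (contiguous).
// g: S_v gates when kda is true, otherwise a single shared gate.
static void gated_delta_step(std::vector<float> & M,
                             const float * k, const float * v, const float * q,
                             const float * g, float beta, float scale, bool kda,
                             int S_v, float * out) {
    // 1) decay: S[i][:] *= exp(g[i])  =>  M[j][i] *= exp(g[i]) for every row j
    for (int j = 0; j < S_v; ++j) {
        for (int i = 0; i < S_v; ++i) {
            M[j * S_v + i] *= std::exp(kda ? g[i] : g[0]);
        }
    }
    // 2) delta[j] = beta * (v[j] - sum_i S[i][j]*k[i]) = beta * (v[j] - dot(row j of M, k))
    std::vector<float> delta(S_v);
    for (int j = 0; j < S_v; ++j) {
        float sk = 0.0f;
        for (int i = 0; i < S_v; ++i) {
            sk += M[j * S_v + i] * k[i];
        }
        delta[j] = (v[j] - sk) * beta;
    }
    // 3) rank-1 update: S[i][j] += k[i] * delta[j]  =>  M[j][:] += delta[j] * k
    for (int j = 0; j < S_v; ++j) {
        for (int i = 0; i < S_v; ++i) {
            M[j * S_v + i] += delta[j] * k[i];
        }
    }
    // 4) output: attn[j] = scale * sum_i S[i][j]*q[i] = scale * dot(row j of M, q)
    for (int j = 0; j < S_v; ++j) {
        float sq = 0.0f;
        for (int i = 0; i < S_v; ++i) {
            sq += M[j * S_v + i] * q[i];
        }
        out[j] = sq * scale;
    }
}
```

Because every row of M is contiguous, both reductions (against k and against q) and the rank-1 update become straight row-wise operations, which is what the `ggml_vec_dot_f32`/`ggml_vec_mad_f32` calls in the new code exploit.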
37 changes: 24 additions & 13 deletions ggml/src/ggml-cuda/gated_delta_net.cu
@@ -1,7 +1,8 @@
#include "gated_delta_net.cuh"

template <int S_v, bool KDA>
__global__ void gated_delta_net_cuda(const float * q,
__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2)
gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
@@ -38,17 +39,19 @@ __global__ void gated_delta_net_cuda(const float * q,

const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
state += state_offset;
curr_state += state_offset;
curr_state += state_offset + col * S_v;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;

constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
float s_shard[rows_per_lane];
// state is stored transposed: M[col][i] = S[i][col], row col is contiguous

#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = curr_state[i * S_v + col];
s_shard[r] = curr_state[i];
}

for (int t = 0; t < n_tokens; t++) {
@@ -62,15 +65,24 @@ __global__ void gated_delta_net_cuda(const float * q,

const float beta_val = *beta_t;

// Cache k and q in registers
float k_reg[rows_per_lane];
float q_reg[rows_per_lane];
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
k_reg[r] = k_t[i];
q_reg[r] = q_t[i];
}

if constexpr (!KDA) {
const float g_val = expf(*g_t);

// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
float kv_shard = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += s_shard[r] * k_t[i];
kv_shard += s_shard[r] * k_reg[r];
}
float kv_col = warp_reduce_sum<warp_size>(kv_shard);

@@ -82,9 +94,8 @@ __global__ void gated_delta_net_cuda(const float * q,
float attn_partial = 0.0f;
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
s_shard[r] = g_val * s_shard[r] + k_reg[r] * delta_col;
attn_partial += s_shard[r] * q_reg[r];
}

float attn_col = warp_reduce_sum<warp_size>(attn_partial);
@@ -98,7 +109,7 @@ __global__ void gated_delta_net_cuda(const float * q,
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
kv_shard += expf(g_t[i]) * s_shard[r] * k_reg[r];
}

float kv_col = warp_reduce_sum<warp_size>(kv_shard);
@@ -112,8 +123,8 @@ __global__ void gated_delta_net_cuda(const float * q,
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_reg[r] * delta_col;
attn_partial += s_shard[r] * q_reg[r];
}

float attn_col = warp_reduce_sum<warp_size>(attn_partial);
@@ -126,11 +137,11 @@ __global__ void gated_delta_net_cuda(const float * q,
attn_data += S_v * H;
}

// Write state back to global memory
// Write state back to global memory (transposed layout)
#pragma unroll
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
state[i * S_v + col] = s_shard[r];
state[col * S_v + i] = s_shard[r];
}
}

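The CUDA kernel keeps the same transposed layout but shards each state column across one warp: lane `lane` covers rows `r * warp_size + lane` of column `col`, and after `curr_state += state_offset + col * S_v` those elements are adjacent in memory. A small host-side C++ sketch of the index arithmetic (the sizes and the coalescing remark are illustrative assumptions, not quoted from the PR):

```cpp
#include <cstdio>

int main() {
    const int S_v       = 128;              // head dimension (example value)
    const int warp_size = 32;               // physical warp size (example value)
    const int rows_per_lane = S_v / warp_size;
    const int col = 5;                      // the state column this warp owns

    for (int r = 0; r < rows_per_lane; ++r) {
        for (int lane = 0; lane < 2; ++lane) {            // show two adjacent lanes per step
            const int i = r * warp_size + lane;
            const long old_idx = (long) i * S_v + col;    // old layout: adjacent lanes are S_v apart
            const long new_idx = (long) col * S_v + i;    // new layout: adjacent lanes are 1 apart
            std::printf("r=%d lane=%d  old=%ld  new=%ld\n", r, lane, old_idx, new_idx);
        }
    }
    return 0;
}
```

With the previous layout, lanes of the same warp read addresses `S_v` floats apart; with the transposed state they read consecutive floats, so the per-column loads and the final write-back (`state[col * S_v + i]`) can coalesce. That motivation is inferred from the diff, not stated in it.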
9 changes: 5 additions & 4 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -2466,13 +2466,14 @@ kernel void kernel_gated_delta_net_impl(

const float scale = 1.0f / sqrt((float)S_v);

device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;

float ls[NSG];

FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] = s_ptr[is*S_v];
ls[j] = s_ptr[is];
}

device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
Expand Down Expand Up @@ -2533,11 +2534,11 @@ kernel void kernel_gated_delta_net_impl(
g_ptr += args.ne21*G;
}

device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;

FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is*S_v] = ls[j];
dst_state[is] = ls[j];
}

#undef S_v
30 changes: 12 additions & 18 deletions src/models/delta-net-base.cpp
@@ -225,9 +225,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
cb(kg_t, "key_gdiff_t", il);

ggml_tensor * s_t = ggml_transpose(ctx0, s);
s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
cb(s_t, "dnet_add_ch_state", il);
s = ggml_reshape_4d(ctx0, s, S_v, S_v, 1, H_v * n_seqs);
cb(s, "dnet_add_ch_state", il);

// [CS, S_v, n_chunks, H_v * n_seqs]
ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
@@ -240,7 +239,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs]

// [CS, S_v, 1, H_v * n_seqs]
ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s);
cb(v_t_p, "v_prime", il);

// [CS, S_v, 1, H_v * n_seqs]
@@ -252,7 +251,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
cb(v_attn, "v_attn", il);

// [S_v, CS, 1, H_v * n_seqs]
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s, ch_q_g_exp);
cb(attn_inter, "attn_inter", il);

// [S_v, CS, 1, H_v * n_seqs]
@@ -268,21 +267,19 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);

s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t);
s_t = ggml_add(ctx0, s_t, kgv);
cb(s_t, "dnet_add_ch_state", il);
s = ggml_mul(ctx0, s, ch_g_last_exp_t);
s = ggml_add(ctx0, s, kgv);
cb(s, "dnet_add_ch_state", il);
}

s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);

// truncate padded tokens
ggml_tensor * o = ggml_view_4d(ctx0, v,
S_v, n_tokens, H_v, n_seqs,
ggml_row_size(v->type, S_v),
ggml_row_size(v->type, S_v * CS * n_chunks),
ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
s = ggml_transpose(ctx0, s_t);
s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs);
cb(s, "output_state", il);

return {o, s};
@@ -341,11 +338,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
g = ggml_exp(ctx0, g);
s = ggml_mul(ctx0, s, g);

ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));

// [1, S_v, H_v, n_seqs]
ggml_tensor * sk;
sk = ggml_mul (ctx0, s_t, k);
sk = ggml_mul (ctx0, s, k);
sk = ggml_sum_rows(ctx0, sk);

// [S_v, 1, H_v, n_seqs]
@@ -362,15 +357,14 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
k = ggml_repeat(ctx0, k, s);
kd = ggml_mul (ctx0, k, d_t);

s_t = ggml_add(ctx0, s_t, kd);
s = ggml_add(ctx0, s, kd);

cb(s_t, "dnet_add_ar_state", il);
cb(s, "dnet_add_ar_state", il);

ggml_tensor * s_q = ggml_mul (ctx0, s_t, q);
ggml_tensor * s_q = ggml_mul (ctx0, s, q);
ggml_tensor * o = ggml_sum_rows(ctx0, s_q);

o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]

return {o, s};
}
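Read as math, the autoregressive branch above appears to build the same per-token recurrence as the CPU and CUDA kernels, only phrased directly on the transposed state M = Sᵀ, which is why the `ggml_cont`/`ggml_transpose` pairs could be dropped. A hedged summary (⊙ is the element-wise product; the gate g may be per channel or a single scalar depending on the model, and broadcasting over heads/sequences plus any output scaling are omitted):

$$
\begin{aligned}
M &\leftarrow M \,\operatorname{diag}\!\left(e^{g}\right) \\
\Delta &= \beta \odot \left(v - M k\right) \\
M &\leftarrow M + \Delta\, k^{\top} \\
o &= M q
\end{aligned}
$$

Substituting M = Sᵀ recovers S ← diag(e^g) S, Δ = β ⊙ (v − Sᵀ k), S ← S + k Δᵀ, o = Sᵀ q. Since each row of M is a column of S, the broadcast `ggml_mul` plus `ggml_sum_rows` pairs compute the M k and M q reductions, and the `ggml_repeat`/`ggml_mul`/`ggml_add` sequence adds the Δ kᵀ outer product, all without transposing the state.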