Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -2464,6 +2464,8 @@ extern "C" {
bool lower,
bool uni);

// TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
// ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
GGML_API struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
Expand Down
11 changes: 5 additions & 6 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10436,8 +10436,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(

const float * state_in_base = (const float *)src_state->data;

const int64_t rq1 = nev1 / neq1;
const int64_t rk1 = nev1 / nek1;
//const int64_t rq1 = nev1 / neq1;
//const int64_t rk1 = nev1 / nek1;
const int64_t rq3 = nev3 / neq3;
const int64_t rk3 = nev3 / nek3;

Expand All @@ -10447,8 +10447,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const int64_t iv1 = ir % H; // head_index
const int64_t iv3 = ir / H; // sequence

const int64_t iq1 = iv1 / rq1;
const int64_t ik1 = iv1 / rk1;
const int64_t iq1 = iv1 % neq1;
const int64_t ik1 = iv1 % nek1;

const int64_t iq3 = iv3 / rq3;
const int64_t ik3 = iv3 / rk3;
Expand All @@ -10468,7 +10468,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);

const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);

if (kda) {
for (int64_t i = 0; i < S_v; ++i) {
Expand Down Expand Up @@ -10501,7 +10501,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(

attn_data += S_v * H; // advance to next token
}

}
}

Expand Down
183 changes: 104 additions & 79 deletions ggml/src/ggml-cuda/gated_delta_net.cu
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
#include "gated_delta_net.cuh"
#include "ggml-cuda/common.cuh"

template <int S_v, bool KDA>
__global__ void __launch_bounds__(S_v, 1)
gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
const int64_t H,
const int64_t n_tokens,
const int64_t n_seqs,
const int64_t sq1,
const int64_t sq2,
const int64_t sq3,
const int64_t sv1,
const int64_t sv2,
const int64_t sv3,
const int64_t sb1,
const int64_t sb2,
const int64_t sb3,
const int64_t rq1,
const int64_t rq3,
const float scale) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
const int col = threadIdx.x; // each thread owns one column

const int64_t iq1 = h_idx / rq1;
const int64_t iq3 = sequence / rq3;
__global__ void gated_delta_net_cuda(const float * q,
const float * k,
const float * v,
const float * g,
const float * beta,
const float * curr_state,
float * dst,
int64_t H,
int64_t n_tokens,
int64_t n_seqs,
int64_t sq1,
int64_t sq2,
int64_t sq3,
int64_t sv1,
int64_t sv2,
int64_t sv3,
int64_t sb1,
int64_t sb2,
int64_t sb3,
const uint3 neqk1_magic,
const uint3 rq3_magic,
float scale) {
const uint32_t h_idx = blockIdx.x;
const uint32_t sequence = blockIdx.y;
// each warp owns one column, using warp-level primitives to reduce across rows
const int lane = threadIdx.x;
const int col = blockIdx.z * blockDim.y + threadIdx.y;

const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
const uint32_t iq3 = fastdiv(sequence, rq3_magic);

const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
float * attn_data = dst;
Expand All @@ -41,17 +41,14 @@ gated_delta_net_cuda(const float * q,
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;

// GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
// TODO: check optimal path for RDNA1 and RDNA2 devices.
#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA)
extern __shared__ float s_shared[];
float * s = s_shared + col * S_v;
#else
float s[S_v];
#endif
constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
float s_shard[rows_per_lane];
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = curr_state[i * S_v + col];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = curr_state[i * S_v + col];
}

for (int t = 0; t < n_tokens; t++) {
Expand All @@ -69,55 +66,71 @@ gated_delta_net_cuda(const float * q,
const float g_val = expf(*g_t);

// kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
float kv_col = 0.0f;
float kv_shard = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
kv_col += s[i] * k_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum<warp_size>(kv_shard);

// delta[col] = (v[col] - g * kv[col]) * beta
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;

// fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_col = 0.0f;
float attn_partial = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = g_val * s[i] + k_t[i] * delta_col;
attn_col += s[i] * q_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}

attn_data[col] = attn_col * scale;
float attn_col = warp_reduce_sum<warp_size>(attn_partial);

if (lane == 0) {
attn_data[col] = attn_col * scale;
}
} else {
// kv[col] = sum_i g[i] * S[i][col] * k[i]
float kv_col = 0.0f;
float kv_shard = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
kv_col += expf(g_t[i]) * s[i] * k_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
}

float kv_col = warp_reduce_sum<warp_size>(kv_shard);

// delta[col] = (v[col] - kv[col]) * beta
float delta_col = (v_t[col] - kv_col) * beta_val;

// fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
// attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
float attn_col = 0.0f;
float attn_partial = 0.0f;
#pragma unroll
for (int i = 0; i < S_v; i++) {
s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col;
attn_col += s[i] * q_t[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}

attn_data[col] = attn_col * scale;
float attn_col = warp_reduce_sum<warp_size>(attn_partial);

if (lane == 0) {
attn_data[col] = attn_col * scale;
}
}

attn_data += S_v * H;
}

// Write state back to global memory
#pragma unroll
for (int i = 0; i < S_v; i++) {
state[i * S_v + col] = s[i];
for (int r = 0; r < rows_per_lane; r++) {
const int i = r * warp_size + lane;
state[i * S_v + col] = s_shard[r];
}
}

Expand All @@ -135,35 +148,43 @@ static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
const float * g_d, const float * b_d, const float * s_d,
float * dst_d,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t rq1, int64_t rq3,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t neqk1, int64_t rq3,
float scale, cudaStream_t stream) {
//TODO: Add chunked kernel for even faster pre-fill
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const int num_warps = 4;
dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);

dim3 grid_dims(H, n_seqs, 1);
dim3 block_dims(S_v, 1, 1);
const uint3 neqk1_magic = init_fastdiv_values(neqk1);
const uint3 rq3_magic = init_fastdiv_values(rq3);

int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

switch (S_v) {
case 32: {
constexpr int sv = 32;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
case 16:
gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 32:
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
}
case 64: {
constexpr int sv = 64;
size_t smem = calculate_smem(sv, cc);
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
}
case 128: {
Expand All @@ -172,7 +193,7 @@ static void launch_gated_delta_net(
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
}
default:
Expand All @@ -190,10 +211,12 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_tensor * src_state = dst->src[5];

GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);

const int64_t S_v = nev0;
const int64_t H = nev1;
Expand All @@ -202,7 +225,9 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *

const bool kda = (src_g->ne[0] == S_v);

const int64_t rq1 = nev1 / neq1;
GGML_ASSERT(neq1 == nek1);
const int64_t neqk1 = neq1;

const int64_t rq3 = nev3 / neq3;

const float * q_d = (const float *) src_q->data;
Expand Down Expand Up @@ -241,10 +266,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
if (kda) {
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale, stream);
sb1, sb2, sb3, neqk1, rq3, scale, stream);
} else {
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale, stream);
sb1, sb2, sb3, neqk1, rq3, scale, stream);
}
}
35 changes: 35 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,41 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_
return res;
}

// Fetch the Metal pipeline for a gated-delta-net op, compiling it on first use.
//
// The pipeline is keyed by the type name of src[0], the value nsg = src[2]->ne[0]/32,
// and the two sizes baked in as function constants (src[2]->ne[0] and src[3]->ne[0]).
// The returned struct always has its nsg field set to that value.
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {
    const ggml_tensor * src_v = op->src[2]; // v: S_v = ne[0], H = ne[1]
    const ggml_tensor * src_g = op->src[3]; // g: G   = ne[0]

    const int ne20 = src_v->ne[0]; // S_v
    const int ne21 = src_v->ne[1]; // H
    const int ne30 = src_g->ne[0]; // G

    GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->ne[0] == ne20 * ne21);
    GGML_ASSERT(ne20 % 32 == 0);

    // one unit per 32 elements of S_v (presumably the simdgroup count - confirm against the kernel)
    const int nsg = src_v->ne[0]/32;

    // `base` selects the kernel template, `name` additionally encodes the specialization constants
    char base[256];
    char name[256];

    snprintf(base, sizeof(base), "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
    snprintf(name, sizeof(name), "%s_ne20=%d_ne30=%d", base, ne20, ne30);

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (res.pipeline) {
        // already available in the library - only nsg needs to be filled in
        res.nsg = nsg;
        return res;
    }

    // not found: compile a specialization with S_v and G as function constants
    ggml_metal_cv_t cv = ggml_metal_cv_init();

    ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
    ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);

    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

    ggml_metal_cv_free(cv);

    res.nsg = nsg;
    return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_GATED_DELTA_NET:
return op->src[2]->ne[0] % 32 == 0;
case GGML_OP_SOLVE_TRI:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
Expand Down
Loading
Loading