Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 122 additions & 11 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,9 @@ struct vk_device_struct {
vk_pipeline pipeline_rwkv_wkv7_f32;
// [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
vk_pipeline pipeline_gated_delta_net[3][2];
vk_pipeline pipeline_gated_delta_net_chunk_intra;
vk_pipeline pipeline_gated_delta_net_chunk_inter;
vk_pipeline pipeline_gated_delta_net_chunk_output;
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
Expand Down Expand Up @@ -1468,6 +1471,18 @@ struct vk_op_gated_delta_net_push_constants {
float scale;
};

// Push constants shared by all three chunked gated-delta-net dispatches
// (intra, inter, output). Field order must match the GLSL
// `layout(push_constant) uniform Parameters` block in the chunk shaders.
struct vk_op_gated_delta_net_chunk_push_constants {
uint32_t H;            // number of heads
uint32_t n_tokens;     // tokens per sequence
uint32_t n_seqs;       // number of sequences (src_v->ne[3])
uint32_t sq1, sq2, sq3; // element strides for q; the inter shader also indexes k with these — assumes k shares q's layout, TODO confirm
uint32_t sv1, sv2, sv3; // element strides for v — presumably ne-derived, verify against setup code
uint32_t sb1, sb2, sb3; // element strides for beta — presumably ne-derived, verify against setup code
uint32_t neq1, rq3;    // neq1 = src_q->ne[1]; rq3 = src_v->ne[3] / src_q->ne[3] (seq broadcast ratio)
uint32_t n_chunks;     // ceil(n_tokens / GDN_CHUNK_SIZE)
uint32_t s_off;        // element offset of the recurrent-state region in dst (= S_v * H * n_tokens * n_seqs)
};

struct vk_op_ssm_scan_push_constants {
uint32_t nb02, nb03, nb12, nb13;
uint32_t nb21, nb22, nb31;
Expand Down Expand Up @@ -4599,6 +4614,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
}

ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_intra, "gated_delta_net_chunk_intra_f32_d128",
gated_delta_net_chunk_intra_f32_len, gated_delta_net_chunk_intra_f32_data, "main",
8, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {128, 64}, 1);
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_inter, "gated_delta_net_chunk_inter_f32_d128",
gated_delta_net_chunk_inter_f32_len, gated_delta_net_chunk_inter_f32_data, "main",
9, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {128, 64}, 1);
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_output, "gated_delta_net_chunk_output_f32_d128",
gated_delta_net_chunk_output_f32_len, gated_delta_net_chunk_output_f32_data, "main",
6, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {128, 64}, 1);

if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
Expand Down Expand Up @@ -10373,9 +10398,13 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
);
}

static constexpr uint32_t GDN_CHUNK_SIZE = 64;
static constexpr uint32_t GDN_CHUNK_THRESHOLD = UINT32_MAX; // Disabled

static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const ggml_tensor * src_q = dst->src[0];
const ggml_tensor * src_v = dst->src[2];
const ggml_tensor * src_g = dst->src[3];
const ggml_tensor * src_beta = dst->src[4];

GGML_ASSERT(dst->buffer != nullptr);
Expand All @@ -10386,11 +10415,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
const uint32_t n_seqs = (uint32_t)src_v->ne[3];

const uint32_t s_off = S_v * H * n_tokens * n_seqs;

vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);

ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
const bool kda = (src_g->ne[0] == (int64_t)S_v);
const bool use_chunked = !kda && S_v == 128 && n_tokens > GDN_CHUNK_THRESHOLD;

vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer src_buf[6] = {};
Expand All @@ -10411,19 +10437,104 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
const uint32_t neq1 = (uint32_t)src_q->ne[1];
const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]);

const float scale = 1.0f / sqrtf((float)S_v);
const vk_op_gated_delta_net_push_constants pc = {
H, n_tokens, n_seqs, s_off,
if (!use_chunked) {
// Autoregressive path (optimal for TG / small n_tokens)
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);

ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

const float scale = 1.0f / sqrtf((float)S_v);
const vk_op_gated_delta_net_push_constants pc = {
H, n_tokens, n_seqs, s_off,
sq1, sq2, sq3,
sv1, sv2, sv3,
sb1, sb2, sb3,
neq1, rq3,
scale
};

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, { H, n_seqs, 1u });
return;
}

// Chunked parallel path (PP acceleration)
const uint32_t n_chunks = (n_tokens + GDN_CHUNK_SIZE - 1) / GDN_CHUNK_SIZE;

vk_pipeline pl_intra = ctx->device->pipeline_gated_delta_net_chunk_intra;
vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter;
vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output;

ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1);
ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1);
ggml_pipeline_request_descriptor_sets(ctx, pl_output, 1);

// Scratch buffer layout within prealloc_split_k
const size_t wu_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * S_v * sizeof(float);
const size_t d_size = (size_t)n_seqs * n_chunks * H * sizeof(float);
const size_t g_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * sizeof(float);
const size_t h_size = (size_t)n_seqs * n_chunks * H * S_v * S_v * sizeof(float);

const size_t w_off = 0;
const size_t u_off = wu_size;
const size_t vn_off = 2 * wu_size;
const size_t dec_off = 3 * wu_size;
const size_t gcum_off = dec_off + d_size;
const size_t h_off = gcum_off + g_size;
const size_t total_scratch = h_off + h_size;

if (ctx->prealloc_size_split_k < total_scratch) {
ctx->prealloc_size_split_k = total_scratch;
ggml_vk_preallocate_buffers(ctx, subctx);
}

if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}

vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size };
vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size };
vk_subbuffer scratch_vnew = { ctx->prealloc_split_k, vn_off, wu_size };
vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size };
vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size };
vk_subbuffer scratch_h = { ctx->prealloc_split_k, h_off, h_size };

const vk_op_gated_delta_net_chunk_push_constants pc = {
H, n_tokens, n_seqs,
sq1, sq2, sq3,
sv1, sv2, sv3,
sb1, sb2, sb3,
neq1, rq3,
scale
n_chunks, s_off
};

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
// Dispatch 1: Intra-chunk (parallel across chunks)
// Bindings: K, V, G, Beta, W_out, U_out, Decay_out, GCum_out
ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra,
{src_buf[1], src_buf[2], src_buf[3], src_buf[4],
scratch_w, scratch_u, scratch_dec, scratch_gcum},
pc, { n_chunks * H, n_seqs, 1u });

ggml_vk_sync_buffers(ctx, subctx);

// Dispatch 2: Inter-chunk state propagation (sequential across chunks)
// Bindings: K, W, U, Decay, GCum, State, H_out, VNew_out, Final(dst)
ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter,
{src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum,
src_buf[5], scratch_h, scratch_vnew, dst_buf},
pc, { H, n_seqs, 1u });

ggml_vk_sync_buffers(ctx, subctx);

// Dispatch 3: Output (parallel across chunks)
// Bindings: Q, K, H, VNew, GCum, Dst
ggml_vk_dispatch_pipeline(ctx, subctx, pl_output,
{src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf},
pc, { n_chunks * H, n_seqs, 1u });

ctx->prealloc_split_k_need_sync = true;
}

static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
Expand Down
126 changes: 126 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#version 450

#extension GL_EXT_control_flow_attributes : require

// Inter-chunk state propagation for chunked gated delta net
//
// Sequential across chunks, parallel across state columns.
// For each chunk c:
// 1. Store state snapshot h[c] for output kernel
// 2. v_corrected = U - W @ S (C x d)
// 3. S_next = exp(g_total) * S + K_gated^T @ v_corrected (d x d)
//
// where K_gated[t] = k[t] * exp(g_total - g_cumsum[t])
//
// Grid: (H, n_seqs, 1)
// Workgroup: S_V threads (one per state column)

// Specialization constants; host sets {128, 64} at pipeline creation.
layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint CHUNK_SIZE = 64;

// Workgroup x-size is the SAME spec constant as S_V (constant_id 0), so
// every invocation has col < S_V — see the early-return note in main().
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Must mirror vk_op_gated_delta_net_chunk_push_constants on the host,
// field for field, in this exact order.
layout(push_constant) uniform Parameters {
uint H;
uint n_tokens;
uint n_seqs;
uint sq1, sq2, sq3;
uint sv1, sv2, sv3;
uint sb1, sb2, sb3;
uint neq1, rq3;
uint n_chunks;
uint s_off;
};

layout(binding = 0) readonly buffer KBuf { float k_in[]; };       // key tensor
layout(binding = 1) readonly buffer WBuf { float w_in[]; };       // per-chunk W from the intra pass
layout(binding = 2) readonly buffer UBuf { float u_in[]; };       // per-chunk U from the intra pass
layout(binding = 3) readonly buffer DecayBuf { float decay_in[]; }; // per-chunk total log-decay g_total
layout(binding = 4) readonly buffer GCumBuf { float gcum_in[]; }; // per-token cumulative log-decay within chunk
layout(binding = 5) readonly buffer StateBuf { float state_in[]; }; // initial recurrent state
layout(binding = 6) writeonly buffer HBuf { float h_out[]; };     // state snapshot at each chunk start (for output pass)
layout(binding = 7) writeonly buffer VNewBuf { float vnew_out[]; }; // corrected v (for output pass)
layout(binding = 8) buffer FinalBuf { float final_out[]; };       // dst; final state written at offset s_off

// Workgroup staging: one row (S_V floats) broadcast to all threads per token.
shared float s_w[S_V];
shared float s_kg[S_V];

void main() {
const uint head_id = gl_WorkGroupID.x;
const uint seq_id = gl_WorkGroupID.y;
const uint col = gl_LocalInvocationID.x;

// NOTE(review): early return before barrier() is only non-divergent because
// local_size_x is the same spec constant as S_V, so this never fires.
// If the two constants ever diverge, this becomes undefined behavior.
if (col >= S_V) return;

const uint iq1 = head_id % neq1;     // broadcast head index into q/k
const uint iq3 = seq_id / rq3;       // broadcast sequence index into q/k

const uint state_size = S_V * S_V;
const uint state_base = (seq_id * H + head_id) * state_size;

// Each thread carries one column of the S_V x S_V state. Large per-thread
// array — occupancy/spill behavior is compiler-dependent.
float state[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
state[i] = state_in[state_base + i * S_V + col];
}

// Chunks are processed strictly in order: each iteration consumes the
// state produced by the previous one.
for (uint c = 0; c < n_chunks; c++) {
const uint chunk_start = c * CHUNK_SIZE;
const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start);

// 1. Snapshot the state at the start of this chunk for the output kernel.
const uint h_base = ((seq_id * n_chunks + c) * H + head_id) * state_size;
[[unroll]] for (uint i = 0; i < S_V; i++) {
h_out[h_base + i * S_V + col] = state[i];
}

const uint wu_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE * S_V;
const uint gcum_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE;

const uint decay_idx = (seq_id * n_chunks + c) * H + head_id;
const float g_total = decay_in[decay_idx];

// Accumulates K_gated^T @ v_corrected for this chunk.
float delta[S_V];
[[unroll]] for (uint i = 0; i < S_V; i++) {
delta[i] = 0.0;
}

for (uint t = 0; t < chunk_len; t++) {
// Stage row W[t] in shared memory; barrier before any thread reads it.
s_w[col] = w_in[wu_base + t * S_V + col];
barrier();

// ws = dot(W[t], state column). Assumes S_V is a multiple of 4.
float ws = 0.0;
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
ws += dot(
vec4(s_w[i], s_w[i+1], s_w[i+2], s_w[i+3]),
vec4(state[i], state[i+1], state[i+2], state[i+3])
);
}

// 2. v_corrected = U - W @ S; also persisted for the output kernel.
float vnew = u_in[wu_base + t * S_V + col] - ws;
vnew_out[wu_base + t * S_V + col] = vnew;

// K_gated[t] = k[t] * exp(g_total - g_cumsum[t])
float g_cumsum_t = gcum_in[gcum_base + t];
float decay_factor = exp(g_total - g_cumsum_t);

// NOTE(review): k is addressed with the q strides (sq1/sq2/sq3) —
// assumes k and q have identical layouts; TODO confirm on the host side.
const uint t_global = chunk_start + t;
const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1;
s_kg[col] = k_in[k_off + col] * decay_factor;
barrier();

// Rank-1 update: delta[i] += k_gated[i] * vnew (outer product column).
[[unroll]] for (uint i = 0; i < S_V; i++) {
delta[i] += s_kg[i] * vnew;
}
// Guard s_kg (and s_w) against being overwritten by the next token
// before every thread has finished reading them.
barrier();
}

// 3. S_next = exp(g_total) * S + accumulated delta.
float total_decay = exp(g_total);
[[unroll]] for (uint i = 0; i < S_V; i++) {
state[i] = total_decay * state[i] + delta[i];
}
}

// Write final state to dst at s_off
[[unroll]] for (uint i = 0; i < S_V; i++) {
final_out[s_off + state_base + i * S_V + col] = state[i];
}
}
Loading