vulkan: Slang flash attention shader#20451

Draft
0cc4m wants to merge 5 commits into master from 0cc4m/vulkan-slang-flash-attention

@0cc4m (Collaborator) commented Mar 12, 2026

This is a port of the existing scalar Flash Attention GLSL shader to Slang. I wanted to port it to see what the state of Slang is: how easy is it to use, and what does it make easier compared to GLSL? For now, the purpose of this PR is just to serve as a point of discussion. @jeffbolznv FYI

I started by copying the GLSL shader and changing just the bare minimum, then went deeper to transform some of the structures using Slang features. Overall I think it has a lot of potential; getting rid of the crazy preprocessor structures we need for GLSL would be nice.

What I like

  • The dequantization code is much cleaner with Slang generics and interfaces. You can define an interface and plug in various dequantization/transformation algorithms, and the result is easier to read, IMO.
  • Templating/Generics also allow putting common patterns like reductions into functions that can be reused.
  • Typealiasing and vector<> also simplify data type choice, so I don't need to have a define per vector type I want to use, just one for the data type.
  • The module system seems nice to abstract out common code into functions that can be reused across many shaders.
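
As a rough illustration of the interface pattern from the first point, a minimal sketch could look like the following. All names here (IDequant, DequantSigned8, DequantUnsigned4, loadValue) are hypothetical and not taken from the PR, and the two formats are simplified stand-ins rather than actual ggml quantization layouts.

```slang
// Hypothetical interface: one static conversion from a raw stored value to float.
interface IDequant
{
    static float dequantize(uint raw, float scale);
}

// Simplified stand-in: signed 8-bit value stored in the low byte.
struct DequantSigned8 : IDequant
{
    static float dequantize(uint raw, float scale)
    {
        // Sign-extend the low 8 bits, then apply the scale.
        int q = int(raw << 24) >> 24;
        return float(q) * scale;
    }
}

// Simplified stand-in: unsigned 4-bit value in the low nibble, recentred around zero.
struct DequantUnsigned4 : IDequant
{
    static float dequantize(uint raw, float scale)
    {
        return (float(raw & 0xFu) - 8.0) * scale;
    }
}

// Generic consumer: the format is chosen by the type argument at compile time.
float loadValue<D : IDequant>(uint raw, float scale)
{
    return D.dequantize(raw, scale);
}
```

A call site would then pick the format with a type argument, e.g. loadValue<DequantSigned8>(raw, scale), instead of a preprocessor define.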

What I don't like

  • The HLSL-based subgroup intrinsic naming seems clunky compared to GLSL, e.g. WaveReadLaneAt() instead of subgroupShuffle(). subgroupShuffleXor() is missing completely, requiring WaveReadLaneAt(value, WaveGetLaneIndex() ^ s) as a workaround.
  • The Generics/Templating system still has flaws that prevent using e.g. the same reduce function for scalars and for vectors. I have to duplicate the code and provide a different function for vectors. A builtin vector type interface seems to be missing.
  • Shared memory can't be passed by reference into a module function. That seems like a huge oversight to me. For a reduction I may need shared memory. To keep the amount of shared memory optimal I have to define it in the main file, but I can't pass it into a module function without strange interface workarounds like this:
public interface ISharedMemory<T> {
    static T get(uint idx);
    static void set(uint idx, T value);
}
groupshared float tmpsh[tmpsh_size];
struct ShMemFloat: ISharedMemory<float> {
    static float get(uint idx) {
        return tmpsh[idx];
    }
    static void set(uint idx, float value) {
        tmpsh[idx] = value;
    }
}

That's very verbose and just makes the code harder to read.
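
For context, this is roughly how a module function might consume such an interface. It is only a sketch: reduceSum and its tid/count parameters are hypothetical names, it relies on the ISharedMemory and ShMemFloat definitions above, and it assumes count is a power of two no larger than tmpsh_size and that every thread in the workgroup calls it.

```slang
// Hypothetical module function: tree reduction through the shared-memory interface.
// Assumes `count` is a power of two and all `count` threads call this uniformly.
public float reduceSum<S : ISharedMemory<float>>(float value, uint tid, uint count)
{
    // Each thread publishes its partial value.
    S.set(tid, value);
    GroupMemoryBarrierWithGroupSync();

    // Halve the number of active threads each step.
    for (uint s = count / 2; s > 0; s >>= 1)
    {
        if (tid < s)
        {
            S.set(tid, S.get(tid) + S.get(tid + s));
        }
        // The barrier sits outside the branch so all threads reach it.
        GroupMemoryBarrierWithGroupSync();
    }

    // Slot 0 now holds the sum over all `count` inputs.
    return S.get(0);
}
```

The main file would then call something like reduceSum<ShMemFloat>(partial, tid, tmpsh_size), where partial is the per-thread value; the indirection through get/set is exactly the boilerplate that direct pass-by-reference would avoid.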

There are probably some things that can still be "slang-ified", but I don't have time right now.

I plan to do some performance checks, but I'm having some trouble with the slang compiler currently.

I'll leave this as is for now, discuss what I found with the Slang developers and hopefully pick it up again in the future.

@github-actions bot added the "Vulkan" (issues specific to the Vulkan backend) and "ggml" (changes relating to the ggml tensor library for machine learning) labels Mar 12, 2026

@jeffbolznv (Collaborator) commented

I read through the code. It definitely has a more "structured" feel, but it would take some getting used to. Will be interesting to see if it has any performance deficit compared to GLSL.

> Typealiasing and vector<> also simplify data type choice

FWIW, https://github.com/KhronosGroup/GLSL/blob/main/extensions/ext/GLSL_EXT_long_vector.txt also introduces a vector type with template syntax. Currently glslang will emit the LongVector capability if it's used, but I don't think that's actually required and it would be possible to support 2-4 component vectors on older drivers.

@jeffbolznv (Collaborator) commented

> FWIW, https://github.com/KhronosGroup/GLSL/blob/main/extensions/ext/GLSL_EXT_long_vector.txt also introduces a vector type with template syntax. Currently glslang will emit the LongVector capability if it's used, but I don't think that's actually required and it would be possible to support 2-4 component vectors on older drivers.

I went ahead and made this change here: KhronosGroup/glslang#4188
