diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7446e0df133..2e41a2eccec 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -11003,13 +11003,17 @@ def set_gguf_parameters(self): self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # these layers act as MLM head, so we don't need them - if name.startswith("decoder."): - return - if name.startswith("model."): name = name[6:] + if self.cls_out_labels: + # For BertForSequenceClassification (direct projection layer) + if name == "classifier.weight": + name = "classifier.out_proj.weight" + + if name == "classifier.bias": + name = "classifier.out_proj.bias" + yield from super().modify_tensors(data_torch, name, bid) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h b/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h deleted file mode 100644 index a7078687288..00000000000 --- a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +++ /dev/null @@ -1,333 +0,0 @@ -#pragma once - -typedef vector unsigned char vec_t; -typedef __vector_quad acc_t; - -template -class tinyBLAS_Q0_PPC { - public: - tinyBLAS_Q0_PPC(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, - int ith, int nth); - - void matmul(int64_t m, int64_t n); - void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) { - vec_t A_pack[mc*kc*2]; - vec_t B_pack[nc*kc*2]; - int comparray[mc*kc]; - constexpr bool is_Ablock_q4 = std::is_same_v; - int64_t ytiles = m / mc; - int64_t xtiles = n / nc; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) { - end = tiles; - } - for (int64_t job = start; job < end; ++job) { - int64_t ii = (job / xtiles) * mc; - int64_t jj = (job % xtiles) * nc; - for (int64_t kk = 0; kk < k; kk += kc) { - if constexpr(is_Ablock_q4) { - packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray); - } else { - packNormal_large(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray); - } - packNormal_large(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true); - KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray); - } - } - } - - private: - inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { - for (int I = 0; I < RM; I++) { - for (int J = 0; J < RN; J++) { - *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); - } - } - } - - inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { - for (int I = 0; I < RM; I++) { - for (int J = 0; J < RN; J++) { - float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I); - *c_ptr += *((float*)&fin_res[idx+I]+J); - } - } - } - - template - inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) { - vector signed int vec_C[4]; - vector float CA[4] = {0}; - vector float res[4] = {0}; - __builtin_mma_disassemble_acc(vec_C, ACC); - for (int i = 0; i < 4; i++) { - CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0)); - res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); - fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); - } - } - - inline void process_q4_elements(vector signed char (&c)[2], int* ca) { - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector signed 
char v8 = vec_splats((signed char)0x8); - vector signed int vsum = {0}; - vector signed int vsum2 = {0}; - c[0] = vec_and(c[1], lowMask); - c[1] = vec_sr(c[1], v4); - c[0] = vec_sub(c[0], v8); - c[1] = vec_sub(c[1], v8); - vsum = vec_sum4s(c[0], vsum); - vsum2 = vec_sum4s(c[1], vsum2); - vsum = vec_add(vsum, vsum2); - *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - } - - template - inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) { - vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; - vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; - vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; - vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - V2 t1, t2, t3, t4, t5, t6, t7, t8; - vector unsigned char xor_vector; - uint8_t flip_vec = 0x80; - xor_vector = vec_splats(flip_vec); - t1 = vec_perm(s1, s2, swiz1); - t2 = vec_perm(s1, s2, swiz2); - t3 = vec_perm(s3, s4, swiz1); - t4 = vec_perm(s3, s4, swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - } - - template - inline void kernel(int64_t ii, int64_t jj) { - if constexpr(RM == 4 && RN == 8) { - KERNEL_4x8(ii,jj); - } else if constexpr(RM == 8 && RN == 4) { - KERNEL_8x4(ii,jj); - } else if constexpr(RM == 8 && RN == 8) { - KERNEL_8x8(ii,jj); - } else { - assert(false && "RN/RM values not supported"); - } - } - template - void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray); - template - void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip); - void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n); - void KERNEL_4x8(int64_t ii, int64_t jj); - void KERNEL_8x4(int64_t ii, int64_t jj); - void KERNEL_8x8(int64_t ii, int64_t jj); - void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN); - template - void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n); - - void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){ - for (int I = 0; I<8; I++) { - float a_scale = unhalf((A+((ii+I)*lda)+blk)->d); - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d)); - *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d)); - } - } - } - - inline void process_q8_elements(const int8_t *qs, int *ca) { - vector signed char c1 = vec_xl(0, qs); - vector signed char c2 = vec_xl(16, qs); - vector signed int vsum1 = {0}; - vector signed int vsum2 = {0}; - vsum1 = vec_sum4s(c1, vsum1); - vsum2 = vec_sum4s(c2, vsum2); - vector signed int vsum = vec_add(vsum1, vsum2); - *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - } - - template - void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) { - int64_t i, j; - block_q8_0 *aoffset = NULL; - VA *vecOffset = NULL; - block_q8_0* aoffsets[8]; - __vector_pair arr[8]; - VB c[8][2] = {0}; - VB c1[8] = {0}; VB c2[8] = {0}; - aoffset = const_cast(a); - vecOffset = vec; - j = (rows >> 3); - int index = 0; - if (j > 0) { - do { - for (int it = 0; 
it < 8; it++) - aoffsets[it] = aoffset + it*lda; - aoffset += 8 * lda; - for (int blk = 0; blk < kc; blk++) { - for (int it = 0; it < 8; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); - c1[it] = c[it][0]; - c2[it] = c[it][1]; - if (comparray){ - process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]); - } - } - vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); - vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); - vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); - vecOffset += 256; - } - j--; - index += 8*kc; - } while(j > 0); - } - - } - - void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) { - int64_t i, j; - TA *aoffset = NULL; - int8_t *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - aoffset = const_cast(a); - vecOffset = vec; - int index = 0; - j = (rows >> 3); - if (j > 0) { - do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; - aoffset += 8 * lda; - for (int blk = 0; blk < kc; blk++) { - c1[1] = reinterpret_cast(vec_xl(0, (aoffset1+blk)->qs)); - c2[1] = reinterpret_cast(vec_xl(0, (aoffset2+blk)->qs)); - c3[1] = reinterpret_cast(vec_xl(0, (aoffset3+blk)->qs)); - c4[1] = reinterpret_cast(vec_xl(0, (aoffset4+blk)->qs)); - c5[1] = reinterpret_cast(vec_xl(0, (aoffset5+blk)->qs)); - c6[1] = reinterpret_cast(vec_xl(0, (aoffset6+blk)->qs)); - c7[1] = reinterpret_cast(vec_xl(0, (aoffset7+blk)->qs)); - c8[1] = reinterpret_cast(vec_xl(0, (aoffset8+blk)->qs)); - - process_q4_elements(c1, &comparray[index + 8*blk+0]); - process_q4_elements(c2, &comparray[index + 8*blk+1]); - process_q4_elements(c3, &comparray[index + 8*blk+2]); - process_q4_elements(c4, &comparray[index + 8*blk+3]); - process_q4_elements(c5, &comparray[index + 8*blk+4]); - process_q4_elements(c6, &comparray[index + 8*blk+5]); - process_q4_elements(c7, &comparray[index + 8*blk+6]); - process_q4_elements(c8, &comparray[index + 8*blk+7]); - vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); - vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); - vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); - vecOffset += 256; - } - j--; - index += 8*kc; - } while (j > 0); - } - } - - void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) { - acc_t acc[8]; - for (int i = 0; i < mc ; i += 8) { - for (int j = 0; j < nc; j += 8) { - vector float fin_res[16] = {0}; - vector float vs[16] = {0}; - for (int64_t kk = 0; kk < kc; kk+=2) { - for (int x = 0; x < 8; x++) { - __builtin_mma_xxsetaccz(&acc[x]); - } - int A_block_idx = (i/8)*(16*kc) + kk*16; - int B_block_idx = (j/8)*(16*kc)+ kk*16; - vec_t *A_block = &vec_A[A_block_idx]; - vec_t *B_block = &vec_B[B_block_idx]; - for (int x = 0; x < 8; x++) { - 
__builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]); - __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]); - } - compute_scale(ii+i, jj+j, l+kk, vs); - int c_index = (i/8)*(8*kc)+ kk*8; - int* c_block = &comparray[c_index]; - compute(&acc[0], 0, 0, c_block, vs, fin_res); - compute(&acc[1], 4, 4, c_block, vs, fin_res); - compute(&acc[2], 0, 8, c_block, vs, fin_res); - compute(&acc[3], 4, 12, c_block, vs, fin_res); - - A_block_idx = (i/8)*(16*kc) + (kk+1)*16; - B_block_idx = (j/8)*(16*kc)+ (kk+1)*16; - A_block = &vec_A[A_block_idx]; - B_block = &vec_B[B_block_idx]; - for (int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]); - __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]); - } - compute_scale(ii+i, jj+j, l+kk+1, vs); - c_index = (i/8)*(8*kc)+ (kk+1)*8; - c_block = &comparray[c_index]; - compute(&acc[4], 0, 0, c_block, vs, fin_res); - compute(&acc[5], 4, 4, c_block, vs, fin_res); - compute(&acc[6], 0, 8, c_block, vs, fin_res); - compute(&acc[7], 4, 12, c_block, vs, fin_res); - - } - if (l == 0) { - save_res(ii+i, jj+j, 0, fin_res); - save_res(ii+i+4, jj+j, 4, fin_res); - save_res(ii+i, jj+j+4, 8, fin_res); - save_res(ii+i+4, jj+j+4, 12, fin_res); - } else { - add_save_res(ii+i, jj+j, 0, fin_res); - add_save_res(ii+i+4, jj+j, 4, fin_res); - add_save_res(ii+i, jj+j+4, 8, fin_res); - add_save_res(ii+i+4, jj+j+4, 12, fin_res); - } - } - } - } - - const TA *const A; - const block_q8_0 *const B; - float *C; - const int64_t k; - int64_t kc; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; - const int ith; - const int nth; -}; diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 8f980c16b96..da412fd009b 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); } #endif #if defined(__MMA__) -#include "sgemm-ppc.h" +typedef vector unsigned char vec_t; +typedef __vector_quad acc_t; #endif //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED FUSED MULTIPLY ADD @@ -2153,7 +2154,7 @@ class tinyBLAS_HP16_PPC { packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); - mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_1, vec_A[x+4], vec_B[x]); } } SAVE_ACC(&acc_0, ii, jj); @@ -2301,43 +2302,299 @@ class tinyBLAS_HP16_PPC { const int nth; }; - template - tinyBLAS_Q0_PPC::tinyBLAS_Q0_PPC(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, - int ith, int nth) +template +class tinyBLAS_Q0_PPC { + public: + tinyBLAS_Q0_PPC(int64_t k, + const TA * A, int64_t lda, + const block_q8_0 * B, int64_t ldb, + float * C, int64_t ldc, + int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { - kc = 64; } - template - void tinyBLAS_Q0_PPC::matmul(int64_t m, int64_t n) { - int mc = 64; int nc = 64; - if (n % 8 == 0 && n < nc) { - nc = n; - mc = 32 ; - kc = 32; - } - const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 
0); - if (is_aligned) { - this->matmul_tiled_q0(m, n, mc, nc, kc); + void matmul(int64_t m, int64_t n) { + const int64_t mc = 64; + const int64_t kc = 64; + int64_t nc = 64; + int64_t n_aligned = 0; + if (n % 64 == 0) { + n_aligned = n; + } else if (n == 4) { + n_aligned = 4; + } else if (n < 64) { + n_aligned = (n / 8) * 8; + } else { + n_aligned = (n / 64) * 64; + } + + if (n_aligned > 0) { + if (n_aligned % 64 == 0) nc = 64; + else if (n_aligned == n) nc = n; + else if (n_aligned % 32 == 0) nc = 32; + else if (n_aligned % 24 == 0) nc = 24; + else if (n_aligned % 16 == 0) nc = 16; + else nc = 8; + } + bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0); + if (can_use_tiled) { + matmul_tiled(m, n_aligned, mc, nc, kc); + if (n > n_aligned) { + mnpack(0, m, n_aligned, n); + } } else { mnpack(0, m, 0, n); } } - template - template - void tinyBLAS_Q0_PPC::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray) { + private: + inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) { + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J); + } + } + } + + inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J); + } + } + } + + inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I); + *c_ptr += *((float *)&vec_C[I] + J); + } + } + } + + template + inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) { + vector signed int vec_C[4]; + vector float CA[4] = {0}; + vector float res[4] = {0}; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int i = 0; i < 4; i++) { + CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0)); + res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); + fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]); + } + } + + inline void process_q4_elements(vector signed char (&c)[2], int * ca) { + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + vector signed int vsum = {0}; + vector signed int vsum2 = {0}; + c[0] = vec_and(c[1], lowMask); + c[1] = vec_sr(c[1], v4); + c[0] = vec_sub(c[0], v8); + c[1] = vec_sub(c[1], v8); + vsum = vec_sum4s(c[0], vsum); + vsum2 = vec_sum4s(c[1], vsum2); + vsum = vec_add(vsum, vsum2); + *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + } + + template + inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) { + vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + V2 t1, t2, t3, t4, t5, t6, t7, t8; + vector unsigned char xor_vector; + uint8_t flip_vec = 0x80; + xor_vector = 
vec_splats(flip_vec); + t1 = vec_perm(s1, s2, swiz1); + t2 = vec_perm(s1, s2, swiz2); + t3 = vec_perm(s3, s4, swiz1); + t4 = vec_perm(s3, s4, swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 16); + vec_xst(t7, 0, vecOffset + 32); + vec_xst(t8, 0, vecOffset + 48); + } + + inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) { + const vector signed char lowMask = vec_splats((signed char)0x0F); + const vector signed char v8 = vec_splats((signed char)0x08); + const vector unsigned char v4 = vec_splats((unsigned char)4); + lo = vec_and(packed, lowMask); + hi = vec_sr(packed, v4); + lo = vec_sub(lo, v8); + hi = vec_sub(hi, v8); + } + + inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) { + vec_t t[8], s[8]; + vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23}; + vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + for (int i = 0; i < 4; i += 2) { + t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1); + t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2); + } + for (int i = 4; i < 8; i += 2) { + t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1); + t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2); + } + s[0] = vec_perm(t[0], t[2], swiz3); + s[1] = vec_perm(t[0], t[2], swiz4); + s[2] = vec_perm(t[1], t[3], swiz3); + s[3] = vec_perm(t[1], t[3], swiz4); + s[4] = vec_perm(t[4], t[6], swiz3); + s[5] = vec_perm(t[4], t[6], swiz4); + s[6] = vec_perm(t[5], t[7], swiz3); + s[7] = vec_perm(t[5], t[7], swiz4); + for (int i = 0; i < 8; ++i) { + vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16)); + } + } + + static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) { + vector signed short i16_hi = vec_unpackh(raw); + vector signed short i16_lo = vec_unpackl(raw); + + vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0); + vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0); + vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0); + vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0); + out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale)); + out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale)); + } + + void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + unsigned char * vecOffset = vec; + for (int i = 0; i < rows; i += 8) { + const block_q4_0 * rows_base[8]; + for (int r = 0; r < 8; r++) { + rows_base[r] = a + (i + r) * lda; + } + for (int blk = 0; blk < blocks; blk++) { + vector unsigned short hp_res[8][4]; + for (int r = 0; r < 8; r++) { + const block_q4_0 * current_blk = rows_base[r] + blk; + vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d)); + vector signed char v_qs = reinterpret_cast(vec_xl(0, current_blk->qs)); + vector signed char c1, c2; + unpack_q4_to_q8(v_qs, c1, c2); + convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]); + convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]); + } + for (int c = 
0; c < 4; c++) { + vector unsigned char c_arr[8]; + for (int r = 0; r < 8; r++) { + c_arr[r] = (vector unsigned char)hp_res[r][c]; + } + vector_permute_store_fp16((vec_t *)c_arr, vecOffset); + vecOffset += 128; + } + } + } + } + + template + static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + unsigned char * vecOffset = vec; + const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23}; + const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + + for (int i = 0; i < rows; i += chunk_size) { + const block_q8_0 * rows_base[chunk_size]; + for (int r = 0; r < chunk_size; r++) { + rows_base[r] = a + (i + r) * lda; + } + for (int blk = 0; blk < blocks; blk++) { + vector unsigned short hp_res[chunk_size][4]; + for (int r = 0; r < chunk_size; r++) { + const block_q8_0 * b = rows_base[r] + blk; + vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d)); + vector signed char c[2]; + __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs); + __builtin_vsx_disassemble_pair(c, & pair); + convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]); + convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]); + } + for (int col = 0; col < 4; col++) { + if constexpr (chunk_size == 8) { + vec_t t[8]; + t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1); + t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2); + t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1); + t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2); + t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1); + t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2); + t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1); + t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2); + + vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0)); + vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16)); + vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32)); + vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48)); + vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64)); + vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80)); + vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96)); + vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112)); + vecOffset += 128; + } else { + vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1); + vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2); + vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1); + vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2); + + vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0)); + vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16)); + vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32)); + vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48)); + vecOffset += 64; + } + } + } + } + } + + void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + if (rows == 4) { + pack_q8_block<4>(a, lda, rows, blocks, vec); + } else { + pack_q8_block<8>(a, lda, rows, blocks, vec); + } + } + + template + void 
packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array & comparray) { int64_t i, j; - TA *aoffset = NULL; - int8_t *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + TA * aoffset = NULL; + int8_t * vecOffset = NULL; + TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL; + TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL; vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - aoffset = const_cast(a); + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { @@ -2363,18 +2620,18 @@ class tinyBLAS_HP16_PPC { c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); - process_q4_elements(c5, &comparray[4]); - process_q4_elements(c6, &comparray[5]); - process_q4_elements(c7, &comparray[6]); - process_q4_elements(c8, &comparray[7]); + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); + process_q4_elements(c5, & comparray[4]); + process_q4_elements(c6, & comparray[5]); + process_q4_elements(c7, & comparray[6]); + process_q4_elements(c8, & comparray[7]); vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); - vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); - vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); + vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false); + vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2405,12 +2662,12 @@ class tinyBLAS_HP16_PPC { c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2434,12 +2691,12 @@ class tinyBLAS_HP16_PPC { case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); break; } - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + 
vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2450,39 +2707,38 @@ class tinyBLAS_HP16_PPC { } } - template template - void tinyBLAS_Q0_PPC::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) { int64_t i, j; - block_q8_0 *aoffset = NULL; - VA *vecOffset = NULL; - block_q8_0* aoffsets[8]; + block_q8_0 * aoffset = NULL; + VA * vecOffset = NULL; + block_q8_0 * aoffsets[8]; __vector_pair arr[8]; VB c[8][2] = {0}; VB c1[8] = {0}; VB c2[8] = {0}; - aoffset = const_cast(a); + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { do { aoffsets[0] = aoffset; for (int it = 1; it < 8; it++) - aoffsets[it] = aoffsets[it-1] + lda; + aoffsets[it] = aoffsets[it - 1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { for (int it = 0; it < 8; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], & arr[it]); c1[it] = c[it][0]; c2[it] = c[it][1]; } vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); - vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); - vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); + vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip); + vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip); for (int it = 0; it < 8; it++) aoffsets[it] += lda; vecOffset += 256; @@ -2501,13 +2757,13 @@ class tinyBLAS_HP16_PPC { if (i > 0) { do { for (int it = 0; it < 4; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], & arr[it]); c1[it] = c[it][0]; c2[it] = c[it][1]; } vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); for (int it = 0; it < 4; it++) { aoffsets[it] += lda; } @@ -2520,24 +2776,24 @@ class tinyBLAS_HP16_PPC { if (rows & 3) { aoffsets[0] = aoffset; for (int it = 1; it < 3; it++ ) - aoffsets[it] = aoffsets[it-1] + lda; + aoffsets[it] = aoffsets[it - 1] + lda; i = (cols >> 3); if (i > 0) { do { switch(rows) { - case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs); - __builtin_vsx_disassemble_pair(c[2], &arr[2]); + case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs); + __builtin_vsx_disassemble_pair(c[2], & arr[2]); c1[2] = c[2][0]; c2[2] = c[2][1]; - case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs); - __builtin_vsx_disassemble_pair(c[1], &arr[1]); + case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs); + __builtin_vsx_disassemble_pair(c[1], & arr[1]); c1[1] = c[1][0]; c2[1] = c[1][1]; - case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs); - __builtin_vsx_disassemble_pair(c[0], &arr[0]); + case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs); + __builtin_vsx_disassemble_pair(c[0], 
& arr[0]); c1[0] = c[0][0]; c2[0] = c[0][1]; break; } vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); for (int it = 0; it < 3; it++) aoffsets[it] += lda; vecOffset += 128; @@ -2547,8 +2803,7 @@ class tinyBLAS_HP16_PPC { } } - template - void tinyBLAS_Q0_PPC::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { int m_rem = MIN(m - m0, 16); int n_rem = MIN(n - n0, 16); @@ -2585,8 +2840,7 @@ class tinyBLAS_HP16_PPC { } - template - void tinyBLAS_Q0_PPC::KERNEL_4x8(int64_t ii, int64_t jj) { + void KERNEL_4x8(int64_t ii, int64_t jj) { vec_t vec_A[8], vec_B[16] = {0}; acc_t acc_0, acc_1; std::array comparray {}; @@ -2594,26 +2848,26 @@ class tinyBLAS_HP16_PPC { vector float vs[8] = {0}; bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); if (std::is_same_v) { - packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray); } else { - packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + packNormal((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]); } for (int I = 0; I<4; I++) { for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); - *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); + *((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 4; i++) { comparray[i] = 0; int ca = 0; @@ -2624,15 +2878,14 @@ class tinyBLAS_HP16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 0, 4, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 0, 4, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii, jj+4, 4, fin_res); + save_res(ii, jj + 4, 4, fin_res); } - template - void tinyBLAS_Q0_PPC::KERNEL_8x4(int64_t ii, int64_t jj) { + void KERNEL_8x4(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[8] = {0}; acc_t acc_0, acc_1; std::array comparray {}; @@ -2640,25 +2893,25 @@ class tinyBLAS_HP16_PPC { vector float vs[8] = {0}; bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); if (std::is_same_v) { - packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray); } else { - 
packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); + packNormal((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]); } - for (int I = 0; I<8; I++) { - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); + for (int I = 0; I < 8; I++) { + for (int J = 0; J < 4; J++) { + *((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 8; i++) { comparray[i] = 0; int ca = 0; @@ -2669,15 +2922,14 @@ class tinyBLAS_HP16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 4, 4, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 4, 4, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii+4, jj, 4, fin_res); + save_res(ii + 4, jj, 4, fin_res); } - template - void tinyBLAS_Q0_PPC::KERNEL_8x8(int64_t ii, int64_t jj) { + void KERNEL_8x8(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[16] = {0}; acc_t acc_0, acc_1, acc_2, acc_3; acc_t acc_4, acc_5, acc_6, acc_7; @@ -2686,30 +2938,30 @@ class tinyBLAS_HP16_PPC { vector float vs[16] = {0}; bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); - __builtin_mma_xxsetaccz(&acc_2); - __builtin_mma_xxsetaccz(&acc_3); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); + __builtin_mma_xxsetaccz(& acc_2); + __builtin_mma_xxsetaccz(& acc_3); if (std::is_same_v) { - packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray); } else { - packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + packNormal((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]); - __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]); + __builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]); } - for (int I = 0; I<8; I++) { - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); - *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + for (int I = 0; I < 8 ; I++) { + for (int J = 0; J < 4; J++) { + *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); + *((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) 
* unhalf((B + ((jj + J + 4) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 8; i++) { comparray[i] = 0; int ca = 0; @@ -2720,19 +2972,99 @@ class tinyBLAS_HP16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 4, 4, comparray, vs, fin_res); - compute(&acc_2, 0, 8, comparray, vs, fin_res); - compute(&acc_3, 4, 12, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 4, 4, comparray, vs, fin_res); + compute(& acc_2, 0, 8, comparray, vs, fin_res); + compute(& acc_3, 4, 12, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii+4, jj, 4, fin_res); - save_res(ii, jj+4, 8, fin_res); - save_res(ii+4, jj+4, 12, fin_res); + save_res(ii + 4, jj, 4, fin_res); + save_res(ii, jj + 4, 8, fin_res); + save_res(ii + 4, jj + 4, 12, fin_res); + } + + void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) { + acc_t acc[8]; + for (int i = 0; i < mc ; i += 16) { + for (int j = 0; j < nc; j += 8) { + int A0_base = (i / 16) * (2 * 32 * kc); + int B0_base = (j / 8) * (32 * kc); + for (int x = 0; x < 8; x++) { + __builtin_mma_xxsetaccz(&acc[x]); + } + for (int64_t kk = 0; kk < kc; kk++) { + int A0_block_idx = A0_base + kk * 32; + int B0_block_idx = B0_base + kk * 32; + int A1_block_idx = A0_block_idx + 32 * kc; + int B1_block_idx = B0_block_idx + 32 * kc; + vec_t * A0_block = & vec_A[A0_block_idx]; + vec_t * B0_block = & vec_B[B0_block_idx]; + vec_t * A1_block = & vec_A[A1_block_idx]; + for (int it = 0; it < 4; it++) { + for (int x = 0; x < 4; x++) { + __builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]); + __builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]); + __builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]); + __builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]); + } + } + } + if (l == 0) { + save_acc(& acc[0], ii + i, jj + j); + save_acc(& acc[1], ii + i, jj + j + 4); + save_acc(& acc[2], ii + i + 4, jj + j); + save_acc(& acc[3], ii + i + 4, jj + j + 4); + save_acc(& acc[4], ii + i + 8, jj + j); + save_acc(& acc[5], ii + i + 8, jj + j + 4); + save_acc(& acc[6], ii + i + 12, jj + j); + save_acc(& acc[7], ii + i + 12, jj + j + 4); + } else { + add_save_acc(& acc[0], ii + i, jj + j); + add_save_acc(& acc[1], ii + i, jj + j + 4); + add_save_acc(& acc[2], ii + i + 4, jj + j); + add_save_acc(& acc[3], ii + i + 4, jj + j + 4); + add_save_acc(& acc[4], ii + i + 8, jj + j); + add_save_acc(& acc[5], ii + i + 8, jj + j + 4); + add_save_acc(& acc[6], ii + i + 12, jj + j); + add_save_acc(& acc[7], ii + i + 12, jj + j + 4); + } + } + } } - template - void tinyBLAS_Q0_PPC::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { + void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) { + vec_t A_pack[mc * kc * 4]; + vec_t B_pack[nc * kc * 4]; + constexpr bool is_Ablock_q4 = std::is_same_v; + int64_t ytiles = m / mc; + int64_t xtiles = n / nc; + int64_t tiles = xtiles * ytiles; + 
int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) { + end = tiles; + } + for (int64_t job = start; job < end; ++job) { + int64_t ii = (job / xtiles) * mc; + int64_t jj = (job % xtiles) * nc; + for (int64_t kk = 0; kk < k; kk += kc) { + if constexpr(is_Ablock_q4) { + packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + } else { + packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + } + packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack); + KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack); + } + } + } + + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2754,32 +3086,32 @@ class tinyBLAS_HP16_PPC { vector float fin_res[4] = {0}; vector float vs[4] = {0}; vector float CA[4] = {0}; - __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value - __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value + __builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value + __builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value for (int l = 0; l < k; l++) { - __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead - __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead - __builtin_mma_xxsetaccz(&acc_0); + __builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead + __builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead + __builtin_mma_xxsetaccz(& acc_0); if (isAblock_q4) { - packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray); } else { - packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); - for(int x = 0; x < 8; x+=4) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]); + packNormal((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true); + for (int x = 0; x < 8; x += 4) { + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]); } - for (int I = 0; Id) * unhalf((B+((jj+J)*ldb)+l)->d)); + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); } } - __builtin_mma_disassemble_acc(vec_C, &acc_0); + __builtin_mma_disassemble_acc(vec_C, & acc_0); if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < RM; i++) { comparray[i] = 0; int ca = 0; @@ -2800,9 +3132,21 @@ class tinyBLAS_HP16_PPC { } } - template + template + inline void kernel(int64_t ii, int64_t jj) { + if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii,jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii,jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii,jj); + } else { + assert(false 
&& "RN/RM values not supported"); + } + } + template - NOINLINE void tinyBLAS_Q0_PPC::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2814,12 +3158,20 @@ class tinyBLAS_HP16_PPC { for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; - this->kernel(ii, jj); + kernel(ii, jj); } } - -template class tinyBLAS_Q0_PPC; -template class tinyBLAS_Q0_PPC; + const TA * const A; + const block_q8_0 * const B; + float * C; + const int64_t k; + int64_t kc; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; class tinyBLAS_PPC { public: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 727e4dd96ee..4b0f81ecb24 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -652,6 +652,7 @@ class MODEL_TENSOR(IntEnum): ENC_OUTPUT_NORM = auto() CLS = auto() # classifier CLS_OUT = auto() # classifier output projection + CLS_NORM = auto() CONV1D = auto() CONVNEXT_DW = auto() CONVNEXT_NORM = auto() @@ -1088,6 +1089,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", MODEL_TENSOR.CLS: "cls", MODEL_TENSOR.CLS_OUT: "cls.output", + MODEL_TENSOR.CLS_NORM: "cls.norm", MODEL_TENSOR.CONV1D: "conv1d", MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw", MODEL_TENSOR.CONVNEXT_NORM: "convnext.{bid}.norm", @@ -1507,6 +1509,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.CLS, MODEL_TENSOR.CLS_OUT, + MODEL_TENSOR.CLS_NORM, ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 228ba70aa34..5fc75c52eb8 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1240,6 +1240,10 @@ class TensorNameMap: MODEL_TENSOR.CLS_OUT: ( "classifier.out_proj", # roberta ), + + MODEL_TENSOR.CLS_NORM: ( + "head.norm", # modern-bert + ), ############################################################################# MODEL_TENSOR.CONVNEXT_DW: ( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 7e4da4e78cf..965066cb668 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -367,6 +367,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_CLS, "cls" }, { LLM_TENSOR_CLS_OUT, "cls.output" }, + { LLM_TENSOR_CLS_NORM, "cls.norm" }, { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, @@ -828,6 +829,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CLS_NORM, }; case LLM_ARCH_JINA_BERT_V2: return { @@ -2518,6 +2520,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 521944370b4..e37f634e373 
100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -497,6 +497,7 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CLS_NORM, LLM_TENSOR_CONV1D, LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_NORM, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index fc05989aa55..7f4b4a933ea 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2761,6 +2761,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params llama_set_param(model->cls_b, param_filter, param_filter_ud); llama_set_param(model->cls_out, param_filter, param_filter_ud); llama_set_param(model->cls_out_b, param_filter, param_filter_ud); + llama_set_param(model->cls_norm, param_filter, param_filter_ud); for (struct llama_layer & layer : model->layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 70d8ff02a92..69272498737 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -185,7 +185,10 @@ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) { } void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + if (cparams.embeddings && + (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) { + const int64_t n_tokens = ubatch->n_tokens; const int64_t n_seq_tokens = ubatch->n_seq_tokens; const int64_t n_seqs_unq = ubatch->n_seqs_unq; @@ -2437,7 +2440,8 @@ void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const { + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const { if (!cparams.embeddings) { return; } @@ -2476,8 +2480,15 @@ void llm_graph_context::build_pooling( } break; case LLAMA_POOLING_TYPE_RANK: { - ggml_tensor * inp_cls = build_inp_cls(); - cur = ggml_get_rows(ctx0, inp, inp_cls); + if (arch == LLM_ARCH_MODERN_BERT) { + // modern bert gte reranker builds mean first then applies prediction head and classifier + // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411 + ggml_tensor * inp_mean = build_inp_mean(); + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); + } else { + ggml_tensor * inp_cls = build_inp_cls(); + cur = ggml_get_rows(ctx0, inp, inp_cls); + } // classification head // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 @@ -2486,7 +2497,15 @@ void llm_graph_context::build_pooling( if (cls_b) { cur = ggml_add(ctx0, cur, cls_b); } - cur = ggml_tanh(ctx0, cur); + if (arch == LLM_ARCH_MODERN_BERT) { + cur = ggml_gelu(ctx0, cur); + } else { + cur = ggml_tanh(ctx0, cur); + } + if (cls_norm) { + // head norm + cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1); + } } // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en diff --git a/src/llama-graph.h b/src/llama-graph.h index 1d69ff1a6fc..74a4685121d 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -1000,7 +1000,8 @@ struct llm_graph_context { ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const; + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const; // // sampling (backend sampling) diff --git a/src/llama-model-saver.cpp 
b/src/llama-model-saver.cpp index 36e353074e0..676efeda709 100644 --- a/src/llama-model-saver.cpp +++ b/src/llama-model-saver.cpp @@ -271,6 +271,7 @@ void llama_model_saver::add_tensors_from_model() { add_tensor(model.cls_b); add_tensor(model.cls_out); add_tensor(model.cls_out_b); + add_tensor(model.cls_norm); for (const struct llama_layer & layer : model.layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 279a4d5ced0..2aebaddf27d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -908,7 +908,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); + hparams.set_swa_pattern(swa_period, true); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -3513,9 +3513,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); } - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_NEO_BERT: @@ -8734,7 +8735,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } // add on pooling layer - llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + llm->build_pooling(cls, cls_b, cls_out, cls_out_b, cls_norm); // add backend sampling layers (if any) llm->build_sampling(); diff --git a/src/llama-model.h b/src/llama-model.h index b3505914293..5ffba24fe98 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -475,6 +475,7 @@ struct llama_model { struct ggml_tensor * cls_b = nullptr; struct ggml_tensor * cls_out = nullptr; struct ggml_tensor * cls_out_b = nullptr; + struct ggml_tensor * cls_norm = nullptr; struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 0cdf9c324ba..99f1fdd9538 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -27,6 +27,7 @@ std::pair llm_build_delta_net_base::build_delta_ne const int64_t S_v = v->ne[0]; const int64_t H_v = v->ne[1]; + const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k); GGML_ASSERT(S_k == S_v); GGML_ASSERT(H_v % H_k == 0); @@ -35,9 +36,10 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); - GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + GGML_ASSERT(g->ne[0] == 1 || 
g->ne[0] == S_v); + GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); const float scale = 1.0f / sqrtf(S_k); @@ -52,8 +54,8 @@ std::pair llm_build_delta_net_base::build_delta_ne q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] - g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs] - b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs] + g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs] + b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [ 1, n_tokens, H_v, n_seqs] const int CS = CHUNK_SIZE; @@ -78,33 +80,76 @@ std::pair llm_build_delta_net_base::build_delta_ne v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs); v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); - g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs); - b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); + g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); - // [CS, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_cs = ggml_cumsum(ctx0, g); + // [CS, g_0, n_chunks, H_v * n_seqs] + // TODO: extend ggml_cumsum with axis parameter to avoid transpose + ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g))); cb(g_cs, "g_cs", il); - ggml_tensor * g_cs_i = g_cs; - ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); + ggml_tensor * kb = nullptr; + ggml_tensor * kq = nullptr; + if (kda) { + const int64_t CHB = n_chunks * H_k * n_seqs; - g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); + ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB] + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] - // [CS, CS, n_chunks, H_v * n_seqs] - ggml_tensor * decay_mask; - decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); - decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); - decay_mask = ggml_exp(ctx0, decay_mask); - cb(decay_mask, "decay_mask", il); + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB] - // [CS, CS, n_chunks, H_k * n_seqs] - ggml_tensor * kb; - kb = ggml_mul_mat(ctx0, k, k_b); - kb = ggml_mul (ctx0, kb, decay_mask); + // decay_mask [chunk_size,chunk_size,S_k,CHB] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); + + // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched + decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB); + + ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS, 1, CHB); + ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, CS, CHB); + ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, CS, 1, CHB); + + ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i); + ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i); + + // decay_k_b_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk 
[BT,1,BT,CHB] + kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j); + kq = ggml_mul_mat(ctx0, decay_q_i, k_j); + + kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs))); + kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs))); + } else { + ggml_tensor * g_cs_i = g_cs; + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); + + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); + + // [CS, CS, n_chunks, H_v * n_seqs] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); + + // [CS, CS, n_chunks, H_k * n_seqs] + kb = ggml_mul_mat(ctx0, k, k_b); + kb = ggml_mul (ctx0, kb, decay_mask); + + // [CS, CS, n_chunks, H_k * n_seqs] + kq = ggml_mul_mat(ctx0, k, q); + kq = ggml_mul(ctx0, kq, decay_mask); + } + + kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); + cb(kq, "kq", il); // [CS, CS, n_chunks, H_k * n_seqs] ggml_tensor * attn; attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); + cb(attn, "attn", il); ggml_tensor * identity; identity = ggml_view_1d(ctx0, attn, CS, 0); @@ -115,6 +160,7 @@ std::pair llm_build_delta_net_base::build_delta_ne cb(lhs, "dnet_add_ch_lhs", il); attn = ggml_neg(ctx0, attn); + cb(attn, "attn_pre_solve", il); ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); attn = ggml_add(ctx0, lin_solve, identity); @@ -123,7 +169,7 @@ std::pair llm_build_delta_net_base::build_delta_ne // [S_v, CS, n_chunks, H_v * n_seqs] v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); - // [CS, 1, n_chunks, H_v * n_seqs] + // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs] ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); @@ -136,16 +182,10 @@ std::pair llm_build_delta_net_base::build_delta_ne ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); cb(k_cd, "k_cumdecay", il); - // [S_k, CS, n_chunks, H_k * n_seqs] - ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp); + // [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp)); ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); - // [CS, CS, n_chunks, H_k * n_seqs] - ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - kq = ggml_mul(ctx0, kq, decay_mask); - kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); - cb(kq, "kq", il); - // vectorized calculation of key_gdiff // improved from the chunked version: // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) @@ -156,8 +196,8 @@ std::pair llm_build_delta_net_base::build_delta_ne // get last element in g_cumsum along CS dimension (ne0) // example: [[x, y, z, ..., last], ...] -> [[last], ...] 
- // [1, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3], + // [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3], g_cs->nb[1], g_cs->nb[2], g_cs->nb[3], @@ -167,16 +207,15 @@ std::pair llm_build_delta_net_base::build_delta_ne // TODO: remove this cont when CUDA supports non-cont unary ops g_last = ggml_cont(ctx0, g_last); - // [1, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); + // [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last)); + cb(g_last_exp_t, "g_last_exp_t", il); - // [CS, 1, n_chunks, H_v * n_seqs] + // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs] ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); cb(g_diff, "g_diff", il); - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp); + ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff))); // [S_k, CS, n_chunks, H_v * n_seqs] ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); @@ -227,8 +266,9 @@ std::pair llm_build_delta_net_base::build_delta_ne ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk); - s_t = ggml_mul(ctx0, s_t, ch_g_last_exp); + ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk); + + s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t); s_t = ggml_add(ctx0, s_t, kgv); cb(s_t, "dnet_add_ch_state", il); } @@ -241,9 +281,9 @@ std::pair llm_build_delta_net_base::build_delta_ne ggml_row_size(v->type, S_v), ggml_row_size(v->type, S_v * CS * n_chunks), ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); - o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] - s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] + s = ggml_transpose(ctx0, s_t); + cb(s, "output_state", il); return {o, s}; } @@ -273,9 +313,10 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); - GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); const float scale = 1.0f / sqrtf(S_k); @@ -291,8 +332,10 @@ std::pair llm_build_delta_net_base::build_delta_ne cb(b, "b_in", il); cb(g, "g_in", il); - g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs); - b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); + // GDA: [1, 1, H_v, n_seqs] + // KDA: [1, S_k, H_v, n_seqs] + g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); // [S_v, S_v, H_v, n_seqs] g = 
ggml_exp(ctx0, g); diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index 133834021d0..8173d894ef2 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -3,8 +3,6 @@ #include "llama-memory-recurrent.h" -#define CHUNK_SIZE 64 - // Causal Conv1d function for Q,K,V // When qkv is 0, it is Q, 1 is K, 2 is V static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) { @@ -67,7 +65,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t } llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : - llm_build_mamba_base(params), model(model) { + llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -86,17 +84,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // Output ids for selecting which tokens to output ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * chunked_causal_mask = - ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity); - - ggml_build_forward_expand(gf, chunked_causal_mask); - ggml_build_forward_expand(gf, chunked_identity); - ggml_build_forward_expand(gf, chunked_diag_mask); - // Kimi dimension constants const int64_t n_head = hparams.n_head(); const int64_t head_dim = hparams.n_embd_head_kda; @@ -177,12 +164,22 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs); - // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens + + const float eps_norm = hparams.f_norm_rms_eps; + + Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm); + Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm); + beta = ggml_sigmoid(ctx0, beta); + + beta = ggml_reshape_4d(ctx0, beta, 1, n_head, n_seq_tokens, n_seqs); + g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs); + + // Choose between build_delta_net_chunking and build_delta_net_autoregressive based on n_tokens std::pair attn_out = n_seq_tokens == 1 ?
- build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) : - build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il); + build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) : + build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il); - ggml_tensor * output = attn_out.first; + ggml_tensor * output = ggml_cont(ctx0, attn_out.first); ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); cb(new_state, "new_state", il); @@ -393,385 +390,3 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_build_forward_expand(gf, cur); } - -/* - This is a ggml implementation of the naive_chunk_kda function of - https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py -*/ -std::pair llm_build_kimi_linear::build_kda_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * gk, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - GGML_ASSERT(ggml_is_contiguous(state)); - - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(gk->ne[0] == S_v && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - // TODO: can this ever be false? 
- const bool use_qk_l2norm = true; - - if (use_qk_l2norm) { - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - } - - const float scale = 1.0f / sqrtf(S_v); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(gk, "gk_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(gk, "gk_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); - - // Do padding - const int64_t chunk_size = CHUNK_SIZE; - - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - gk = ggml_pad(ctx0, gk, 0, pad, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); - - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(gk, "gk_pad", il); - - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); - - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); - - const int64_t HB = H_k * n_seqs; - - q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, HB); - k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, HB); - k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, HB); - v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, HB); - v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, HB); - - gk = ggml_cont_4d(ctx0, gk, S_k, chunk_size, n_chunks, HB); - beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, HB); - - // switch for cumsum - gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 0, 2, 3), chunk_size, S_k, n_chunks, HB); - cb(gk, "gk", il); - ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk); - cb(gk_cumsum, "gk_cumsum", il); - -/* - Compute Akk and Aqk loop together - Akk loop: - for i in range(BT): - k_i = k[..., i, :] # k_i [B,H,NT,S] - g_i = g[..., i:i+1, :] # g_i [B,H,NT,1,S] - A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i) - Aqk loop: - for j in range(BT): - k_j = k[:, :, i, j] - g_j = g[:, :, i, j:j+1, :] - A[..., j] = torch.einsum('... c d, ... d -> ... 
c', q_i * (g_i - g_j).exp(), k_j) -*/ - const int64_t CHB = n_chunks * H_k * n_seqs; - ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB] - ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] - - ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB] - // decay_mask [chunk_size,chunk_size,S_k,CHB] - ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i); - cb(decay_mask, "decay_mask", il); - - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - cb(decay_mask, "decay_masked", il); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - - // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched - decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB); - - ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB); - ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB); - ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB); - - ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i); - ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i); - - // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB] - ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j); - ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j); - Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB))); - Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB))); - cb(Akk, "Akk", il); - cb(Aqk, "Aqk", il); - - Akk = ggml_mul(ctx0, Akk, beta); - Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask)); - cb(Akk, "attn_pre_solve", il); - - Aqk = ggml_mul(ctx0, Aqk, diag_mask); - Aqk = ggml_scale(ctx0, Aqk, scale); // scale q - cb(Aqk, "Aqk_masked", il); - - // for i in range(1, chunk_size): - // row = attn[..., i, :i].clone() - // sub = attn[..., :i, :i].clone() - // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - // - // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A) - ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, Akk, true, true, false); - Akk = ggml_mul(ctx0, lin_solve, causal_mask); - Akk = ggml_add(ctx0, Akk, identity); - - cb(Akk, "attn_solved", il); - - // switch back for downstream - gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB); - ggml_tensor * gkexp = ggml_exp(ctx0, gk_cumsum); - cb(gk_cumsum, "gk_cumsum", il); - - // u = (A*beta[..., None, :]) @ v aka U_[t] - ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk); - - ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp); - cb(kbeta_gkexp, "kbeta_gkexp", il); - - ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk); - cb(k_cumdecay, "k_cumdecay", il); - - ggml_tensor * core_attn_out = nullptr; - ggml_tensor * new_state = ggml_dup(ctx0, state); - - cb(new_state, "new_state", il); - - for (int64_t chunk = 0; chunk < n_chunks; 
chunk++) { -// extract one chunk worth of data - auto chunkify = [=](ggml_tensor * t) { - return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); - }; - auto chunkify_A = [=](ggml_tensor * t) { - return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); - }; - - -// k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B] - ggml_tensor * k_chunk = chunkify(k); - ggml_tensor * q_chunk = chunkify(q); - ggml_tensor * vb_chunk = chunkify(vb); - -// gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B] - ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum); - ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay); - ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk); - ggml_tensor * Aqk_chunk = chunkify_A(Aqk); - - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); - - // new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B] - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t] - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - - // v_new = v_i - v_prime or U_[t] - W_[t]*S_[t] - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - - // q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B] - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - // or Gamma_[t]*Q_]t] @ S - ggml_tensor * q_gk_exp = ggml_mul(ctx0, q_chunk, gkexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp); - attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q - - // v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B] - // core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t]) - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk_chunk); - - // o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - - core_attn_out = core_attn_out == nullptr ? 
core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1); - - ggml_tensor * gk_cum_last = - ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cs_chunk, gk_cs_chunk->ne[0], 1, gk_cs_chunk->ne[2], gk_cs_chunk->ne[3], - gk_cs_chunk->nb[1], gk_cs_chunk->nb[2], gk_cs_chunk->nb[3], - gk_cs_chunk->nb[1] * (gk_cs_chunk->ne[1] - 1))); - - ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last))); - - ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cs_chunk, gk_cum_last)); - - ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff); - - ggml_tensor * key_gkdiff = ggml_mul(ctx0, k_chunk, gk_diff_exp); - - // rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S) - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff))); - - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gkexp_last, gkexp_last->ne[0], gkexp_last->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); - } - - core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs); - - // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); - - cb(new_state, "output_state", il); - - return {output_tokens, new_state}; -} - -std::pair llm_build_kimi_linear::build_kda_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * gk, - ggml_tensor * beta, - ggml_tensor * state, - int il) { - GGML_ASSERT(ggml_is_contiguous(v)); - GGML_ASSERT(ggml_is_contiguous(gk)); - - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_k && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_k && state->ne[2] == H_v && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(gk, "gk_in", il); - -// g [H,1,B,1] g_t [1,H,B,1] => [1,1,H,B] -// gk [S,H,1,B] => [S,1,H,B] gk_t [1,S,H,B] -// beta [H,1,1,B] beta_t [1,H,1,B] => [1,1,H,B] - gk = ggml_reshape_4d(ctx0, gk, S_k, 1, H_k, n_seqs); - ggml_tensor * gk_t = ggml_cont(ctx0, 
ggml_transpose(ctx0, gk)); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - - // Apply exponential to gk_t - gk_t = ggml_exp(ctx0, gk_t); - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * gk_t - // S = S * g_i[..., None].exp() - state = ggml_mul(ctx0, state, gk_t); - - ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); - -// state [S,S,H,B] k [S,1,H,B] k_state [S_v,1,H,B] - k = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs); - ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k); - - // v_i - (k_i[..., None] * S).sum(-2) - v = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - ggml_tensor * v_diff = ggml_sub(ctx0, v, k_state); - - // b_i[..., None] * k_i - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta_t); - - // S = S + torch.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2)) - // v_diff_t [1,S_v,H,B] k_beta_t [1,S_k,H,B] state [S_v,S_k,H,B] - state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta)))); - - q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs); - state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); - ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q); - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); - - return {core_attn_out, state}; -} - diff --git a/src/models/models.h b/src/models/models.h index 920a8e5798f..e23918f55f0 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -320,8 +320,7 @@ struct llm_build_jamba : public llm_build_mamba_base { llm_build_jamba(const llama_model & model, const llm_graph_params & params); }; -// TODO: derive llm_build_delta_net_base instead -struct llm_build_kimi_linear : public llm_build_mamba_base { +struct llm_build_kimi_linear : public llm_build_delta_net_base { llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params); std::pair build_kda_autoregressive( @@ -542,8 +541,7 @@ struct llm_build_qwen3next : public llm_build_delta_net_base { const llama_model & model; }; -// TODO: derive llm_build_delta_net_base instead -struct llm_build_qwen35 : public llm_graph_context { +struct llm_build_qwen35 : public llm_build_delta_net_base { llm_build_qwen35(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -556,39 +554,12 @@ struct llm_build_qwen35 : public llm_graph_context { ggml_tensor * build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il); - ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); - // returns pair of output and new state - std::pair build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il); - - // returns pair of output and new state - std::pair build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, @@ -604,7 +575,7 @@ struct llm_build_qwen35 : public llm_graph_context { }; // TODO: derive llm_build_delta_net_base instead -struct 
llm_build_qwen35moe : public llm_graph_context { +struct llm_build_qwen35moe : public llm_build_delta_net_base { llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -617,38 +588,12 @@ struct llm_build_qwen35moe : public llm_graph_context { ggml_tensor * build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il); ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); - // returns pair of output and new state - std::pair build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il); - - // returns pair of output and new state - std::pair build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, diff --git a/src/models/modern-bert.cpp b/src/models/modern-bert.cpp index bb12ed819f7..32066c712b4 100644 --- a/src/models/modern-bert.cpp +++ b/src/models/modern-bert.cpp @@ -104,13 +104,6 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll LLM_NORM, -1); cb(cur, "final_norm_out", -1); - if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { - // extracting cls token - cur = ggml_view_1d(ctx0, cur, hparams.n_embd, 0); - cb(cur, "cls_pooled_embd", -1); - } - - cb(cur, "res_embd", -1); res->t_embd = cur; ggml_build_forward_expand(gf, cur); } diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 94c68dbb268..7e1749b2c81 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -2,10 +2,8 @@ #include "llama-memory-recurrent.h" -#define CHUNK_SIZE 64 - llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) : - llm_graph_context(params), model(model) { + llm_build_delta_net_base(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -25,17 +23,6 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); - - ggml_build_forward_expand(gf, causal_mask); - ggml_build_forward_expand(gf, identity); - ggml_build_forward_expand(gf, diag_mask); - for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -45,7 +32,7 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); @@ -95,361 +82,6 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & 
model, const llm_graph_pa ggml_build_forward_expand(gf, cur); } -// utility to get one slice from the third dimension -// input dim: [x, y, c, b] -// output dim: [x, y, 1, b] -static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { - return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); -} - -std::pair llm_build_qwen35::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); - - // Do padding - const int64_t chunk_size = CHUNK_SIZE; - - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, pad, 0, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); - - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(g, "g_pad", il); - - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); - - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); - - q = 
ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); - v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - - g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); - - ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); - cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * gcs_j_broadcast = - ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - - ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); - - ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, causal_mask); - attn = ggml_add(ctx0, attn, identity); - cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); - - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); - - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * k_cumdecay = - ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); - cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); - attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); - attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); - cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - - // vectorized calculation of key_gdiff - // improved from the chunked version: - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - // get last element in g_cumsum along chunk_size dimension (ne0) - // example: [[x, y, z, ..., last], ...] -> [[last], ...] 
- ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], - g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], - (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); - g_last = ggml_cont(ctx0, g_last); - cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); - cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, - 1, chunk_size, n_chunks, g_diff_exp->ne[3]); - - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); - cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); - cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) - - // state to be updated per chunk - ggml_tensor * new_state = state; // ggml_dup(ctx0, state); - cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) - - // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) - ggml_tensor * core_attn_out = nullptr; - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - // shape: (S_k, chunk_size, 1, H_k * n_seqs) - ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul - - // shape: (S_v, chunk_size, 1, H_v * n_seqs) - ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat - - // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul - - // shape: (chunk_size, 1, H_v * n_seqs) - ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat - - // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - // replaced by precomputed attn_kq - ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); - cb(attn_chunk, "attn_chunk", il); - - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); - - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) - - // v_new = v_i - v_prime - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - cb(v_new, "v_new_chunk", il); - - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); - cb(attn_inter, "attn_inter_chunk", il); - - // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); - cb(v_attn, "v_attn_chunk", il); - - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) - - core_attn_out = core_attn_out == nullptr - ? 
core_attn_out_chunk - : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); - - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); - //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); - - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); - } - - // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - cb(output_tokens, "output_tokens", il); - - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); - - return {output_tokens, new_state}; -} - -std::pair llm_build_qwen35::build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - - // Apply exponential to g_t - g_t = ggml_exp(ctx0, g_t); - - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * g_t - state = ggml_mul(ctx0, state, g_t); - - // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); 
- // we need to sum over dim=-2, so we transpose, sum, then transpose again - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); - - // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) - ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - // delta = (v_t - kv_mem) * beta_t - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] - ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - - // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta - ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); - state = ggml_add(ctx0, state, k_t_delta); - - // Compute the attention output - // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - // again, since it's over dim = -2, transpose, sum, transpose back - ggml_tensor * core_attn_out = - ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); - - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); - - return {core_attn_out, state}; -} - std::pair llm_build_qwen35::build_qkvz( ggml_tensor * input, int il) { @@ -561,9 +193,6 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( ggml_tensor * llm_build_qwen35::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il) { const auto * mctx_cur = inp->mctx; @@ -589,6 +218,9 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs); cb(beta, "beta", il); + + beta = ggml_sigmoid(ctx0, beta); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); @@ -596,6 +228,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus cb(gate, "gate", il); @@ -603,8 +236,6 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); - // Build the convolution states tensor ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); cb(conv_states, "conv_states", il); @@ -613,11 +244,12 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); cb(conv_states, "conv_states_reshaped", il); - 
qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); - cb(qkv_mixed, "qkv_mixed_permuted", il); + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); cb(conv_input, "conv_input", il); @@ -637,7 +269,10 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); cb(conv_states_all, "conv_states_updated", il); - // Apply SSM convolution + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); + cb(state, "state_predelta", il); + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); cb(conv_output_proper, "conv_output_raw", il); @@ -651,31 +286,41 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); // Extract the convolved Q, K, V from conv_output - ggml_tensor * q_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + 0); + + ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + cb(q_conv, "q_conv", il); - ggml_tensor * k_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, - head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(k_conv, "k_conv", il); - ggml_tensor * v_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, - 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(v_conv, "v_conv", il); - // Unsqueeze them - q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + const float eps_norm = hparams.f_norm_rms_eps; - ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); - state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); - cb(state, "state_predelta", il); + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); + + //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); - // if head keys and value keys are different, repeat Q/K to match V's head count - // V heads are in tiled order (from conversion), so simple tiled repeat works + // if head keys and value keys are different, repeat to force tensors into matching shapes if (num_k_heads != num_v_heads) { GGML_ASSERT(num_v_heads 
% num_k_heads == 0); + // TODO: try to avoid these explicit repeats by utilizing op broadcast q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); } @@ -689,7 +334,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); } ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; @@ -698,19 +343,15 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); - - // Reshape both attn_out_final and z to 2D tensors for normalization - // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // Apply gated normalization: self.norm(core_attn_out, z) - ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il); // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); @@ -721,7 +362,8 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(cur, "linear_attn_out", il); // Reshape back to original dimensions - cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; } diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 93da7ea628c..e12a5dea737 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -2,10 +2,8 @@ #include "llama-memory-recurrent.h" -#define CHUNK_SIZE 64 - llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) : - llm_graph_context(params), model(model) { + llm_build_delta_net_base(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -25,17 +23,6 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, 
GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); - - ggml_build_forward_expand(gf, causal_mask); - ggml_build_forward_expand(gf, identity); - ggml_build_forward_expand(gf, diag_mask); - for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -45,7 +32,7 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); @@ -95,362 +82,6 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr ggml_build_forward_expand(gf, cur); } -// utility to get one slice from the third dimension -// input dim: [x, y, c, b] -// output dim: [x, y, 1, b] -static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { - return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); -} - -std::pair llm_build_qwen35moe::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - 
-    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
-
-    // Do padding
-    const int64_t chunk_size = CHUNK_SIZE;
-
-    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
-    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
-
-    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
-    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
-    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
-    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
-    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
-
-    cb(q, "q_pad", il);
-    cb(k, "k_pad", il);
-    cb(v, "v_pad", il);
-    cb(beta, "beta_pad", il);
-    cb(g, "g_pad", il);
-
-    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
-    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
-
-    cb(v_beta, "v_beta", il);
-    cb(k_beta, "k_beta", il);
-
-    q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
-    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
-
-    g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
-    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
-
-    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
-    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
-    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * gcs_j_broadcast =
-        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
-    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-
-    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
-
-    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
-    ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
-    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
-    ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
-
-    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn = ggml_mul(ctx0, lin_solve, causal_mask);
-    attn = ggml_add(ctx0, attn, identity);
-    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
-
-    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
-    ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
-
-    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
-    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * k_cumdecay =
-        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
-    cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
-    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
-    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
-    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-
-    // vectorized calculation of key_gdiff
-    // improved from the chunked version:
-    // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
-    // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
-    // key_gdiff = key * g_diff.unsqueeze(-1)
-    // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-    // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
-    // get last element in g_cumsum along chunk_size dimension (ne0)
-    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
-    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
-        g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
-        (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
-    g_last = ggml_cont(ctx0, g_last);
-    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
-    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
-    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
-        1, chunk_size, n_chunks, g_diff_exp->ne[3]);
-
-    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
-    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
-    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
-
-
-    // state to be updated per chunk
-    ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
-    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
-
-    // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
-    ggml_tensor * core_attn_out = nullptr;
-
-    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
-        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
-
-        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
-        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
-
-        // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
-
-        // shape: (chunk_size, 1, H_v * n_seqs)
-        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
-
-        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-        // replaced by precomputed attn_kq
-        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
-        cb(attn_chunk, "attn_chunk", il);
-
-        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
-
-        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
-        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
-
-        // v_new = v_i - v_prime
-        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
-        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
-        cb(v_new, "v_new_chunk", il);
-
-        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
-        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
-        cb(attn_inter, "attn_inter_chunk", il);
-
-        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
-        cb(v_attn, "v_attn_chunk", il);
-
-        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
-        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)
-
-        core_attn_out = core_attn_out == nullptr
-            ? core_attn_out_chunk
-            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
-
-        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
-        //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
-
-        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
-        new_state = ggml_add(ctx0,
-            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
-            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
-    }
-
-    // truncate padded tokens
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
-        S_v, n_tokens, H_v, n_seqs,
-        ggml_row_size(core_attn_out->type, S_v),
-        ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
-        ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
-    output_tokens = ggml_cont(ctx0, output_tokens);
-    cb(output_tokens, "output_tokens", il);
-
-    // permute back to (S_v, H_v, n_tokens, n_seqs)
-    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
-    output_tokens = ggml_cont(ctx0, output_tokens);
-
-    return {output_tokens, new_state};
-}
-
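As a reading aid for the chunked path removed above (not part of the patch; the chunked and single-token implementations presumably now live in llm_build_delta_net_base, which this diff does not show): the PyTorch-style comments in the loop amount to the following per-chunk update, where S is the running state, Q_i, K_i, V_i are the chunk's projections, A_i is the solved intra-chunk matrix ("attn" above), g is the cumulative log-decay within the chunk and g_last its value at the chunk's final position:

    V'_i = K_cumdecay_i . S                                   (v_prime)
    V~_i = V_i - V'_i                                         (v_new)
    O_i  = (Q_i * exp(g)) . S + A_i . V~_i                    (attn_inter + attn @ v_new)
    S   <- exp(g_last) . S + (K_i * exp(g_last - g))^T . V~_i (key_gdiff update)

Here "*" is an element-wise product broadcast over the chunk dimension; the equations only restate the removed code, they are not a specification of the replacement.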
-std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_delta_net_autoregressive(
-    ggml_tensor * q,
-    ggml_tensor * k,
-    ggml_tensor * v,
-    ggml_tensor * g,
-    ggml_tensor * beta,
-    ggml_tensor * state,
-    int il) {
-    const int64_t S_k = q->ne[0];
-    const int64_t H_k = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q = ggml_scale(ctx0, q, scale);
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(g, "g_in", il);
-
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
-    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
-
-    // Apply exponential to g_t
-    g_t = ggml_exp(ctx0, g_t);
-
-    // Apply the gated delta rule for the single timestep
-    // last_recurrent_state = last_recurrent_state * g_t
-    state = ggml_mul(ctx0, state, g_t);
-
-    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
-    ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
-    // we need to sum over dim=-2, so we transpose, sum, then transpose again
-    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
-
-    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
-    ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
-    // delta = (v_t - kv_mem) * beta_t
-    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs]
-    ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
-
-    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
-    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
-    state = ggml_add(ctx0, state, k_t_delta);
-
-    // Compute the attention output
-    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t
-    ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
-    // again, since it's over dim = -2, transpose, sum, transpose back
-    ggml_tensor * core_attn_out =
-        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
-
-    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
-    cb(core_attn_out, "output_tokens", il);
-    cb(state, "new_state", il);
-
-    return {core_attn_out, state};
-}
-
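The removed single-token path is the same gated delta rule without chunking. Per its own comments, one decode step with state S, query/key/value q_t, k_t, v_t, log-decay g_t and mixing factor beta_t = sigmoid(beta) is, roughly (orientation up to the ggml storage layout):

    S     <- exp(g_t) * S
    v_hat  = S^T . k_t                                (kv_mem: read the state with the key)
    S     <- S + outer(k_t, beta_t * (v_t - v_hat))   (delta update)
    o_t    = S^T . q_t                                (core_attn_out: read the updated state with the query)

This is a summary of the deleted code only, not a claim about the shared base-class implementation.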
 std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_qkvz(
     ggml_tensor * input,
     int il) {
@@ -562,9 +193,6 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn(
 ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
     llm_graph_input_rs * inp,
     ggml_tensor * cur,
-    ggml_tensor * causal_mask,
-    ggml_tensor * identity,
-    ggml_tensor * diag_mask,
     int il) {
     const auto * mctx_cur = inp->mctx;
@@ -590,6 +218,9 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
     ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
     beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
     cb(beta, "beta", il);
+
+    beta = ggml_sigmoid(ctx0, beta);
+
     ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur);
     alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
     cb(alpha, "alpha", il);
@@ -597,6 +228,7 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
     ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
     ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
     cb(alpha_softplus, "a_softplus", il);
+
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
     cb(gate, "gate", il);
@@ -604,8 +236,6 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);

-    // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();
-
     // Build the convolution states tensor
     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
     cb(conv_states, "conv_states", il);
@@ -614,11 +244,12 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
     ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
     const int64_t conv_kernel_size = conv_kernel->ne[0];
     const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
-    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
+
+    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);

-    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
-    cb(qkv_mixed, "qkv_mixed_permuted", il);
+    qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
+    cb(qkv_mixed, "qkv_mixed_transposed", il);

     ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
     cb(conv_input, "conv_input", il);
@@ -638,7 +269,10 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
     ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
     cb(conv_states_all, "conv_states_updated", il);

-    // Apply SSM convolution
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
+    cb(state, "state_predelta", il);
+
     ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
     cb(conv_output_proper, "conv_output_raw", il);
@@ -652,31 +286,41 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
     int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);

     // Extract the convolved Q, K, V from conv_output
-    ggml_tensor * q_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
+    ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+        ggml_row_size(conv_qkv_mix->type, head_k_dim),
+        nb1_qkv,
+        nb1_qkv * n_seq_tokens,
+        0);
+
+    ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+        ggml_row_size(conv_qkv_mix->type, head_k_dim),
+        nb1_qkv,
+        nb1_qkv * n_seq_tokens,
+        head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+
+    ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
+        ggml_row_size(conv_qkv_mix->type, head_v_dim),
+        nb1_qkv,
+        nb1_qkv * n_seq_tokens,
+        ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));
+
     cb(q_conv, "q_conv", il);
-    ggml_tensor * k_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
-            head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(k_conv, "k_conv", il);
-    ggml_tensor * v_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
-            2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(v_conv, "v_conv", il);

-    // Unsqueeze them
-    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
-    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
-    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+    const float eps_norm = hparams.f_norm_rms_eps;

-    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
-    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
-    cb(state, "state_predelta", il);
+    q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
+    k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);

-    // if head keys and value keys are different, repeat Q/K to match V's head count
-    // V heads are in tiled order (from conversion), so simple tiled repeat works
+    //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
+    // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (num_k_heads != num_v_heads) {
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
+        // TODO: try to avoid these explicit repeats by utilizing op broadcast
         q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
         k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
     }
@@ -690,7 +334,7 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
     if (n_seq_tokens == 1) {
         attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
     } else {
-        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
+        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
     }
     ggml_tensor * output = attn_out.first;
     ggml_tensor * new_state = attn_out.second;
@@ -699,19 +343,15 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
     // Update the recurrent states
     ggml_build_forward_expand(gf,
-            ggml_cpy(ctx0, new_state,
-                ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
-                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
-
-    // Reshape both attn_out_final and z to 2D tensors for normalization
-    // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+        ggml_cpy(ctx0, new_state,
+            ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));

     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

     // Apply gated normalization: self.norm(core_attn_out, z)
-    ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
+    ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);

     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
     ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
@@ -722,7 +362,8 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
     cb(cur, "linear_attn_out", il);

     // Reshape back to original dimensions
-    cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+
     return cur;
 }
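A note on the strided views introduced in the hunk above: conv_qkv_mix is assumed to pack [ q | k | v ] contiguously along ne0 for every token, so q_conv, k_conv and v_conv share the per-token stride nb1_qkv and differ only in their byte offset. A minimal sketch of that addressing, with hypothetical toy sizes and a flattened 2D layout (ctx0 is assumed to be a valid ggml_context; none of these numbers come from the patch):

    // toy sizes, purely illustrative
    const int64_t head_k_dim = 4, num_k_heads = 2;
    const int64_t head_v_dim = 4, num_v_heads = 4;
    const int64_t n_tokens   = 8;
    const int64_t qkv_dim    = 2 * head_k_dim * num_k_heads + head_v_dim * num_v_heads;

    // one fused row per token: [ q | k | v ]
    ggml_tensor * qkv = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, qkv_dim, n_tokens);
    const size_t  nb1 = ggml_row_size(qkv->type, qkv_dim); // byte stride from one token to the next

    // the three sub-tensors are views into the same buffer, differing only in byte offset
    ggml_tensor * q = ggml_view_2d(ctx0, qkv, head_k_dim * num_k_heads, n_tokens, nb1, 0);
    ggml_tensor * k = ggml_view_2d(ctx0, qkv, head_k_dim * num_k_heads, n_tokens, nb1,
                                   ggml_row_size(qkv->type, head_k_dim * num_k_heads));
    ggml_tensor * v = ggml_view_2d(ctx0, qkv, head_v_dim * num_v_heads, n_tokens, nb1,
                                   ggml_row_size(qkv->type, 2 * head_k_dim * num_k_heads));

The patch's ggml_view_4d calls express the same idea, only keeping the head and sequence dimensions separate instead of flattening them.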
diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp
index 0fdf2d42c25..974120ea6f2 100644
--- a/src/models/qwen3next.cpp
+++ b/src/models/qwen3next.cpp
@@ -306,8 +306,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * beta = ggml_sigmoid(ctx0, b);

-    beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
-
     // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
     ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
@@ -318,6 +316,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
     cb(gate, "gate", il);

+    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
+    gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
+
     // Get convolution states from cache
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);