From e1cb817483ebda4e3ebc30fd07f4292c654f4339 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 1 Apr 2026 11:50:17 +0200 Subject: [PATCH 1/8] memory: respect unified KV cache in hybrid memory for eval tasks (#21224) The hybrid memory paths (`llama-memory-hybrid.cpp` and `llama-memory-hybrid-iswa.cpp`) always used sequential equal split, ignoring the unified KV cache flag. This caused hellaswag, winogrande, and multiple-choice evaluations to fail on hybrid models (models with both attention and recurrent/SSM layers, such as Qwen3.5-35B-A3B) with: split_equal: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag) PR #19954 fixed this for `llama-kv-cache-iswa.cpp` by automatically enabling unified KV mode and setting n_parallel >= 4 for multi-choice eval tasks. However, the hybrid memory paths were not updated. This commit mirrors the iswa fix: use non-sequential split when KV cache is unified (n_stream == 1), which is automatically set by llama-perplexity for hellaswag/winogrande/multiple-choice since #19954. Tested on Qwen3.5-35B-A3B (hybrid attention+SSM MoE model): - HellaSwag: 83.0% (400 tasks) - Winogrande: 74.5% (400 tasks) - MMLU: 41.2% - ARC-Challenge: 56.2% - TruthfulQA: 37.7% All previously failed with llama_decode() error. --- src/llama-memory-hybrid-iswa.cpp | 6 +++--- src/llama-memory-hybrid.cpp | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp index 411769672af..10e6b459797 100644 --- a/src/llama-memory-hybrid-iswa.cpp +++ b/src/llama-memory-hybrid-iswa.cpp @@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_base()->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); } if (ubatch.n_tokens == 0) { diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index a1b45e4a3cc..4ce1af592c1 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); } if (ubatch.n_tokens == 0) { From 84f82e846caaf871012143a2d658a974edc410c5 Mon Sep 17 00:00:00 2001 From: Michael Wand Date: Wed, 1 Apr 2026 03:04:58 -0700 Subject: [PATCH 2/8] ggml-cuda: Add generic NVFP4 MMQ kernel (#21074) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduced NVFP4 generic MMQ kernel * Added extra FP8 guard, hope to solve ci HIP failure * Rename tiles and use HIP_FP8_AVAILABLE * Removed remaning FP8 straggler and added const int * Const * Removed DECL_MMQ_CASE artifact * Removed newline * Removed space after else * Changed HIP FP8 NVFP4 conversion gate * Added new line to bottom of mmq.cu 270 * Removed extra spaces * Removed single space in front of else on line 814 * Added NVFP4 to generate cu script so HIP can see it, further tightened logic * Include generated mmq-instance-nvfp4.cu * Added NVFP4 mmq to HIP Check ignore list * Update ggml/src/ggml-cuda/mmq.cuh Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4 Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/mmq.cuh Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4 in tile assert Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/mmq.cuh Added function name ending for end if Co-authored-by: Johannes Gäßler * Added function names to closing endif Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 27 ++++-- ggml/src/ggml-cuda/ggml-cuda.cu | 2 - ggml/src/ggml-cuda/mmq.cu | 5 +- ggml/src/ggml-cuda/mmq.cuh | 89 +++++++++++++++++-- .../template-instances/generate_cu_files.py | 2 +- .../template-instances/mmq-instance-nvfp4.cu | 5 ++ scripts/hip/gcn-cdna-vgpr-check.py | 6 +- 7 files changed, 117 insertions(+), 19 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 7d7f20af3a0..9affe023403 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -800,19 +800,32 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { } static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) { -#ifdef FP8_AVAILABLE - const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. -#if defined(GGML_USE_HIP) && defined(CDNA3) - // ROCm dose not support fp8 in software on devices with fp8 hardware, +#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 + // ROCm does not support fp8 in software on devices with fp8 hardware, // but CDNA3 supports only e4m3_fnuz (no inf). + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast(&bits); + return static_cast(xf) / 2; #else +#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP) + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); -#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP) return static_cast(xf) / 2; #else - NO_DEVICE_CODE; -#endif // FP8_AVAILABLE + if (x == 0 || (x == 0x7F && x != 0xFF)) { // Convert NaN to 0.0f + return 0.0f; + } + const int exp = (x >> 3) & 0xF; + const int man = x & 0x7; + float raw; + if (exp == 0) { + raw = ldexpf((float) man, -9); + } else { + raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7); + } + return static_cast(raw / 2); +#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP) +#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 } __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index d1239b1c5f7..75b62129ade 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4791,9 +4791,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: -#ifdef FP8_AVAILABLE case GGML_TYPE_NVFP4: -#endif // FP8_AVAILABLE case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 9a69f41d159..27b4145ac9a 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -23,6 +23,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con case GGML_TYPE_MXFP4: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_Q2_K: mul_mat_q_case(ctx, args, stream); break; @@ -273,6 +276,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -362,5 +366,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t } return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; - } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 255e59f6fc6..51e8dad4ce7 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -68,6 +68,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_MXFP4: return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_NVFP4: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q2_K: return MMQ_Q8_1_DS_LAYOUT_D2S6; case GGML_TYPE_Q3_K: @@ -189,6 +191,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; @@ -206,12 +209,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } } -#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) -#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) -#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 +#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); @@ -220,6 +224,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4"); +static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding."); + static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { switch (type) { @@ -230,6 +236,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; // tile sizes are the same for Q8_1 and FP4 for blackwell case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4; case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; @@ -826,6 +833,65 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr } } + +template +static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kb0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4; + constexpr int rows_per_warp = warp_size / threads_per_row; + const int kbx = threadIdx.x % threads_per_row; + const int row_in_warp = threadIdx.x / threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx; + const uint32_t * __restrict__ src_qs = reinterpret_cast(bxi->qs); + const int kqs = 16 * kbx; + const int ksc = 4 * kbx; + +#pragma unroll + for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) { + const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4); + const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y; + x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]); +#else + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y; + x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + } +} + template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -1229,7 +1295,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -// Used for Q3_K, IQ2_S, and IQ2_XS +// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -3261,6 +3327,14 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; @@ -4069,6 +4143,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); +extern DECL_MMQ_CASE(GGML_TYPE_NVFP4); extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index b7b5832293e..40d51f93fa4 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -35,7 +35,7 @@ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", - "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4" + "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4" ] SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu new file mode 100644 index 00000000000..2cb140d35a3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_NVFP4); diff --git a/scripts/hip/gcn-cdna-vgpr-check.py b/scripts/hip/gcn-cdna-vgpr-check.py index 38db47d3d18..bbbce52ef39 100644 --- a/scripts/hip/gcn-cdna-vgpr-check.py +++ b/scripts/hip/gcn-cdna-vgpr-check.py @@ -139,7 +139,11 @@ def main(): '_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil', '_ZL18flash_attn_ext_vecILi128ELi2EL9ggml_type2ELS0_2ELb0EEvPKcS2_S2_S2_S2_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS6_IjLj3EEiiiiiiiiiiiliiliiiiil', '_ZL9mul_mat_qIL9ggml_type10ELi16ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii', - '_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii' + '_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii', + '_ZL9mul_mat_qIL9ggml_type40ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii', + '_ZL9mul_mat_qIL9ggml_type40ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii', + '_ZL9mul_mat_qIL9ggml_type40ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii', + '_ZL9mul_mat_qIL9ggml_type40ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii' } functions = parse_log_file(log_file) From 6b949d10785a80ead503f44a2cbe139ffcf8034a Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Wed, 1 Apr 2026 18:54:15 +0800 Subject: [PATCH 3/8] sycl : support nvfp4 type in mul_mat (#21227) --- ggml/src/ggml-sycl/common.hpp | 7 ++ ggml/src/ggml-sycl/convert.cpp | 18 +++++ ggml/src/ggml-sycl/dequantize.hpp | 32 +++++++++ ggml/src/ggml-sycl/mmvq.cpp | 22 +++++- ggml/src/ggml-sycl/type.hpp | 112 ++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/vecdotq.hpp | 42 +++++++++++ 6 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-sycl/type.hpp diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index fcb0db99c6b..fd84c917853 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -23,6 +23,7 @@ #include "ggml-impl.h" #include "ggml-sycl.h" #include "presets.hpp" +#include "type.hpp" #include "sycl_hw.hpp" namespace syclexp = sycl::ext::oneapi::experimental; @@ -965,4 +966,10 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) { return val; } +static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) { + const uint32_t bits = x * (x != 0x7F && x != 0xFF); + const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); + return static_cast(xf) / 2; +} + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index d17aca2cac4..d7f60cbc9ea 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -482,6 +482,18 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t }); } +template +static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + GGML_ASSERT(k % QK_NVFP4 == 0); + const int nb = k / QK_NVFP4; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_nvfp4(vx, y, k); + }); +} + + template static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, @@ -641,6 +653,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return dequantize_row_iq4_nl_sycl; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_sycl; case GGML_TYPE_F32: return convert_unary_sycl; #ifdef GGML_SYCL_HAS_BF16 @@ -648,6 +662,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return convert_unary_sycl; #endif default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); return nullptr; } } @@ -708,6 +723,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return dequantize_row_iq4_nl_sycl; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_sycl; case GGML_TYPE_F16: return convert_unary_sycl; #ifdef GGML_SYCL_HAS_BF16 @@ -715,6 +732,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return convert_unary_sycl; #endif default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); return nullptr; } } diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index da2a605daa8..3272724f41b 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -838,4 +838,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr } } + +template +static void dequantize_block_nvfp4( + const void * __restrict__ vx, + dst_t * __restrict__ yy, + const int64_t ne) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = item_ct1.get_group(2); + const int tid = item_ct1.get_local_id(2); + + const int64_t base = i * QK_NVFP4; + if (base >= ne) { + return; + } + + const block_nvfp4 * x = (const block_nvfp4 *) vx; + const block_nvfp4 & xb = x[i]; + + const int sub = tid / (QK_NVFP4_SUB / 2); + const int j = tid % (QK_NVFP4_SUB / 2); + + const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j]; + + const int64_t y0 = base + sub * QK_NVFP4_SUB + j; + const int64_t y1 = y0 + QK_NVFP4_SUB / 2; + + yy[y0] = ggml_sycl_cast(d * kvalues_mxfp4[q & 0x0F]); + yy[y1] = ggml_sycl_cast(d * kvalues_mxfp4[q >> 4]); +} + + #endif // GGML_SYCL_DEQUANTIZE_HPP diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 316aa0d0fb5..5abc50fabfe 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -613,6 +613,23 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float } } +static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_NVFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -1145,8 +1162,11 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_MXFP4: mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; default: - GGML_ABORT("fatal error"); + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type)); } } GGML_UNUSED(src1); diff --git a/ggml/src/ggml-sycl/type.hpp b/ggml/src/ggml-sycl/type.hpp new file mode 100644 index 00000000000..d7ff89d7d42 --- /dev/null +++ b/ggml/src/ggml-sycl/type.hpp @@ -0,0 +1,112 @@ +#pragma once + +#include +#include +#include + +inline uint8_t float_to_e4m3(float f) +{ + if (sycl::isnan(f)) { + return 0x7F; // Canonical NaN (positive) + } + + uint32_t bits = sycl::bit_cast(f); + uint32_t sign = (bits >> 31) & 0x1u; + uint32_t exp = (bits >> 23) & 0xFFu; + uint32_t mant = bits & 0x7FFFFFu; + + // Zero + if (exp == 0 && mant == 0) { + return static_cast(sign << 7); + } + + // Extract biased exponent and mantissa for FP8 + int e = static_cast(exp) - 127; // true exponent (IEEE bias 127) + uint32_t m = mant; + + // Handle very large values → NaN (NVIDIA behavior for E4M3) + if (e > 7) { // max exponent for E4M3 is 7 (biased 14) + return static_cast((sign << 7) | 0x7F); + } + + // Handle subnormals and normal numbers + if (e < -6) { // smallest normal exponent is -6 + // Subnormal in FP8: shift mantissa right + int shift = -6 - e; + m = (m | 0x800000u) >> (shift + 1); // +1 because we lose the implicit 1 position + if (shift > 23) m = 0; + } else { + // Normal number: adjust exponent bias from 127 to 7 + int new_exp = e + 7; + m = (m >> 20) & 0x7u; // take top 3 mantissa bits (after implicit 1) + m |= (static_cast(new_exp) << 3); + } + + // Round-to-nearest-even (simple guard + round bit) + // For better accuracy you can add sticky bit, but this is sufficient for most use cases + uint32_t round_bit = (mant >> 19) & 0x1u; // bit after the 3 mantissa bits + if (round_bit) { + m += 1; + // Carry into exponent if mantissa overflows + if ((m & 0x8u) != 0) { + m = (m & 0x7u) | ((m & 0x38u) << 1); // simple carry handling + // If exponent overflows after carry → NaN + if ((m >> 3) > 14) { + return static_cast((sign << 7) | 0x7F); + } + } + } + + uint8_t result = static_cast((sign << 7) | (m & 0x7F)); + return result; +} + +inline float e4m3_to_float(uint8_t x) +{ + if (x == 0) return 0.0f; + + uint8_t sign = (x >> 7) & 0x1u; + uint8_t exp = (x >> 3) & 0xFu; + uint8_t mant = x & 0x7u; + + // NaN (NVIDIA uses 0x7F / 0xFF as NaN) + if (exp == 0xF && mant != 0) { + return std::numeric_limits::quiet_NaN(); + } + if (exp == 0xF) { // 0x7F or 0xFF treated as NaN + return std::numeric_limits::quiet_NaN(); + } + + float val; + + if (exp == 0) { + // Subnormal + val = mant * (1.0f / 8.0f) * sycl::pow(2.0f, -6.0f); + } else { + // Normal: implicit leading 1 + bias 7 + val = (1.0f + mant / 8.0f) * sycl::pow(2.0f, static_cast(exp) - 7.0f); + } + + return sign ? -val : val; +} + +// The actual type definition +struct __nv_fp8_e4m3 { + uint8_t raw; + + __nv_fp8_e4m3() = default; + + explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {} + explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast(h))) {} + + operator float() const { return e4m3_to_float(raw); } + operator sycl::half() const { return static_cast(static_cast(*this)); } + + // Allow direct access for vector loads/stores + operator uint8_t&() { return raw; } + operator uint8_t() const { return raw; } +}; + +using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>; +using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>; + diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 9a267d85a0c..eab9850aed7 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -15,6 +15,7 @@ #include "dpct/helper.hpp" #include "ggml.h" +#include "type.hpp" #include "quants.hpp" typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, @@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) { return x32; } +static __dpct_inline__ int get_int_b2(const void * x, const int & i32) { + const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment + + int x32 = x16[2*i32 + 0] << 0; + x32 |= x16[2*i32 + 1] << 16; + + return x32; +} + +static __dpct_inline__ int get_int_b4(const void * x, const int & i32) { + return ((const int *) x)[i32]; // assume at least 4 byte alignment +} static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) { const uint16_t* x16 = @@ -755,6 +768,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq, return d * sumi; } +#define VDR_NVFP4_Q8_1_MMVQ 4 +#define VDR_NVFP4_Q8_1_MMQ 8 + +static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq, + const block_q8_1 * __restrict__ bq8_1, + const int32_t & iqs) { + const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq; + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) { + const int32_t iqs0 = iqs + 2*i; + const int32_t iqs1 = iqs0 + 1; + const int32_t is = iqs0 >> 1; + const sycl::int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4); + const sycl::int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4); + const block_q8_1 * bq8 = bq8_1 + (is >> 1); + const int32_t i8 = ((is & 1) << 2); + + int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0); + sumi = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi); + sumi = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi); + sumi = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi); + + const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0]; + sum += d * float(sumi); + } + + return sum; +} static __dpct_inline__ float vec_dot_q5_0_q8_1(const void *__restrict__ vbq, From 296bc0538b14539d694ae742f8ba06a28ca920f8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 1 Apr 2026 16:01:45 +0300 Subject: [PATCH 4/8] ggml : bump version to 0.9.10 (ggml/1454) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index ab558438e95..2ffc3b391fe 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 9) +set(GGML_VERSION_PATCH 10) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) From 6422036fcb32d72e3aca93dcf97fdfeeca20c27d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 1 Apr 2026 16:02:34 +0300 Subject: [PATCH 5/8] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 7af38c3eb73..abe4128c8f2 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -a04eea0761a85d18f3f504d6ab970c5c9dce705f +50634c28837c24ac68b380b5750b41e701c87d73 From 0356e33aafcc1cb409910244482fac1ec8bafe9f Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Wed, 1 Apr 2026 15:31:58 +0200 Subject: [PATCH 6/8] scripts: add function call test script (#21234) * scripts: add function call test script * add reasoning_content * fix lint --- scripts/server-test-function-call.py | 1135 ++++++++++++++++++++++++++ 1 file changed, 1135 insertions(+) create mode 100755 scripts/server-test-function-call.py diff --git a/scripts/server-test-function-call.py b/scripts/server-test-function-call.py new file mode 100755 index 00000000000..b3aae1a961e --- /dev/null +++ b/scripts/server-test-function-call.py @@ -0,0 +1,1135 @@ +#!/usr/bin/env python3 +""" +Test tool calling capability via chat completions endpoint. + +Each test case contains: + - tools: list of tool definitions (OpenAI-compatible) + - messages: initial conversation messages + - mock_tool_responses: dict mapping tool_name -> callable(arguments) -> str (JSON) + - validate: callable(tool_calls_history, final_content) -> (passed: bool, reason: str) +""" + +import argparse +import json +import requests +import sys + +# --------------------------------------------------------------------------- +# Color / formatting helpers +# --------------------------------------------------------------------------- + +RESET = "\x1b[0m" +BOLD = "\x1b[1m" +DIM = "\x1b[2m" +# Foreground colors +CYAN = "\x1b[36m" +YELLOW = "\x1b[33m" +GREEN = "\x1b[32m" +RED = "\x1b[31m" +BLUE = "\x1b[34m" +WHITE = "\x1b[97m" + + +def _print(text="", end="\n"): + sys.stdout.write(text + end) + sys.stdout.flush() + + +def print_header(title): + bar = "─" * 60 + _print(f"\n{BOLD}{CYAN}┌{bar}┐{RESET}") + _print( + f"{BOLD}{CYAN}│ {WHITE}{title}{CYAN}{' ' * max(0, 58 - len(title))}│{RESET}" + ) + _print(f"{BOLD}{CYAN}└{bar}┘{RESET}") + + +def print_tool_call(name, args): + args_str = json.dumps(args) + _print( + f"\n {BOLD}{YELLOW}⚙ tool call{RESET} {CYAN}{name}{RESET}{DIM}({args_str}){RESET}" + ) + + +def print_tool_result(result): + preview = result[:160] + ("…" if len(result) > 160 else "") + _print(f" {DIM}{BLUE}↳ result{RESET} {DIM}{preview}{RESET}") + + +def print_model_output(text): + # printed inline during streaming; prefix with a visual marker on first chunk + sys.stdout.write(text) + sys.stdout.flush() + + +def print_pass(reason): + _print(f"\n{BOLD}{GREEN}✔ PASS{RESET} {reason}") + + +def print_fail(reason): + _print(f"\n{BOLD}{RED}✘ FAIL{RESET} {reason}") + + +def print_info(msg): + _print(f"{DIM}{msg}{RESET}") + + +# --------------------------------------------------------------------------- +# HTTP helpers +# --------------------------------------------------------------------------- + + +def chat_completion(url, messages, tools=None, stream=False): + payload = { + "messages": messages, + "stream": stream, + "max_tokens": 4096, + } + if tools: + payload["tools"] = tools + payload["tool_choice"] = "auto" + + try: + response = requests.post(url, json=payload, stream=stream) + response.raise_for_status() + except requests.exceptions.RequestException as e: + body = e.response.content if (e.response is not None) else b"" + print_fail(f"Request error: {e} | body: {body}") + return None + + full_content = "" + reasoning_content = "" + tool_calls: list[dict] = [] + + if stream: + for line in response.iter_lines(): + if not line: + continue + decoded = line.decode("utf-8") + if not decoded.startswith("data: "): + continue + data_str = decoded[6:] + if data_str == "[DONE]": + break + try: + data = json.loads(data_str) + except json.JSONDecodeError: + continue + choices = data.get("choices", []) + if not choices: + continue + delta = choices[0].get("delta", {}) + if delta.get("reasoning_content"): + reasoning_content += delta["reasoning_content"] + if delta.get("content"): + full_content += delta["content"] + print_model_output(delta["content"]) + for tc in delta.get("tool_calls", []): + idx = tc.get("index", 0) + while len(tool_calls) <= idx: + tool_calls.append( + { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + ) + if "id" in tc: + tool_calls[idx]["id"] += tc["id"] + if "function" in tc: + if "name" in tc["function"]: + tool_calls[idx]["function"]["name"] += tc["function"]["name"] + if "arguments" in tc["function"]: + tool_calls[idx]["function"]["arguments"] += tc["function"][ + "arguments" + ] + else: + data = response.json() + choices = data.get("choices", []) + if choices: + msg = choices[0].get("message", {}) + full_content = msg.get("content") or "" + reasoning_content = msg.get("reasoning_content") or "" + tool_calls = msg.get("tool_calls") or [] + if full_content: + print_model_output(full_content) + + result = {"content": full_content, "tool_calls": tool_calls} + if reasoning_content: + result["reasoning_content"] = reasoning_content + return result + + +def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turns=6): + """ + Drive the multi-turn tool-call loop: + 1. Send messages to model. + 2. If the model returns tool calls, execute mocks and append results. + 3. Repeat until no more tool calls or max_turns reached. + + Returns (all_tool_calls, final_content). + """ + msgs = list(messages) + all_tool_calls: list[dict] = [] + + for _ in range(max_turns): + result = chat_completion(url, msgs, tools=tools, stream=stream) + if result is None: + return all_tool_calls, None + + tcs = result.get("tool_calls") or [] + content = result.get("content") or "" + + if not tcs: + # Print a visual separator before the final model response + if content: + _print(f"\n{DIM}{'·'*60}{RESET}") + _print(f"{DIM} model response:{RESET}\n") + return all_tool_calls, content + + # Record tool calls for validation + all_tool_calls.extend(tcs) + + # Append assistant message with tool calls + assistant_msg: dict = { + "role": "assistant", + "content": content, + "tool_calls": tcs, + } + reasoning = result.get("reasoning_content") + if reasoning: + assistant_msg["reasoning_content"] = reasoning + msgs.append(assistant_msg) + + # Execute each tool call via mock and append tool result messages + for tc in tcs: + tool_name = tc["function"]["name"] + try: + args = json.loads(tc["function"]["arguments"]) + except json.JSONDecodeError: + args = {} + + print_tool_call(tool_name, args) + + mock_fn = mock_tool_responses.get(tool_name) + if mock_fn: + tool_result = mock_fn(args) + else: + tool_result = json.dumps({"error": f"Unknown tool: {tool_name}"}) + + print_tool_result(tool_result) + + msgs.append( + { + "role": "tool", + "tool_call_id": tc.get("id", ""), + "content": tool_result, + } + ) + + return all_tool_calls, None + + +# --------------------------------------------------------------------------- +# Test case runner +# --------------------------------------------------------------------------- + + +def run_test(url, test_case, stream): + name = test_case["name"] + mode = f"{'stream' if stream else 'non-stream'}" + print_header(f"{name} [{mode}]") + + all_tool_calls, final_content = run_agentic_loop( + url, + messages=test_case["messages"], + tools=test_case["tools"], + mock_tool_responses=test_case["mock_tool_responses"], + stream=stream, + ) + + if final_content is None and not all_tool_calls: + print_fail("No response from server.") + return False + + passed, reason = test_case["validate"](all_tool_calls, final_content) + if passed: + print_pass(reason) + else: + print_fail(reason) + return passed + + +# --------------------------------------------------------------------------- +# Test case definitions +# --------------------------------------------------------------------------- + +# ---- Test 1: E-commerce multi-step search (Azzoo = anonymized marketplace) ---- + +_AZZOO_TOOLS = [ + { + "type": "function", + "function": { + "name": "azzoo_search_products", + "description": ( + "Search for products on Azzoo marketplace by keyword. " + "Returns a list of matching products with IDs, titles, ratings and prices." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search keyword or phrase", + }, + "page": { + "type": "string", + "description": "Page number (1-based)", + "default": "1", + }, + }, + "required": ["query"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "azzoo_get_product", + "description": "Retrieve detailed information about a specific Azzoo product including specs and price.", + "parameters": { + "type": "object", + "properties": { + "product_id": { + "type": "string", + "description": "Azzoo product identifier (e.g. AZB12345)", + }, + }, + "required": ["product_id"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "azzoo_get_reviews", + "description": "Fetch customer reviews for an Azzoo product.", + "parameters": { + "type": "object", + "properties": { + "product_id": { + "type": "string", + "description": "Azzoo product identifier", + }, + "page": { + "type": "string", + "description": "Review page number", + "default": "1", + }, + }, + "required": ["product_id"], + }, + }, + }, +] + +_AZZOO_SEARCH_RESULT = { + "results": [ + { + "product_id": "AZB00001", + "title": "SteelBrew Pro Kettle 1.7L", + "rating": 4.6, + "price": 34.99, + }, + { + "product_id": "AZB00002", + "title": "HeatKeep Gooseneck Kettle", + "rating": 4.3, + "price": 27.50, + }, + { + "product_id": "AZB00003", + "title": "QuickBoil Stainless Kettle", + "rating": 4.1, + "price": 21.00, + }, + ] +} +_AZZOO_PRODUCT_RESULT = { + "product_id": "AZB00001", + "title": "SteelBrew Pro Kettle 1.7L", + "price": 34.99, + "rating": 4.6, + "review_count": 2847, + "specs": { + "material": "18/8 stainless steel", + "capacity": "1.7 L", + "auto_shutoff": True, + "keep_warm": "30 min", + "warranty": "2 years", + }, +} +_AZZOO_REVIEWS_RESULT = { + "product_id": "AZB00001", + "average_rating": 4.6, + "reviews": [ + { + "rating": 5, + "title": "Excellent build quality", + "body": "Very sturdy, boils fast and stays warm longer than expected.", + }, + { + "rating": 5, + "title": "Great for loose-leaf tea", + "body": "The wide spout makes filling a teapot easy. No leaks after months of use.", + }, + { + "rating": 3, + "title": "Minor lid issue", + "body": "The lid doesn't always click shut properly, but overall happy with it.", + }, + { + "rating": 4, + "title": "Good value", + "body": "Heats quickly and the auto shutoff works reliably.", + }, + ], +} + +AZZOO_TEST_CASE = { + "name": "Azzoo E-commerce: search -> product detail -> reviews", + "messages": [ + { + "role": "user", + "content": ( + "I need a durable stainless steel tea kettle for my weekly tea gatherings. " + "Please search Azzoo for 'stainless steel tea kettle', then get full details " + "on the top-rated result, and finally fetch its customer reviews so I can " + "check for recurring complaints. Give me a summary with pros and cons." + ), + } + ], + "tools": _AZZOO_TOOLS, + "mock_tool_responses": { + "azzoo_search_products": lambda _: json.dumps(_AZZOO_SEARCH_RESULT), + "azzoo_get_product": lambda _: json.dumps(_AZZOO_PRODUCT_RESULT), + "azzoo_get_reviews": lambda _: json.dumps(_AZZOO_REVIEWS_RESULT), + }, + "validate": lambda tcs, content: _validate_azzoo(tcs, content), +} + + +def _validate_azzoo(tcs, content): + names = [tc["function"]["name"] for tc in tcs] + if not names: + return False, "No tool calls made" + if "azzoo_search_products" not in names: + return False, f"Expected azzoo_search_products to be called, got: {names}" + # After search the model should look up product details + if "azzoo_get_product" not in names and "azzoo_get_reviews" not in names: + return False, f"Expected follow-up product/review lookup, got: {names}" + # Verify product lookup used an ID from search results + for tc in tcs: + if tc["function"]["name"] == "azzoo_get_product": + try: + args = json.loads(tc["function"]["arguments"]) + pid = args.get("product_id", "") + if not pid: + return False, "azzoo_get_product called with empty product_id" + except json.JSONDecodeError: + return False, "azzoo_get_product arguments are not valid JSON" + if not content: + return False, "No final summary produced" + return True, f"All expected tools called in order: {names}" + + +# ---- Test 2: Fitness BMI + exercise recommendations ---- + +_FITNESS_TOOLS = [ + { + "type": "function", + "function": { + "name": "calculate_bmi", + "description": "Calculate Body Mass Index (BMI) from weight and height.", + "parameters": { + "type": "object", + "properties": { + "weight_kg": { + "type": "number", + "description": "Body weight in kilograms", + }, + "height_m": {"type": "number", "description": "Height in meters"}, + }, + "required": ["weight_kg", "height_m"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_exercises", + "description": ( + "Fetch a list of exercises filtered by muscle group, difficulty, category, " + "and/or force type." + ), + "parameters": { + "type": "object", + "properties": { + "muscle": { + "type": "string", + "description": "Target muscle group (e.g. chest, back, legs)", + }, + "difficulty": { + "type": "string", + "description": "Difficulty level: beginner, intermediate, expert", + }, + "category": { + "type": "string", + "description": "Exercise category (e.g. strength, cardio, stretching)", + }, + "force": { + "type": "string", + "description": "Force type: push, pull, static", + }, + }, + "required": [], + }, + }, + }, +] + +_BMI_RESULT = {"bmi": 24.5, "category": "Normal weight", "healthy_range": "18.5 – 24.9"} +_EXERCISES_RESULT = { + "exercises": [ + { + "name": "Push-Up", + "muscle": "chest", + "difficulty": "beginner", + "equipment": "none", + "instructions": "Keep body straight, lower chest to floor.", + }, + { + "name": "Incline Dumbbell Press", + "muscle": "chest", + "difficulty": "beginner", + "equipment": "dumbbells, bench", + "instructions": "Press dumbbells up from chest on incline bench.", + }, + { + "name": "Chest Fly (cables)", + "muscle": "chest", + "difficulty": "beginner", + "equipment": "cable machine", + "instructions": "Bring cables together in an arc motion.", + }, + ] +} + +FITNESS_TEST_CASE = { + "name": "Fitness: BMI calculation + exercise suggestions", + "messages": [ + { + "role": "user", + "content": ( + "I'm a 32-year-old male, 78 kg and 1.80 m tall. " + "Please calculate my BMI and then suggest some beginner chest exercises I can do " + "to build strength. Give me a short personalised plan." + ), + } + ], + "tools": _FITNESS_TOOLS, + "mock_tool_responses": { + "calculate_bmi": lambda _: json.dumps(_BMI_RESULT), + "get_exercises": lambda _: json.dumps(_EXERCISES_RESULT), + }, + "validate": lambda tcs, content: _validate_fitness(tcs, content), +} + + +def _validate_fitness(tcs, content): + names = [tc["function"]["name"] for tc in tcs] + if not names: + return False, "No tool calls made" + if "calculate_bmi" not in names: + return False, f"Expected calculate_bmi to be called, got: {names}" + # Validate BMI args contain plausible values + for tc in tcs: + if tc["function"]["name"] == "calculate_bmi": + try: + args = json.loads(tc["function"]["arguments"]) + w = args.get("weight_kg") + h = args.get("height_m") + if w is None or h is None: + return False, f"calculate_bmi missing weight_kg or height_m: {args}" + if not (50 <= float(w) <= 200): + return False, f"calculate_bmi weight out of plausible range: {w}" + if not (1.0 <= float(h) <= 2.5): + return False, f"calculate_bmi height out of plausible range: {h}" + except (json.JSONDecodeError, ValueError) as e: + return False, f"calculate_bmi argument error: {e}" + if not content: + return False, "No final plan produced" + return True, f"Tools called: {names}" + + +# ---- Test 3: Community class planning (anonymised cooking/topic discovery) ---- + +_COMMUNITY_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_trending_questions", + "description": ( + "Fetch commonly asked questions on a topic from search engine 'People Also Ask' boxes." + ), + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Topic to search for"}, + "max_results": { + "type": "integer", + "description": "Maximum questions to return", + "default": 10, + }, + }, + "required": ["query"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "search_mobile_apps", + "description": "Search the mobile app store for apps matching a category or keyword.", + "parameters": { + "type": "object", + "properties": { + "keyword": { + "type": "string", + "description": "Search keyword (e.g. 'Italian cooking')", + }, + "platform": { + "type": "string", + "enum": ["ios", "android", "both"], + "default": "both", + }, + "max_results": { + "type": "integer", + "description": "Number of results", + "default": 10, + }, + }, + "required": ["keyword"], + }, + }, + }, +] + +_TRENDING_QUESTIONS_RESULT = { + "query": "Italian cuisine", + "questions": [ + "What are the most popular Italian dishes?", + "What makes Italian food different from other cuisines?", + "How do you make authentic Italian pasta from scratch?", + "What are traditional Italian desserts?", + "What herbs are commonly used in Italian cooking?", + "Is Italian food healthy?", + "What wine pairs best with Italian pasta?", + ], +} +_APPS_RESULT = { + "keyword": "Italian cooking", + "results": [ + { + "name": "PastaPro", + "rating": 4.5, + "installs": "500K+", + "focus": "pasta recipes only", + }, + { + "name": "CookEasy", + "rating": 4.2, + "installs": "1M+", + "focus": "general cooking, limited Italian content", + }, + { + "name": "ItalianKitchen", + "rating": 3.8, + "installs": "100K+", + "focus": "regional Italian recipes, no video", + }, + ], +} + +COMMUNITY_CLASS_TEST_CASE = { + "name": "Community class planning: trending topics + app gap analysis", + "messages": [ + { + "role": "user", + "content": ( + "I want to start teaching Italian cooking classes at my community centre. " + "First, find out what people commonly ask about Italian cuisine online. " + "Then search for existing Italian cooking apps to see what they cover. " + "Use both results to suggest three unique angles for my classes that fill gaps " + "in what apps already offer." + ), + } + ], + "tools": _COMMUNITY_TOOLS, + "mock_tool_responses": { + "get_trending_questions": lambda _: json.dumps(_TRENDING_QUESTIONS_RESULT), + "search_mobile_apps": lambda _: json.dumps(_APPS_RESULT), + }, + "validate": lambda tcs, content: _validate_community(tcs, content), +} + + +def _validate_community(tcs, content): + names = [tc["function"]["name"] for tc in tcs] + if not names: + return False, "No tool calls made" + missing = [ + t for t in ("get_trending_questions", "search_mobile_apps") if t not in names + ] + if missing: + return False, f"Missing expected tool calls: {missing}; got: {names}" + if not content: + return False, "No class suggestion produced" + return True, f"Both discovery tools called: {names}" + + +# ---- Test 4: Multi-hostname geolocation filter (anonymized gallery discovery) ---- +# Inspired by: checking gallery website server locations to find truly remote venues. +# Anonymized: galleryone.de → halle-eins.de, gallerytwo.fr → galerie-deux.fr, +# gallerythree.it → galleria-tre.it + +_GEO_TOOLS = [ + { + "type": "function", + "function": { + "name": "lookup_ip_geolocation", + "description": ( + "Retrieve geolocation data for an IP address or hostname, including country, " + "city, coordinates, and network info. Useful for verifying physical server " + "locations or personalising regional content." + ), + "parameters": { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "IP address or hostname to look up (e.g. '8.8.8.8' or 'example.com').", + }, + }, + "required": ["host"], + }, + }, + }, +] + +# Mock: one urban (Berlin → discard), two rural (keep) +_GEO_RESPONSES = { + "halle-eins.de": { + "host": "halle-eins.de", + "city": "Berlin", + "country": "DE", + "lat": 52.5200, + "lon": 13.4050, + "is_major_city": True, + }, + "galerie-deux.fr": { + "host": "galerie-deux.fr", + "city": "Rocamadour", + "country": "FR", + "lat": 44.7994, + "lon": 1.6178, + "is_major_city": False, + }, + "galleria-tre.it": { + "host": "galleria-tre.it", + "city": "Matera", + "country": "IT", + "lat": 40.6664, + "lon": 16.6044, + "is_major_city": False, + }, +} + + +def _geo_mock(args): + host = args.get("host", "") + return json.dumps(_GEO_RESPONSES.get(host, {"error": f"unknown host: {host}"})) + + +GEO_TEST_CASE = { + "name": "Gallery geolocation: filter urban venues, keep remote ones", + "messages": [ + { + "role": "user", + "content": ( + "I have abstract paintings to exhibit in remote European galleries. " + "I received enquiries from three venues: halle-eins.de, galerie-deux.fr, " + "and galleria-tre.it. Please look up the geolocation of each website's server. " + "Discard any venue whose server is in a major city (e.g. Berlin, Paris, Rome). " + "For the remaining venues, report their exact coordinates so I can check " + "whether hiking trails are nearby — my work thrives where nature and art meet." + ), + } + ], + "tools": _GEO_TOOLS, + "mock_tool_responses": { + "lookup_ip_geolocation": _geo_mock, + }, + "validate": lambda tcs, content: _validate_geo(tcs, content), +} + + +def _validate_geo(tcs, content): + names = [tc["function"]["name"] for tc in tcs] + if not names: + return False, "No tool calls made" + # Expect exactly one geolocation call per domain (3 total) + geo_calls = [tc for tc in tcs if tc["function"]["name"] == "lookup_ip_geolocation"] + if len(geo_calls) < 3: + return ( + False, + f"Expected geolocation called 3 times (once per domain), got {len(geo_calls)}", + ) + queried_hosts = set() + for tc in geo_calls: + try: + args = json.loads(tc["function"]["arguments"]) + host = args.get("host", "") + if not host: + return False, f"lookup_ip_geolocation called with empty host: {args}" + queried_hosts.add(host) + except json.JSONDecodeError: + return False, "lookup_ip_geolocation arguments are not valid JSON" + expected = {"halle-eins.de", "galerie-deux.fr", "galleria-tre.it"} + if not expected.issubset(queried_hosts): + return ( + False, + f"Not all domains queried. Expected {expected}, got {queried_hosts}", + ) + if not content: + return False, "No final summary produced" + return True, f"All 3 domains geolocated: {sorted(queried_hosts)}" + + +# ---- Test 5: EV fleet expansion — stock → security → property → video ---- +# Inspired by: multi-step business analysis combining finance, cybersecurity, +# real estate and educational content. +# Anonymized: Tesla → Voltara (VLTR), Rivian → Rivex (RVXN), +# Trenton → Halverton + +_EV_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_stock_quote", + "description": "Retrieve the latest market quote for a financial instrument by ticker symbol.", + "parameters": { + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "Ticker symbol (e.g. 'VLTR', 'RVXN')", + }, + "interval": { + "type": "string", + "description": "Time interval: 1min, 5min, 1h, 1day, 1week", + "default": "1day", + }, + }, + "required": ["symbol"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_security_advisories", + "description": ( + "Fetch current cybersecurity advisories from the national security agency, " + "covering known vulnerabilities and exploits for industrial and consumer systems." + ), + "parameters": { + "type": "object", + "properties": { + "keyword": { + "type": "string", + "description": "Filter advisories by keyword or product name", + }, + "limit": { + "type": "integer", + "description": "Maximum number of advisories to return", + "default": 5, + }, + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "search_commercial_properties", + "description": "Search for commercial properties (offices, garages, warehouses) available for rent or sale in a given city.", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name to search in"}, + "property_type": { + "type": "string", + "description": "Type of property: office, garage, warehouse, premises", + }, + "operation": { + "type": "string", + "enum": ["rent", "sale"], + "default": "rent", + }, + "max_price": { + "type": "integer", + "description": "Maximum monthly rent or sale price", + }, + }, + "required": ["city", "property_type"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_video_recommendations", + "description": "Fetch a list of recommended videos related to a given topic or reference video.", + "parameters": { + "type": "object", + "properties": { + "topic": { + "type": "string", + "description": "Topic or keyword to search for related videos", + }, + }, + "required": ["topic"], + }, + }, + }, +] + +_STOCK_RESULT_VLTR = { + "symbol": "VLTR", + "company": "Voltara Inc.", + "price": 218.45, + "change_pct": "+2.3%", + "market_cap": "694B", + "currency": "USD", +} +_STOCK_RESULT_RVXN = { + "symbol": "RVXN", + "company": "Rivex Motors", + "price": 12.80, + "change_pct": "-1.1%", + "market_cap": "11B", + "currency": "USD", +} +_ADVISORIES_RESULT = { + "count": 2, + "advisories": [ + { + "id": "ICSA-24-102-01", + "title": "Voltara In-Vehicle Infotainment System Authentication Bypass", + "severity": "Medium", + "summary": "Improper authentication in the OTA update module may allow an adjacent attacker to install unsigned firmware.", + "published": "2024-04-11", + }, + { + "id": "ICSA-24-085-03", + "title": "Voltara Charging Management API Input Validation Flaw", + "severity": "Low", + "summary": "Insufficient input validation in the charging session API could expose internal error messages.", + "published": "2024-03-26", + }, + ], +} +_PROPERTIES_RESULT = { + "city": "Halverton", + "listings": [ + { + "id": "HV-0041", + "type": "garage", + "area_sqm": 420, + "monthly_rent": 2800, + "ev_power_outlets": 12, + "address": "14 Ironworks Lane, Halverton", + }, + { + "id": "HV-0089", + "type": "warehouse", + "area_sqm": 900, + "monthly_rent": 4200, + "ev_power_outlets": 30, + "address": "7 Depot Road, Halverton", + }, + ], +} +_VIDEOS_RESULT = { + "topic": "fleet electrification", + "recommendations": [ + { + "title": "How to Build an EV Fleet from Scratch", + "channel": "Fleet Future", + "views": "182K", + }, + { + "title": "EV Charging Infrastructure for Commercial Fleets", + "channel": "GreenDrive Pro", + "views": "94K", + }, + { + "title": "Total Cost of Ownership: Electric vs Diesel Vans", + "channel": "LogisticsTech", + "views": "61K", + }, + ], +} + + +def _ev_stock_mock(args): + symbol = args.get("symbol", "").upper() + if symbol == "VLTR": + return json.dumps(_STOCK_RESULT_VLTR) + if symbol == "RVXN": + return json.dumps(_STOCK_RESULT_RVXN) + return json.dumps({"error": f"Unknown symbol: {symbol}"}) + + +EV_FLEET_TEST_CASE = { + "name": "EV fleet expansion: stock → cybersecurity → property → videos", + "messages": [ + { + "role": "user", + "content": ( + "I'm expanding my courier business into electric vehicles and need a multi-step analysis:\n" + "1. Get the latest stock quote for Voltara (VLTR) and Rivex (RVXN). " + "If either is above $50, continue with that company.\n" + "2. Search for cybersecurity advisories related to that company's vehicle models " + "to understand any tech risks.\n" + "3. Find commercial garage or warehouse properties in Halverton suitable for " + "EV charging infrastructure.\n" + "4. Recommend videos on fleet electrification strategies.\n" + "Please work through all four steps and give me a concise summary." + ), + } + ], + "tools": _EV_TOOLS, + "mock_tool_responses": { + "get_stock_quote": _ev_stock_mock, + "get_security_advisories": lambda _: json.dumps(_ADVISORIES_RESULT), + "search_commercial_properties": lambda _: json.dumps(_PROPERTIES_RESULT), + "get_video_recommendations": lambda _: json.dumps(_VIDEOS_RESULT), + }, + "validate": lambda tcs, content: _validate_ev(tcs, content), +} + + +def _validate_ev(tcs, content): + names = [tc["function"]["name"] for tc in tcs] + if not names: + return False, "No tool calls made" + # Stock quote must come first + if names[0] != "get_stock_quote": + return False, f"Expected get_stock_quote to be called first, got: {names[0]}" + stock_calls = [tc for tc in tcs if tc["function"]["name"] == "get_stock_quote"] + for tc in stock_calls: + try: + args = json.loads(tc["function"]["arguments"]) + sym = args.get("symbol", "") + if not sym: + return False, f"get_stock_quote called with empty symbol: {args}" + except json.JSONDecodeError: + return False, "get_stock_quote arguments are not valid JSON" + # All four pipeline tools expected + required = [ + "get_stock_quote", + "get_security_advisories", + "search_commercial_properties", + "get_video_recommendations", + ] + missing = [t for t in required if t not in names] + if missing: + return False, f"Missing pipeline steps: {missing}" + if not content: + return False, "No final summary produced" + return True, f"Full 4-step pipeline executed: {names}" + + +# --------------------------------------------------------------------------- +# All test cases +# --------------------------------------------------------------------------- + +ALL_TEST_CASES = [ + AZZOO_TEST_CASE, + FITNESS_TEST_CASE, + COMMUNITY_CLASS_TEST_CASE, + GEO_TEST_CASE, + EV_FLEET_TEST_CASE, +] + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description="Test llama-server tool-calling capability." + ) + parser.add_argument("--host", default="localhost") + parser.add_argument("--port", default=8080, type=int) + parser.add_argument( + "--no-stream", action="store_true", help="Disable streaming mode tests" + ) + parser.add_argument( + "--stream-only", action="store_true", help="Only run streaming mode tests" + ) + parser.add_argument( + "--test", + help="Run only the test whose name contains this substring (case-insensitive)", + ) + args = parser.parse_args() + + url = f"http://{args.host}:{args.port}/v1/chat/completions" + print_info(f"Testing server at {url}") + + modes = [] + if not args.stream_only: + modes.append(False) + if not args.no_stream: + modes.append(True) + + cases: list[dict] = ALL_TEST_CASES + if args.test: + name_filter = args.test.lower() + cases = [c for c in cases if name_filter in str(c["name"]).lower()] + if not cases: + print_fail(f"No test cases matched '{args.test}'") + sys.exit(1) + + total = 0 + passed = 0 + for stream in modes: + for case in cases: + total += 1 + if run_test(url, case, stream=stream): + passed += 1 + + color = GREEN if passed == total else RED + _print(f"\n{BOLD}{color}{'─'*60}{RESET}") + _print(f"{BOLD}{color} Results: {passed}/{total} passed{RESET}") + _print(f"{BOLD}{color}{'─'*60}{RESET}\n") + sys.exit(0 if passed == total else 1) + + +if __name__ == "__main__": + main() From 744c0c7310aad90e99a29c5739e4ee317fb6a748 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 1 Apr 2026 16:58:01 +0300 Subject: [PATCH 7/8] llama : rotate activations for better quantization (#21038) * llama : rotate activations for better quantization * cont : rotate V more + refactor * cont : rotate caches separately + support non-power-of-2 head sizes * cont : simplify * cont : add reference for V rotation * cont : refactor * cont : support context shift * cont : consolidate * cont : dedup + allow different types for the rotation matrix * cont : add env variable to disable rotation * cont : simplify attn rot kv cache logic + rename env * cont : pre-compute the Hadamard matrices --- src/llama-graph.cpp | 119 ++++++++++++++++++----- src/llama-graph.h | 8 ++ src/llama-kv-cache.cpp | 210 ++++++++++++++++++++++++++++++++++++++++- src/llama-kv-cache.h | 26 +++++ 4 files changed, 337 insertions(+), 26 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index c2833b75ced..0e7d96ca10d 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -19,7 +19,7 @@ // dedup helpers -static ggml_tensor * build_kq_mask( +static ggml_tensor * build_attn_inp_kq_mask( ggml_context * ctx, const llama_kv_cache_context * mctx, const llama_ubatch & ubatch, @@ -28,7 +28,11 @@ static ggml_tensor * build_kq_mask( const auto n_tokens = ubatch.n_tokens; const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_kq_mask"); + + return res; } static bool can_reuse_kq_mask( @@ -52,6 +56,21 @@ static bool can_reuse_kq_mask( // impl +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_mul_mat (ctx, rot, res); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + return res; +} + void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -429,6 +448,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { mctx->set_input_v_idxs(self_v_idxs, ubatch); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + + if (self_k_rot) { + mctx->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->set_input_v_rot(self_v_rot); + } } bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { @@ -476,6 +503,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + + if (self_k_rot) { + mctx->get_base()->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->get_base()->set_input_v_rot(self_v_rot); + } } bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { @@ -532,6 +567,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + if (inp_attn->self_k_rot) { + mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -630,6 +673,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); } + if (inp_attn->self_k_rot) { + attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -2002,13 +2053,13 @@ static std::unique_ptr build_attn_inp_kv_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); - - ggml_set_input(inp->self_kq_mask); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } + inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0); + return inp; } @@ -2034,6 +2085,15 @@ ggml_tensor * llm_graph_context::build_attn( int il) const { GGML_ASSERT(v_mla == nullptr); + if (inp->self_k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot); + k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot); + } + + if (inp->self_v_rot) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot); + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced // expand k later to enable rope fusion which directly writes into k-v cache @@ -2061,6 +2121,10 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (inp->self_v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot); + } + if (wo) { cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { @@ -2090,9 +2154,7 @@ static std::unique_ptr build_attn_inp_k_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); - ggml_set_input(inp->self_kq_mask); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } @@ -2171,6 +2233,18 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_mla, float kq_scale, int il) const { + if (inp->self_k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot); + if (k_cur) { + k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot); + } + } + if (inp->self_v_rot) { + if (v_cur) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot); + } + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -2211,6 +2285,10 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (inp->self_v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot); + } + if (wo) { cur = build_lora_mm(wo, cur); } @@ -2293,12 +2371,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); - ggml_set_input(inp->self_kq_mask); - ggml_set_name(inp->self_kq_mask, "self_kq_mask"); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; - ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv"); } { @@ -2307,14 +2381,13 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); - ggml_set_input(inp->self_kq_mask_swa); - ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); - + inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; - ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv"); } + inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0); + return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } @@ -2473,9 +2546,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); - ggml_set_input(inp_attn->self_kq_mask); - + inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; } @@ -2483,9 +2554,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); - ggml_set_input(inp_attn->self_kq_mask_swa); - + inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; } diff --git a/src/llama-graph.h b/src/llama-graph.h index 4855685ef71..bb0ad75198f 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -308,6 +308,10 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + // note: assumes v_rot^ == I + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; + // note: these have to be copies because in order to be able to reuse a graph, its inputs // need to carry these parameters with them. otherwise, they can point to freed // llm_graph_params from a previous batch, causing stack-use-after-return @@ -384,6 +388,10 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + // note: using same rotation matrices for both base and swa cache + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; + const llama_hparams hparams; const llama_cparams cparams; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 5f57ba9e1d8..3e0fd3107f3 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -13,6 +13,65 @@ #include #include +static bool ggml_is_power_of_2(int n) { + return (n & (n - 1)) == 0; +} + +// orthonormal Walsh-Hadamard rotation matrix +// note: res^2 == I +static void ggml_gen_hadamard(ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + + const int n = tensor->ne[0]; + + assert(ggml_is_power_of_2(n)); + assert(tensor->ne[1] == n); + assert(tensor->ne[2] == 1); + assert(tensor->ne[3] == 1); + + std::vector data_f32; + + float * data = (float *) tensor->data; + + if (tensor->type != GGML_TYPE_F32) { + data_f32.resize(n*n); + data = data_f32.data(); + } + + data[0*n + 0] = 1.0 / sqrtf(n); + + for (int s = 1; s < n; s *= 2) { + for (int i = 0; i < s; i++) { + for (int j = 0; j < s; j++) { + const float val = data[i*n + j]; + + data[(i + s)*n + (j )] = val; + data[(i )*n + (j + s)] = val; + data[(i + s)*n + (j + s)] = -val; + } + } + } + + if (tensor->type != GGML_TYPE_F32) { + ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr); + } +} + +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_mul_mat (ctx, rot, res); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + return res; +} + // // llama_kv_cache // @@ -209,6 +268,48 @@ llama_kv_cache::llama_kv_cache( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } + const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); + const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; + if (attn_rot_disable) { + LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); + } + + attn_rot_k = + !attn_rot_disable && + ggml_is_quantized(type_k) && + !hparams.is_n_embd_k_gqa_variable() && + hparams.n_embd_head_k() % 64 == 0; + + attn_rot_v = + !attn_rot_disable && + ggml_is_quantized(type_v) && + !hparams.is_n_embd_v_gqa_variable() && + hparams.n_embd_head_v() % 64 == 0; + + LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k); + LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v); + + // pre-compute the haramard matrices and keep them in host memory + // TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers + if (attn_rot_k || attn_rot_v) { + for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) { + attn_rot_hadamard[n] = std::vector(n*n); + + ggml_init_params params = { + /* .mem_size = */ 1*ggml_tensor_overhead(), + /* .mem_buffer = */ nullptr, + /* .no_alloc = */ true, + }; + + ggml_context_ptr ctx { ggml_init(params) }; + + ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n); + tmp->data = attn_rot_hadamard[n].data(); + + ggml_gen_hadamard(tmp); + } + } + const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; } @@ -1004,6 +1105,14 @@ bool llama_kv_cache::get_has_shift() const { return result; } +ggml_type llama_kv_cache::type_k() const { + return layers[0].k->type; +} + +ggml_type llama_kv_cache::type_v() const { + return layers[0].v->type; +} + uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; @@ -1189,6 +1298,47 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama return v_idxs; } +ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_k) { + int nrot = 64; + + // TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V) + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088 + do { + nrot *= 2; + } while (hparams.n_embd_head_k() % nrot == 0); + nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_k_rot"); + } + + return res; +} + +ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_v) { + int nrot = 64; + // using smaller rotation matrices for V seems beneficial + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570 + //do { + // nrot *= 2; + //} while (hparams.n_embd_head_v() % nrot == 0); + //nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_v_rot"); + } + + return res; +} + void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -1507,6 +1657,24 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch } } +void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + +void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + size_t llama_kv_cache::total_size() const { size_t size = 0; @@ -1542,6 +1710,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift( ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, float freq_scale, @@ -1567,10 +1736,16 @@ ggml_tensor * llama_kv_cache::build_rope_shift( // dequantize to f32 -> RoPE -> quantize back tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + // rotate back + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_rope_ext(ctx, tmp, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + // rotate fwd + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_cpy(ctx, tmp, cur); } else { // we rotate only the first n_rot dimensions @@ -1591,6 +1766,9 @@ class llm_graph_input_k_shift : public llm_graph_input_i { ggml_tensor * k_shift; // I32 [kv_size*n_stream] + // note: assumes k_rot^2 == I + ggml_tensor * k_rot = nullptr; + const llama_kv_cache * kv_self; }; @@ -1600,6 +1778,10 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { if (k_shift) { kv_self->set_input_k_shift(k_shift); } + + if (k_rot) { + kv_self->set_input_k_rot(k_rot); + } } ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { @@ -1611,6 +1793,8 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); ggml_set_input(inp->k_shift); + inp->k_rot = build_input_k_rot(ctx); + const auto & cparams = lctx->get_cparams(); for (const auto & layer : layers) { @@ -1635,7 +1819,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_row_size(layer.k->type, n_embd_k_gqa), ggml_row_size(layer.k->type, n_embd_nope)); - ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il); + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il); ggml_build_forward_expand(gf, cur); } @@ -2239,6 +2423,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const { return n_kv; } +ggml_type llama_kv_cache_context::type_k() const { + return kv->type_k(); +} + +ggml_type llama_kv_cache_context::type_v() const { + return kv->type_v(); +} + ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } @@ -2263,6 +2455,14 @@ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, con return kv->build_input_v_idxs(ctx, ubatch); } +ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const { + return kv->build_input_k_rot(ctx); +} + +ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const { + return kv->build_input_v_rot(ctx); +} + void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } @@ -2282,3 +2482,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } + +void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const { + kv->set_input_k_rot(dst); +} + +void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const { + kv->set_input_v_rot(dst); +} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 90a0610c49d..d4569a06f71 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -152,6 +152,9 @@ class llama_kv_cache : public llama_memory_i { bool get_has_shift() const; + ggml_type type_k() const; + ggml_type type_v() const; + // // graph_build API // @@ -191,6 +194,9 @@ class llama_kv_cache : public llama_memory_i { ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; @@ -199,6 +205,9 @@ class llama_kv_cache : public llama_memory_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: const llama_model & model; const llama_hparams & hparams; @@ -226,6 +235,13 @@ class llama_kv_cache : public llama_memory_i { // SWA const uint32_t n_swa = 0; + // env: LLAMA_ATTN_ROT_DISABLE + bool attn_rot_k = false; + bool attn_rot_v = false; + + // pre-computed hadamard martrices + std::unordered_map> attn_rot_hadamard; + // env: LLAMA_KV_CACHE_DEBUG int debug = 0; @@ -262,6 +278,7 @@ class llama_kv_cache : public llama_memory_i { ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, float freq_scale, @@ -328,6 +345,9 @@ class llama_kv_cache_context : public llama_memory_context_i { uint32_t get_n_kv() const; + ggml_type type_k() const; + ggml_type type_v() const; + // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; @@ -347,6 +367,9 @@ class llama_kv_cache_context : public llama_memory_context_i { ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; @@ -354,6 +377,9 @@ class llama_kv_cache_context : public llama_memory_context_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: llama_memory_status status; From 1d6d4cf7a5361046f778414c5b1f5ecbc07eeb77 Mon Sep 17 00:00:00 2001 From: Jonathan <47618606+jbuchananr@users.noreply.github.com> Date: Wed, 1 Apr 2026 07:22:44 -0700 Subject: [PATCH 8/8] fix: tool call parsing for LFM2 and LFM2.5 models (#21242) * fix: tool call parsing for LFM2 and LFM2.5 models' * refactor: add test / break out lfm2 and lfm2.5 parsing logic --- common/chat.cpp | 100 ++++++++++++++++++++++--- models/templates/LFM2.5-Instruct.jinja | 45 +++++++++++ tests/test-chat.cpp | 61 +++++++++++++++ 3 files changed, 197 insertions(+), 9 deletions(-) create mode 100644 models/templates/LFM2.5-Instruct.jinja diff --git a/common/chat.cpp b/common/chat.cpp index c2ca17c7430..7536c0cd015 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1274,11 +1274,12 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp return data; } -// LFM2 format: -// - Reasoning: {reasoning} (optional, only if enable_thinking is true) -// - Content: text after reasoning (optional) -// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|> -// Tool calls can appear multiple times (parallel tool calls) +// LFM2 format: uses <|tool_list_start|>[...]<|tool_list_end|> in system prompt +// and <|tool_call_start|>[name(arg="val")]<|tool_call_end|> for tool calls. +// - Reasoning: {reasoning} (optional) +// - Content: text before a tool call (optional) +// - Tool calls: Python-style, e.g. [function_name(arg1="value1", arg2="value2")] +// Tool calls can appear multiple times (parallel tool calls supported) static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const autoparser::generation_params & inputs) { common_chat_params data; @@ -1319,9 +1320,9 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { return generation_prompt + reasoning + p.content(p.rest()) + end; } - auto tool_calls = p.rule("tool-calls", - p.trigger_rule("tool-call", p.literal(TOOL_CALL_START) + + p.trigger_rule("tool-call", + p.literal(TOOL_CALL_START) + p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls) + p.literal(TOOL_CALL_END) ) @@ -1349,6 +1350,80 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, TOOL_CALL_START } }; } + return data; +} + +// LFM2.5 format: uses plain "List of tools: [...]" in system prompt, no wrapper tokens. +// Tool calls are bare [name(arg="val")], though model may optionally emit <|tool_call_start|>. +// - Reasoning: {reasoning} (optional) +// - Content: text before a tool call (optional) +// - Tool calls: Python-style, e.g. [function_name(arg1="value1", arg2="value2")] +// Tool calls can appear multiple times (parallel tool calls supported) +static common_chat_params common_chat_params_init_lfm2_5(const common_chat_template & tmpl, + const autoparser::generation_params & inputs) { + common_chat_params data; + + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; + data.preserved_tokens = { + "<|tool_call_start|>", + "<|tool_call_end|>", + "", + "", + }; + + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; + auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; + + const std::string THINK_START = ""; + const std::string THINK_END = ""; + + data.thinking_start_tag = THINK_START; + data.thinking_end_tag = THINK_END; + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START); + auto end = p.end(); + + auto reasoning = p.eps(); + if (extract_reasoning && inputs.enable_thinking) { + reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END); + } + + if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + return generation_prompt + reasoning + p.content(p.rest()) + end; + } + + auto tool_calls = p.rule("tool-calls", + p.trigger_rule("tool-call", + p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls) + ) + ); + + auto content = p.content(p.until_one_of({"<|tool_call_start|>", "["})); + auto maybe_start = p.optional(p.literal("<|tool_call_start|>")); + return generation_prompt + reasoning + content + maybe_start + tool_calls + end; + }); + + data.parser = parser.save(); + + if (include_grammar) { + data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + auto schema = function.at("parameters"); + builder.resolve_refs(schema); + }); + parser.build_grammar(builder, data.grammar_lazy); + }); + foreach_function(inputs.tools, [&](const json & tool) { + const std::string name = tool.at("function").at("name"); + data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[" + name + "(" }); + }); + } return data; } @@ -1530,14 +1605,21 @@ static std::optional try_specialized_template( return common_chat_params_init_kimi_k2(tmpl, params); } - // LFM2 - uses <|tool_list_start|>/<|tool_list_end|> markers and <|tool_call_start|>[name(args)]<|tool_call_end|> format - // Detection: template has "<|tool_list_start|>" and "<|tool_list_end|>" markers + // LFM2 format detection: template uses <|tool_list_start|>[...]<|tool_list_end|> around the tool list + // and <|tool_call_start|>[...]<|tool_call_end|> around each tool call if (src.find("<|tool_list_start|>") != std::string::npos && src.find("<|tool_list_end|>") != std::string::npos) { LOG_DBG("Using specialized template: LFM2\n"); return common_chat_params_init_lfm2(tmpl, params); } + // LFM2.5 format detection: template uses plain "List of tools: [...]" with no special tokens + if (src.find("List of tools: [") != std::string::npos && + src.find("<|tool_list_start|>") == std::string::npos) { + LOG_DBG("Using specialized template: LFM2.5\n"); + return common_chat_params_init_lfm2_5(tmpl, params); + } + // GigaChatV3 format detection if (src.find("<|role_sep|>") != std::string::npos && src.find("<|message_sep|>") != std::string::npos && diff --git a/models/templates/LFM2.5-Instruct.jinja b/models/templates/LFM2.5-Instruct.jinja new file mode 100644 index 00000000000..7778756dd92 --- /dev/null +++ b/models/templates/LFM2.5-Instruct.jinja @@ -0,0 +1,45 @@ +{{- bos_token -}} +{%- set keep_past_thinking = keep_past_thinking | default(false) -%} +{%- set ns = namespace(system_prompt="") -%} +{%- if messages[0]["role"] == "system" -%} + {%- set ns.system_prompt = messages[0]["content"] -%} + {%- set messages = messages[1:] -%} +{%- endif -%} +{%- if tools -%} + {%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: [" -%} + {%- for tool in tools -%} + {%- if tool is not string -%} + {%- set tool = tool | tojson -%} + {%- endif -%} + {%- set ns.system_prompt = ns.system_prompt + tool -%} + {%- if not loop.last -%} + {%- set ns.system_prompt = ns.system_prompt + ", " -%} + {%- endif -%} + {%- endfor -%} + {%- set ns.system_prompt = ns.system_prompt + "]" -%} +{%- endif -%} +{%- if ns.system_prompt -%} + {{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}} +{%- endif -%} +{%- set ns.last_assistant_index = -1 -%} +{%- for message in messages -%} + {%- if message["role"] == "assistant" -%} + {%- set ns.last_assistant_index = loop.index0 -%} + {%- endif -%} +{%- endfor -%} +{%- for message in messages -%} + {{- "<|im_start|>" + message["role"] + "\n" -}} + {%- set content = message["content"] -%} + {%- if content is not string -%} + {%- set content = content | tojson -%} + {%- endif -%} + {%- if message["role"] == "assistant" and not keep_past_thinking and loop.index0 != ns.last_assistant_index -%} + {%- if "" in content -%} + {%- set content = content.split("")[-1] | trim -%} + {%- endif -%} + {%- endif -%} + {{- content + "<|im_end|>\n" -}} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- "<|im_start|>assistant\n" -}} +{%- endif -%} \ No newline at end of file diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 1c4da681955..b66916687bf 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -2712,6 +2712,67 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .run(); } + // LFM2.5 tests - uses plain "List of tools: [...]" and bare [name(args)] without wrapper tokens + { + auto tst = peg_tester("models/templates/LFM2.5-Instruct.jinja", detailed_debug); + + // Basic content only + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + + // Single tool call without reasoning + tst.test("[special_function(arg1=1)]") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + // Tool call with string argument + tst.test("[get_time(city=\"XYZCITY\")]") + .tools({ get_time_tool }) + .expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}")) + .run(); + + // Tool call with reasoning (enable_thinking=true) + tst.test("I'm\nthinking[special_function(arg1=1)]") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + + // Multiple tool calls (parallel) + tst.test("[special_function(arg1=1), special_function_with_opt(arg1=1, arg2=2)]") + .parallel_tool_calls(true) + .tools({ + special_function_tool, special_function_tool_with_optional_param + }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); + + // Tool call with content before tool call + tst.test("Let me check the time.[get_time(city=\"Paris\")]") + .tools({ get_time_tool }) + .expect(message_with_reasoning_content_and_multiple_tool_calls( + "", "Let me check the time.", { { "get_time", "{\"city\":\"Paris\"}" } } + )) + .run(); + + // Partial tool call (streaming) + tst.test("[special_function(arg1=") + .tools({ special_function_tool }) + .is_partial(true) + .expect(simple_assist_msg("", "", "special_function", "{\"arg1\": ")) + .run(); + + // Tool call with empty arguments + tst.test("[empty_args()]") + .tools({ empty_args_tool }) + .expect(simple_assist_msg("", "", "empty_args", "{}")) + .run(); + } + // Apertus-8B-Instruct tests - FUNC_NAME_AS_KEY format // Format: <|tools_prefix|>[{"function_name": {...arguments...}}]<|tools_suffix|> {