diff --git a/common/chat.cpp b/common/chat.cpp
index c2ca17c7430..7536c0cd015 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -1274,11 +1274,12 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
return data;
}
-// LFM2 format:
-// - Reasoning: <think>{reasoning}</think> (optional, only if enable_thinking is true)
-// - Content: text after reasoning (optional)
-// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
-// Tool calls can appear multiple times (parallel tool calls)
+// LFM2 format: uses <|tool_list_start|>[...]<|tool_list_end|> in system prompt
+// and <|tool_call_start|>[name(arg="val")]<|tool_call_end|> for tool calls.
+// - Reasoning: <think>{reasoning}</think> (optional)
+// - Content: text before a tool call (optional)
+// - Tool calls: Python-style, e.g. [function_name(arg1="value1", arg2="value2")]
+// Tool calls can appear multiple times (parallel tool calls supported)
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
const autoparser::generation_params & inputs) {
common_chat_params data;
@@ -1319,9 +1320,9 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return generation_prompt + reasoning + p.content(p.rest()) + end;
}
-
auto tool_calls = p.rule("tool-calls",
- p.trigger_rule("tool-call", p.literal(TOOL_CALL_START) +
+ p.trigger_rule("tool-call",
+ p.literal(TOOL_CALL_START) +
p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls) +
p.literal(TOOL_CALL_END)
)
@@ -1349,6 +1350,80 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, TOOL_CALL_START }
};
}
+ return data;
+}
+
+// LFM2.5 format: uses plain "List of tools: [...]" in system prompt, no wrapper tokens.
+// Tool calls are bare [name(arg="val")], though model may optionally emit <|tool_call_start|>.
+// - Reasoning: <think>{reasoning}</think> (optional)
+// - Content: text before a tool call (optional)
+// - Tool calls: Python-style, e.g. [function_name(arg1="value1", arg2="value2")]
+// Tool calls can appear multiple times (parallel tool calls supported)
+static common_chat_params common_chat_params_init_lfm2_5(const common_chat_template & tmpl,
+ const autoparser::generation_params & inputs) {
+ common_chat_params data;
+
+ data.prompt = common_chat_template_direct_apply(tmpl, inputs);
+ data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
+ data.supports_thinking = true;
+ data.preserved_tokens = {
+ "<|tool_call_start|>",
+ "<|tool_call_end|>",
+ "<think>",
+ "</think>",
+ };
+
+ auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
+ auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
+ auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
+
+ const std::string THINK_START = "<think>";
+ const std::string THINK_END = "</think>";
+
+ data.thinking_start_tag = THINK_START;
+ data.thinking_end_tag = THINK_END;
+
+ auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
+ auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
+ auto end = p.end();
+
+ auto reasoning = p.eps();
+ if (extract_reasoning && inputs.enable_thinking) {
+ reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
+ }
+
+ if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
+ return generation_prompt + reasoning + p.content(p.rest()) + end;
+ }
+
+ auto tool_calls = p.rule("tool-calls",
+ p.trigger_rule("tool-call",
+ p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls)
+ )
+ );
+
+ auto content = p.content(p.until_one_of({"<|tool_call_start|>", "["}));
+ auto maybe_start = p.optional(p.literal("<|tool_call_start|>"));
+ return generation_prompt + reasoning + content + maybe_start + tool_calls + end;
+ });
+
+ data.parser = parser.save();
+
+ if (include_grammar) {
+ data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ auto schema = function.at("parameters");
+ builder.resolve_refs(schema);
+ });
+ parser.build_grammar(builder, data.grammar_lazy);
+ });
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const std::string name = tool.at("function").at("name");
+ data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[" + name + "(" });
+ });
+ }
return data;
}
@@ -1530,14 +1605,21 @@ static std::optional<common_chat_params> try_specialized_template(
return common_chat_params_init_kimi_k2(tmpl, params);
}
- // LFM2 - uses <|tool_list_start|>/<|tool_list_end|> markers and <|tool_call_start|>[name(args)]<|tool_call_end|> format
- // Detection: template has "<|tool_list_start|>" and "<|tool_list_end|>" markers
+ // LFM2 format detection: template uses <|tool_list_start|>[...]<|tool_list_end|> around the tool list
+ // and <|tool_call_start|>[...]<|tool_call_end|> around each tool call
if (src.find("<|tool_list_start|>") != std::string::npos &&
src.find("<|tool_list_end|>") != std::string::npos) {
LOG_DBG("Using specialized template: LFM2\n");
return common_chat_params_init_lfm2(tmpl, params);
}
+ // LFM2.5 format detection: template uses plain "List of tools: [...]" with no special tokens
+ if (src.find("List of tools: [") != std::string::npos &&
+ src.find("<|tool_list_start|>") == std::string::npos) {
+ LOG_DBG("Using specialized template: LFM2.5\n");
+ return common_chat_params_init_lfm2_5(tmpl, params);
+ }
+
// GigaChatV3 format detection
if (src.find("<|role_sep|>") != std::string::npos &&
src.find("<|message_sep|>") != std::string::npos &&
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index ab558438e95..2ffc3b391fe 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9)
-set(GGML_VERSION_PATCH 9)
+set(GGML_VERSION_PATCH 10)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 7d7f20af3a0..9affe023403 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -800,19 +800,32 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
}
static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
-#ifdef FP8_AVAILABLE
- const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
-#if defined(GGML_USE_HIP) && defined(CDNA3)
- // ROCm dose not support fp8 in software on devices with fp8 hardware,
+#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
+ // ROCm does not support fp8 in software on devices with fp8 hardware,
// but CDNA3 supports only e4m3_fnuz (no inf).
+ const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast<const __hip_fp8_e4m3_fnuz *>(&bits);
+ return static_cast<float>(xf) / 2;
#else
+#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
+ const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
-#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP)
return static_cast<float>(xf) / 2;
#else
- NO_DEVICE_CODE;
-#endif // FP8_AVAILABLE
+ if (x == 0 || x == 0x7F || x == 0xFF) { // Convert NaN to 0.0f
+ return 0.0f;
+ }
+ const int exp = (x >> 3) & 0xF;
+ const int man = x & 0x7;
+ float raw;
+ if (exp == 0) {
+ raw = ldexpf((float) man, -9);
+ } else {
+ raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
+ }
+ return static_cast<float>(raw / 2);
+#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
+#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
}
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index d1239b1c5f7..75b62129ade 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4791,9 +4791,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
-#ifdef FP8_AVAILABLE
case GGML_TYPE_NVFP4:
-#endif // FP8_AVAILABLE
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 9a69f41d159..27b4145ac9a 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -23,6 +23,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
case GGML_TYPE_MXFP4:
mul_mat_q_case(ctx, args, stream);
break;
+ case GGML_TYPE_NVFP4:
+ mul_mat_q_case(ctx, args, stream);
+ break;
case GGML_TYPE_Q2_K:
mul_mat_q_case(ctx, args, stream);
break;
@@ -273,6 +276,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
+ case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@@ -362,5 +366,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
}
return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
-
}
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 255e59f6fc6..51e8dad4ce7 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -68,6 +68,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_MXFP4:
return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_NVFP4:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q2_K:
return MMQ_Q8_1_DS_LAYOUT_D2S6;
case GGML_TYPE_Q3_K:
@@ -189,6 +191,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
+ case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@@ -206,12 +209,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
}
}
-#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
-#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
-#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
-#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
-#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4
+#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
@@ -220,6 +224,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
+static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
+
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
@@ -230,6 +236,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
// tile sizes are the same for Q8_1 and FP4 for blackwell
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
+ case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@@ -826,6 +833,65 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
}
}
+
+template <int mmq_y, bool need_check>
+static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
+ int * __restrict__ x_tile,
+ const int kb0,
+ const int i_max,
+ const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
+ constexpr int rows_per_warp = warp_size / threads_per_row;
+ const int kbx = threadIdx.x % threads_per_row;
+ const int row_in_warp = threadIdx.x / threads_per_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
+ int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
+
+ if constexpr (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx;
+ const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
+ const int kqs = 16 * kbx;
+ const int ksc = 4 * kbx;
+
+#pragma unroll
+ for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
+ const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4);
+ const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x;
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x;
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y;
+ x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y;
+ x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
+#else
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x;
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x;
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y;
+ x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y;
+ x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+ }
+}
+
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -1229,7 +1295,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
-// Used for Q3_K, IQ2_S, and IQ2_XS
+// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -3261,6 +3327,14 @@ struct mmq_type_traits {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
+ static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
@@ -4069,6 +4143,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
+extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index b7b5832293e..40d51f93fa4 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -35,7 +35,7 @@
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
- "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
+ "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4"
]
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu
new file mode 100644
index 00000000000..2cb140d35a3
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_NVFP4);
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index fcb0db99c6b..fd84c917853 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -23,6 +23,7 @@
#include "ggml-impl.h"
#include "ggml-sycl.h"
#include "presets.hpp"
+#include "type.hpp"
#include "sycl_hw.hpp"
namespace syclexp = sycl::ext::oneapi::experimental;
@@ -965,4 +966,10 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) {
return val;
}
+static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) {
+ const uint32_t bits = x * (x != 0x7F && x != 0xFF);
+ const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
+ return static_cast<float>(xf) / 2;
+}
+
#endif // GGML_SYCL_COMMON_HPP
diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp
index d17aca2cac4..d7f60cbc9ea 100644
--- a/ggml/src/ggml-sycl/convert.cpp
+++ b/ggml/src/ggml-sycl/convert.cpp
@@ -482,6 +482,18 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t
});
}
+template <typename dst_t>
+static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+ GGML_ASSERT(k % QK_NVFP4 == 0);
+ const int nb = k / QK_NVFP4;
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_nvfp4(vx, y, k);
+ });
+}
+
+
template
static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -641,6 +653,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_sycl;
+ case GGML_TYPE_NVFP4:
+ return dequantize_row_nvfp4_sycl;
case GGML_TYPE_F32:
return convert_unary_sycl;
#ifdef GGML_SYCL_HAS_BF16
@@ -648,6 +662,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
return convert_unary_sycl;
#endif
default:
+ GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
return nullptr;
}
}
@@ -708,6 +723,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_sycl;
+ case GGML_TYPE_NVFP4:
+ return dequantize_row_nvfp4_sycl;
case GGML_TYPE_F16:
return convert_unary_sycl;
#ifdef GGML_SYCL_HAS_BF16
@@ -715,6 +732,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return convert_unary_sycl;
#endif
default:
+ GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
return nullptr;
}
}
diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp
index da2a605daa8..3272724f41b 100644
--- a/ggml/src/ggml-sycl/dequantize.hpp
+++ b/ggml/src/ggml-sycl/dequantize.hpp
@@ -838,4 +838,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr
}
}
+
+template <typename dst_t>
+static void dequantize_block_nvfp4(
+ const void * __restrict__ vx,
+ dst_t * __restrict__ yy,
+ const int64_t ne) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const int64_t i = item_ct1.get_group(2);
+ const int tid = item_ct1.get_local_id(2);
+
+ const int64_t base = i * QK_NVFP4;
+ if (base >= ne) {
+ return;
+ }
+
+ const block_nvfp4 * x = (const block_nvfp4 *) vx;
+ const block_nvfp4 & xb = x[i];
+
+ const int sub = tid / (QK_NVFP4_SUB / 2);
+ const int j = tid % (QK_NVFP4_SUB / 2);
+
+ const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]);
+ const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];
+
+ const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
+ const int64_t y1 = y0 + QK_NVFP4_SUB / 2;
+
+ yy[y0] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
+ yy[y1] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
+}
+
+
#endif // GGML_SYCL_DEQUANTIZE_HPP
diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp
index 316aa0d0fb5..5abc50fabfe 100644
--- a/ggml/src/ggml-sycl/mmvq.cpp
+++ b/ggml/src/ggml-sycl/mmvq.cpp
@@ -613,6 +613,23 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float
}
}
+static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_NVFP4 == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+
+ {
+ stream->submit([&](sycl::handler & cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
@@ -1145,8 +1162,11 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
case GGML_TYPE_MXFP4:
mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
+ case GGML_TYPE_NVFP4:
+ mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
default:
- GGML_ABORT("fatal error");
+ GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type));
}
}
GGML_UNUSED(src1);
diff --git a/ggml/src/ggml-sycl/type.hpp b/ggml/src/ggml-sycl/type.hpp
new file mode 100644
index 00000000000..d7ff89d7d42
--- /dev/null
+++ b/ggml/src/ggml-sycl/type.hpp
@@ -0,0 +1,112 @@
+#pragma once
+
+#include <sycl/sycl.hpp>
+#include <cstdint>
+#include <limits>
+
+inline uint8_t float_to_e4m3(float f)
+{
+ if (sycl::isnan(f)) {
+ return 0x7F; // Canonical NaN (positive)
+ }
+
+ uint32_t bits = sycl::bit_cast<uint32_t>(f);
+ uint32_t sign = (bits >> 31) & 0x1u;
+ uint32_t exp = (bits >> 23) & 0xFFu;
+ uint32_t mant = bits & 0x7FFFFFu;
+
+ // Zero
+ if (exp == 0 && mant == 0) {
+ return static_cast<uint8_t>(sign << 7);
+ }
+
+ // Extract biased exponent and mantissa for FP8
+ int e = static_cast<int>(exp) - 127; // true exponent (IEEE bias 127)
+ uint32_t m = mant;
+
+ // Handle very large values → NaN (NVIDIA behavior for E4M3)
+ if (e > 7) { // max exponent for E4M3 is 7 (biased 14)
+ return static_cast<uint8_t>((sign << 7) | 0x7F);
+ }
+
+ // Handle subnormals and normal numbers
+ if (e < -6) { // smallest normal exponent is -6
+ // Subnormal in FP8: shift mantissa right
+ int shift = -6 - e;
+ m = (m | 0x800000u) >> (shift + 1); // +1 because we lose the implicit 1 position
+ if (shift > 23) m = 0;
+ } else {
+ // Normal number: adjust exponent bias from 127 to 7
+ int new_exp = e + 7;
+ m = (m >> 20) & 0x7u; // take top 3 mantissa bits (after implicit 1)
+ m |= (static_cast<uint32_t>(new_exp) << 3);
+ }
+
+ // Round-to-nearest-even (simple guard + round bit)
+ // For better accuracy you can add sticky bit, but this is sufficient for most use cases
+ uint32_t round_bit = (mant >> 19) & 0x1u; // bit after the 3 mantissa bits
+ if (round_bit) {
+ m += 1;
+ // Carry into exponent if mantissa overflows
+ if ((m & 0x8u) != 0) {
+ m = (m & 0x7u) | ((m & 0x38u) << 1); // simple carry handling
+ // If exponent overflows after carry → NaN
+ if ((m >> 3) > 14) {
+ return static_cast<uint8_t>((sign << 7) | 0x7F);
+ }
+ }
+ }
+
+ uint8_t result = static_cast<uint8_t>((sign << 7) | (m & 0x7F));
+ return result;
+}
+
+inline float e4m3_to_float(uint8_t x)
+{
+ if (x == 0) return 0.0f;
+
+ uint8_t sign = (x >> 7) & 0x1u;
+ uint8_t exp = (x >> 3) & 0xFu;
+ uint8_t mant = x & 0x7u;
+
+ // NaN (NVIDIA uses 0x7F / 0xFF as NaN)
+ if (exp == 0xF && mant != 0) {
+ return std::numeric_limits<float>::quiet_NaN();
+ }
+ if (exp == 0xF) { // 0x7F or 0xFF treated as NaN
+ return std::numeric_limits<float>::quiet_NaN();
+ }
+
+ float val;
+
+ if (exp == 0) {
+ // Subnormal
+ val = mant * (1.0f / 8.0f) * sycl::pow(2.0f, -6.0f);
+ } else {
+ // Normal: implicit leading 1 + bias 7
+ val = (1.0f + mant / 8.0f) * sycl::pow(2.0f, static_cast<float>(exp) - 7.0f);
+ }
+
+ return sign ? -val : val;
+}
+
+// The actual type definition
+struct __nv_fp8_e4m3 {
+ uint8_t raw;
+
+ __nv_fp8_e4m3() = default;
+
+ explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {}
+ explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast<float>(h))) {}
+
+ operator float() const { return e4m3_to_float(raw); }
+ operator sycl::half() const { return static_cast<sycl::half>(static_cast<float>(*this)); }
+
+ // Allow direct access for vector loads/stores
+ operator uint8_t&() { return raw; }
+ operator uint8_t() const { return raw; }
+};
+
+using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>;
+using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>;
+
diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp
index 9a267d85a0c..eab9850aed7 100644
--- a/ggml/src/ggml-sycl/vecdotq.hpp
+++ b/ggml/src/ggml-sycl/vecdotq.hpp
@@ -15,6 +15,7 @@
#include "dpct/helper.hpp"
#include "ggml.h"
+#include "type.hpp"
#include "quants.hpp"
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
@@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) {
return x32;
}
+static __dpct_inline__ int get_int_b2(const void * x, const int & i32) {
+ const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
+
+ int x32 = x16[2*i32 + 0] << 0;
+ x32 |= x16[2*i32 + 1] << 16;
+
+ return x32;
+}
+
+static __dpct_inline__ int get_int_b4(const void * x, const int & i32) {
+ return ((const int *) x)[i32]; // assume at least 4 byte alignment
+}
static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
const uint16_t* x16 =
@@ -755,6 +768,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq,
return d * sumi;
}
+#define VDR_NVFP4_Q8_1_MMVQ 4
+#define VDR_NVFP4_Q8_1_MMQ 8
+
+static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq,
+ const block_q8_1 * __restrict__ bq8_1,
+ const int32_t & iqs) {
+ const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq;
+ float sum = 0.0f;
+#pragma unroll
+ for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) {
+ const int32_t iqs0 = iqs + 2*i;
+ const int32_t iqs1 = iqs0 + 1;
+ const int32_t is = iqs0 >> 1;
+ const sycl::int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4);
+ const sycl::int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4);
+ const block_q8_1 * bq8 = bq8_1 + (is >> 1);
+ const int32_t i8 = ((is & 1) << 2);
+
+ int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0);
+ sumi = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi);
+ sumi = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi);
+ sumi = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi);
+
+ const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0];
+ sum += d * float(sumi);
+ }
+
+ return sum;
+}
static __dpct_inline__ float
vec_dot_q5_0_q8_1(const void *__restrict__ vbq,
diff --git a/models/templates/LFM2.5-Instruct.jinja b/models/templates/LFM2.5-Instruct.jinja
new file mode 100644
index 00000000000..7778756dd92
--- /dev/null
+++ b/models/templates/LFM2.5-Instruct.jinja
@@ -0,0 +1,45 @@
+{{- bos_token -}}
+{%- set keep_past_thinking = keep_past_thinking | default(false) -%}
+{%- set ns = namespace(system_prompt="") -%}
+{%- if messages[0]["role"] == "system" -%}
+ {%- set ns.system_prompt = messages[0]["content"] -%}
+ {%- set messages = messages[1:] -%}
+{%- endif -%}
+{%- if tools -%}
+ {%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: [" -%}
+ {%- for tool in tools -%}
+ {%- if tool is not string -%}
+ {%- set tool = tool | tojson -%}
+ {%- endif -%}
+ {%- set ns.system_prompt = ns.system_prompt + tool -%}
+ {%- if not loop.last -%}
+ {%- set ns.system_prompt = ns.system_prompt + ", " -%}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- set ns.system_prompt = ns.system_prompt + "]" -%}
+{%- endif -%}
+{%- if ns.system_prompt -%}
+ {{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
+{%- endif -%}
+{%- set ns.last_assistant_index = -1 -%}
+{%- for message in messages -%}
+ {%- if message["role"] == "assistant" -%}
+ {%- set ns.last_assistant_index = loop.index0 -%}
+ {%- endif -%}
+{%- endfor -%}
+{%- for message in messages -%}
+ {{- "<|im_start|>" + message["role"] + "\n" -}}
+ {%- set content = message["content"] -%}
+ {%- if content is not string -%}
+ {%- set content = content | tojson -%}
+ {%- endif -%}
+ {%- if message["role"] == "assistant" and not keep_past_thinking and loop.index0 != ns.last_assistant_index -%}
+ {%- if "</think>" in content -%}
+ {%- set content = content.split("</think>")[-1] | trim -%}
+ {%- endif -%}
+ {%- endif -%}
+ {{- content + "<|im_end|>\n" -}}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+ {{- "<|im_start|>assistant\n" -}}
+{%- endif -%}
\ No newline at end of file
diff --git a/scripts/hip/gcn-cdna-vgpr-check.py b/scripts/hip/gcn-cdna-vgpr-check.py
index 38db47d3d18..bbbce52ef39 100644
--- a/scripts/hip/gcn-cdna-vgpr-check.py
+++ b/scripts/hip/gcn-cdna-vgpr-check.py
@@ -139,7 +139,11 @@ def main():
'_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_vecILi128ELi2EL9ggml_type2ELS0_2ELb0EEvPKcS2_S2_S2_S2_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS6_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL9mul_mat_qIL9ggml_type10ELi16ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
- '_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
+ '_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
+ '_ZL9mul_mat_qIL9ggml_type40ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
+ '_ZL9mul_mat_qIL9ggml_type40ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
+ '_ZL9mul_mat_qIL9ggml_type40ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
+ '_ZL9mul_mat_qIL9ggml_type40ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
}
functions = parse_log_file(log_file)
diff --git a/scripts/server-test-function-call.py b/scripts/server-test-function-call.py
new file mode 100755
index 00000000000..b3aae1a961e
--- /dev/null
+++ b/scripts/server-test-function-call.py
@@ -0,0 +1,1135 @@
+#!/usr/bin/env python3
+"""
+Test tool calling capability via chat completions endpoint.
+
+Each test case contains:
+ - tools: list of tool definitions (OpenAI-compatible)
+ - messages: initial conversation messages
+ - mock_tool_responses: dict mapping tool_name -> callable(arguments) -> str (JSON)
+ - validate: callable(tool_calls_history, final_content) -> (passed: bool, reason: str)
+"""
+
+import argparse
+import json
+import requests
+import sys
+
+# ---------------------------------------------------------------------------
+# Color / formatting helpers
+# ---------------------------------------------------------------------------
+
+RESET = "\x1b[0m"
+BOLD = "\x1b[1m"
+DIM = "\x1b[2m"
+# Foreground colors
+CYAN = "\x1b[36m"
+YELLOW = "\x1b[33m"
+GREEN = "\x1b[32m"
+RED = "\x1b[31m"
+BLUE = "\x1b[34m"
+WHITE = "\x1b[97m"
+
+
+def _print(text="", end="\n"):
+ sys.stdout.write(text + end)
+ sys.stdout.flush()
+
+
+def print_header(title):
+ bar = "─" * 60
+ _print(f"\n{BOLD}{CYAN}┌{bar}┐{RESET}")
+ _print(
+ f"{BOLD}{CYAN}│ {WHITE}{title}{CYAN}{' ' * max(0, 58 - len(title))}│{RESET}"
+ )
+ _print(f"{BOLD}{CYAN}└{bar}┘{RESET}")
+
+
+def print_tool_call(name, args):
+ args_str = json.dumps(args)
+ _print(
+ f"\n {BOLD}{YELLOW}⚙ tool call{RESET} {CYAN}{name}{RESET}{DIM}({args_str}){RESET}"
+ )
+
+
+def print_tool_result(result):
+ preview = result[:160] + ("…" if len(result) > 160 else "")
+ _print(f" {DIM}{BLUE}↳ result{RESET} {DIM}{preview}{RESET}")
+
+
+def print_model_output(text):
+    # printed inline during streaming (raw chunks, no added prefix or newline)
+ sys.stdout.write(text)
+ sys.stdout.flush()
+
+
+def print_pass(reason):
+ _print(f"\n{BOLD}{GREEN}✔ PASS{RESET} {reason}")
+
+
+def print_fail(reason):
+ _print(f"\n{BOLD}{RED}✘ FAIL{RESET} {reason}")
+
+
+def print_info(msg):
+ _print(f"{DIM}{msg}{RESET}")
+
+
+# ---------------------------------------------------------------------------
+# HTTP helpers
+# ---------------------------------------------------------------------------
+
+
+def chat_completion(url, messages, tools=None, stream=False):
+ payload = {
+ "messages": messages,
+ "stream": stream,
+ "max_tokens": 4096,
+ }
+ if tools:
+ payload["tools"] = tools
+ payload["tool_choice"] = "auto"
+
+ try:
+ response = requests.post(url, json=payload, stream=stream)
+ response.raise_for_status()
+ except requests.exceptions.RequestException as e:
+ body = e.response.content if (e.response is not None) else b""
+ print_fail(f"Request error: {e} | body: {body}")
+ return None
+
+ full_content = ""
+ reasoning_content = ""
+ tool_calls: list[dict] = []
+
+ if stream:
+ for line in response.iter_lines():
+ if not line:
+ continue
+ decoded = line.decode("utf-8")
+ if not decoded.startswith("data: "):
+ continue
+ data_str = decoded[6:]
+ if data_str == "[DONE]":
+ break
+ try:
+ data = json.loads(data_str)
+ except json.JSONDecodeError:
+ continue
+ choices = data.get("choices", [])
+ if not choices:
+ continue
+ delta = choices[0].get("delta", {})
+ if delta.get("reasoning_content"):
+ reasoning_content += delta["reasoning_content"]
+ if delta.get("content"):
+ full_content += delta["content"]
+ print_model_output(delta["content"])
+ for tc in delta.get("tool_calls", []):
+ idx = tc.get("index", 0)
+ while len(tool_calls) <= idx:
+ tool_calls.append(
+ {
+ "id": "",
+ "type": "function",
+ "function": {"name": "", "arguments": ""},
+ }
+ )
+ if "id" in tc:
+ tool_calls[idx]["id"] += tc["id"]
+ if "function" in tc:
+ if "name" in tc["function"]:
+ tool_calls[idx]["function"]["name"] += tc["function"]["name"]
+ if "arguments" in tc["function"]:
+ tool_calls[idx]["function"]["arguments"] += tc["function"][
+ "arguments"
+ ]
+ else:
+ data = response.json()
+ choices = data.get("choices", [])
+ if choices:
+ msg = choices[0].get("message", {})
+ full_content = msg.get("content") or ""
+ reasoning_content = msg.get("reasoning_content") or ""
+ tool_calls = msg.get("tool_calls") or []
+ if full_content:
+ print_model_output(full_content)
+
+ result = {"content": full_content, "tool_calls": tool_calls}
+ if reasoning_content:
+ result["reasoning_content"] = reasoning_content
+ return result
+
+
+def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turns=6):
+ """
+ Drive the multi-turn tool-call loop:
+ 1. Send messages to model.
+ 2. If the model returns tool calls, execute mocks and append results.
+ 3. Repeat until no more tool calls or max_turns reached.
+
+ Returns (all_tool_calls, final_content).
+ """
+ msgs = list(messages)
+ all_tool_calls: list[dict] = []
+
+ for _ in range(max_turns):
+ result = chat_completion(url, msgs, tools=tools, stream=stream)
+ if result is None:
+ return all_tool_calls, None
+
+ tcs = result.get("tool_calls") or []
+ content = result.get("content") or ""
+
+ if not tcs:
+ # Print a visual separator before the final model response
+ if content:
+ _print(f"\n{DIM}{'·'*60}{RESET}")
+ _print(f"{DIM} model response:{RESET}\n")
+ return all_tool_calls, content
+
+ # Record tool calls for validation
+ all_tool_calls.extend(tcs)
+
+ # Append assistant message with tool calls
+ assistant_msg: dict = {
+ "role": "assistant",
+ "content": content,
+ "tool_calls": tcs,
+ }
+ reasoning = result.get("reasoning_content")
+ if reasoning:
+ assistant_msg["reasoning_content"] = reasoning
+ msgs.append(assistant_msg)
+
+ # Execute each tool call via mock and append tool result messages
+ for tc in tcs:
+ tool_name = tc["function"]["name"]
+ try:
+ args = json.loads(tc["function"]["arguments"])
+ except json.JSONDecodeError:
+ args = {}
+
+ print_tool_call(tool_name, args)
+
+ mock_fn = mock_tool_responses.get(tool_name)
+ if mock_fn:
+ tool_result = mock_fn(args)
+ else:
+ tool_result = json.dumps({"error": f"Unknown tool: {tool_name}"})
+
+ print_tool_result(tool_result)
+
+ msgs.append(
+ {
+ "role": "tool",
+ "tool_call_id": tc.get("id", ""),
+ "content": tool_result,
+ }
+ )
+
+ return all_tool_calls, None
+
+
+# ---------------------------------------------------------------------------
+# Test case runner
+# ---------------------------------------------------------------------------
+
+
+def run_test(url, test_case, stream):
+ name = test_case["name"]
+ mode = f"{'stream' if stream else 'non-stream'}"
+ print_header(f"{name} [{mode}]")
+
+ all_tool_calls, final_content = run_agentic_loop(
+ url,
+ messages=test_case["messages"],
+ tools=test_case["tools"],
+ mock_tool_responses=test_case["mock_tool_responses"],
+ stream=stream,
+ )
+
+ if final_content is None and not all_tool_calls:
+ print_fail("No response from server.")
+ return False
+
+ passed, reason = test_case["validate"](all_tool_calls, final_content)
+ if passed:
+ print_pass(reason)
+ else:
+ print_fail(reason)
+ return passed
+
+
+# ---------------------------------------------------------------------------
+# Test case definitions
+# ---------------------------------------------------------------------------
+
+# ---- Test 1: E-commerce multi-step search (Azzoo = anonymized marketplace) ----
+
+_AZZOO_TOOLS = [
+ {
+ "type": "function",
+ "function": {
+ "name": "azzoo_search_products",
+ "description": (
+ "Search for products on Azzoo marketplace by keyword. "
+ "Returns a list of matching products with IDs, titles, ratings and prices."
+ ),
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "query": {
+ "type": "string",
+ "description": "Search keyword or phrase",
+ },
+ "page": {
+ "type": "string",
+ "description": "Page number (1-based)",
+ "default": "1",
+ },
+ },
+ "required": ["query"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "azzoo_get_product",
+ "description": "Retrieve detailed information about a specific Azzoo product including specs and price.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "product_id": {
+ "type": "string",
+ "description": "Azzoo product identifier (e.g. AZB12345)",
+ },
+ },
+ "required": ["product_id"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "azzoo_get_reviews",
+ "description": "Fetch customer reviews for an Azzoo product.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "product_id": {
+ "type": "string",
+ "description": "Azzoo product identifier",
+ },
+ "page": {
+ "type": "string",
+ "description": "Review page number",
+ "default": "1",
+ },
+ },
+ "required": ["product_id"],
+ },
+ },
+ },
+]
+
+_AZZOO_SEARCH_RESULT = {
+ "results": [
+ {
+ "product_id": "AZB00001",
+ "title": "SteelBrew Pro Kettle 1.7L",
+ "rating": 4.6,
+ "price": 34.99,
+ },
+ {
+ "product_id": "AZB00002",
+ "title": "HeatKeep Gooseneck Kettle",
+ "rating": 4.3,
+ "price": 27.50,
+ },
+ {
+ "product_id": "AZB00003",
+ "title": "QuickBoil Stainless Kettle",
+ "rating": 4.1,
+ "price": 21.00,
+ },
+ ]
+}
+_AZZOO_PRODUCT_RESULT = {
+ "product_id": "AZB00001",
+ "title": "SteelBrew Pro Kettle 1.7L",
+ "price": 34.99,
+ "rating": 4.6,
+ "review_count": 2847,
+ "specs": {
+ "material": "18/8 stainless steel",
+ "capacity": "1.7 L",
+ "auto_shutoff": True,
+ "keep_warm": "30 min",
+ "warranty": "2 years",
+ },
+}
+_AZZOO_REVIEWS_RESULT = {
+ "product_id": "AZB00001",
+ "average_rating": 4.6,
+ "reviews": [
+ {
+ "rating": 5,
+ "title": "Excellent build quality",
+ "body": "Very sturdy, boils fast and stays warm longer than expected.",
+ },
+ {
+ "rating": 5,
+ "title": "Great for loose-leaf tea",
+ "body": "The wide spout makes filling a teapot easy. No leaks after months of use.",
+ },
+ {
+ "rating": 3,
+ "title": "Minor lid issue",
+ "body": "The lid doesn't always click shut properly, but overall happy with it.",
+ },
+ {
+ "rating": 4,
+ "title": "Good value",
+ "body": "Heats quickly and the auto shutoff works reliably.",
+ },
+ ],
+}
+
+AZZOO_TEST_CASE = {
+ "name": "Azzoo E-commerce: search -> product detail -> reviews",
+ "messages": [
+ {
+ "role": "user",
+ "content": (
+ "I need a durable stainless steel tea kettle for my weekly tea gatherings. "
+ "Please search Azzoo for 'stainless steel tea kettle', then get full details "
+ "on the top-rated result, and finally fetch its customer reviews so I can "
+ "check for recurring complaints. Give me a summary with pros and cons."
+ ),
+ }
+ ],
+ "tools": _AZZOO_TOOLS,
+ "mock_tool_responses": {
+ "azzoo_search_products": lambda _: json.dumps(_AZZOO_SEARCH_RESULT),
+ "azzoo_get_product": lambda _: json.dumps(_AZZOO_PRODUCT_RESULT),
+ "azzoo_get_reviews": lambda _: json.dumps(_AZZOO_REVIEWS_RESULT),
+ },
+ "validate": lambda tcs, content: _validate_azzoo(tcs, content),
+}
+
+
+def _validate_azzoo(tcs, content):
+ names = [tc["function"]["name"] for tc in tcs]
+ if not names:
+ return False, "No tool calls made"
+ if "azzoo_search_products" not in names:
+ return False, f"Expected azzoo_search_products to be called, got: {names}"
+ # After search the model should look up product details
+ if "azzoo_get_product" not in names and "azzoo_get_reviews" not in names:
+ return False, f"Expected follow-up product/review lookup, got: {names}"
+ # Verify product lookup used an ID from search results
+ for tc in tcs:
+ if tc["function"]["name"] == "azzoo_get_product":
+ try:
+ args = json.loads(tc["function"]["arguments"])
+ pid = args.get("product_id", "")
+ if not pid:
+ return False, "azzoo_get_product called with empty product_id"
+ except json.JSONDecodeError:
+ return False, "azzoo_get_product arguments are not valid JSON"
+ if not content:
+ return False, "No final summary produced"
+ return True, f"All expected tools called in order: {names}"
+
+
+# ---- Test 2: Fitness BMI + exercise recommendations ----
+
+_FITNESS_TOOLS = [
+ {
+ "type": "function",
+ "function": {
+ "name": "calculate_bmi",
+ "description": "Calculate Body Mass Index (BMI) from weight and height.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "weight_kg": {
+ "type": "number",
+ "description": "Body weight in kilograms",
+ },
+ "height_m": {"type": "number", "description": "Height in meters"},
+ },
+ "required": ["weight_kg", "height_m"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_exercises",
+ "description": (
+ "Fetch a list of exercises filtered by muscle group, difficulty, category, "
+ "and/or force type."
+ ),
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "muscle": {
+ "type": "string",
+ "description": "Target muscle group (e.g. chest, back, legs)",
+ },
+ "difficulty": {
+ "type": "string",
+ "description": "Difficulty level: beginner, intermediate, expert",
+ },
+ "category": {
+ "type": "string",
+ "description": "Exercise category (e.g. strength, cardio, stretching)",
+ },
+ "force": {
+ "type": "string",
+ "description": "Force type: push, pull, static",
+ },
+ },
+ "required": [],
+ },
+ },
+ },
+]
+
+_BMI_RESULT = {"bmi": 24.5, "category": "Normal weight", "healthy_range": "18.5 – 24.9"}
+_EXERCISES_RESULT = {
+ "exercises": [
+ {
+ "name": "Push-Up",
+ "muscle": "chest",
+ "difficulty": "beginner",
+ "equipment": "none",
+ "instructions": "Keep body straight, lower chest to floor.",
+ },
+ {
+ "name": "Incline Dumbbell Press",
+ "muscle": "chest",
+ "difficulty": "beginner",
+ "equipment": "dumbbells, bench",
+ "instructions": "Press dumbbells up from chest on incline bench.",
+ },
+ {
+ "name": "Chest Fly (cables)",
+ "muscle": "chest",
+ "difficulty": "beginner",
+ "equipment": "cable machine",
+ "instructions": "Bring cables together in an arc motion.",
+ },
+ ]
+}
+
+FITNESS_TEST_CASE = {
+ "name": "Fitness: BMI calculation + exercise suggestions",
+ "messages": [
+ {
+ "role": "user",
+ "content": (
+ "I'm a 32-year-old male, 78 kg and 1.80 m tall. "
+ "Please calculate my BMI and then suggest some beginner chest exercises I can do "
+ "to build strength. Give me a short personalised plan."
+ ),
+ }
+ ],
+ "tools": _FITNESS_TOOLS,
+ "mock_tool_responses": {
+ "calculate_bmi": lambda _: json.dumps(_BMI_RESULT),
+ "get_exercises": lambda _: json.dumps(_EXERCISES_RESULT),
+ },
+ "validate": lambda tcs, content: _validate_fitness(tcs, content),
+}
+
+
+def _validate_fitness(tcs, content):
+ names = [tc["function"]["name"] for tc in tcs]
+ if not names:
+ return False, "No tool calls made"
+ if "calculate_bmi" not in names:
+ return False, f"Expected calculate_bmi to be called, got: {names}"
+ # Validate BMI args contain plausible values
+ for tc in tcs:
+ if tc["function"]["name"] == "calculate_bmi":
+ try:
+ args = json.loads(tc["function"]["arguments"])
+ w = args.get("weight_kg")
+ h = args.get("height_m")
+ if w is None or h is None:
+ return False, f"calculate_bmi missing weight_kg or height_m: {args}"
+ if not (50 <= float(w) <= 200):
+ return False, f"calculate_bmi weight out of plausible range: {w}"
+ if not (1.0 <= float(h) <= 2.5):
+ return False, f"calculate_bmi height out of plausible range: {h}"
+ except (json.JSONDecodeError, ValueError) as e:
+ return False, f"calculate_bmi argument error: {e}"
+ if not content:
+ return False, "No final plan produced"
+ return True, f"Tools called: {names}"
+
+
+# ---- Test 3: Community class planning (anonymised cooking/topic discovery) ----
+
+_COMMUNITY_TOOLS = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_trending_questions",
+ "description": (
+ "Fetch commonly asked questions on a topic from search engine 'People Also Ask' boxes."
+ ),
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "query": {"type": "string", "description": "Topic to search for"},
+ "max_results": {
+ "type": "integer",
+ "description": "Maximum questions to return",
+ "default": 10,
+ },
+ },
+ "required": ["query"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "search_mobile_apps",
+ "description": "Search the mobile app store for apps matching a category or keyword.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "keyword": {
+ "type": "string",
+ "description": "Search keyword (e.g. 'Italian cooking')",
+ },
+ "platform": {
+ "type": "string",
+ "enum": ["ios", "android", "both"],
+ "default": "both",
+ },
+ "max_results": {
+ "type": "integer",
+ "description": "Number of results",
+ "default": 10,
+ },
+ },
+ "required": ["keyword"],
+ },
+ },
+ },
+]
+
+_TRENDING_QUESTIONS_RESULT = {
+ "query": "Italian cuisine",
+ "questions": [
+ "What are the most popular Italian dishes?",
+ "What makes Italian food different from other cuisines?",
+ "How do you make authentic Italian pasta from scratch?",
+ "What are traditional Italian desserts?",
+ "What herbs are commonly used in Italian cooking?",
+ "Is Italian food healthy?",
+ "What wine pairs best with Italian pasta?",
+ ],
+}
+_APPS_RESULT = {
+ "keyword": "Italian cooking",
+ "results": [
+ {
+ "name": "PastaPro",
+ "rating": 4.5,
+ "installs": "500K+",
+ "focus": "pasta recipes only",
+ },
+ {
+ "name": "CookEasy",
+ "rating": 4.2,
+ "installs": "1M+",
+ "focus": "general cooking, limited Italian content",
+ },
+ {
+ "name": "ItalianKitchen",
+ "rating": 3.8,
+ "installs": "100K+",
+ "focus": "regional Italian recipes, no video",
+ },
+ ],
+}
+
+COMMUNITY_CLASS_TEST_CASE = {
+ "name": "Community class planning: trending topics + app gap analysis",
+ "messages": [
+ {
+ "role": "user",
+ "content": (
+ "I want to start teaching Italian cooking classes at my community centre. "
+ "First, find out what people commonly ask about Italian cuisine online. "
+ "Then search for existing Italian cooking apps to see what they cover. "
+ "Use both results to suggest three unique angles for my classes that fill gaps "
+ "in what apps already offer."
+ ),
+ }
+ ],
+ "tools": _COMMUNITY_TOOLS,
+ "mock_tool_responses": {
+ "get_trending_questions": lambda _: json.dumps(_TRENDING_QUESTIONS_RESULT),
+ "search_mobile_apps": lambda _: json.dumps(_APPS_RESULT),
+ },
+ "validate": lambda tcs, content: _validate_community(tcs, content),
+}
+
+
+def _validate_community(tcs, content):
+ names = [tc["function"]["name"] for tc in tcs]
+ if not names:
+ return False, "No tool calls made"
+ missing = [
+ t for t in ("get_trending_questions", "search_mobile_apps") if t not in names
+ ]
+ if missing:
+ return False, f"Missing expected tool calls: {missing}; got: {names}"
+ if not content:
+ return False, "No class suggestion produced"
+ return True, f"Both discovery tools called: {names}"
+
+
+# ---- Test 4: Multi-hostname geolocation filter (anonymized gallery discovery) ----
+# Inspired by: checking gallery website server locations to find truly remote venues.
+# Anonymized: galleryone.de → halle-eins.de, gallerytwo.fr → galerie-deux.fr,
+# gallerythree.it → galleria-tre.it
+
+_GEO_TOOLS = [
+ {
+ "type": "function",
+ "function": {
+ "name": "lookup_ip_geolocation",
+ "description": (
+ "Retrieve geolocation data for an IP address or hostname, including country, "
+ "city, coordinates, and network info. Useful for verifying physical server "
+ "locations or personalising regional content."
+ ),
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "host": {
+ "type": "string",
+ "description": "IP address or hostname to look up (e.g. '8.8.8.8' or 'example.com').",
+ },
+ },
+ "required": ["host"],
+ },
+ },
+ },
+]
+
+# Mock: one urban (Berlin → discard), two rural (keep)
+_GEO_RESPONSES = {
+ "halle-eins.de": {
+ "host": "halle-eins.de",
+ "city": "Berlin",
+ "country": "DE",
+ "lat": 52.5200,
+ "lon": 13.4050,
+ "is_major_city": True,
+ },
+ "galerie-deux.fr": {
+ "host": "galerie-deux.fr",
+ "city": "Rocamadour",
+ "country": "FR",
+ "lat": 44.7994,
+ "lon": 1.6178,
+ "is_major_city": False,
+ },
+ "galleria-tre.it": {
+ "host": "galleria-tre.it",
+ "city": "Matera",
+ "country": "IT",
+ "lat": 40.6664,
+ "lon": 16.6044,
+ "is_major_city": False,
+ },
+}
+
+
+def _geo_mock(args):
+ host = args.get("host", "")
+ return json.dumps(_GEO_RESPONSES.get(host, {"error": f"unknown host: {host}"}))
+
+
+GEO_TEST_CASE = {
+ "name": "Gallery geolocation: filter urban venues, keep remote ones",
+ "messages": [
+ {
+ "role": "user",
+ "content": (
+ "I have abstract paintings to exhibit in remote European galleries. "
+ "I received enquiries from three venues: halle-eins.de, galerie-deux.fr, "
+ "and galleria-tre.it. Please look up the geolocation of each website's server. "
+ "Discard any venue whose server is in a major city (e.g. Berlin, Paris, Rome). "
+ "For the remaining venues, report their exact coordinates so I can check "
+ "whether hiking trails are nearby — my work thrives where nature and art meet."
+ ),
+ }
+ ],
+ "tools": _GEO_TOOLS,
+ "mock_tool_responses": {
+ "lookup_ip_geolocation": _geo_mock,
+ },
+ "validate": lambda tcs, content: _validate_geo(tcs, content),
+}
+
+
+def _validate_geo(tcs, content):
+ names = [tc["function"]["name"] for tc in tcs]
+ if not names:
+ return False, "No tool calls made"
+    # Expect at least one geolocation call per domain (3 or more in total)
+ geo_calls = [tc for tc in tcs if tc["function"]["name"] == "lookup_ip_geolocation"]
+ if len(geo_calls) < 3:
+ return (
+ False,
+ f"Expected geolocation called 3 times (once per domain), got {len(geo_calls)}",
+ )
+ queried_hosts = set()
+ for tc in geo_calls:
+ try:
+ args = json.loads(tc["function"]["arguments"])
+ host = args.get("host", "")
+ if not host:
+ return False, f"lookup_ip_geolocation called with empty host: {args}"
+ queried_hosts.add(host)
+ except json.JSONDecodeError:
+ return False, "lookup_ip_geolocation arguments are not valid JSON"
+ expected = {"halle-eins.de", "galerie-deux.fr", "galleria-tre.it"}
+ if not expected.issubset(queried_hosts):
+ return (
+ False,
+ f"Not all domains queried. Expected {expected}, got {queried_hosts}",
+ )
+ if not content:
+ return False, "No final summary produced"
+ return True, f"All 3 domains geolocated: {sorted(queried_hosts)}"
+
+
+# ---- Test 5: EV fleet expansion — stock → security → property → video ----
+# Inspired by: multi-step business analysis combining finance, cybersecurity,
+# real estate and educational content.
+# Anonymized: Tesla → Voltara (VLTR), Rivian → Rivex (RVXN),
+# Trenton → Halverton
+
+_EV_TOOLS = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_stock_quote",
+ "description": "Retrieve the latest market quote for a financial instrument by ticker symbol.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "symbol": {
+ "type": "string",
+ "description": "Ticker symbol (e.g. 'VLTR', 'RVXN')",
+ },
+ "interval": {
+ "type": "string",
+ "description": "Time interval: 1min, 5min, 1h, 1day, 1week",
+ "default": "1day",
+ },
+ },
+ "required": ["symbol"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_security_advisories",
+ "description": (
+ "Fetch current cybersecurity advisories from the national security agency, "
+ "covering known vulnerabilities and exploits for industrial and consumer systems."
+ ),
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "keyword": {
+ "type": "string",
+ "description": "Filter advisories by keyword or product name",
+ },
+ "limit": {
+ "type": "integer",
+ "description": "Maximum number of advisories to return",
+ "default": 5,
+ },
+ },
+ "required": [],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "search_commercial_properties",
+ "description": "Search for commercial properties (offices, garages, warehouses) available for rent or sale in a given city.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "city": {"type": "string", "description": "City name to search in"},
+ "property_type": {
+ "type": "string",
+ "description": "Type of property: office, garage, warehouse, premises",
+ },
+ "operation": {
+ "type": "string",
+ "enum": ["rent", "sale"],
+ "default": "rent",
+ },
+ "max_price": {
+ "type": "integer",
+ "description": "Maximum monthly rent or sale price",
+ },
+ },
+ "required": ["city", "property_type"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_video_recommendations",
+ "description": "Fetch a list of recommended videos related to a given topic or reference video.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "topic": {
+ "type": "string",
+ "description": "Topic or keyword to search for related videos",
+ },
+ },
+ "required": ["topic"],
+ },
+ },
+ },
+]
+
+_STOCK_RESULT_VLTR = {
+ "symbol": "VLTR",
+ "company": "Voltara Inc.",
+ "price": 218.45,
+ "change_pct": "+2.3%",
+ "market_cap": "694B",
+ "currency": "USD",
+}
+_STOCK_RESULT_RVXN = {
+ "symbol": "RVXN",
+ "company": "Rivex Motors",
+ "price": 12.80,
+ "change_pct": "-1.1%",
+ "market_cap": "11B",
+ "currency": "USD",
+}
+_ADVISORIES_RESULT = {
+ "count": 2,
+ "advisories": [
+ {
+ "id": "ICSA-24-102-01",
+ "title": "Voltara In-Vehicle Infotainment System Authentication Bypass",
+ "severity": "Medium",
+ "summary": "Improper authentication in the OTA update module may allow an adjacent attacker to install unsigned firmware.",
+ "published": "2024-04-11",
+ },
+ {
+ "id": "ICSA-24-085-03",
+ "title": "Voltara Charging Management API Input Validation Flaw",
+ "severity": "Low",
+ "summary": "Insufficient input validation in the charging session API could expose internal error messages.",
+ "published": "2024-03-26",
+ },
+ ],
+}
+_PROPERTIES_RESULT = {
+ "city": "Halverton",
+ "listings": [
+ {
+ "id": "HV-0041",
+ "type": "garage",
+ "area_sqm": 420,
+ "monthly_rent": 2800,
+ "ev_power_outlets": 12,
+ "address": "14 Ironworks Lane, Halverton",
+ },
+ {
+ "id": "HV-0089",
+ "type": "warehouse",
+ "area_sqm": 900,
+ "monthly_rent": 4200,
+ "ev_power_outlets": 30,
+ "address": "7 Depot Road, Halverton",
+ },
+ ],
+}
+_VIDEOS_RESULT = {
+ "topic": "fleet electrification",
+ "recommendations": [
+ {
+ "title": "How to Build an EV Fleet from Scratch",
+ "channel": "Fleet Future",
+ "views": "182K",
+ },
+ {
+ "title": "EV Charging Infrastructure for Commercial Fleets",
+ "channel": "GreenDrive Pro",
+ "views": "94K",
+ },
+ {
+ "title": "Total Cost of Ownership: Electric vs Diesel Vans",
+ "channel": "LogisticsTech",
+ "views": "61K",
+ },
+ ],
+}
+
+
+def _ev_stock_mock(args):
+ symbol = args.get("symbol", "").upper()
+ if symbol == "VLTR":
+ return json.dumps(_STOCK_RESULT_VLTR)
+ if symbol == "RVXN":
+ return json.dumps(_STOCK_RESULT_RVXN)
+ return json.dumps({"error": f"Unknown symbol: {symbol}"})
+
+
+EV_FLEET_TEST_CASE = {
+ "name": "EV fleet expansion: stock → cybersecurity → property → videos",
+ "messages": [
+ {
+ "role": "user",
+ "content": (
+ "I'm expanding my courier business into electric vehicles and need a multi-step analysis:\n"
+ "1. Get the latest stock quote for Voltara (VLTR) and Rivex (RVXN). "
+ "If either is above $50, continue with that company.\n"
+ "2. Search for cybersecurity advisories related to that company's vehicle models "
+ "to understand any tech risks.\n"
+ "3. Find commercial garage or warehouse properties in Halverton suitable for "
+ "EV charging infrastructure.\n"
+ "4. Recommend videos on fleet electrification strategies.\n"
+ "Please work through all four steps and give me a concise summary."
+ ),
+ }
+ ],
+ "tools": _EV_TOOLS,
+ "mock_tool_responses": {
+ "get_stock_quote": _ev_stock_mock,
+ "get_security_advisories": lambda _: json.dumps(_ADVISORIES_RESULT),
+ "search_commercial_properties": lambda _: json.dumps(_PROPERTIES_RESULT),
+ "get_video_recommendations": lambda _: json.dumps(_VIDEOS_RESULT),
+ },
+ "validate": lambda tcs, content: _validate_ev(tcs, content),
+}
+
+
+def _validate_ev(tcs, content):
+ names = [tc["function"]["name"] for tc in tcs]
+ if not names:
+ return False, "No tool calls made"
+ # Stock quote must come first
+ if names[0] != "get_stock_quote":
+ return False, f"Expected get_stock_quote to be called first, got: {names[0]}"
+ stock_calls = [tc for tc in tcs if tc["function"]["name"] == "get_stock_quote"]
+ for tc in stock_calls:
+ try:
+ args = json.loads(tc["function"]["arguments"])
+ sym = args.get("symbol", "")
+ if not sym:
+ return False, f"get_stock_quote called with empty symbol: {args}"
+ except json.JSONDecodeError:
+ return False, "get_stock_quote arguments are not valid JSON"
+ # All four pipeline tools expected
+ required = [
+ "get_stock_quote",
+ "get_security_advisories",
+ "search_commercial_properties",
+ "get_video_recommendations",
+ ]
+ missing = [t for t in required if t not in names]
+ if missing:
+ return False, f"Missing pipeline steps: {missing}"
+ if not content:
+ return False, "No final summary produced"
+ return True, f"Full 4-step pipeline executed: {names}"
+
+
+# ---------------------------------------------------------------------------
+# All test cases
+# ---------------------------------------------------------------------------
+
+ALL_TEST_CASES = [
+ AZZOO_TEST_CASE,
+ FITNESS_TEST_CASE,
+ COMMUNITY_CLASS_TEST_CASE,
+ GEO_TEST_CASE,
+ EV_FLEET_TEST_CASE,
+]
+
+
+# ---------------------------------------------------------------------------
+# Entry point
+# ---------------------------------------------------------------------------
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Test llama-server tool-calling capability."
+ )
+ parser.add_argument("--host", default="localhost")
+ parser.add_argument("--port", default=8080, type=int)
+ parser.add_argument(
+ "--no-stream", action="store_true", help="Disable streaming mode tests"
+ )
+ parser.add_argument(
+ "--stream-only", action="store_true", help="Only run streaming mode tests"
+ )
+ parser.add_argument(
+ "--test",
+ help="Run only the test whose name contains this substring (case-insensitive)",
+ )
+ args = parser.parse_args()
+
+ url = f"http://{args.host}:{args.port}/v1/chat/completions"
+ print_info(f"Testing server at {url}")
+
+ modes = []
+ if not args.stream_only:
+ modes.append(False)
+ if not args.no_stream:
+ modes.append(True)
+
+ cases: list[dict] = ALL_TEST_CASES
+ if args.test:
+ name_filter = args.test.lower()
+ cases = [c for c in cases if name_filter in str(c["name"]).lower()]
+ if not cases:
+ print_fail(f"No test cases matched '{args.test}'")
+ sys.exit(1)
+
+ total = 0
+ passed = 0
+ for stream in modes:
+ for case in cases:
+ total += 1
+ if run_test(url, case, stream=stream):
+ passed += 1
+
+ color = GREEN if passed == total else RED
+ _print(f"\n{BOLD}{color}{'─'*60}{RESET}")
+ _print(f"{BOLD}{color} Results: {passed}/{total} passed{RESET}")
+ _print(f"{BOLD}{color}{'─'*60}{RESET}\n")
+ sys.exit(0 if passed == total else 1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last
index 7af38c3eb73..abe4128c8f2 100644
--- a/scripts/sync-ggml.last
+++ b/scripts/sync-ggml.last
@@ -1 +1 @@
-a04eea0761a85d18f3f504d6ab970c5c9dce705f
+50634c28837c24ac68b380b5750b41e701c87d73
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index c2833b75ced..0e7d96ca10d 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -19,7 +19,7 @@
// dedup helpers
-static ggml_tensor * build_kq_mask(
+static ggml_tensor * build_attn_inp_kq_mask(
ggml_context * ctx,
const llama_kv_cache_context * mctx,
const llama_ubatch & ubatch,
@@ -28,7 +28,11 @@ static ggml_tensor * build_kq_mask(
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
- return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+ ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+ ggml_set_input(res);
+ ggml_set_name(res, "attn_inp_kq_mask");
+
+ return res;
}
static bool can_reuse_kq_mask(
@@ -52,6 +56,21 @@ static bool can_reuse_kq_mask(
// impl
+static ggml_tensor * ggml_mul_mat_aux(
+ ggml_context * ctx,
+ ggml_tensor * cur,
+ ggml_tensor * rot) {
+ const auto n = rot->ne[0];
+
+ ggml_tensor * res;
+
+ res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
+ res = ggml_mul_mat (ctx, rot, res);
+ res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
+
+ return res;
+}
+
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) {
const int64_t n_tokens = ubatch->n_tokens;
@@ -429,6 +448,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+ if (self_k_rot) {
+ mctx->set_input_k_rot(self_k_rot);
+ }
+
+ if (self_v_rot) {
+ mctx->set_input_v_rot(self_v_rot);
+ }
}
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
@@ -476,6 +503,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+
+ if (self_k_rot) {
+ mctx->get_base()->set_input_k_rot(self_k_rot);
+ }
+
+ if (self_v_rot) {
+ mctx->get_base()->set_input_v_rot(self_v_rot);
+ }
}
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
@@ -532,6 +567,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+ if (inp_attn->self_k_rot) {
+ mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot);
+ }
+
+ if (inp_attn->self_v_rot) {
+ mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot);
+ }
+
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
@@ -630,6 +673,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
+ if (inp_attn->self_k_rot) {
+ attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
+ }
+
+ if (inp_attn->self_v_rot) {
+ attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
+ }
+
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
@@ -2002,13 +2053,13 @@ static std::unique_ptr build_attn_inp_kv_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
- inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
-
- ggml_set_input(inp->self_kq_mask);
-
+ inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
+ inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
+ inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0);
+
return inp;
}
@@ -2034,6 +2085,15 @@ ggml_tensor * llm_graph_context::build_attn(
int il) const {
GGML_ASSERT(v_mla == nullptr);
+ if (inp->self_k_rot) {
+ q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
+ k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
+ }
+
+ if (inp->self_v_rot) {
+ v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
+ }
+
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
@@ -2061,6 +2121,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
+ if (inp->self_v_rot) {
+ cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
+ }
+
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
@@ -2090,9 +2154,7 @@ static std::unique_ptr build_attn_inp_k_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
- inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
- ggml_set_input(inp->self_kq_mask);
-
+ inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
@@ -2171,6 +2233,18 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v_mla,
float kq_scale,
int il) const {
+ if (inp->self_k_rot) {
+ q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
+ if (k_cur) {
+ k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
+ }
+ }
+ if (inp->self_v_rot) {
+ if (v_cur) {
+ v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
+ }
+ }
+
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
@@ -2211,6 +2285,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
+ if (inp->self_v_rot) {
+ cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
+ }
+
if (wo) {
cur = build_lora_mm(wo, cur);
}
@@ -2293,12 +2371,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
- inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
- ggml_set_input(inp->self_kq_mask);
- ggml_set_name(inp->self_kq_mask, "self_kq_mask");
-
+ inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
- ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
}
{
@@ -2307,14 +2381,13 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
- inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
- ggml_set_input(inp->self_kq_mask_swa);
- ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
-
+ inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
- ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
}
+ inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
+ inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
+
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}
@@ -2473,9 +2546,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
- inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
- ggml_set_input(inp_attn->self_kq_mask);
-
+ inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
}
@@ -2483,9 +2554,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
- inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
- ggml_set_input(inp_attn->self_kq_mask_swa);
-
+ inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
}
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 4855685ef71..bb0ad75198f 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -308,6 +308,10 @@ class llm_graph_input_attn_kv : public llm_graph_input_i {
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
+ // note: assumes v_rot^2 == I (applied once before attention, once after to undo it)
+ ggml_tensor * self_k_rot = nullptr;
+ ggml_tensor * self_v_rot = nullptr;
+
// note: these have to be copies because in order to be able to reuse a graph, its inputs
// need to carry these parameters with them. otherwise, they can point to freed
// llm_graph_params from a previous batch, causing stack-use-after-return
@@ -384,6 +388,10 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
+ // note: using same rotation matrices for both base and swa cache
+ ggml_tensor * self_k_rot = nullptr;
+ ggml_tensor * self_v_rot = nullptr;
+
const llama_hparams hparams;
const llama_cparams cparams;
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 5f57ba9e1d8..3e0fd3107f3 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -13,6 +13,65 @@
#include