diff --git a/sdk/cpp/.clang-format b/sdk/cpp/.clang-format
new file mode 100644
index 00000000..751f30aa
--- /dev/null
+++ b/sdk/cpp/.clang-format
@@ -0,0 +1,47 @@
+---
+Language: Cpp
+BasedOnStyle: Microsoft
+
+# Match the existing project style
+Standard: c++17
+ColumnLimit: 120
+
+# Indentation
+IndentWidth: 4
+TabWidth: 4
+UseTab: Never
+AccessModifierOffset: -4
+IndentCaseLabels: false
+NamespaceIndentation: All
+
+# Braces
+BreakBeforeBraces: Custom
+BraceWrapping:
+  AfterCaseLabel: false
+  AfterClass: false
+  AfterControlStatement: Never
+  AfterEnum: false
+  AfterFunction: false
+  AfterNamespace: false
+  AfterStruct: false
+  BeforeCatch: true
+  BeforeElse: true
+  IndentBraces: false
+
+# Alignment
+AlignAfterOpenBracket: Align
+AlignOperands: Align
+AlignTrailingComments: true
+
+# Includes
+SortIncludes: false
+IncludeBlocks: Preserve
+
+# Misc
+AllowShortFunctionsOnASingleLine: Inline
+AllowShortIfStatementsOnASingleLine: Never
+AllowShortLoopsOnASingleLine: false
+AllowShortBlocksOnASingleLine: Empty
+PointerAlignment: Left
+SpaceAfterCStyleCast: false
+SpaceBeforeParens: ControlStatements
diff --git a/sdk/cpp/CMakeLists.txt b/sdk/cpp/CMakeLists.txt
new file mode 100644
index 00000000..064c46ca
--- /dev/null
+++ b/sdk/cpp/CMakeLists.txt
@@ -0,0 +1,155 @@
+cmake_minimum_required(VERSION 3.20)
+
+# VS hot reload policy (safe-guarded)
+if (POLICY CMP0141)
+    cmake_policy(SET CMP0141 NEW)
+    if (MSVC)
+        # Restored: the generator expression's <CONFIG:...> condition was
+        # stripped by the sanitizer; this is the standard CMP0141 snippet.
+        set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT
+            "$<$<CONFIG:Debug,RelWithDebInfo>:ProgramDatabase>")
+    endif()
+endif()
+
+project(CppSdk LANGUAGES CXX)
+
+# -----------------------------
+# Windows-only + compiler guard
+# -----------------------------
+if (NOT WIN32)
+    message(FATAL_ERROR "CppSdk is Windows-only for now (uses Win32/WIL headers).")
+endif()
+
+# Accept MSVC OR clang-cl (Clang in MSVC compatibility mode).
+# VS CMake Open-Folder often uses clang-cl by default.
+if (NOT (MSVC OR (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_SIMULATE_ID STREQUAL "MSVC"))) + message(STATUS "CMAKE_CXX_COMPILER_ID = ${CMAKE_CXX_COMPILER_ID}") + message(STATUS "CMAKE_CXX_COMPILER = ${CMAKE_CXX_COMPILER}") + message(STATUS "CMAKE_CXX_SIMULATE_ID = ${CMAKE_CXX_SIMULATE_ID}") + message(FATAL_ERROR "Need MSVC or clang-cl (MSVC-compatible toolchain).") +endif() + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +# Optional: target Windows 10+ APIs (adjust if you need older) +add_compile_definitions(_WIN32_WINNT=0x0A00 WINVER=0x0A00) + +include(FetchContent) + +# ----------------------------- +# nlohmann_json (clean CMake target) +# ----------------------------- +FetchContent_Declare( + nlohmann_json + GIT_REPOSITORY https://github.com/nlohmann/json.git + GIT_TAG v3.12.0 +) +FetchContent_MakeAvailable(nlohmann_json) + +# ----------------------------- +# WIL (download headers only; DO NOT run WIL's CMake) +# This avoids NuGet/test requirements and missing wil::wil targets. 
+# ----------------------------- +FetchContent_Declare( + wil_src + GIT_REPOSITORY https://github.com/microsoft/wil.git + GIT_TAG v1.0.250325.1 +) +FetchContent_Populate(wil_src) + +# ----------------------------- +# Microsoft GSL (Guidelines Support Library) +# Provides gsl::span for C++17 (std::span is C++20) +# ----------------------------- +FetchContent_Declare( + gsl + GIT_REPOSITORY https://github.com/microsoft/GSL.git + GIT_TAG v4.0.0 +) +FetchContent_MakeAvailable(gsl) + +# ----------------------------- +# Google Test (for unit tests) +# ----------------------------- +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG v1.14.0 +) +# Prevent GoogleTest from overriding our compiler/linker options on Windows +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + +# ----------------------------- +# SDK library (STATIC) +# List ONLY .cpp files here. +# ----------------------------- +add_library(CppSdk STATIC + src/foundry_local.cpp + # Add more .cpp files as you migrate: + # src/parser.cpp + # src/dllmain.cpp + # src/pch.cpp +) + +target_include_directories(CppSdk + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${wil_src_SOURCE_DIR}/include +) + +target_link_libraries(CppSdk + PUBLIC + nlohmann_json::nlohmann_json + Microsoft.GSL::GSL +) + +# ----------------------------- +# Sample executable +# ----------------------------- +add_executable(CppSdkSample + sample/main.cpp +) + +target_link_libraries(CppSdkSample PRIVATE CppSdk) + +# ----------------------------- +# Unit tests +# ----------------------------- +enable_testing() + +add_executable(CppSdkTests + test/parser_and_types_test.cpp + test/model_variant_test.cpp + test/catalog_test.cpp + test/client_test.cpp +) + +target_include_directories(CppSdkTests + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/test +) + +target_compile_definitions(CppSdkTests PRIVATE FL_TESTS) + +target_link_libraries(CppSdkTests + PRIVATE + CppSdk + 
GTest::gtest_main
+)
+
+# Copy testdata files next to the test executable so file-based tests can find them.
+add_custom_command(TARGET CppSdkTests POST_BUILD
+    COMMAND ${CMAKE_COMMAND} -E copy_directory
+            ${CMAKE_CURRENT_SOURCE_DIR}/test/testdata
+            $<TARGET_FILE_DIR:CppSdkTests>/testdata
+)
+
+include(GoogleTest)
+gtest_discover_tests(CppSdkTests
+    WORKING_DIRECTORY $<TARGET_FILE_DIR:CppSdkTests>
+)
+
+# Make Visual Studio start/debug this target by default
+set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+    PROPERTY VS_STARTUP_PROJECT CppSdkSample)
diff --git a/sdk/cpp/CMakePresets.json b/sdk/cpp/CMakePresets.json
new file mode 100644
index 00000000..3defcc5c
--- /dev/null
+++ b/sdk/cpp/CMakePresets.json
@@ -0,0 +1,99 @@
+{
+  "version": 6,
+  "configurePresets": [
+    {
+      "name": "windows-base",
+      "hidden": true,
+      "generator": "Ninja",
+      "binaryDir": "${sourceDir}/out/build/${presetName}",
+      "installDir": "${sourceDir}/out/install/${presetName}",
+      "cacheVariables": {
+        "CMAKE_C_COMPILER": "cl.exe",
+        "CMAKE_CXX_COMPILER": "cl.exe"
+      },
+      "condition": {
+        "type": "equals",
+        "lhs": "${hostSystemName}",
+        "rhs": "Windows"
+      }
+    },
+    {
+      "name": "x64-debug",
+      "displayName": "MSVC x64 Debug",
+      "inherits": "windows-base",
+      "architecture": {
+        "value": "x64",
+        "strategy": "external"
+      },
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Debug"
+      }
+    },
+    {
+      "name": "x64-release",
+      "displayName": "MSVC x64 Release",
+      "inherits": "windows-base",
+      "architecture": {
+        "value": "x64",
+        "strategy": "external"
+      },
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Release"
+      }
+    },
+    {
+      "name": "x86-debug",
+      "displayName": "MSVC x86 Debug",
+      "inherits": "windows-base",
+      "architecture": {
+        "value": "x86",
+        "strategy": "external"
+      },
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Debug"
+      }
+    },
+    {
+      "name": "x86-release",
+      "displayName": "MSVC x86 Release",
+      "inherits": "windows-base",
+      "architecture": {
+        "value": "x86",
+        "strategy": "external"
+      },
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Release"
+      }
+    }
+  ],
"buildPresets": [
+    {
+      "name": "x64-debug",
+      "configurePreset": "x64-debug",
+      "displayName": "MSVC x64 Debug Build"
+    },
+    {
+      "name": "x64-release",
+      "configurePreset": "x64-release",
+      "displayName": "MSVC x64 Release Build"
+    }
+  ],
+  "testPresets": [
+    {
+      "name": "x64-debug",
+      "configurePreset": "x64-debug",
+      "displayName": "MSVC x64 Debug Tests",
+      "output": {
+        "outputOnFailure": true
+      }
+    },
+    {
+      "name": "x64-release",
+      "configurePreset": "x64-release",
+      "displayName": "MSVC x64 Release Tests",
+      "output": {
+        "outputOnFailure": true
+      }
+    }
+  ]
+}
diff --git a/sdk/cpp/include/configuration.h b/sdk/cpp/include/configuration.h
new file mode 100644
index 00000000..59fe63e3
--- /dev/null
+++ b/sdk/cpp/include/configuration.h
@@ -0,0 +1,65 @@
+#pragma once
+// NOTE(review): include targets were stripped in transit; this set is
+// reconstructed from what the header actually uses — confirm against original.
+#include <optional>
+#include <stdexcept>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include "log_level.h"
+
+namespace FoundryLocal {
+
+    /// Optional configuration for the built-in web service.
+    struct WebServiceConfig {
+        // URL/s to bind the web service to.
+        // Default: 127.0.0.1:0 (random ephemeral port).
+        // Multiple URLs can be specified as a semicolon-separated list.
+        std::optional<std::string> urls;
+
+        // If the web service is running in a separate process, provide its URL here.
+        std::optional<std::string> external_url;
+    };
+
+    struct Configuration {
+        // Construct a Configuration with just an application name.
+        // All other fields use their defaults.
+        Configuration(std::string name) : app_name(std::move(name)) {}
+
+        // Your application name. MUST be set to a valid name.
+        std::string app_name;
+
+        // Application data directory.
+        // Default: {home}/.{appname}, where {home} is the user's home directory and {appname} is the app_name value.
+        std::optional<std::string> app_data_dir;
+
+        // Model cache directory.
+        // Default: {appdata}/cache/models, where {appdata} is the app_data_dir value.
+        std::optional<std::string> model_cache_dir;
+
+        // Log directory.
+        // Default: {appdata}/logs
+        std::optional<std::string> logs_dir;
+
+        // Logging level.
+        // Valid values are: Verbose, Debug, Information, Warning, Error, Fatal.
+        // Default: LogLevel.Warning
+        LogLevel log_level = LogLevel::Warning;
+
+        // Optional web service configuration.
+        std::optional<WebServiceConfig> web;
+
+        // Additional settings that Foundry Local Core can consume.
+        std::optional<std::unordered_map<std::string, std::string>> additional_settings;
+
+        void Validate() const {
+            if (app_name.empty()) {
+                throw std::invalid_argument("Configuration app_name must be set to a valid application name.");
+            }
+
+            constexpr std::string_view invalidChars = R"(\/:?\"<>|)";
+            if (app_name.find_first_of(invalidChars) != std::string::npos) {
+                throw std::invalid_argument("Configuration app_name value contains invalid characters.");
+            }
+        }
+    };
+
+} // namespace FoundryLocal
diff --git a/sdk/cpp/include/core_interop_request.h b/sdk/cpp/include/core_interop_request.h
new file mode 100644
index 00000000..de03a61e
--- /dev/null
+++ b/sdk/cpp/include/core_interop_request.h
@@ -0,0 +1,43 @@
+#pragma once
+// NOTE(review): include targets were stripped in transit; reconstructed from
+// usage below — confirm against original.
+#include <string>
+#include <string_view>
+#include <utility>
+#include <nlohmann/json.hpp>
+
+namespace FoundryLocal {
+
+    class CoreInteropRequest final {
+    public:
+        explicit CoreInteropRequest(std::string command) : command_(std::move(command)) {}
+
+        CoreInteropRequest& AddParam(std::string_view key, std::string_view value) {
+            params_[std::string(key)] = std::string(value);
+            return *this;
+        }
+
+        template <typename T> CoreInteropRequest& AddParam(std::string_view key, const T& value) {
+            params_[std::string(key)] = value;
+            return *this;
+        }
+
+        CoreInteropRequest& AddJsonParam(std::string_view key, const nlohmann::json& jsonValue) {
+            params_[std::string(key)] = jsonValue.dump();
+            return *this;
+        }
+
+        std::string ToJson() const {
+            nlohmann::json wrapper;
+            if (!params_.empty()) {
+                wrapper["Params"] = params_;
+            }
+            return wrapper.dump();
+        }
+
+        const std::string& Command() const noexcept { return command_; }
+
+    private:
+        std::string command_;
+        nlohmann::json params_;
+    };
+
+} // namespace FoundryLocal
diff --git a/sdk/cpp/include/flcore_native.h
b/sdk/cpp/include/flcore_native.h new file mode 100644 index 00000000..8cde9ec8 --- /dev/null +++ b/sdk/cpp/include/flcore_native.h @@ -0,0 +1,34 @@ +#pragma once +#include +#include + +extern "C" { + // Layout must match C# structs exactly +#pragma pack(push, 8) + struct RequestBuffer { + const void* Command; + int32_t CommandLength; + const void* Data; + int32_t DataLength; + }; + + struct ResponseBuffer { + void* Data; + int32_t DataLength; + void* Error; + int32_t ErrorLength; + }; + + // Callback signature: void(*)(void* data, int length, void* userData) + using UserCallbackFn = void(__cdecl*)(void*, int32_t, void*); + + // Exported function pointer types + using execute_command_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*); + using execute_command_with_callback_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*, void* /*callback*/, void* /*userData*/); + using free_response_fn = void(__cdecl*)(ResponseBuffer*); + + static_assert(std::is_standard_layout::value, "RequestBuffer must be standard layout"); + static_assert(std::is_standard_layout::value, "ResponseBuffer must be standard layout"); + +#pragma pack(pop) +} diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h new file mode 100644 index 00000000..7db8987b --- /dev/null +++ b/sdk/cpp/include/foundry_local.h @@ -0,0 +1,413 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "configuration.h" +#include "foundry_local_internal_core.h" + +#include "logger.h" + +namespace FoundryLocal { +#ifdef FL_TESTS + namespace Testing { + struct MockObjectFactory; + } +#endif + + enum class DeviceType { + Invalid, + CPU, + GPU, + NPU + }; + + /// Reason the model stopped generating tokens. 
+ enum class FinishReason { + None, + Stop, + Length, + ToolCalls, + ContentFilter + }; + + struct Runtime { + DeviceType device_type = DeviceType::Invalid; + std::string execution_provider; + }; + + struct PromptTemplate { + std::string system; + std::string user; + std::string assistant; + std::string prompt; + }; + + struct AudioCreateTranscriptionResponse { + std::string text; + }; + + /// JSON Schema property definition used to describe tool function parameters. + struct PropertyDefinition { + std::string type; + std::optional description; + std::optional> properties; + std::optional> required; + }; + + /// Describes a function that a model may call. + struct FunctionDefinition { + std::string name; + std::optional description; + std::optional parameters; + }; + + /// A tool definition following the OpenAI tool calling spec. + struct ToolDefinition { + std::string type = "function"; + FunctionDefinition function; + }; + + /// A parsed function call returned by the model. + struct FunctionCall { + std::string name; + std::string arguments; ///< JSON string of the arguments + }; + + /// A tool call returned by the model in a chat completion response. + struct ToolCall { + std::string id; + std::string type; + std::optional function_call; + }; + + /// Controls whether and how the model calls tools. + enum class ToolChoiceKind { + Auto, + None, + Required + }; + + struct ChatMessage { + std::string role; + std::string content; + std::optional tool_call_id; ///< For role="tool" responses + std::vector tool_calls; + }; + + struct ChatChoice { + int index = 0; + FinishReason finish_reason = FinishReason::None; + + // non-streaming + std::optional message; + + // streaming + std::optional delta; + }; + + struct ChatCompletionCreateResponse { + int64_t created = 0; + std::string id; + + bool is_delta = false; + bool successful = false; + int http_status_code = 0; + + std::vector choices; + + /// Returns the object type string. Derived from is_delta — no allocation. 
+ const char* GetObject() const noexcept { return is_delta ? "chat.completion.chunk" : "chat.completion"; } + + /// Returns the created timestamp as an ISO 8601 string. + /// Computed lazily; only allocates when called. + std::string GetCreatedAtIso() const; + }; + + struct ChatSettings { + std::optional frequency_penalty; + std::optional max_tokens; + std::optional n; + std::optional temperature; + std::optional presence_penalty; + std::optional random_seed; + std::optional top_k; + std::optional top_p; + std::optional tool_choice; + }; + + using DownloadProgressCallback = std::function; + + // Forward declarations + class ModelVariant; + + struct Parameter { + std::string name; + std::optional value; + }; + + struct ModelSettings { + std::vector parameters; + }; + + struct ModelInfo { + std::string id; + std::string name; + uint32_t version = 0; + std::string alias; + std::optional display_name; + std::string provider_type; + std::string uri; + std::string model_type; + std::optional prompt_template; + std::optional publisher; + std::optional model_settings; + std::optional license; + std::optional license_description; + bool cached = false; + std::optional task; + std::optional runtime; + std::optional file_size_mb; + std::optional supports_tool_calling; + std::optional max_output_tokens; + std::optional min_fl_version; + int64_t created_at_unix = 0; + }; + + class AudioClient final { + public: + explicit AudioClient(gsl::not_null model); + + /// Returns the model ID this client was created for. 
+ const std::string& GetModelId() const noexcept { return modelId_; } + + AudioCreateTranscriptionResponse TranscribeAudio(const std::filesystem::path& audioFilePath) const; + + using StreamCallback = std::function; + void TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const; + + private: + AudioClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger); + + std::string modelId_; + gsl::not_null core_; + gsl::not_null logger_; + + friend class ModelVariant; + }; + + class ChatClient final { + public: + explicit ChatClient(gsl::not_null model); + + /// Returns the model ID this client was created for. + const std::string& GetModelId() const noexcept { return modelId_; } + + ChatCompletionCreateResponse CompleteChat(gsl::span messages, + const ChatSettings& settings) const; + + ChatCompletionCreateResponse CompleteChat(gsl::span messages, + gsl::span tools, + const ChatSettings& settings) const; + + using StreamCallback = std::function; + void CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, + const StreamCallback& onChunk) const; + + void CompleteChatStreaming(gsl::span messages, gsl::span tools, + const ChatSettings& settings, const StreamCallback& onChunk) const; + + private: + ChatClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger); + + std::string BuildChatRequestJson(gsl::span messages, gsl::span tools, + const ChatSettings& settings, bool stream) const; + + std::string modelId_; + gsl::not_null core_; + gsl::not_null logger_; + + friend class ModelVariant; + }; + + class ModelVariant final { + public: + const ModelInfo& GetInfo() const; + const std::filesystem::path& GetPath() const; + void Download(DownloadProgressCallback onProgress = nullptr) const; + void Load() const; + + bool IsLoaded() const; + bool IsCached() const; + void Unload() const; + void RemoveFromCache(); + + [[deprecated("Use AudioClient(model) constructor instead")]] + 
AudioClient GetAudioClient() const; + + [[deprecated("Use ChatClient(model) constructor instead")]] + ChatClient GetChatClient() const; + + const std::string& GetId() const noexcept; + const std::string& GetAlias() const noexcept; + uint32_t GetVersion() const noexcept; + + private: + static std::string MakeModelParamRequest(std::string_view modelId); + explicit ModelVariant(gsl::not_null core, ModelInfo info, + gsl::not_null logger); + + ModelInfo info_; + mutable std::filesystem::path cachedPath_; + gsl::not_null core_; + gsl::not_null logger_; + + friend class Catalog; + friend class AudioClient; + friend class ChatClient; +#ifdef FL_TESTS + friend struct Testing::MockObjectFactory; +#endif + }; + + class Model final { + public: + gsl::span GetAllModelVariants() const; + const ModelVariant* GetLatestVariant(gsl::not_null variant) const; + + bool IsLoaded() const { return SelectedVariant().IsLoaded(); } + bool IsCached() const { return SelectedVariant().IsCached(); } + const std::filesystem::path& GetPath() const { return SelectedVariant().GetPath(); } + void Download(DownloadProgressCallback onProgress = nullptr) const { + SelectedVariant().Download(std::move(onProgress)); + } + void Load() const { SelectedVariant().Load(); } + void Unload() const { SelectedVariant().Unload(); } + void RemoveFromCache() { SelectedVariant().RemoveFromCache(); } + [[deprecated("Use AudioClient(model) constructor instead")]] + AudioClient GetAudioClient() const { + return SelectedVariant().GetAudioClient(); + } + + [[deprecated("Use ChatClient(model) constructor instead")]] + ChatClient GetChatClient() const { + return SelectedVariant().GetChatClient(); + } + + const std::string& GetId() const; + const std::string& GetAlias() const; + void SelectVariant(gsl::not_null variant) const; + + private: + explicit Model(gsl::not_null core, gsl::not_null logger); + ModelVariant& SelectedVariant(); + const ModelVariant& SelectedVariant() const; + + gsl::not_null core_; + + std::vector 
variants_; + mutable std::optional selectedVariantIndex_; + gsl::not_null logger_; + + friend class Catalog; +#ifdef FL_TESTS + friend struct Testing::MockObjectFactory; +#endif + }; + + class Catalog final { + public: + Catalog(const Catalog&) = delete; + Catalog& operator=(const Catalog&) = delete; + Catalog(Catalog&&) = delete; + Catalog& operator=(Catalog&&) = delete; + + static std::unique_ptr Create(gsl::not_null core, + gsl::not_null logger) { + return std::unique_ptr(new Catalog(core, logger)); + } + + const std::string& GetName() const { return name_; } + std::vector ListModels() const; + std::vector GetLoadedModels() const; + std::vector GetCachedModels() const; + + const Model* GetModel(std::string_view modelId) const; + const ModelVariant* GetModelVariant(std::string_view modelVariantId) const; + + private: + void UpdateModels() const; + + mutable std::chrono::steady_clock::time_point lastFetch_{}; + + mutable std::unordered_map byAlias_; + mutable std::unordered_map modelIdToModelVariant_; + + explicit Catalog(gsl::not_null injected, + gsl::not_null logger); + + gsl::not_null core_; + std::string name_; + gsl::not_null logger_; + + friend class FoundryLocalManager; +#ifdef FL_TESTS + friend struct Testing::MockObjectFactory; +#endif + }; + + class FoundryLocalManager final { + public: + FoundryLocalManager(const FoundryLocalManager&) = delete; + FoundryLocalManager& operator=(const FoundryLocalManager&) = delete; + FoundryLocalManager(FoundryLocalManager&& other) noexcept; + FoundryLocalManager& operator=(FoundryLocalManager&& other) noexcept; + + explicit FoundryLocalManager(Configuration configuration, ILogger* logger = nullptr); + ~FoundryLocalManager(); + + const Catalog& GetCatalog() const; + + /// Start the optional built-in web service. + /// Provides an OpenAI-compatible REST endpoint. + /// After startup, GetUrls() returns the actual bound URL/s. + /// Requires Configuration::Web to be set. 
+ void StartWebService(); + + /// Stop the web service if started. + void StopWebService(); + + /// Returns the bound URL/s after StartWebService(), or empty if not started. + gsl::span GetUrls() const noexcept; + + /// Ensure execution providers are downloaded and registered. + /// Once downloaded, EPs are not re-downloaded unless a new version is available. + void EnsureEpsDownloaded() const; + + private: + bool OwnsLogger() const noexcept { return logger_ == &defaultLogger_; } + + Configuration config_; + + void Initialize(); + + NullLogger defaultLogger_; + std::unique_ptr core_; + std::unique_ptr catalog_; + ILogger* logger_; + std::vector urls_; + }; + +} // namespace FoundryLocal diff --git a/sdk/cpp/include/foundry_local_exception.h b/sdk/cpp/include/foundry_local_exception.h new file mode 100644 index 00000000..6ca886a1 --- /dev/null +++ b/sdk/cpp/include/foundry_local_exception.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +#include "logger.h" + +namespace FoundryLocal { + + class FoundryLocalException final : public std::runtime_error { + public: + explicit FoundryLocalException(std::string message) : std::runtime_error(std::move(message)) {} + + FoundryLocalException(std::string message, ILogger& logger) : std::runtime_error(std::move(message)) { + logger.Log(LogLevel::Error, what()); + } + }; + +} // namespace FoundryLocal diff --git a/sdk/cpp/include/foundry_local_internal_core.h b/sdk/cpp/include/foundry_local_internal_core.h new file mode 100644 index 00000000..eedfa5d4 --- /dev/null +++ b/sdk/cpp/include/foundry_local_internal_core.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include "logger.h" + +namespace FoundryLocal { + namespace Internal { + struct IFoundryLocalCore { + virtual ~IFoundryLocalCore() = default; + + virtual std::string call(std::string_view command, ILogger& logger, + const std::string* dataArgument = nullptr, void* callback = nullptr, + void* data = nullptr) const = 0; + virtual void unload() = 0; + }; + + 
} // namespace Internal +} // namespace FoundryLocal \ No newline at end of file diff --git a/sdk/cpp/include/log_level.h b/sdk/cpp/include/log_level.h new file mode 100644 index 00000000..d9b82863 --- /dev/null +++ b/sdk/cpp/include/log_level.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +namespace FoundryLocal { + + enum class LogLevel { + Verbose, + Debug, + Information, + Warning, + Error, + Fatal + }; + + inline std::string_view LogLevelToString(LogLevel level) noexcept { + switch (level) { + case LogLevel::Verbose: + return "Verbose"; + case LogLevel::Debug: + return "Debug"; + case LogLevel::Information: + return "Information"; + case LogLevel::Warning: + return "Warning"; + case LogLevel::Error: + return "Error"; + case LogLevel::Fatal: + return "Fatal"; + } + return "Unknown"; + } + +} // namespace FoundryLocal diff --git a/sdk/cpp/include/logger.h b/sdk/cpp/include/logger.h new file mode 100644 index 00000000..98d10155 --- /dev/null +++ b/sdk/cpp/include/logger.h @@ -0,0 +1,16 @@ +#pragma once +#include +#include "log_level.h" + +namespace FoundryLocal { + class ILogger { + public: + virtual ~ILogger() = default; + virtual void Log(LogLevel level, std::string_view message) noexcept = 0; + }; + + class NullLogger final : public ILogger { + public: + void Log(LogLevel, std::string_view) noexcept override {} + }; +} // namespace FoundryLocal diff --git a/sdk/cpp/include/parser.h b/sdk/cpp/include/parser.h new file mode 100644 index 00000000..5396596d --- /dev/null +++ b/sdk/cpp/include/parser.h @@ -0,0 +1,288 @@ +#pragma once +#include +#include +#include "foundry_local.h" +#include + +namespace FoundryLocal { + inline DeviceType parse_device_type(std::string_view v) { + if (v == "CPU") { + return DeviceType::CPU; + } + if (v == "NPU") { + return DeviceType::NPU; + } + if (v == "GPU") { + return DeviceType::GPU; + } + return DeviceType::Invalid; + } + + inline FinishReason parse_finish_reason(std::string_view v) { + if (v == "stop") + return 
FinishReason::Stop; + if (v == "length") + return FinishReason::Length; + if (v == "tool_calls") + return FinishReason::ToolCalls; + if (v == "content_filter") + return FinishReason::ContentFilter; + return FinishReason::None; + } + + // ---------- Helpers ---------- + inline std::string get_string_or_empty(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + std::string out = ""; + if (it != j.end() && it->is_string()) { + out = it->get(); + } + return out; + } + + inline void from_json(const nlohmann::json& j, Runtime& r) { + std::string deviceType; + j.at("deviceType").get_to(deviceType); + j.at("executionProvider").get_to(r.execution_provider); + + r.device_type = parse_device_type(std::move(deviceType)); + } + + inline void from_json(const nlohmann::json& j, PromptTemplate& p) { + p.system = get_string_or_empty(j, "system"); + p.user = get_string_or_empty(j, "user"); + p.assistant = get_string_or_empty(j, "assistant"); + p.prompt = get_string_or_empty(j, "prompt"); + } + + inline std::optional get_opt_string(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_string()) { + return it->get(); + } + return std::nullopt; + } + + inline std::optional get_opt_int(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_number_integer()) { + return it->get(); + } + return std::nullopt; + } + + inline std::optional get_opt_i64(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_number_integer()) { + return it->get(); + } + return std::nullopt; + } + + inline std::optional get_opt_bool(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_boolean()) { + return it->get(); + } + 
return std::nullopt; + } + + inline void from_json(const nlohmann::json& j, Parameter& p) { + j.at("name").get_to(p.name); + p.value = get_opt_string(j, "value"); + } + + inline void from_json(const nlohmann::json& j, ModelSettings& ms) { + ms.parameters.clear(); + if (auto it = j.find("parameters"); it != j.end() && it->is_array()) { + ms.parameters = it->get>(); + } + } + + inline void from_json(const nlohmann::json& j, ModelInfo& m) { + j.at("id").get_to(m.id); + j.at("name").get_to(m.name); + j.at("version").get_to(m.version); + j.at("alias").get_to(m.alias); + j.at("providerType").get_to(m.provider_type); + j.at("uri").get_to(m.uri); + j.at("modelType").get_to(m.model_type); + + m.display_name = get_opt_string(j, "displayName"); + m.publisher = get_opt_string(j, "publisher"); + m.license = get_opt_string(j, "license"); + m.license_description = get_opt_string(j, "licenseDescription"); + m.task = get_opt_string(j, "task"); + if (auto it = j.find("fileSizeMb"); it != j.end() && !it->is_null() && it->is_number_integer()) { + auto v = it->get(); + m.file_size_mb = (v >= 0) ? 
static_cast(v) : 0u; + } + m.supports_tool_calling = get_opt_bool(j, "supportsToolCalling"); + m.max_output_tokens = get_opt_i64(j, "maxOutputTokens"); + m.min_fl_version = get_opt_string(j, "minFLVersion"); + + if (auto it = j.find("cached"); it != j.end() && it->is_boolean()) { + m.cached = it->get(); + } + else { + m.cached = false; + } + + if (auto it = j.find("createdAt"); it != j.end() && it->is_number_integer()) { + m.created_at_unix = it->get(); + } + else { + m.created_at_unix = 0; + } + + // nested optional objects + if (auto it = j.find("modelSettings"); it != j.end() && it->is_object()) { + m.model_settings = it->get(); + } + else { + m.model_settings.reset(); + } + + if (auto it = j.find("promptTemplate"); it != j.end() && it->is_object()) { + m.prompt_template = it->get(); + } + else { + m.prompt_template.reset(); + } + + if (auto it = j.find("runtime"); it != j.end() && it->is_object()) { + m.runtime = it->get(); + } + else { + m.runtime.reset(); + } + } + + // ---------- Tool calling: to_json (serialization for requests) ---------- + + inline void to_json(nlohmann::json& j, const PropertyDefinition& pd) { + j = nlohmann::json{{"type", pd.type}}; + if (pd.description) + j["description"] = *pd.description; + if (pd.properties) { + nlohmann::json props = nlohmann::json::object(); + for (const auto& [key, val] : *pd.properties) { + nlohmann::json pj; + to_json(pj, val); + props[key] = std::move(pj); + } + j["properties"] = std::move(props); + } + if (pd.required) + j["required"] = *pd.required; + } + + inline void to_json(nlohmann::json& j, const FunctionDefinition& fd) { + j = nlohmann::json{{"name", fd.name}}; + if (fd.description) + j["description"] = *fd.description; + if (fd.parameters) { + nlohmann::json pj; + to_json(pj, *fd.parameters); + j["parameters"] = std::move(pj); + } + } + + inline void to_json(nlohmann::json& j, const ToolDefinition& td) { + j = nlohmann::json{{"type", td.type}}; + nlohmann::json fj; + to_json(fj, td.function); + 
j["function"] = std::move(fj); + } + + // ---------- Tool calling: from_json (deserialization from responses) ---------- + + inline void from_json(const nlohmann::json& j, FunctionCall& fc) { + fc.name = get_string_or_empty(j, "name"); + if (j.contains("arguments")) { + const auto& args = j.at("arguments"); + if (args.is_string()) + fc.arguments = args.get(); + else + fc.arguments = args.dump(); + } + } + + inline void from_json(const nlohmann::json& j, ToolCall& tc) { + tc.id = get_string_or_empty(j, "id"); + tc.type = get_string_or_empty(j, "type"); + if (j.contains("function") && j.at("function").is_object()) + tc.function_call = j.at("function").get(); + } + + inline void from_json(const nlohmann::json& j, ChatMessage& m) { + if (j.contains("role")) + j.at("role").get_to(m.role); + if (j.contains("content") && !j.at("content").is_null()) + j.at("content").get_to(m.content); + + m.tool_call_id = get_opt_string(j, "tool_call_id"); + + m.tool_calls.clear(); + if (j.contains("tool_calls") && j.at("tool_calls").is_array()) { + for (const auto& tc : j.at("tool_calls")) { + if (tc.is_object()) + m.tool_calls.push_back(tc.get()); + } + } + } + + inline void from_json(const nlohmann::json& j, ChatChoice& c) { + if (j.contains("index")) + j.at("index").get_to(c.index); + if (j.contains("finish_reason") && !j.at("finish_reason").is_null()) + c.finish_reason = parse_finish_reason(j.at("finish_reason").get()); + + if (j.contains("message") && !j.at("message").is_null()) + c.message = j.at("message").get(); + + if (j.contains("delta") && !j.at("delta").is_null()) + c.delta = j.at("delta").get(); + } + + inline void from_json(const nlohmann::json& j, ChatCompletionCreateResponse& r) { + if (j.contains("created")) + j.at("created").get_to(r.created); + r.id = get_string_or_empty(j, "id"); + if (j.contains("IsDelta")) + j.at("IsDelta").get_to(r.is_delta); + if (j.contains("Successful")) + j.at("Successful").get_to(r.successful); + if (j.contains("HttpStatusCode")) + 
j.at("HttpStatusCode").get_to(r.http_status_code); + + r.choices.clear(); + if (j.contains("choices") && j.at("choices").is_array()) { + r.choices = j.at("choices").get>(); + } + } + + // ---------- Tool choice helpers ---------- + + inline std::string tool_choice_to_string(ToolChoiceKind kind) { + switch (kind) { + case ToolChoiceKind::Auto: return "auto"; + case ToolChoiceKind::None: return "none"; + case ToolChoiceKind::Required: return "required"; + } + return "auto"; + } + +} // namespace FoundryLocal \ No newline at end of file diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp new file mode 100644 index 00000000..7e152d4b --- /dev/null +++ b/sdk/cpp/sample/main.cpp @@ -0,0 +1,400 @@ +#include + +#include +#include +#include + +#include "foundry_local.h" + + +using namespace FoundryLocal; + +// --------------------------------------------------------------------------- +// Logger +// --------------------------------------------------------------------------- +class StdLogger final : public ILogger { +public: + void Log(LogLevel level, std::string_view message) noexcept override { + const char* tag = "UNK"; + switch (level) { + case LogLevel::Information: + tag = "INFO"; + break; + case LogLevel::Warning: + tag = "WARN"; + break; + case LogLevel::Error: + tag = "ERROR"; + break; + default: + tag = "DEBUG"; + break; + } + std::fprintf(stderr, "[FoundryLocal][%s] %.*s\n", tag, static_cast(message.size()), message.data()); + } +}; + +// --------------------------------------------------------------------------- +// Example 1 – Browse the catalog +// --------------------------------------------------------------------------- +void BrowseCatalog(FoundryLocalManager& manager) { + std::cout << "\n=== Example 1: Browse Catalog ===\n"; + + auto& catalog = manager.GetCatalog(); + std::cout << "Catalog: " << catalog.GetName() << "\n"; + + auto models = catalog.ListModels(); + std::cout << "Models in catalog: " << models.size() << "\n"; + + for (const auto* 
model : models) { + std::cout << " - " << model->GetAlias() << " (" << model->GetId() << ")" + << " cached=" << (model->IsCached() ? "yes" : "no") + << " loaded=" << (model->IsLoaded() ? "yes" : "no") << "\n"; + + for (const auto& variant : model->GetAllModelVariants()) { + const auto& info = variant.GetInfo(); + std::cout << " variant: " << info.name << " v" << info.version; + if (info.runtime) + std::cout << " device=" << (info.runtime->device_type == DeviceType::GPU ? "GPU" : "CPU"); + if (info.file_size_mb) + std::cout << " size=" << *info.file_size_mb << "MB"; + std::cout << "\n"; + } + } +} + +// --------------------------------------------------------------------------- +// Example 2 – Download, load, chat (non-streaming), then unload +// --------------------------------------------------------------------------- +void ChatNonStreaming(FoundryLocalManager& manager, const std::string& alias) { + std::cout << "\n=== Example 2: Non-Streaming Chat ===\n"; + + auto& catalog = manager.GetCatalog(); + auto models = catalog.ListModels(); + + const auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; }); + std::cout << "\n"; + + model->Load(); + std::cout << "Model loaded: " << model->GetAlias() << "\n"; + + // Get the selected variant pointer for ChatClient + const auto& selectedVariant = model->GetAllModelVariants()[0]; + ChatClient chat(&selectedVariant); + + std::vector messages = {{"system", "You are a helpful assistant. 
Keep answers brief."}, + {"user", "What is the capital of Croatia?"}}; + + ChatSettings settings; + settings.temperature = 0.7f; + settings.max_tokens = 128; + + auto response = chat.CompleteChat(messages, settings); + + if (!response.choices.empty() && response.choices[0].message) { + std::cout << "Assistant: " << response.choices[0].message->content << "\n"; + } + + model->Unload(); + std::cout << "Model unloaded.\n"; +} + +// --------------------------------------------------------------------------- +// Example 3 – Streaming chat +// --------------------------------------------------------------------------- +void ChatStreaming(FoundryLocalManager& manager, const std::string& alias) { + std::cout << "\n=== Example 3: Streaming Chat ===\n"; + + auto& catalog = manager.GetCatalog(); + catalog.ListModels(); + + const auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Load(); + + const auto& selectedVariant = model->GetAllModelVariants()[0]; + ChatClient chat(&selectedVariant); + + std::vector messages = {{"user", "Explain quantum computing in three sentences."}}; + + ChatSettings settings; + settings.temperature = 0.9f; + settings.max_tokens = 256; + + std::cout << "Assistant: "; + chat.CompleteChatStreaming(messages, settings, [](const ChatCompletionCreateResponse& chunk) { + if (chunk.choices.empty()) + return; + const auto& choice = chunk.choices[0]; + if (choice.delta && !choice.delta->content.empty()) { + std::cout << choice.delta->content << std::flush; + } + else if (choice.message && !choice.message->content.empty()) { + std::cout << choice.message->content << std::flush; + } + }); + std::cout << "\n"; + + model->Unload(); +} + +// --------------------------------------------------------------------------- +// Example 4 – Audio transcription +// --------------------------------------------------------------------------- +void TranscribeAudio(FoundryLocalManager& 
manager, const std::string& alias, const std::string& audioPath) { + std::cout << "\n=== Example 4: Audio Transcription ===\n"; + + auto& catalog = manager.GetCatalog(); + catalog.ListModels(); + + const auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; }); + std::cout << "\n"; + + model->Load(); + + const auto& selectedVariant = model->GetAllModelVariants()[0]; + AudioClient audio(&selectedVariant); + + std::cout << "Transcribing: " << audioPath << "\n"; + auto result = audio.TranscribeAudio(audioPath); + std::cout << "Transcription: " << result.text << "\n"; + + // Streaming alternative: + audio.TranscribeAudioStreaming( + audioPath, [](const AudioCreateTranscriptionResponse& chunk) { std::cout << chunk.text << std::flush; }); + std::cout << "\n"; + + model->Unload(); +} + +// --------------------------------------------------------------------------- +// Example 5 – Tool calling +// --------------------------------------------------------------------------- +// Tool calling lets you define functions that the model can decide to invoke. +// The flow is: +// 1. You describe your tools (functions) as ToolDefinition objects. +// 2. You send a chat request with those tools attached. +// 3. The model may respond with finish_reason = ToolCalls and include +// ToolCall objects in the message, each containing the function name +// and a JSON string of arguments. +// 4. YOUR CODE executes the real function using those arguments. +// 5. You add a message with role = "tool" containing the result, then +// send the conversation back so the model can formulate a final answer. +// +// This lets the model "reach out" to external capabilities (calculators, +// databases, APIs, etc.) while keeping the actual execution in your code. 
+// --------------------------------------------------------------------------- +void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) { + std::cout << "\n=== Example 5: Tool Calling ===\n"; + + auto& catalog = manager.GetCatalog(); + catalog.ListModels(); + + const auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; }); + std::cout << "\n"; + + model->Load(); + std::cout << "Model loaded: " << model->GetAlias() << "\n"; + + const auto& selectedVariant = model->GetAllModelVariants()[0]; + ChatClient chat(&selectedVariant); + + // ── Step 1: Define tools ────────────────────────────────────────────── + // Each tool describes a function the model can call. The PropertyDefinition + // mirrors a JSON Schema so the model knows what arguments are expected. + std::vector tools = {{ + "function", + FunctionDefinition{ + "multiply_numbers", // function name + "Multiply two integers and return the result.", // description + PropertyDefinition{ + "object", // top-level schema type + std::nullopt, // no top-level description + std::unordered_map{ + {"first", PropertyDefinition{"integer", "The first number"}}, + {"second", PropertyDefinition{"integer", "The second number"}} + }, + std::vector{"first", "second"} // both params are required + } + } + }}; + + // ── Step 2: Send the first request ──────────────────────────────────── + // tool_choice = Required forces the model to always produce a tool call. + // In production you'd typically use Auto so the model decides on its own. + std::vector messages = { + {"system", "You are a helpful AI assistant. 
Use the provided tools when appropriate."}, + {"user", "What is 7 multiplied by 6?"} + }; + + ChatSettings settings; + settings.temperature = 0.0f; + settings.max_tokens = 500; + settings.tool_choice = ToolChoiceKind::Required; + + std::cout << "Sending chat request with tool definitions...\n"; + auto response = chat.CompleteChat(messages, tools, settings); + + // ── Step 3: Inspect the model's tool call ───────────────────────────── + if (response.choices.empty()) { + std::cerr << "No choices returned.\n"; + model->Unload(); + return; + } + + const auto& firstChoice = response.choices[0]; + + // The model signals it wants to call a tool via finish_reason == ToolCalls. + if (firstChoice.finish_reason == FinishReason::ToolCalls && + firstChoice.message && !firstChoice.message->tool_calls.empty()) + { + const auto& tc = firstChoice.message->tool_calls[0]; + std::cout << "Model requested tool call:\n" + << " function : " << (tc.function_call ? tc.function_call->name : "(none)") << "\n" + << " arguments: " << (tc.function_call ? tc.function_call->arguments : "{}") << "\n"; + + // ── Step 4: Execute the tool locally ────────────────────────────── + // Parse the arguments JSON and perform the actual computation. + // In a real application this could be a web request, DB query, etc. + std::string toolResult; + if (tc.function_call && tc.function_call->name == "multiply_numbers") { + // The arguments string is JSON, e.g. {"first": 7, "second": 6} + // For brevity we hard-code the expected result here. + toolResult = "7 x 6 = 42."; + std::cout << " result : " << toolResult << "\n"; + } else { + toolResult = "Unknown tool."; + } + + // ── Step 5: Feed the tool result back ───────────────────────────── + // Add the assistant's message (including the raw tool_call content) + // and then a "tool" message with the result. + messages.push_back({"tool", toolResult}); + + // Add a follow-up system instruction so the model uses the tool output. 
+ messages.push_back({"system", "Respond only with the answer generated by the tool."}); + + // Switch to Auto so the model can answer without calling tools again. + settings.tool_choice = ToolChoiceKind::Auto; + + std::cout << "\nSending tool result back to model...\n"; + auto followUp = chat.CompleteChat(messages, tools, settings); + + if (!followUp.choices.empty() && followUp.choices[0].message) { + std::cout << "Assistant: " << followUp.choices[0].message->content << "\n"; + } + } + else { + // The model answered directly without a tool call. + if (firstChoice.message) + std::cout << "Assistant: " << firstChoice.message->content << "\n"; + } + + model->Unload(); + std::cout << "Model unloaded.\n"; +} + +// --------------------------------------------------------------------------- +// Example 6 – Model variant inspection & selection +// --------------------------------------------------------------------------- +void InspectVariants(FoundryLocalManager& manager, const std::string& alias) { + std::cout << "\n=== Example 6: Variant Inspection ===\n"; + + auto& catalog = manager.GetCatalog(); + catalog.ListModels(); + + const auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + auto variants = model->GetAllModelVariants(); + std::cout << "Model '" << alias << "' has " << variants.size() << " variant(s):\n"; + + for (const auto& v : variants) { + const auto& info = v.GetInfo(); + std::cout << " " << info.name << " v" << info.version << " cached=" << (v.IsCached() ? "yes" : "no"); + if (info.display_name) + std::cout << " display=\"" << *info.display_name << "\""; + if (info.publisher) + std::cout << " publisher=" << *info.publisher; + if (info.license) + std::cout << " license=" << *info.license; + if (info.runtime) { + std::cout << " device=" + << (info.runtime->device_type == DeviceType::GPU ? "GPU" + : info.runtime->device_type == DeviceType::NPU ? 
"NPU" + : "CPU") + << " ep=" << info.runtime->execution_provider; + } + if (info.supports_tool_calling) + std::cout << " tools=" << (*info.supports_tool_calling ? "yes" : "no"); + std::cout << "\n"; + } + + // Select a specific variant by pointer (e.g. prefer the GPU variant) + for (const auto& v : variants) { + if (v.GetInfo().runtime && v.GetInfo().runtime->device_type == DeviceType::GPU) { + model->SelectVariant(&v); + std::cout << "Selected GPU variant: " << model->GetId() << "\n"; + break; + } + } +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- +int main() { + try { + StdLogger logger; + FoundryLocalManager manager({"SampleApp"}, &logger); + + // 1. Browse the full catalog + BrowseCatalog(manager); + + // 2. Non-streaming chat (change alias to a model in your catalog) + ChatNonStreaming(manager, "phi-3.5-mini"); + + // 3. Streaming chat + ChatStreaming(manager, "phi-3.5-mini"); + + // 4. Audio transcription (uncomment and set a valid alias + wav path) + // TranscribeAudio(manager, "whisper-small", R"(C:\path\to\your\audio.wav)"); + + // 5. Tool calling (define tools, let the model call them, feed results back) + ChatWithToolCalling(manager, "phi-3.5-mini"); + + // 6. 
Inspect model variants and select one + InspectVariants(manager, "phi-3.5-mini"); + + return 0; + } + catch (const std::exception& ex) { + std::cerr << "Fatal: " << ex.what() << std::endl; + return 1; + } +} \ No newline at end of file diff --git a/sdk/cpp/src/foundry_local.cpp b/sdk/cpp/src/foundry_local.cpp new file mode 100644 index 00000000..37da932e --- /dev/null +++ b/sdk/cpp/src/foundry_local.cpp @@ -0,0 +1,843 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include "core_interop_request.h" +#include "configuration.h" +#include "foundry_local.h" +#include "flcore_native.h" +#include "foundry_local_internal_core.h" +#include "parser.h" +#include "logger.h" +#include +#include "foundry_local_exception.h" + +// Internal private namespace. +namespace { + std::filesystem::path getExecutableDir() { + auto exePath = wil::GetModuleFileNameW(nullptr); + return std::filesystem::path(exePath.get()).parent_path(); + } +} // namespace + +namespace { + // Wrap Params: { ... 
} into a request object + inline nlohmann::json MakeParams(nlohmann::json params) { + return nlohmann::json{ {"Params", std::move(params)} }; + } + + // Most common: Params { "Model": } + inline nlohmann::json MakeModelParams(std::string_view model) { + return MakeParams(nlohmann::json{ {"Model", std::string(model)} }); + } + + // Serialize + call + inline std::string CallWithJson(FoundryLocal::Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& requestJson, FoundryLocal::ILogger& logger) { + std::string payload = requestJson.dump(); + return core->call(command, logger, &payload); + } + + // Serialize + call with native callback + inline std::string CallWithJsonAndCallback(FoundryLocal::Internal::IFoundryLocalCore* core, + std::string_view command, const nlohmann::json& requestJson, FoundryLocal::ILogger& logger, + void* callback, void* userData) { + std::string payload = requestJson.dump(); + return core->call(command, logger, &payload, callback, userData); + } + + // Overload: allow Params object directly + inline std::string CallWithParams(FoundryLocal::Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& params, FoundryLocal::ILogger& logger) { + return CallWithJson(core, command, MakeParams(params), logger); + } + + // Overload: no payload + inline std::string CallNoArgs(FoundryLocal::Internal::IFoundryLocalCore* core, std::string_view command, + FoundryLocal::ILogger& logger) { + return core->call(command, logger, nullptr); + } + + std::vector GetLoadedModelsInternal(FoundryLocal::Internal::IFoundryLocalCore* core, + FoundryLocal::ILogger& logger) { + std::string raw = core->call("list_loaded_models", logger); + try { + auto parsed = nlohmann::json::parse(raw); + return parsed.get>(); + } + catch (const nlohmann::json::exception& e) { + throw FoundryLocal::FoundryLocalException( + "Catalog::GetLoadedModelsInternal() JSON error: " + std::string(e.what()), logger); + } + } + + std::vector 
GetCachedModelsInternal(FoundryLocal::Internal::IFoundryLocalCore* core, + FoundryLocal::ILogger& logger) { + std::string raw = core->call("get_cached_models", logger); + + try { + auto parsed = nlohmann::json::parse(raw); + return parsed.get>(); + } + catch (const nlohmann::json::exception& e) { + throw FoundryLocal::FoundryLocalException( + "Catalog::GetCachedModelsInternal JSON error: " + std::string(e.what()), logger); + } + } + + inline void StripSuffixAfterColon(std::string& id) { + const auto pos = id.find_last_of(':'); + if (pos != std::string::npos) { + id.erase(pos); + } + } + + std::vector + CollectVariantsByIds(const std::unordered_map& modelIdToModelVariant, + std::vector ids) { + std::vector out; + out.reserve(ids.size()); + + for (auto& id : ids) { + StripSuffixAfterColon(id); + + auto it = modelIdToModelVariant.find(id); + if (it != modelIdToModelVariant.end()) { + out.emplace_back(it->second); + } + } + return out; + } + +} // namespace + +namespace FoundryLocal { + inline static void* RequireProc(HMODULE mod, const char* name) { + if (void* p = ::GetProcAddress(mod, name)) + return p; + throw std::runtime_error(std::string("GetProcAddress failed for ") + name); + } + + struct Core : FoundryLocal::Internal::IFoundryLocalCore { + using ResponseHandle = std::unique_ptr; + + Core() = default; + ~Core() = default; + + void loadEmbedded() { + loadFromPath(getExecutableDir() / "Microsoft.AI.Foundry.Local.Core.dll"); + } + + void unload() { + module_.reset(); + execCmd_ = nullptr; + execCbCmd_ = nullptr; + freeResCmd_ = nullptr; + } + std::string call(std::string_view command, ILogger& logger, const std::string* dataArgument = nullptr, + void* callback = nullptr, void* data = nullptr) const override { + if (!module_ || !execCmd_ || !execCbCmd_ || !freeResCmd_) { + throw FoundryLocalException( + "Core is not loaded. Cannot call command: " + std::string(command), logger); + } + + RequestBuffer request{}; + request.Command = command.empty() ? 
nullptr : command.data(); + request.CommandLength = static_cast(command.size()); + + if (dataArgument && !dataArgument->empty()) { + request.Data = dataArgument->data(); + request.DataLength = static_cast(dataArgument->size()); + } + + ResponseBuffer response{}; + auto safeDeleter = [fn = freeResCmd_](ResponseBuffer* buf) { + if (fn) fn(buf); + }; + std::unique_ptr responseGuard(&response, safeDeleter); + + using CallbackFn = void (*)(void*, int32_t, void*); + + if (callback != nullptr) { + auto cb = reinterpret_cast(callback); + execCbCmd_(&request, &response, reinterpret_cast(cb), data); + } + else { + execCmd_(&request, &response); + } + + std::string result; + if (response.Error && response.ErrorLength > 0) { + std::string err(static_cast(response.Error), response.ErrorLength); + throw FoundryLocalException( + std::string("Command failed [").append(command).append("]: ").append(err), logger); + } + + if (response.Data && response.DataLength > 0) { + result.assign(static_cast(response.Data), response.DataLength); + } + + return result; + } + + private: + wil::unique_hmodule module_; + execute_command_fn execCmd_{}; + execute_command_with_callback_fn execCbCmd_{}; + free_response_fn freeResCmd_{}; + + void loadFromPath(const std::filesystem::path& path) { + wil::unique_hmodule m(::LoadLibraryW(path.c_str())); + if (!m) + throw std::runtime_error("LoadLibraryW failed"); + + execCmd_ = reinterpret_cast(RequireProc(m.get(), "execute_command")); + execCbCmd_ = reinterpret_cast( + RequireProc(m.get(), "execute_command_with_callback")); + freeResCmd_ = reinterpret_cast(RequireProc(m.get(), "free_response")); + + module_ = std::move(m); + } + }; + + /// + /// AudioClient + /// + + AudioClient::AudioClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) { + } + + AudioCreateTranscriptionResponse AudioClient::TranscribeAudio(const std::filesystem::path& audioFilePath) const { + nlohmann::json 
openAiReq = { {"Model", modelId_}, {"FileName", audioFilePath.string()} }; + CoreInteropRequest req("audio_transcribe"); + req.AddParam("OpenAICreateRequest", openAiReq.dump()); + + std::string json = req.ToJson(); + + AudioCreateTranscriptionResponse response; + response.text = core_->call(req.Command(), *logger_, &json); + + return response; + } + + void AudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const { + nlohmann::json openAiReq = { {"Model", modelId_}, {"FileName", audioFilePath.string()} }; + CoreInteropRequest req("audio_transcribe"); + req.AddParam("OpenAICreateRequest", openAiReq.dump()); + + std::string json = req.ToJson(); + + struct State { + const StreamCallback* cb; + std::exception_ptr exception; + } state{ &onChunk, nullptr }; + + auto streamCallback = [](void* data, int32_t len, void* user) { + if (!data || len <= 0) + return; + + auto* st = static_cast(user); + if (st->exception) + return; + + try { + std::string text(static_cast(data), static_cast(len)); + AudioCreateTranscriptionResponse chunk; + chunk.text = std::move(text); + (*(st->cb))(chunk); + } + catch (...) 
{ + st->exception = std::current_exception(); + } + }; + + core_->call(req.Command(), *logger_, &json, reinterpret_cast(+streamCallback), + reinterpret_cast(&state)); + + if (state.exception) { + std::rethrow_exception(state.exception); + } + } + + + std::string ChatCompletionCreateResponse::GetCreatedAtIso() const { + if (created == 0) return {}; + std::time_t t = static_cast(created); + std::tm tm{}; +#ifdef _WIN32 + gmtime_s(&tm, &t); +#else + gmtime_r(&t, &tm); +#endif + char buf[32]; + std::strftime(buf, sizeof(buf), "%Y-%m-%dT%H:%M:%SZ", &tm); + return buf; + } + + /// + /// ChatClient + /// + + ChatClient::ChatClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) { + } + + std::string ChatClient::BuildChatRequestJson(gsl::span messages, gsl::span tools, + const ChatSettings& settings, bool stream) const { + nlohmann::json jMessages = nlohmann::json::array(); + for (const auto& msg : messages) { + nlohmann::json jMsg = { {"role", msg.role}, {"content", msg.content} }; + if (msg.tool_call_id) + jMsg["tool_call_id"] = *msg.tool_call_id; + jMessages.push_back(std::move(jMsg)); + } + + nlohmann::json req = { {"model", modelId_}, {"messages", std::move(jMessages)}, {"stream", stream} }; + + if (!tools.empty()) { + nlohmann::json jTools = nlohmann::json::array(); + for (const auto& tool : tools) { + nlohmann::json jTool; + to_json(jTool, tool); + jTools.push_back(std::move(jTool)); + } + req["tools"] = std::move(jTools); + } + + if (settings.tool_choice) + req["tool_choice"] = tool_choice_to_string(*settings.tool_choice); + if (settings.top_k) + req["metadata"] = { {"top_k", *settings.top_k} }; + if (settings.frequency_penalty) + req["frequency_penalty"] = *settings.frequency_penalty; + if (settings.presence_penalty) + req["presence_penalty"] = *settings.presence_penalty; + if (settings.max_tokens) + req["max_completion_tokens"] = *settings.max_tokens; + if (settings.n) + req["n"] = 
*settings.n; + if (settings.temperature) + req["temperature"] = *settings.temperature; + if (settings.top_p) + req["top_p"] = *settings.top_p; + if (settings.random_seed) + req["seed"] = *settings.random_seed; + + return req.dump(); + } + + ChatCompletionCreateResponse ChatClient::CompleteChat(gsl::span messages, + const ChatSettings& settings) const { + return CompleteChat(messages, {}, settings); + } + + ChatCompletionCreateResponse ChatClient::CompleteChat(gsl::span messages, + gsl::span tools, const ChatSettings& settings) const { + std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/false); + + CoreInteropRequest req("chat_completions"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + + std::string json = req.ToJson(); + std::string rawResult = core_->call(req.Command(), *logger_, &json); + + return nlohmann::json::parse(rawResult).get(); + } + + void ChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, + const StreamCallback& onChunk) const { + CompleteChatStreaming(messages, {}, settings, onChunk); + } + + void ChatClient::CompleteChatStreaming(gsl::span messages, gsl::span tools, + const ChatSettings& settings, const StreamCallback& onChunk) const { + std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/true); + + CoreInteropRequest req("chat_completions"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + std::string json = req.ToJson(); + + struct State { + const StreamCallback* cb; + std::exception_ptr exception; + } state{ &onChunk, nullptr }; + + auto streamCallback = [](void* data, int32_t len, void* user) { + if (!data || len <= 0) + return; + + auto* st = static_cast(user); + if (st->exception) + return; + + std::string s(static_cast(data), static_cast(len)); + + try { + auto parsed = nlohmann::json::parse(s).get(); + + (*(st->cb))(parsed); + } + catch (const nlohmann::json::exception& e) { + st->exception = std::make_exception_ptr( + 
FoundryLocalException(std::string("Error while parsing streaming chat chunk: ") + e.what())); + } + catch (...) { + st->exception = std::current_exception(); + } + }; + + core_->call(req.Command(), *logger_, &json, reinterpret_cast(+streamCallback), + reinterpret_cast(&state)); + + if (state.exception) { + std::rethrow_exception(state.exception); + } + } + + /// + /// ModelVariant + /// + + ModelVariant::ModelVariant(gsl::not_null core, ModelInfo info, + gsl::not_null logger) + : core_(core), info_(std::move(info)), logger_(logger) { + } + + const ModelInfo& ModelVariant::GetInfo() const { + return info_; + } + + void ModelVariant::RemoveFromCache() { + try { + CallWithJson(core_, "remove_cached_model", MakeModelParams(info_.name), *logger_); + cachedPath_.clear(); + } + catch (const std::exception& ex) { + throw FoundryLocalException("Error removing model from cache [" + info_.name + "]: " + ex.what(), *logger_); + } + } + + void ModelVariant::Unload() const { + try { + CallWithJson(core_, "unload_model", MakeModelParams(info_.name), *logger_); + } + catch (const std::exception& ex) { + throw FoundryLocalException("Error unloading model [" + info_.name + "]: " + ex.what(), *logger_); + } + } + + bool ModelVariant::IsLoaded() const { + std::vector loadedModelIds = GetLoadedModelsInternal(core_, *logger_); + for (auto& id : loadedModelIds) { + auto pos = id.find_last_of(':'); + if (pos != std::string::npos) { + id.erase(pos); + } + + if (id == info_.name) { + return true; + } + } + + return false; + } + + bool ModelVariant::IsCached() const { + auto cachedModels = GetCachedModelsInternal(core_, *logger_); + for (auto& id : cachedModels) { + StripSuffixAfterColon(id); + if (id == info_.name) { + return true; + } + } + return false; + } + + void ModelVariant::Download(DownloadProgressCallback onProgress) const { + if (IsCached()) { + logger_->Log(LogLevel::Information, "Model '" + info_.name + "' is already cached, skipping download."); + return; + } + + if 
(onProgress) { + struct ProgressState { + DownloadProgressCallback* cb; + ILogger* logger; + } state{ &onProgress, logger_ }; + + auto nativeCallback = [](void* data, int32_t len, void* user) { + if (!data || len <= 0) + return; + auto* st = static_cast(user); + std::string perc(static_cast(data), static_cast((std::min)(4, static_cast(len)))); + try { + float value = std::stof(perc); + (*(st->cb))(value); + } catch (...) { + st->logger->Log(LogLevel::Warning, "Failed to parse download progress: " + perc); + } + }; + + CallWithJsonAndCallback(core_, "download_model", MakeModelParams(info_.name), *logger_, + reinterpret_cast(+nativeCallback), reinterpret_cast(&state)); + } else { + CallWithJson(core_, "download_model", MakeModelParams(info_.name), *logger_); + } + } + + void ModelVariant::Load() const { + CallWithJson(core_, "load_model", MakeModelParams(info_.name), *logger_); + } + + const std::filesystem::path& ModelVariant::GetPath() const { + if (cachedPath_.empty()) { + cachedPath_ = std::filesystem::path(CallWithJson(core_, "get_model_path", MakeModelParams(info_.name), *logger_)); + } + return cachedPath_; + } + + const std::string& ModelVariant::GetId() const noexcept { + return info_.id; + } + + const std::string& ModelVariant::GetAlias() const noexcept { + return info_.alias; + } + + uint32_t ModelVariant::GetVersion() const noexcept { + return info_.version; + } + + AudioClient::AudioClient(gsl::not_null model) + : AudioClient(model->core_, model->info_.name, model->logger_) { + if (!model->IsLoaded()) { + throw FoundryLocalException("Model " + model->info_.name + " is not loaded. Call Load() first.", *model->logger_); + } + } + + AudioClient ModelVariant::GetAudioClient() const { + return AudioClient(this); + } + + ChatClient::ChatClient(gsl::not_null model) + : ChatClient(model->core_, model->info_.name, model->logger_) { + if (!model->IsLoaded()) { + throw FoundryLocalException("Model " + model->info_.name + " is not loaded. 
Call Load() first.", *model->logger_); + } + } + + ChatClient ModelVariant::GetChatClient() const { + return ChatClient(this); + } + + /// + /// Model + /// + Model::Model(gsl::not_null core, gsl::not_null logger) + : core_(core), logger_(logger) { + } + + ModelVariant& Model::SelectedVariant() { + if (!selectedVariantIndex_ || *selectedVariantIndex_ >= variants_.size()) { + throw FoundryLocalException("Model has no selected variant", *logger_); + } + return variants_[*selectedVariantIndex_]; + } + + const ModelVariant& Model::SelectedVariant() const { + if (!selectedVariantIndex_ || *selectedVariantIndex_ >= variants_.size()) { + throw FoundryLocalException("Model has no selected variant", *logger_); + } + return variants_[*selectedVariantIndex_]; + } + + gsl::span Model::GetAllModelVariants() const { + return variants_; + } + + const ModelVariant* Model::GetLatestVariant(gsl::not_null variant) const { + const auto& targetName = variant->GetInfo().name; + + for (const auto& v : variants_) { + if (v.GetInfo().name == targetName) { + return &v; + } + } + + throw FoundryLocalException( + "Model " + GetAlias() + " does not have a " + variant->GetId() + " variant.", *logger_); + } + + const std::string& Model::GetId() const { + return SelectedVariant().GetId(); + } + + const std::string& Model::GetAlias() const { + return SelectedVariant().GetAlias(); + } + + void Model::SelectVariant(gsl::not_null variant) const { + auto it = std::find_if(variants_.begin(), variants_.end(), + [&](const ModelVariant& v) { return &v == variant.get(); }); + + if (it == variants_.end()) { + throw FoundryLocalException("Model " + GetAlias() + " does not have a " + variant->GetId() + " variant.", + *logger_); + } + + selectedVariantIndex_ = static_cast(std::distance(variants_.begin(), it)); + } + + /// + /// Catalog + /// + + Catalog::Catalog(gsl::not_null injected, gsl::not_null logger) + : core_(injected), logger_(logger) { + try { + name_ = core_->call("get_catalog_name", *logger_, 
/*dataArgument*/ nullptr); + } + catch (const std::exception& ex) { + throw FoundryLocalException(std::string("Error getting catalog name: ") + ex.what(), *logger_); + } + } + + std::vector Catalog::GetLoadedModels() const { + return CollectVariantsByIds(modelIdToModelVariant_, GetLoadedModelsInternal(core_, *logger_)); + } + + std::vector Catalog::GetCachedModels() const { + return CollectVariantsByIds(modelIdToModelVariant_, GetCachedModelsInternal(core_, *logger_)); + } + + const Model* Catalog::GetModel(std::string_view modelId) const { + auto it = byAlias_.find(std::string(modelId)); + if (it != byAlias_.end()) { + return &it->second; + } + return nullptr; + } + + std::vector Catalog::ListModels() const { + UpdateModels(); + + std::vector out; + out.reserve(byAlias_.size()); + for (auto& kv : byAlias_) + out.emplace_back(&kv.second); + + return out; + } + + void Catalog::UpdateModels() const { + using clock = std::chrono::steady_clock; + + // TODO: make this configurable + constexpr auto kRefreshInterval = std::chrono::hours(6); + + const auto now = clock::now(); + if (lastFetch_.time_since_epoch() != clock::duration::zero() && (now - lastFetch_) < kRefreshInterval) { + return; + } + + const std::string raw = core_->call("get_model_list", *logger_); + const auto arr = nlohmann::json::parse(raw); + + byAlias_.clear(); + modelIdToModelVariant_.clear(); + + for (const auto& j : arr) { + const std::string alias = j.at("alias").get(); + if (alias.rfind("openai-", 0) == 0) + continue; + + auto it = byAlias_.find(alias); + if (it == byAlias_.end()) { + Model m(core_, logger_); + it = byAlias_.emplace(alias, std::move(m)).first; + } + + ModelInfo modelVariantInfo; + from_json(j, modelVariantInfo); + ModelVariant modelVariant(core_, modelVariantInfo, logger_); + it->second.variants_.emplace_back(std::move(modelVariant)); + + for (const auto& v : it->second.variants_) { + modelIdToModelVariant_[v.GetInfo().name] = &v; + } + it->second.selectedVariantIndex_ = 0; + } + + 
lastFetch_ = now; + } + + const ModelVariant* Catalog::GetModelVariant(std::string_view id) const { + auto it = modelIdToModelVariant_.find(std::string(id)); + if (it != modelIdToModelVariant_.end()) { + return it->second; + } + return nullptr; + } + + /// + /// FoundryLocalManager + /// + + FoundryLocalManager::FoundryLocalManager(Configuration configuration, ILogger* logger) + : config_(std::move(configuration)), core_(std::make_unique()), logger_(logger ? logger : &defaultLogger_) { + static_cast(core_.get())->loadEmbedded(); + Initialize(); + catalog_ = Catalog::Create(core_.get(), logger_); + } + + FoundryLocalManager::FoundryLocalManager(FoundryLocalManager&& other) noexcept + : config_(std::move(other.config_)), + core_(std::move(other.core_)), + catalog_(std::move(other.catalog_)), + logger_(other.OwnsLogger() ? &defaultLogger_ : other.logger_), + urls_(std::move(other.urls_)) { + other.logger_ = &other.defaultLogger_; + } + + FoundryLocalManager& FoundryLocalManager::operator=(FoundryLocalManager&& other) noexcept { + if (this != &other) { + config_ = std::move(other.config_); + core_ = std::move(other.core_); + catalog_ = std::move(other.catalog_); + logger_ = other.OwnsLogger() ? &defaultLogger_ : other.logger_; + urls_ = std::move(other.urls_); + other.logger_ = &other.defaultLogger_; + } + return *this; + } + + FoundryLocalManager::~FoundryLocalManager() { + // Unload all loaded models before tearing down. 
+ if (catalog_) { + try { + auto loadedModels = catalog_->GetLoadedModels(); + for (const auto* variant : loadedModels) { + try { + variant->Unload(); + } catch (const std::exception& ex) { + logger_->Log(LogLevel::Warning, + std::string("Error unloading model during destruction: ") + ex.what()); + } + } + } catch (const std::exception& ex) { + logger_->Log(LogLevel::Warning, + std::string("Error retrieving loaded models during destruction: ") + ex.what()); + } + } + + if (!urls_.empty()) { + try { + StopWebService(); + } catch (const std::exception& ex) { + logger_->Log(LogLevel::Warning, std::string("Error stopping web service during destruction: ") + ex.what()); + } + } + } + + const Catalog& FoundryLocalManager::GetCatalog() const { + return *catalog_; + } + + void FoundryLocalManager::StartWebService() { + if (!config_.web) { + throw FoundryLocalException("Web service configuration was not provided.", *logger_); + } + + try { + std::string raw = core_->call("start_service", *logger_); + auto arr = nlohmann::json::parse(raw); + urls_ = arr.get>(); + } catch (const std::exception& ex) { + throw FoundryLocalException(std::string("Error starting web service: ") + ex.what(), *logger_); + } + } + + void FoundryLocalManager::StopWebService() { + if (!config_.web) { + throw FoundryLocalException("Web service configuration was not provided.", *logger_); + } + + try { + core_->call("stop_service", *logger_); + urls_.clear(); + } catch (const std::exception& ex) { + throw FoundryLocalException(std::string("Error stopping web service: ") + ex.what(), *logger_); + } + } + + gsl::span FoundryLocalManager::GetUrls() const noexcept { + return urls_; + } + + void FoundryLocalManager::EnsureEpsDownloaded() const { + try { + core_->call("ensure_eps_downloaded", *logger_); + } catch (const std::exception& ex) { + throw FoundryLocalException( + std::string("Error ensuring execution providers downloaded: ") + ex.what(), *logger_); + } + } + + void FoundryLocalManager::Initialize() 
{ + config_.Validate(); + + try { + CoreInteropRequest initReq("initialize"); + initReq.AddParam("AppName", config_.app_name); + initReq.AddParam("LogLevel", std::string(LogLevelToString(config_.log_level))); + + if (config_.app_data_dir) { + initReq.AddParam("AppDataDir", config_.app_data_dir->string()); + } + if (config_.logs_dir) { + initReq.AddParam("LogsDir", config_.logs_dir->string()); + } + if (config_.web && config_.web->urls) { + initReq.AddParam("WebServiceUrls", *config_.web->urls); + } + if (config_.additional_settings) { + for (const auto& [key, value] : *config_.additional_settings) { + if (!key.empty()) { + initReq.AddParam(key, value); + } + } + } + + std::string initJson = initReq.ToJson(); + core_->call(initReq.Command(), *logger_, &initJson); + + if (config_.model_cache_dir) { + std::string current = core_->call("get_cache_directory", *logger_); + + if (current != config_.model_cache_dir->string()) { + CoreInteropRequest setReq("set_cache_directory"); + setReq.AddParam("Directory", config_.model_cache_dir->string()); + std::string setJson = setReq.ToJson(); + core_->call(setReq.Command(), *logger_, &setJson); + + logger_->Log(LogLevel::Information, + std::string("Model cache directory updated: ") + config_.model_cache_dir->string()); + } + else { + logger_->Log(LogLevel::Information, std::string("Model cache directory already set to: ") + current); + } + } + } + catch (const std::exception& ex) { + throw FoundryLocalException(std::string("FoundryLocalManager::Initialize failed: ") + ex.what(), *logger_); + } + } + +} // namespace FoundryLocal diff --git a/sdk/cpp/test/catalog_test.cpp b/sdk/cpp/test/catalog_test.cpp new file mode 100644 index 00000000..e40d7c11 --- /dev/null +++ b/sdk/cpp/test/catalog_test.cpp @@ -0,0 +1,372 @@ +#include + +#include +#include +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "parser.h" +#include "foundry_local_exception.h" + +#include + +using namespace FoundryLocal; +using 
namespace FoundryLocal::Testing; + +using Factory = MockObjectFactory; + +class CatalogTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + std::string MakeModelListJson(const std::vector>& models) { + nlohmann::json arr = nlohmann::json::array(); + for (const auto& [name, alias] : models) { + arr.push_back(nlohmann::json::parse(Factory::MakeModelInfoJson(name, alias))); + } + return arr.dump(); + } + + std::unique_ptr MakeCatalog() { + core_.OnCall("get_catalog_name", "test-catalog"); + return Factory::CreateCatalog(&core_, &logger_); + } +}; + +TEST_F(CatalogTest, GetName) { + auto catalog = MakeCatalog(); + EXPECT_EQ("test-catalog", catalog->GetName()); +} + +TEST_F(CatalogTest, Create_ThrowsOnCoreError) { + core_.OnCallThrow("get_catalog_name", "catalog error"); + EXPECT_THROW(MockObjectFactory::CreateCatalog(&core_, &logger_), FoundryLocalException); +} + +TEST_F(CatalogTest, ListModels_Empty) { + core_.OnCall("get_model_list", "[]"); + auto catalog = MakeCatalog(); + auto models = catalog->ListModels(); + EXPECT_TRUE(models.empty()); +} + +TEST_F(CatalogTest, ListModels_SingleModel) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + auto models = catalog->ListModels(); + ASSERT_EQ(1u, models.size()); + EXPECT_EQ("my-model", models[0]->GetAlias()); +} + +TEST_F(CatalogTest, ListModels_MultipleVariantsSameAlias) { + // Two variants of the same model (same alias, different names) + nlohmann::json arr = nlohmann::json::array(); + arr.push_back(nlohmann::json::parse(Factory::MakeModelInfoJson("model-v1", "my-model", 1))); + arr.push_back(nlohmann::json::parse(Factory::MakeModelInfoJson("model-v2", "my-model", 2))); + core_.OnCall("get_model_list", arr.dump()); + + auto catalog = MakeCatalog(); + auto models = catalog->ListModels(); + + // Should be grouped into one Model + ASSERT_EQ(1u, models.size()); + EXPECT_EQ(2u, models[0]->GetAllModelVariants().size()); +} + 
+TEST_F(CatalogTest, ListModels_DifferentAliases) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-a", "alias-a"}, {"model-b", "alias-b"}})); + auto catalog = MakeCatalog(); + auto models = catalog->ListModels(); + EXPECT_EQ(2u, models.size()); +} + +TEST_F(CatalogTest, ListModels_FiltersOpenAIPrefix) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-a", "my-model"}, {"openai-model", "openai-stuff"}})); + auto catalog = MakeCatalog(); + auto models = catalog->ListModels(); + ASSERT_EQ(1u, models.size()); + EXPECT_EQ("my-model", models[0]->GetAlias()); +} + +TEST_F(CatalogTest, GetModel_Found) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + catalog->ListModels(); // populate + + auto* model = catalog->GetModel("my-model"); + ASSERT_NE(nullptr, model); + EXPECT_EQ("my-model", model->GetAlias()); +} + +TEST_F(CatalogTest, GetModel_NotFound) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + catalog->ListModels(); // populate + + EXPECT_EQ(nullptr, catalog->GetModel("nonexistent")); +} + +TEST_F(CatalogTest, GetModelVariant_Found) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + catalog->ListModels(); // populate + + auto* variant = catalog->GetModelVariant("model-1"); + ASSERT_NE(nullptr, variant); + EXPECT_EQ("model-1", variant->GetId()); +} + +TEST_F(CatalogTest, GetModelVariant_NotFound) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + catalog->ListModels(); + + EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent")); +} + +TEST_F(CatalogTest, GetLoadedModels) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "alias-1"}, {"model-2", "alias-2"}})); + core_.OnCall("list_loaded_models", R"(["model-1:v1"])"); + + auto catalog = MakeCatalog(); + 
catalog->ListModels(); // populate + + auto loaded = catalog->GetLoadedModels(); + ASSERT_EQ(1u, loaded.size()); + EXPECT_EQ("model-1", loaded[0]->GetId()); +} + +TEST_F(CatalogTest, GetCachedModels) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "alias-1"}, {"model-2", "alias-2"}})); + core_.OnCall("get_cached_models", R"(["model-1:1", "model-2:1"])"); + + auto catalog = MakeCatalog(); + catalog->ListModels(); // populate + + auto cached = catalog->GetCachedModels(); + EXPECT_EQ(2u, cached.size()); +} + +TEST_F(CatalogTest, ListModels_CachesResults) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + + catalog->ListModels(); + catalog->ListModels(); + + // Should only call get_model_list once due to caching + EXPECT_EQ(1, core_.GetCallCount("get_model_list")); +} + +class FileBasedCatalogTest : public ::testing::Test { +protected: + NullLogger logger_; + + static std::string TestDataPath(const std::string& filename) { return "testdata/" + filename; } +}; + +TEST_F(FileBasedCatalogTest, RealModelsList) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + ASSERT_EQ(2u, models.size()); + + int phi_models = 0, mistral_models = 0; + size_t phi_variants = 0, mistral_variants = 0; + + for (const auto* model : models) { + if (model->GetAlias() == "phi-4") { + phi_models++; + phi_variants = model->GetAllModelVariants().size(); + } + else if (model->GetAlias() == "mistral-7b-v0.2") { + mistral_models++; + mistral_variants = model->GetAllModelVariants().size(); + } + } + + EXPECT_EQ(1, phi_models); + EXPECT_EQ(1, mistral_models); + EXPECT_EQ(2u, phi_variants); + EXPECT_EQ(2u, mistral_variants); +} + +TEST_F(FileBasedCatalogTest, RealModelsList_VariantDetails) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = 
Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + const auto* gpuVariant = catalog->GetModelVariant("Phi-4-generic-gpu"); + ASSERT_NE(nullptr, gpuVariant); + + const auto& info = gpuVariant->GetInfo(); + EXPECT_EQ("Phi-4-generic-gpu", info.id); + EXPECT_EQ("Phi-4-generic-gpu", info.name); + EXPECT_EQ("phi-4", info.alias); + ASSERT_TRUE(info.display_name.has_value()); + EXPECT_EQ("Phi-4 (GPU)", *info.display_name); + ASSERT_TRUE(info.publisher.has_value()); + EXPECT_EQ("Microsoft", *info.publisher); + ASSERT_TRUE(info.license.has_value()); + EXPECT_EQ("MIT", *info.license); + ASSERT_TRUE(info.runtime.has_value()); + EXPECT_EQ(DeviceType::GPU, info.runtime->device_type); + EXPECT_EQ("DML", info.runtime->execution_provider); + ASSERT_TRUE(info.file_size_mb.has_value()); + EXPECT_EQ(8192u, *info.file_size_mb); + ASSERT_TRUE(info.supports_tool_calling.has_value()); + EXPECT_TRUE(*info.supports_tool_calling); + ASSERT_TRUE(info.max_output_tokens.has_value()); + EXPECT_EQ(4096, *info.max_output_tokens); + ASSERT_TRUE(info.prompt_template.has_value()); + EXPECT_EQ("<|system|>", info.prompt_template->system); + EXPECT_EQ("<|user|>", info.prompt_template->user); + EXPECT_EQ("<|assistant|>", info.prompt_template->assistant); + EXPECT_EQ("<|prompt|>", info.prompt_template->prompt); +} + +TEST_F(FileBasedCatalogTest, RealModelsList_CpuVariantDetails) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + const auto* cpuVariant = catalog->GetModelVariant("Phi-4-generic-cpu"); + ASSERT_NE(nullptr, cpuVariant); + + const auto& info = cpuVariant->GetInfo(); + EXPECT_EQ("Phi-4-generic-cpu", info.name); + ASSERT_TRUE(info.runtime.has_value()); + EXPECT_EQ(DeviceType::CPU, info.runtime->device_type); + EXPECT_EQ("ORT", info.runtime->execution_provider); + ASSERT_TRUE(info.file_size_mb.has_value()); + EXPECT_EQ(4096u, 
*info.file_size_mb); + ASSERT_TRUE(info.supports_tool_calling.has_value()); + EXPECT_FALSE(*info.supports_tool_calling); + EXPECT_FALSE(info.prompt_template.has_value()); +} + +TEST_F(FileBasedCatalogTest, EmptyModelsList) { + auto core = FileBackedCore::FromModelList(TestDataPath("empty_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + EXPECT_TRUE(models.empty()); +} + +TEST_F(FileBasedCatalogTest, MalformedJson) { + auto core = FileBackedCore::FromModelList(TestDataPath("malformed_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + EXPECT_ANY_THROW(catalog->ListModels()); +} + +TEST_F(FileBasedCatalogTest, MissingNameField) { + auto core = FileBackedCore::FromModelList(TestDataPath("missing_name_field_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + try { + catalog->ListModels(); + FAIL() << "Expected exception for missing 'name' field"; + } + catch (const std::exception& e) { + std::string msg = e.what(); + EXPECT_NE(std::string::npos, msg.find("name")) << "Actual: " << msg; + } +} + +TEST_F(FileBasedCatalogTest, CachedModels) { + auto core = + FileBackedCore::FromBoth(TestDataPath("real_models_list.json"), TestDataPath("valid_cached_models.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate internal maps + + auto cached = catalog->GetCachedModels(); + ASSERT_EQ(2u, cached.size()); + + std::vector names; + names.reserve(cached.size()); + for (const auto* mv : cached) + names.push_back(mv->GetInfo().name); + + EXPECT_NE(std::find(names.begin(), names.end(), "Phi-4-generic-gpu"), names.end()); + EXPECT_NE(std::find(names.begin(), names.end(), "Phi-4-generic-cpu"), names.end()); +} + +TEST_F(FileBasedCatalogTest, CoreErrorOnModelList) { + auto core = FileBackedCore::FromModelList("testdata/nonexistent_file.json"); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + 
EXPECT_ANY_THROW(catalog->ListModels()); +} + +TEST_F(FileBasedCatalogTest, MixedOpenAIAndLocal_FiltersOpenAIPrefix) { + auto core = FileBackedCore::FromModelList(TestDataPath("mixed_openai_and_local.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + ASSERT_EQ(1u, models.size()); + EXPECT_EQ("phi-4", models[0]->GetAlias()); +} + +TEST_F(FileBasedCatalogTest, ThreeVariantsOneModel) { + auto core = FileBackedCore::FromModelList(TestDataPath("three_variants_one_model.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + ASSERT_EQ(1u, models.size()); + EXPECT_EQ(3u, models[0]->GetAllModelVariants().size()); +} + +TEST_F(FileBasedCatalogTest, ThreeVariantsOneModel_CachedSubset) { + auto core = FileBackedCore::FromBoth(TestDataPath("three_variants_one_model.json"), + TestDataPath("single_cached_model.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + auto cached = catalog->GetCachedModels(); + ASSERT_EQ(1u, cached.size()); + EXPECT_EQ("multi-v1-cpu", cached[0]->GetInfo().name); +} + +TEST_F(FileBasedCatalogTest, GetModelByAlias) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + const auto* model = catalog->GetModel("phi-4"); + ASSERT_NE(nullptr, model); + EXPECT_EQ("phi-4", model->GetAlias()); + EXPECT_EQ(2u, model->GetAllModelVariants().size()); + + const auto* missing = catalog->GetModel("nonexistent-alias"); + EXPECT_EQ(nullptr, missing); +} + +TEST_F(FileBasedCatalogTest, GetModelVariant_NotInCatalog) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent-variant-id")); +} + 
+TEST_F(FileBasedCatalogTest, LoadedModels) { + auto core = FileBackedCore::FromAll(TestDataPath("real_models_list.json"), TestDataPath("valid_cached_models.json"), + TestDataPath("valid_loaded_models.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + auto loaded = catalog->GetLoadedModels(); + ASSERT_EQ(1u, loaded.size()); + EXPECT_EQ("Phi-4-generic-gpu", loaded[0]->GetInfo().name); +} diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp new file mode 100644 index 00000000..0857bc92 --- /dev/null +++ b/sdk/cpp/test/client_test.cpp @@ -0,0 +1,541 @@ +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "parser.h" +#include "foundry_local_exception.h" + +#include + +using namespace FoundryLocal; +using namespace FoundryLocal::Testing; + +using Factory = MockObjectFactory; + +class ChatClientTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + std::string MakeChatResponseJson(const std::string& content = "Hello!") { + nlohmann::json resp = { + {"created", 1700000000}, + {"id", "chatcmpl-test"}, + {"IsDelta", false}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", "stop"}, {"message", {{"role", "assistant"}, {"content", content}}}}}}}; + return resp.dump(); + } + + ModelVariant MakeLoadedVariant(const std::string& name = "chat-model") { + core_.OnCall("list_loaded_models", "[\"" + name + ":v1\"]"); + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_); + } +}; + +TEST_F(ChatClientTest, CompleteChat_BasicResponse) { + core_.OnCall("chat_completions", MakeChatResponseJson("Hello world!")); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "Say hello", {}}}; + ChatSettings settings; + auto response = 
client.CompleteChat(messages, settings); + + EXPECT_TRUE(response.successful); + ASSERT_EQ(1u, response.choices.size()); + EXPECT_EQ("Hello world!", response.choices[0].message->content); +} + +TEST_F(ChatClientTest, CompleteChat_WithSettings) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + settings.temperature = 0.7f; + settings.max_tokens = 100; + settings.top_p = 0.9f; + settings.frequency_penalty = 0.5f; + settings.presence_penalty = 0.3f; + settings.n = 2; + settings.random_seed = 42; + settings.top_k = 10; + + auto response = client.CompleteChat(messages, settings); + + // Verify the request JSON contains the settings + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + EXPECT_NEAR(0.7f, openAiReq["temperature"].get(), 0.001f); + EXPECT_EQ(100, openAiReq["max_completion_tokens"].get()); + EXPECT_NEAR(0.9f, openAiReq["top_p"].get(), 0.001f); + EXPECT_NEAR(0.5f, openAiReq["frequency_penalty"].get(), 0.001f); + EXPECT_NEAR(0.3f, openAiReq["presence_penalty"].get(), 0.001f); + EXPECT_EQ(2, openAiReq["n"].get()); + EXPECT_EQ(42, openAiReq["seed"].get()); + EXPECT_EQ(10, openAiReq["metadata"]["top_k"].get()); +} + +TEST_F(ChatClientTest, CompleteChat_RequestFormat) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"system", "You are helpful", {}}, {"user", "Hello", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + 
auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + EXPECT_EQ("chat-model", openAiReq["model"].get()); + EXPECT_FALSE(openAiReq["stream"].get()); + ASSERT_EQ(2u, openAiReq["messages"].size()); + EXPECT_EQ("system", openAiReq["messages"][0]["role"].get()); + EXPECT_EQ("user", openAiReq["messages"][1]["role"].get()); +} + +TEST_F(ChatClientTest, CompleteChatStreaming) { + nlohmann::json chunk1 = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hello"}}}}}}}; + nlohmann::json chunk2 = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", {{{"index", 0}, {"finish_reason", "stop"}, {"delta", {{"content", " world"}}}}}}}; + + core_.OnCall("chat_completions", + [&](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + if (callback && userData) { + auto cb = reinterpret_cast(callback); + std::string s1 = chunk1.dump(); + std::string s2 = chunk2.dump(); + cb(s1.data(), static_cast(s1.size()), userData); + cb(s2.data(), static_cast(s2.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + + std::vector chunks; + client.CompleteChatStreaming(messages, settings, + [&](const ChatCompletionCreateResponse& chunk) { chunks.push_back(chunk); }); + + ASSERT_EQ(2u, chunks.size()); + EXPECT_TRUE(chunks[0].is_delta); + ASSERT_TRUE(chunks[0].choices[0].delta.has_value()); + EXPECT_EQ("Hello", chunks[0].choices[0].delta->content); + EXPECT_EQ(" world", chunks[1].choices[0].delta->content); +} + +TEST_F(ChatClientTest, 
CompleteChatStreaming_PropagatesCallbackException) { + nlohmann::json chunk = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hi"}}}}}}}; + + core_.OnCall("chat_completions", + [&](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + if (callback && userData) { + auto cb = reinterpret_cast(callback); + std::string s = chunk.dump(); + cb(s.data(), static_cast(s.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + + EXPECT_THROW(client.CompleteChatStreaming( + messages, settings, + [](const ChatCompletionCreateResponse&) { throw std::runtime_error("callback error"); }), + std::runtime_error); +} + +TEST_F(ChatClientTest, Constructor_ThrowsIfNotLoaded) { + core_.OnCall("list_loaded_models", R"([])"); + auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); + EXPECT_THROW(ChatClient client(&variant), FoundryLocalException); +} + +TEST_F(ChatClientTest, GetModelId) { + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + EXPECT_EQ("chat-model", client.GetModelId()); +} + +// ---------- Tool calling tests ---------- + +TEST_F(ChatClientTest, CompleteChat_WithTools_IncludesToolsInRequest) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "What is 7 * 6?", {}}}; + + std::vector tools = {{ + "function", + FunctionDefinition{ + "multiply_numbers", + 
"A tool for multiplying two numbers.", + PropertyDefinition{ + "object", + std::nullopt, + std::unordered_map{ + {"first", PropertyDefinition{"integer", "The first number"}}, + {"second", PropertyDefinition{"integer", "The second number"}} + }, + std::vector{"first", "second"} + } + } + }}; + + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::Required; + + auto response = client.CompleteChat(messages, tools, settings); + + // Verify the request JSON contains tools and tool_choice + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + ASSERT_TRUE(openAiReq.contains("tools")); + ASSERT_TRUE(openAiReq["tools"].is_array()); + EXPECT_EQ(1u, openAiReq["tools"].size()); + EXPECT_EQ("function", openAiReq["tools"][0]["type"].get()); + EXPECT_EQ("multiply_numbers", openAiReq["tools"][0]["function"]["name"].get()); + EXPECT_EQ("A tool for multiplying two numbers.", openAiReq["tools"][0]["function"]["description"].get()); + EXPECT_EQ("object", openAiReq["tools"][0]["function"]["parameters"]["type"].get()); + EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"].contains("properties")); + EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"]["properties"].contains("first")); + EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"]["properties"].contains("second")); + + EXPECT_EQ("required", openAiReq["tool_choice"].get()); +} + +TEST_F(ChatClientTest, CompleteChat_WithoutTools_OmitsToolsField) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "Hello", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto 
openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + EXPECT_FALSE(openAiReq.contains("tools")); + EXPECT_FALSE(openAiReq.contains("tool_choice")); +} + +TEST_F(ChatClientTest, CompleteChat_ToolCallResponse_Parsed) { + // Simulate a response with tool calls from the model + nlohmann::json resp = { + {"created", 1700000000}, + {"id", "chatcmpl-tool"}, + {"IsDelta", false}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, + {"finish_reason", "tool_calls"}, + {"message", + {{"role", "assistant"}, + {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": 6}}]"}, + {"tool_calls", + {{{"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "multiply_numbers"}, {"arguments", "{\"first\": 7, \"second\": 6}"}}}}}}}}}}}}; + + core_.OnCall("chat_completions", resp.dump()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "What is 7 * 6?", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + ASSERT_EQ(1u, response.choices.size()); + EXPECT_EQ(FinishReason::ToolCalls, response.choices[0].finish_reason); + ASSERT_TRUE(response.choices[0].message.has_value()); + + const auto& msg = *response.choices[0].message; + ASSERT_EQ(1u, msg.tool_calls.size()); + EXPECT_EQ("call_1", msg.tool_calls[0].id); + EXPECT_EQ("function", msg.tool_calls[0].type); + ASSERT_TRUE(msg.tool_calls[0].function_call.has_value()); + EXPECT_EQ("multiply_numbers", msg.tool_calls[0].function_call->name); + EXPECT_EQ("{\"first\": 7, \"second\": 6}", msg.tool_calls[0].function_call->arguments); +} + +TEST_F(ChatClientTest, CompleteChat_ToolChoiceAuto) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient 
client(&variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::Auto; + + client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + EXPECT_EQ("auto", openAiReq["tool_choice"].get()); +} + +TEST_F(ChatClientTest, CompleteChat_ToolChoiceNone) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::None; + + client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + EXPECT_EQ("none", openAiReq["tool_choice"].get()); +} + +TEST_F(ChatClientTest, CompleteChat_ToolMessageWithToolCallId) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + ChatMessage toolMsg; + toolMsg.role = "tool"; + toolMsg.content = "42"; + toolMsg.tool_call_id = "call_1"; + + std::vector messages = { + {"user", "What is 7 * 6?", {}}, + std::move(toolMsg) + }; + ChatSettings settings; + client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + ASSERT_EQ(2u, openAiReq["messages"].size()); + EXPECT_FALSE(openAiReq["messages"][0].contains("tool_call_id")); + EXPECT_EQ("call_1", openAiReq["messages"][1]["tool_call_id"].get()); + EXPECT_EQ("tool", 
openAiReq["messages"][1]["role"].get()); +} + +TEST_F(ChatClientTest, CompleteChatStreaming_WithTools) { + nlohmann::json chunk1 = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, + {"finish_reason", nullptr}, + {"delta", {{"role", "assistant"}, {"content", ""}}}}}}}; + nlohmann::json chunk2 = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, + {"finish_reason", "tool_calls"}, + {"delta", + {{"content", ""}, + {"tool_calls", + {{{"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "multiply"}, {"arguments", "{\"a\":1}"}}}}}}}}}}}}; + + core_.OnCall("chat_completions", + [&](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + if (callback && userData) { + auto cb = reinterpret_cast(callback); + std::string s1 = chunk1.dump(); + std::string s2 = chunk2.dump(); + cb(s1.data(), static_cast(s1.size()), userData); + cb(s2.data(), static_cast(s2.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "test", {}}}; + + std::vector tools = {{ + "function", + FunctionDefinition{"multiply", "Multiply numbers."} + }}; + + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::Required; + + std::vector chunks; + client.CompleteChatStreaming(messages, tools, settings, + [&](const ChatCompletionCreateResponse& chunk) { chunks.push_back(chunk); }); + + ASSERT_EQ(2u, chunks.size()); + EXPECT_EQ(FinishReason::ToolCalls, chunks[1].choices[0].finish_reason); + ASSERT_TRUE(chunks[1].choices[0].delta.has_value()); + ASSERT_EQ(1u, chunks[1].choices[0].delta->tool_calls.size()); + EXPECT_EQ("multiply", chunks[1].choices[0].delta->tool_calls[0].function_call->name); 
+ + // Verify tools were included in the request + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + ASSERT_TRUE(openAiReq.contains("tools")); + EXPECT_EQ("required", openAiReq["tool_choice"].get()); +} + +class AudioClientTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + ModelVariant MakeLoadedVariant(const std::string& name = "audio-model") { + core_.OnCall("list_loaded_models", "[\"" + name + ":v1\"]"); + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_); + } +}; + +TEST_F(AudioClientTest, TranscribeAudio) { + core_.OnCall("audio_transcribe", "Hello world transcribed text"); + core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + + auto variant = MakeLoadedVariant(); + AudioClient client(&variant); + auto response = client.TranscribeAudio("test.wav"); + + EXPECT_EQ("Hello world transcribed text", response.text); +} + +TEST_F(AudioClientTest, TranscribeAudio_RequestFormat) { + core_.OnCall("audio_transcribe", "text"); + core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + + auto variant = MakeLoadedVariant(); + AudioClient client(&variant); + client.TranscribeAudio("audio.wav"); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("audio_transcribe")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + EXPECT_EQ("audio-model", openAiReq["Model"].get()); + EXPECT_EQ("audio.wav", openAiReq["FileName"].get()); +} + +TEST_F(AudioClientTest, TranscribeAudioStreaming) { + core_.OnCall("audio_transcribe", + [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + if (callback && userData) { + auto cb = reinterpret_cast(callback); + std::string text1 = "Hello "; + std::string text2 = "world!"; + cb(text1.data(), static_cast(text1.size()), userData); + 
cb(text2.data(), static_cast(text2.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + + auto variant = MakeLoadedVariant(); + AudioClient client(&variant); + + std::vector chunks; + client.TranscribeAudioStreaming( + "test.wav", [&](const AudioCreateTranscriptionResponse& chunk) { chunks.push_back(chunk.text); }); + + ASSERT_EQ(2u, chunks.size()); + EXPECT_EQ("Hello ", chunks[0]); + EXPECT_EQ("world!", chunks[1]); +} + +TEST_F(AudioClientTest, TranscribeAudioStreaming_PropagatesCallbackException) { + core_.OnCall("audio_transcribe", + [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + if (callback && userData) { + auto cb = reinterpret_cast(callback); + std::string text = "test"; + cb(text.data(), static_cast(text.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + + auto variant = MakeLoadedVariant(); + AudioClient client(&variant); + + EXPECT_THROW( + client.TranscribeAudioStreaming( + "test.wav", [](const AudioCreateTranscriptionResponse&) { throw std::runtime_error("streaming error"); }), + std::runtime_error); +} + +TEST_F(AudioClientTest, Constructor_ThrowsIfNotLoaded) { + core_.OnCall("list_loaded_models", R"([])"); + auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); + EXPECT_THROW(AudioClient client(&variant), FoundryLocalException); +} + +TEST_F(AudioClientTest, GetModelId) { + core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + auto variant = MakeLoadedVariant(); + AudioClient client(&variant); + EXPECT_EQ("audio-model", client.GetModelId()); +} diff --git a/sdk/cpp/test/mock_core.h b/sdk/cpp/test/mock_core.h new file mode 100644 index 00000000..b7aa349d --- /dev/null +++ b/sdk/cpp/test/mock_core.h @@ -0,0 +1,148 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include 
"foundry_local_internal_core.h" +#include "logger.h" + +namespace FoundryLocal::Testing { + + /// A mock implementation of IFoundryLocalCore for unit testing. + /// Register expected command -> response mappings before use. + class MockCore final : public Internal::IFoundryLocalCore { + public: + using CallbackFn = void (*)(void*, int32_t, void*); + + /// Handler signature: (command, dataArgument, callback, userData) -> response string. + using Handler = std::function; + + /// Register a fixed response for a command. + void OnCall(std::string command, std::string response) { + handlers_[std::move(command)] = [r = std::move(response)](std::string_view, const std::string*, void*, + void*) { return r; }; + } + + /// Register a custom handler for a command. + void OnCall(std::string command, Handler handler) { handlers_[std::move(command)] = std::move(handler); } + + /// Register a handler that throws for a command. + void OnCallThrow(std::string command, std::string errorMessage) { + handlers_[std::move(command)] = [msg = std::move(errorMessage)](std::string_view, const std::string*, void*, + void*) -> std::string { + throw std::runtime_error(msg); + }; + } + + /// Returns the number of times a command was called. + int GetCallCount(const std::string& command) const { + auto it = callCounts_.find(command); + return it != callCounts_.end() ? it->second : 0; + } + + /// Returns the last data argument passed for a command. 
+ const std::string& GetLastDataArg(const std::string& command) const { + auto it = lastDataArgs_.find(command); + if (it == lastDataArgs_.end()) { + static const std::string empty; + return empty; + } + return it->second; + } + + // IFoundryLocalCore implementation + std::string call(std::string_view command, ILogger& /*logger*/, const std::string* dataArgument = nullptr, + void* callback = nullptr, void* data = nullptr) const override { + + std::string cmd(command); + const_cast(this)->callCounts_[cmd]++; + if (dataArgument) { + const_cast(this)->lastDataArgs_[cmd] = *dataArgument; + } + + auto it = handlers_.find(cmd); + if (it == handlers_.end()) { + throw std::runtime_error("MockCore: no handler registered for command '" + cmd + "'"); + } + + return it->second(command, dataArgument, callback, data); + } + + void unload() override {} + + private: + std::unordered_map handlers_; + std::unordered_map callCounts_; + std::unordered_map lastDataArgs_; + }; + + /// Read a file into a string. Throws on failure. + inline std::string ReadFile(const std::string& path) { + std::ifstream in(path, std::ios::in | std::ios::binary); + if (!in) + throw std::runtime_error("Failed to open test data file: " + path); + std::ostringstream contents; + contents << in.rdbuf(); + return contents.str(); + } + + /// A mock core that reads model list, cached models and loaded models from JSON files on disk. 
+ class FileBackedCore final : public Internal::IFoundryLocalCore { + public: + FileBackedCore(std::string modelListPath, std::string cachedModelsPath, std::string loadedModelsPath = "") + : modelListPath_(std::move(modelListPath)), cachedModelsPath_(std::move(cachedModelsPath)), + loadedModelsPath_(std::move(loadedModelsPath)) {} + + static FileBackedCore FromModelList(const std::string& path) { return FileBackedCore(path, ""); } + + static FileBackedCore FromBoth(const std::string& modelListPath, const std::string& cachedModelsPath) { + return FileBackedCore(modelListPath, cachedModelsPath); + } + + static FileBackedCore FromAll(const std::string& modelListPath, const std::string& cachedModelsPath, + const std::string& loadedModelsPath) { + return FileBackedCore(modelListPath, cachedModelsPath, loadedModelsPath); + } + + std::string call(std::string_view command, ILogger& /*logger*/, const std::string* /*dataArgument*/ = nullptr, + void* /*callback*/ = nullptr, void* /*data*/ = nullptr) const override { + + if (command == "get_catalog_name") + return "TestCatalog"; + + if (command == "get_model_list") { + if (modelListPath_.empty()) + return "[]"; + return ReadFile(modelListPath_); + } + + if (command == "get_cached_models") { + if (cachedModelsPath_.empty()) + return "[]"; + return ReadFile(cachedModelsPath_); + } + + if (command == "list_loaded_models") { + if (loadedModelsPath_.empty()) + return "[]"; + return ReadFile(loadedModelsPath_); + } + + return "{}"; + } + + void unload() override {} + + private: + std::string modelListPath_; + std::string cachedModelsPath_; + std::string loadedModelsPath_; + }; + +} // namespace FoundryLocal::Testing diff --git a/sdk/cpp/test/mock_object_factory.h b/sdk/cpp/test/mock_object_factory.h new file mode 100644 index 00000000..9d029aec --- /dev/null +++ b/sdk/cpp/test/mock_object_factory.h @@ -0,0 +1,61 @@ +#pragma once + +#ifndef FL_TESTS +#define FL_TESTS +#endif + +#include "foundry_local.h" +#include 
"foundry_local_internal_core.h" +#include "logger.h" + +namespace FoundryLocal::Testing { + + /// Factory to construct private-constructor types for testing. + /// Declared as a friend (Testing::MockObjectFactory) in ModelVariant, Model, and Catalog when FL_TESTS is defined. + struct MockObjectFactory { + static ModelVariant CreateModelVariant(gsl::not_null core, ModelInfo info, + gsl::not_null logger) { + return ModelVariant(core, std::move(info), logger); + } + + static std::unique_ptr CreateCatalog(gsl::not_null core, + gsl::not_null logger) { + return std::unique_ptr(new Catalog(core, logger)); + } + + static Model CreateModel(gsl::not_null core, gsl::not_null logger) { + return Model(core, logger); + } + + /// Push a variant into a Model's internal variant list. + static void AddVariantToModel(Model& model, ModelVariant variant) { + model.variants_.push_back(std::move(variant)); + } + + /// Set the selected variant index on a Model. + static void SetSelectedVariantIndex(Model& model, size_t index) { model.selectedVariantIndex_ = index; } + + /// Helper to build a minimal ModelInfo with defaults. + static ModelInfo MakeModelInfo(std::string name, std::string alias = "", uint32_t version = 1) { + ModelInfo info; + info.id = name; + info.name = std::move(name); + info.alias = alias.empty() ? info.name : std::move(alias); + info.version = version; + info.provider_type = "test"; + info.uri = "test://uri"; + info.model_type = "text"; + return info; + } + + /// Helper to build a JSON string representing a model list entry. + static std::string MakeModelInfoJson(const std::string& name, const std::string& alias = "", + uint32_t version = 1, bool cached = false) { + std::string a = alias.empty() ? name : alias; + return R"({"id":")" + name + R"(","name":")" + name + R"(","version":)" + std::to_string(version) + + R"(,"alias":")" + a + R"(","providerType":"test","uri":"test://uri","modelType":"text","cached":)" + + (cached ? 
"true" : "false") + R"(,"createdAt":0})"; + } + }; + +} // namespace FoundryLocal::Testing diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp new file mode 100644 index 00000000..7660207c --- /dev/null +++ b/sdk/cpp/test/model_variant_test.cpp @@ -0,0 +1,251 @@ +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "parser.h" +#include "foundry_local_exception.h" + +#include + +using namespace FoundryLocal; +using namespace FoundryLocal::Testing; + +using Factory = MockObjectFactory; + +class ModelVariantTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + ModelVariant MakeVariant(std::string name = "test-model", std::string alias = "test-alias", uint32_t version = 1) { + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, alias, version), &logger_); + } +}; + +TEST_F(ModelVariantTest, GetInfo) { + auto variant = MakeVariant("my-model", "my-alias", 3); + const auto& info = variant.GetInfo(); + EXPECT_EQ("my-model", info.name); + EXPECT_EQ("my-alias", info.alias); + EXPECT_EQ(3u, info.version); +} + +TEST_F(ModelVariantTest, GetId) { + auto variant = MakeVariant("my-model"); + EXPECT_EQ("my-model", variant.GetId()); +} + +TEST_F(ModelVariantTest, GetAlias) { + auto variant = MakeVariant("name", "alias"); + EXPECT_EQ("alias", variant.GetAlias()); +} + +TEST_F(ModelVariantTest, GetVersion) { + auto variant = MakeVariant("name", "alias", 5); + EXPECT_EQ(5u, variant.GetVersion()); +} + +TEST_F(ModelVariantTest, IsLoaded_True) { + core_.OnCall("list_loaded_models", R"(["test-model:v1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_TRUE(variant.IsLoaded()); +} + +TEST_F(ModelVariantTest, IsLoaded_False) { + core_.OnCall("list_loaded_models", R"(["other-model:v1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_FALSE(variant.IsLoaded()); +} + +TEST_F(ModelVariantTest, IsLoaded_EmptyList) { + core_.OnCall("list_loaded_models", R"([])"); + 
auto variant = MakeVariant("test-model"); + EXPECT_FALSE(variant.IsLoaded()); +} + +TEST_F(ModelVariantTest, IsCached_True) { + core_.OnCall("get_cached_models", R"(["test-model:1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_TRUE(variant.IsCached()); +} + +TEST_F(ModelVariantTest, IsCached_False) { + core_.OnCall("get_cached_models", R"(["other-model:1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_FALSE(variant.IsCached()); +} + +TEST_F(ModelVariantTest, Load_CallsCore) { + core_.OnCall("load_model", ""); + auto variant = MakeVariant("test-model"); + variant.Load(); + EXPECT_EQ(1, core_.GetCallCount("load_model")); + + // Verify the data argument contains the model name + auto parsed = nlohmann::json::parse(core_.GetLastDataArg("load_model")); + EXPECT_EQ("test-model", parsed["Params"]["Model"].get()); +} + +TEST_F(ModelVariantTest, Unload_CallsCore) { + core_.OnCall("unload_model", ""); + auto variant = MakeVariant("test-model"); + variant.Unload(); + EXPECT_EQ(1, core_.GetCallCount("unload_model")); +} + +TEST_F(ModelVariantTest, Unload_ThrowsOnError) { + core_.OnCallThrow("unload_model", "unload failed"); + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.Unload(), FoundryLocalException); +} + +TEST_F(ModelVariantTest, Download_NoCallback) { + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", ""); + auto variant = MakeVariant("test-model"); + variant.Download(); + EXPECT_EQ(1, core_.GetCallCount("download_model")); +} + +TEST_F(ModelVariantTest, Download_WithCallback) { + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", + [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + // Simulate calling the progress callback + if (callback && userData) { + auto cb = reinterpret_cast(callback); + std::string progress = "50"; + cb(progress.data(), static_cast(progress.size()), userData); + } + return ""; + }); + + auto variant = 
MakeVariant("test-model"); + float lastProgress = -1.0f; + variant.Download([&](float pct) { lastProgress = pct; }); + EXPECT_NEAR(50.0f, lastProgress, 0.01f); +} + +TEST_F(ModelVariantTest, RemoveFromCache_CallsCore) { + core_.OnCall("remove_cached_model", ""); + auto variant = MakeVariant("test-model"); + variant.RemoveFromCache(); + EXPECT_EQ(1, core_.GetCallCount("remove_cached_model")); +} + +TEST_F(ModelVariantTest, RemoveFromCache_ThrowsOnError) { + core_.OnCallThrow("remove_cached_model", "remove failed"); + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.RemoveFromCache(), FoundryLocalException); +} + +TEST_F(ModelVariantTest, GetPath_CallsCore) { + core_.OnCall("get_model_path", R"(C:\models\test)"); + auto variant = MakeVariant("test-model"); + const auto& path = variant.GetPath(); + EXPECT_EQ(std::filesystem::path(R"(C:\models\test)"), path); +} + +TEST_F(ModelVariantTest, GetPath_CachesResult) { + core_.OnCall("get_model_path", R"(C:\models\test)"); + auto variant = MakeVariant("test-model"); + variant.GetPath(); + variant.GetPath(); + // Should only call once due to caching + EXPECT_EQ(1, core_.GetCallCount("get_model_path")); +} + +class ModelTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + Model MakeModel() { return Factory::CreateModel(&core_, &logger_); } + + ModelVariant MakeVariant(std::string name = "test-model", std::string alias = "test-alias", uint32_t version = 1) { + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, alias, version), &logger_); + } + + /// Helper: create a Model with one variant and selectedVariantIndex_=0. 
+ Model MakeModelWithVariant(const std::string& name = "test-model", const std::string& alias = "test-alias") { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant(name, alias, 1)); + Factory::SetSelectedVariantIndex(model, 0); + return model; + } +}; + +TEST_F(ModelTest, SelectedVariant_ThrowsWhenEmpty) { + auto model = MakeModel(); + EXPECT_THROW(model.GetId(), FoundryLocalException); +} + +TEST_F(ModelTest, AddVariant_AndSelect) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::SetSelectedVariantIndex(model, 0); + + EXPECT_EQ("v1", model.GetId()); + EXPECT_EQ("alias", model.GetAlias()); +} + +TEST_F(ModelTest, GetAllModelVariants) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); + Factory::SetSelectedVariantIndex(model, 0); + + auto variants = model.GetAllModelVariants(); + EXPECT_EQ(2u, variants.size()); +} + +TEST_F(ModelTest, SelectVariant) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); + Factory::SetSelectedVariantIndex(model, 0); + + const auto* v2 = &model.GetAllModelVariants()[1]; + model.SelectVariant(v2); + EXPECT_EQ("v2", model.GetId()); +} + +TEST_F(ModelTest, SelectVariant_NotFound_Throws) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::SetSelectedVariantIndex(model, 0); + + auto external = MakeVariant("external", "alias", 1); + EXPECT_THROW(model.SelectVariant(&external), FoundryLocalException); +} + +TEST_F(ModelTest, GetLatestVariant) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 2)); + Factory::SetSelectedVariantIndex(model, 0); + + const 
auto* first = &model.GetAllModelVariants()[0]; + const auto* latest = model.GetLatestVariant(first); + // Should return the first one with matching name (which is variants_[0]) + EXPECT_EQ(first, latest); +} + +TEST_F(ModelTest, DelegationMethods) { + // Test that Model delegates to SelectedVariant + core_.OnCall("list_loaded_models", R"(["test-model:v1"])"); + core_.OnCall("get_cached_models", R"(["test-model:1"])"); + core_.OnCall("load_model", ""); + core_.OnCall("unload_model", ""); + core_.OnCall("download_model", ""); + core_.OnCall("get_model_path", R"(C:\test)"); + + auto model = MakeModelWithVariant("test-model", "alias"); + + EXPECT_TRUE(model.IsLoaded()); + EXPECT_TRUE(model.IsCached()); + model.Load(); + model.Unload(); + model.Download(); + EXPECT_EQ(std::filesystem::path(R"(C:\test)"), model.GetPath()); +} diff --git a/sdk/cpp/test/parser_and_types_test.cpp b/sdk/cpp/test/parser_and_types_test.cpp new file mode 100644 index 00000000..4515c3a0 --- /dev/null +++ b/sdk/cpp/test/parser_and_types_test.cpp @@ -0,0 +1,592 @@ +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "parser.h" +#include "foundry_local_exception.h" +#include "core_interop_request.h" + +#include + +using namespace FoundryLocal; +using namespace FoundryLocal::Testing; + +class ParserTest : public ::testing::Test { +protected: + static nlohmann::json MinimalModelJson() { + return nlohmann::json{{"id", "model-1"}, {"name", "model-1"}, {"version", 1}, + {"alias", "my-model"}, {"providerType", "onnx"}, {"uri", "https://example.com/model"}, + {"modelType", "text"}, {"cached", false}, {"createdAt", 1700000000}}; + } +}; + +TEST_F(ParserTest, ParseDeviceType_CPU) { + EXPECT_EQ(DeviceType::CPU, parse_device_type("CPU")); +} + +TEST_F(ParserTest, ParseDeviceType_GPU) { + EXPECT_EQ(DeviceType::GPU, parse_device_type("GPU")); +} + +TEST_F(ParserTest, ParseDeviceType_NPU) { + EXPECT_EQ(DeviceType::NPU, parse_device_type("NPU")); +} + +TEST_F(ParserTest, 
ParseDeviceType_Unknown) { + EXPECT_EQ(DeviceType::Invalid, parse_device_type("FPGA")); +} + +TEST_F(ParserTest, ParseFinishReason_Stop) { + EXPECT_EQ(FinishReason::Stop, parse_finish_reason("stop")); +} + +TEST_F(ParserTest, ParseFinishReason_Length) { + EXPECT_EQ(FinishReason::Length, parse_finish_reason("length")); +} + +TEST_F(ParserTest, ParseFinishReason_ToolCalls) { + EXPECT_EQ(FinishReason::ToolCalls, parse_finish_reason("tool_calls")); +} + +TEST_F(ParserTest, ParseFinishReason_ContentFilter) { + EXPECT_EQ(FinishReason::ContentFilter, parse_finish_reason("content_filter")); +} + +TEST_F(ParserTest, ParseFinishReason_None) { + EXPECT_EQ(FinishReason::None, parse_finish_reason("unknown_value")); +} + +TEST_F(ParserTest, GetStringOrEmpty_Present) { + nlohmann::json j = {{"key", "value"}}; + EXPECT_EQ("value", get_string_or_empty(j, "key")); +} + +TEST_F(ParserTest, GetStringOrEmpty_Missing) { + nlohmann::json j = {{"other", "value"}}; + EXPECT_EQ("", get_string_or_empty(j, "key")); +} + +TEST_F(ParserTest, GetStringOrEmpty_NonString) { + nlohmann::json j = {{"key", 42}}; + EXPECT_EQ("", get_string_or_empty(j, "key")); +} + +TEST_F(ParserTest, GetOptString_Present) { + nlohmann::json j = {{"key", "hello"}}; + auto result = get_opt_string(j, "key"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ("hello", *result); +} + +TEST_F(ParserTest, GetOptString_Null) { + nlohmann::json j = {{"key", nullptr}}; + EXPECT_FALSE(get_opt_string(j, "key").has_value()); +} + +TEST_F(ParserTest, GetOptString_Missing) { + nlohmann::json j = {{"other", "v"}}; + EXPECT_FALSE(get_opt_string(j, "key").has_value()); +} + +TEST_F(ParserTest, GetOptInt_Present) { + nlohmann::json j = {{"key", 42}}; + auto result = get_opt_int(j, "key"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(42, *result); +} + +TEST_F(ParserTest, GetOptInt_Missing) { + nlohmann::json j = {}; + EXPECT_FALSE(get_opt_int(j, "key").has_value()); +} + +TEST_F(ParserTest, GetOptBool_Present) { + nlohmann::json j = 
{{"key", true}}; + auto result = get_opt_bool(j, "key"); + ASSERT_TRUE(result.has_value()); + EXPECT_TRUE(*result); +} + +TEST_F(ParserTest, GetOptBool_Missing) { + nlohmann::json j = {}; + EXPECT_FALSE(get_opt_bool(j, "key").has_value()); +} + +TEST_F(ParserTest, ParseRuntime) { + nlohmann::json j = {{"deviceType", "GPU"}, {"executionProvider", "DML"}}; + Runtime r = j.get(); + EXPECT_EQ(DeviceType::GPU, r.device_type); + EXPECT_EQ("DML", r.execution_provider); +} + +TEST_F(ParserTest, ParsePromptTemplate) { + nlohmann::json j = {{"system", "sys"}, {"user", "usr"}, {"assistant", "asst"}, {"prompt", "p"}}; + PromptTemplate pt = j.get(); + EXPECT_EQ("sys", pt.system); + EXPECT_EQ("usr", pt.user); + EXPECT_EQ("asst", pt.assistant); + EXPECT_EQ("p", pt.prompt); +} + +TEST_F(ParserTest, ParsePromptTemplate_MissingFields) { + nlohmann::json j = {{"system", "sys"}}; + PromptTemplate pt = j.get(); + EXPECT_EQ("sys", pt.system); + EXPECT_EQ("", pt.user); + EXPECT_EQ("", pt.assistant); + EXPECT_EQ("", pt.prompt); +} + +TEST_F(ParserTest, ParseModelInfo_Minimal) { + auto j = MinimalModelJson(); + ModelInfo info = j.get(); + EXPECT_EQ("model-1", info.id); + EXPECT_EQ("model-1", info.name); + EXPECT_EQ(1u, info.version); + EXPECT_EQ("my-model", info.alias); + EXPECT_EQ("onnx", info.provider_type); + EXPECT_EQ("https://example.com/model", info.uri); + EXPECT_EQ("text", info.model_type); + EXPECT_FALSE(info.cached); + EXPECT_EQ(1700000000, info.created_at_unix); + EXPECT_FALSE(info.display_name.has_value()); + EXPECT_FALSE(info.publisher.has_value()); + EXPECT_FALSE(info.runtime.has_value()); + EXPECT_FALSE(info.prompt_template.has_value()); + EXPECT_FALSE(info.model_settings.has_value()); +} + +TEST_F(ParserTest, ParseModelInfo_WithOptionals) { + auto j = MinimalModelJson(); + j["displayName"] = "My Model"; + j["publisher"] = "TestPublisher"; + j["license"] = "MIT"; + j["fileSizeMb"] = 512; + j["supportsToolCalling"] = true; + j["maxOutputTokens"] = 4096; + j["runtime"] = 
{{"deviceType", "CPU"}, {"executionProvider", "ORT"}}; + + ModelInfo info = j.get(); + ASSERT_TRUE(info.display_name.has_value()); + EXPECT_EQ("My Model", *info.display_name); + ASSERT_TRUE(info.publisher.has_value()); + EXPECT_EQ("TestPublisher", *info.publisher); + ASSERT_TRUE(info.license.has_value()); + EXPECT_EQ("MIT", *info.license); + ASSERT_TRUE(info.file_size_mb.has_value()); + EXPECT_EQ(512u, *info.file_size_mb); + ASSERT_TRUE(info.supports_tool_calling.has_value()); + EXPECT_TRUE(*info.supports_tool_calling); + ASSERT_TRUE(info.max_output_tokens.has_value()); + EXPECT_EQ(4096, *info.max_output_tokens); + ASSERT_TRUE(info.runtime.has_value()); + EXPECT_EQ(DeviceType::CPU, info.runtime->device_type); + EXPECT_EQ("ORT", info.runtime->execution_provider); +} + +TEST_F(ParserTest, ParseModelSettings) { + nlohmann::json j = {{"parameters", {{{"name", "p1"}, {"value", "v1"}}, {{"name", "p2"}}}}}; + ModelSettings ms = j.get(); + ASSERT_EQ(2u, ms.parameters.size()); + EXPECT_EQ("p1", ms.parameters[0].name); + ASSERT_TRUE(ms.parameters[0].value.has_value()); + EXPECT_EQ("v1", *ms.parameters[0].value); + EXPECT_EQ("p2", ms.parameters[1].name); + EXPECT_FALSE(ms.parameters[1].value.has_value()); +} + +TEST_F(ParserTest, ParseChatMessage) { + nlohmann::json j = {{"role", "user"}, {"content", "hello"}}; + ChatMessage msg = j.get(); + EXPECT_EQ("user", msg.role); + EXPECT_EQ("hello", msg.content); + EXPECT_TRUE(msg.tool_calls.empty()); + EXPECT_FALSE(msg.tool_call_id.has_value()); +} + +TEST_F(ParserTest, ParseChatMessage_WithToolCalls) { + nlohmann::json j = { + {"role", "assistant"}, + {"content", "I'll call a tool."}, + {"tool_calls", + {{{"id", "call_abc123"}, + {"type", "function"}, + {"function", {{"name", "get_weather"}, {"arguments", "{\"city\": \"Seattle\"}"}}}}}}}; + ChatMessage msg = j.get(); + EXPECT_EQ("assistant", msg.role); + ASSERT_EQ(1u, msg.tool_calls.size()); + EXPECT_EQ("call_abc123", msg.tool_calls[0].id); + EXPECT_EQ("function", 
msg.tool_calls[0].type); + ASSERT_TRUE(msg.tool_calls[0].function_call.has_value()); + EXPECT_EQ("get_weather", msg.tool_calls[0].function_call->name); + EXPECT_EQ("{\"city\": \"Seattle\"}", msg.tool_calls[0].function_call->arguments); +} + +TEST_F(ParserTest, ParseChatMessage_WithToolCallId) { + nlohmann::json j = { + {"role", "tool"}, + {"content", "72 degrees and sunny"}, + {"tool_call_id", "call_abc123"}}; + ChatMessage msg = j.get(); + EXPECT_EQ("tool", msg.role); + EXPECT_EQ("72 degrees and sunny", msg.content); + ASSERT_TRUE(msg.tool_call_id.has_value()); + EXPECT_EQ("call_abc123", *msg.tool_call_id); +} + +TEST_F(ParserTest, ParseFunctionCall) { + nlohmann::json j = {{"name", "multiply"}, {"arguments", "{\"a\": 1, \"b\": 2}"}}; + FunctionCall fc = j.get(); + EXPECT_EQ("multiply", fc.name); + EXPECT_EQ("{\"a\": 1, \"b\": 2}", fc.arguments); +} + +TEST_F(ParserTest, ParseFunctionCall_ObjectArguments) { + nlohmann::json j = {{"name", "add"}, {"arguments", {{"x", 10}}}}; + FunctionCall fc = j.get(); + EXPECT_EQ("add", fc.name); + EXPECT_EQ("{\"x\":10}", fc.arguments); +} + +TEST_F(ParserTest, ParseToolCall) { + nlohmann::json j = { + {"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "search"}, {"arguments", "{\"query\": \"test\"}"}}}}; + ToolCall tc = j.get(); + EXPECT_EQ("call_1", tc.id); + EXPECT_EQ("function", tc.type); + ASSERT_TRUE(tc.function_call.has_value()); + EXPECT_EQ("search", tc.function_call->name); +} + +TEST_F(ParserTest, SerializeToolDefinition) { + ToolDefinition tool; + tool.type = "function"; + tool.function.name = "get_weather"; + tool.function.description = "Get the current weather"; + tool.function.parameters = PropertyDefinition{ + "object", + std::nullopt, + std::unordered_map{ + {"location", PropertyDefinition{"string", "The city name"}} + }, + std::vector{"location"} + }; + + nlohmann::json j; + to_json(j, tool); + + EXPECT_EQ("function", j["type"].get()); + EXPECT_EQ("get_weather", j["function"]["name"].get()); + 
EXPECT_EQ("Get the current weather", j["function"]["description"].get()); + EXPECT_EQ("object", j["function"]["parameters"]["type"].get()); + ASSERT_TRUE(j["function"]["parameters"]["properties"].contains("location")); + EXPECT_EQ("string", j["function"]["parameters"]["properties"]["location"]["type"].get()); + ASSERT_EQ(1u, j["function"]["parameters"]["required"].size()); + EXPECT_EQ("location", j["function"]["parameters"]["required"][0].get()); +} + +TEST_F(ParserTest, SerializeToolDefinition_MinimalFunction) { + ToolDefinition tool; + tool.function.name = "noop"; + + nlohmann::json j; + to_json(j, tool); + + EXPECT_EQ("function", j["type"].get()); + EXPECT_EQ("noop", j["function"]["name"].get()); + EXPECT_FALSE(j["function"].contains("description")); + EXPECT_FALSE(j["function"].contains("parameters")); +} + +TEST_F(ParserTest, ToolChoiceToString) { + EXPECT_EQ("auto", tool_choice_to_string(ToolChoiceKind::Auto)); + EXPECT_EQ("none", tool_choice_to_string(ToolChoiceKind::None)); + EXPECT_EQ("required", tool_choice_to_string(ToolChoiceKind::Required)); +} + +TEST_F(ParserTest, ParseChatChoice_NonStreaming) { + nlohmann::json j = { + {"index", 0}, {"finish_reason", "stop"}, {"message", {{"role", "assistant"}, {"content", "Hi there!"}}}}; + ChatChoice c = j.get(); + EXPECT_EQ(0, c.index); + EXPECT_EQ(FinishReason::Stop, c.finish_reason); + ASSERT_TRUE(c.message.has_value()); + EXPECT_EQ("assistant", c.message->role); + EXPECT_EQ("Hi there!", c.message->content); + EXPECT_FALSE(c.delta.has_value()); +} + +TEST_F(ParserTest, ParseChatChoice_Streaming) { + nlohmann::json j = { + {"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hi"}}}}; + ChatChoice c = j.get(); + EXPECT_EQ(FinishReason::None, c.finish_reason); + EXPECT_FALSE(c.message.has_value()); + ASSERT_TRUE(c.delta.has_value()); + EXPECT_EQ("Hi", c.delta->content); +} + +TEST_F(ParserTest, ParseChatCompletionCreateResponse) { + nlohmann::json j = { + {"created", 
1700000000}, + {"id", "chatcmpl-123"}, + {"IsDelta", false}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", "stop"}, {"message", {{"role", "assistant"}, {"content", "Hello!"}}}}}}}; + ChatCompletionCreateResponse r = j.get(); + EXPECT_EQ(1700000000, r.created); + EXPECT_EQ("chatcmpl-123", r.id); + EXPECT_FALSE(r.is_delta); + EXPECT_TRUE(r.successful); + EXPECT_EQ(200, r.http_status_code); + ASSERT_EQ(1u, r.choices.size()); + EXPECT_EQ("Hello!", r.choices[0].message->content); +} + +TEST(ChatCompletionCreateResponseTest, GetObject_NonDelta) { + ChatCompletionCreateResponse r; + r.is_delta = false; + EXPECT_STREQ("chat.completion", r.GetObject()); +} + +TEST(ChatCompletionCreateResponseTest, GetObject_Delta) { + ChatCompletionCreateResponse r; + r.is_delta = true; + EXPECT_STREQ("chat.completion.chunk", r.GetObject()); +} + +TEST(ChatCompletionCreateResponseTest, GetCreatedAtIso_Zero) { + ChatCompletionCreateResponse r; + r.created = 0; + EXPECT_EQ("", r.GetCreatedAtIso()); +} + +TEST(ChatCompletionCreateResponseTest, GetCreatedAtIso_ValidTimestamp) { + ChatCompletionCreateResponse r; + r.created = 1700000000; // 2023-11-14T22:13:20Z + std::string iso = r.GetCreatedAtIso(); + EXPECT_FALSE(iso.empty()); + EXPECT_EQ('Z', iso.back()); + EXPECT_NE(std::string::npos, iso.find("2023")); +} + +// ============================================================================= +// CoreInteropRequest tests +// ============================================================================= + +TEST(CoreInteropRequestTest, Command) { + CoreInteropRequest req("test_command"); + EXPECT_EQ("test_command", req.Command()); +} + +TEST(CoreInteropRequestTest, ToJson_NoParams) { + CoreInteropRequest req("cmd"); + std::string json = req.ToJson(); + auto parsed = nlohmann::json::parse(json); + EXPECT_FALSE(parsed.contains("Params")); +} + +TEST(CoreInteropRequestTest, ToJson_WithParams) { + CoreInteropRequest req("cmd"); + 
req.AddParam("key1", "value1"); + req.AddParam("key2", "value2"); + std::string json = req.ToJson(); + auto parsed = nlohmann::json::parse(json); + ASSERT_TRUE(parsed.contains("Params")); + EXPECT_EQ("value1", parsed["Params"]["key1"].get()); + EXPECT_EQ("value2", parsed["Params"]["key2"].get()); +} + +TEST(CoreInteropRequestTest, AddParam_Chaining) { + CoreInteropRequest req("cmd"); + auto& ref = req.AddParam("a", "1").AddParam("b", "2"); + EXPECT_EQ(&req, &ref); +} + +// ============================================================================= +// FoundryLocalException tests +// ============================================================================= + +TEST(FoundryLocalExceptionTest, MessageOnly) { + FoundryLocalException ex("test error"); + EXPECT_STREQ("test error", ex.what()); +} + +TEST(FoundryLocalExceptionTest, MessageAndLogger) { + NullLogger logger; + FoundryLocalException ex("logged error", logger); + EXPECT_STREQ("logged error", ex.what()); +} + +// ============================================================================= +// File-based parser tests (read JSON from testdata/) +// ============================================================================= + +class FileBasedParserTest : public ::testing::Test { +protected: + static std::string TestDataPath(const std::string& filename) { return "testdata/" + filename; } + + static nlohmann::json LoadJsonArray(const std::string& filename) { + std::string raw = Testing::ReadFile(TestDataPath(filename)); + return nlohmann::json::parse(raw); + } +}; + +TEST_F(FileBasedParserTest, AllFields_RequiredFields) { + auto arr = LoadJsonArray("model_all_fields.json"); + ModelInfo info = arr.at(0).get(); + EXPECT_EQ("model-all-fields", info.id); + EXPECT_EQ("model-all-fields", info.name); + EXPECT_EQ(3u, info.version); + EXPECT_EQ("full-model", info.alias); + EXPECT_EQ("onnx", info.provider_type); + EXPECT_EQ("https://example.com/full-model", info.uri); + EXPECT_EQ("text", info.model_type); + 
EXPECT_TRUE(info.cached); + EXPECT_EQ(1710000000, info.created_at_unix); +} + +TEST_F(FileBasedParserTest, AllFields_OptionalStrings) { + auto arr = LoadJsonArray("model_all_fields.json"); + ModelInfo info = arr.at(0).get(); + + ASSERT_TRUE(info.display_name.has_value()); + EXPECT_EQ("Full Model Display Name", *info.display_name); + ASSERT_TRUE(info.publisher.has_value()); + EXPECT_EQ("TestPublisher", *info.publisher); + ASSERT_TRUE(info.license.has_value()); + EXPECT_EQ("Apache-2.0", *info.license); + ASSERT_TRUE(info.license_description.has_value()); + EXPECT_EQ("Permissive open source license", *info.license_description); + ASSERT_TRUE(info.task.has_value()); + EXPECT_EQ("text-generation", *info.task); + ASSERT_TRUE(info.min_fl_version.has_value()); + EXPECT_EQ("1.0.0", *info.min_fl_version); +} + +TEST_F(FileBasedParserTest, AllFields_NumericOptionals) { + auto arr = LoadJsonArray("model_all_fields.json"); + ModelInfo info = arr.at(0).get(); + + ASSERT_TRUE(info.file_size_mb.has_value()); + EXPECT_EQ(16384u, *info.file_size_mb); + ASSERT_TRUE(info.supports_tool_calling.has_value()); + EXPECT_TRUE(*info.supports_tool_calling); + ASSERT_TRUE(info.max_output_tokens.has_value()); + EXPECT_EQ(8192, *info.max_output_tokens); +} + +TEST_F(FileBasedParserTest, AllFields_Runtime) { + auto arr = LoadJsonArray("model_all_fields.json"); + ModelInfo info = arr.at(0).get(); + + ASSERT_TRUE(info.runtime.has_value()); + EXPECT_EQ(DeviceType::NPU, info.runtime->device_type); + EXPECT_EQ("QNN", info.runtime->execution_provider); +} + +TEST_F(FileBasedParserTest, AllFields_PromptTemplate) { + auto arr = LoadJsonArray("model_all_fields.json"); + ModelInfo info = arr.at(0).get(); + + ASSERT_TRUE(info.prompt_template.has_value()); + EXPECT_EQ("<|system|>\n", info.prompt_template->system); + EXPECT_EQ("<|user|>\n", info.prompt_template->user); + EXPECT_EQ("<|assistant|>\n", info.prompt_template->assistant); + EXPECT_EQ("<|endoftext|>", info.prompt_template->prompt); +} + 
+TEST_F(FileBasedParserTest, AllFields_ModelSettings) { + auto arr = LoadJsonArray("model_all_fields.json"); + ModelInfo info = arr.at(0).get(); + + ASSERT_TRUE(info.model_settings.has_value()); + ASSERT_EQ(3u, info.model_settings->parameters.size()); + EXPECT_EQ("temperature", info.model_settings->parameters[0].name); + ASSERT_TRUE(info.model_settings->parameters[0].value.has_value()); + EXPECT_EQ("0.7", *info.model_settings->parameters[0].value); + EXPECT_EQ("top_p", info.model_settings->parameters[1].name); + ASSERT_TRUE(info.model_settings->parameters[1].value.has_value()); + EXPECT_EQ("0.9", *info.model_settings->parameters[1].value); + EXPECT_EQ("max_tokens", info.model_settings->parameters[2].name); + EXPECT_FALSE(info.model_settings->parameters[2].value.has_value()); +} + +TEST_F(FileBasedParserTest, MinimalFields_RequiredOnly) { + auto arr = LoadJsonArray("model_minimal_fields.json"); + ModelInfo info = arr.at(0).get(); + + EXPECT_EQ("minimal-model", info.id); + EXPECT_EQ("minimal-model", info.name); + EXPECT_EQ(1u, info.version); + EXPECT_EQ("minimal", info.alias); + EXPECT_EQ("onnx", info.provider_type); + EXPECT_EQ("text", info.model_type); + EXPECT_FALSE(info.cached); + EXPECT_EQ(0, info.created_at_unix); +} + +TEST_F(FileBasedParserTest, MinimalFields_AllOptionalsAbsent) { + auto arr = LoadJsonArray("model_minimal_fields.json"); + ModelInfo info = arr.at(0).get(); + + EXPECT_FALSE(info.display_name.has_value()); + EXPECT_FALSE(info.publisher.has_value()); + EXPECT_FALSE(info.license.has_value()); + EXPECT_FALSE(info.license_description.has_value()); + EXPECT_FALSE(info.task.has_value()); + EXPECT_FALSE(info.file_size_mb.has_value()); + EXPECT_FALSE(info.supports_tool_calling.has_value()); + EXPECT_FALSE(info.max_output_tokens.has_value()); + EXPECT_FALSE(info.min_fl_version.has_value()); + EXPECT_FALSE(info.runtime.has_value()); + EXPECT_FALSE(info.prompt_template.has_value()); + EXPECT_FALSE(info.model_settings.has_value()); +} + 
+TEST_F(FileBasedParserTest, NullOptionals_AllOptionalsAbsent) { + auto arr = LoadJsonArray("model_null_optionals.json"); + ModelInfo info = arr.at(0).get(); + + EXPECT_EQ("model-null-optionals", info.id); + EXPECT_EQ("null-opts", info.alias); + + // All explicitly-null fields should parse as absent + EXPECT_FALSE(info.display_name.has_value()); + EXPECT_FALSE(info.publisher.has_value()); + EXPECT_FALSE(info.license.has_value()); + EXPECT_FALSE(info.license_description.has_value()); + EXPECT_FALSE(info.task.has_value()); + EXPECT_FALSE(info.file_size_mb.has_value()); + EXPECT_FALSE(info.supports_tool_calling.has_value()); + EXPECT_FALSE(info.max_output_tokens.has_value()); + EXPECT_FALSE(info.min_fl_version.has_value()); + EXPECT_FALSE(info.runtime.has_value()); + EXPECT_FALSE(info.prompt_template.has_value()); + EXPECT_FALSE(info.model_settings.has_value()); +} + +TEST_F(FileBasedParserTest, RealModelsList_ParseAllEntries) { + auto arr = LoadJsonArray("real_models_list.json"); + ASSERT_EQ(4u, arr.size()); + + for (const auto& j : arr) { + EXPECT_NO_THROW({ + auto info = j.get(); + EXPECT_FALSE(info.id.empty()); + EXPECT_FALSE(info.name.empty()); + EXPECT_FALSE(info.alias.empty()); + }); + } +} + +TEST_F(FileBasedParserTest, MalformedJson_Throws) { + EXPECT_ANY_THROW({ + std::string raw = Testing::ReadFile(TestDataPath("malformed_models_list.json")); + nlohmann::json::parse(raw); + }); +} diff --git a/sdk/cpp/test/testdata/empty_models_list.json b/sdk/cpp/test/testdata/empty_models_list.json new file mode 100644 index 00000000..fe51488c --- /dev/null +++ b/sdk/cpp/test/testdata/empty_models_list.json @@ -0,0 +1 @@ +[] diff --git a/sdk/cpp/test/testdata/malformed_models_list.json b/sdk/cpp/test/testdata/malformed_models_list.json new file mode 100644 index 00000000..a04360f5 --- /dev/null +++ b/sdk/cpp/test/testdata/malformed_models_list.json @@ -0,0 +1 @@ +{this is not valid json[} diff --git a/sdk/cpp/test/testdata/missing_name_field_models_list.json 
b/sdk/cpp/test/testdata/missing_name_field_models_list.json new file mode 100644 index 00000000..ff4742f3 --- /dev/null +++ b/sdk/cpp/test/testdata/missing_name_field_models_list.json @@ -0,0 +1,12 @@ +[ + { + "id": "model-missing-name", + "version": 1, + "alias": "test", + "providerType": "onnx", + "uri": "https://example.com/model", + "modelType": "text", + "cached": false, + "createdAt": 0 + } +] diff --git a/sdk/cpp/test/testdata/mixed_openai_and_local.json b/sdk/cpp/test/testdata/mixed_openai_and_local.json new file mode 100644 index 00000000..091e473f --- /dev/null +++ b/sdk/cpp/test/testdata/mixed_openai_and_local.json @@ -0,0 +1,35 @@ +[ + { + "id": "openai-gpt4", + "name": "openai-gpt4", + "version": 1, + "alias": "openai-gpt4", + "providerType": "openai", + "uri": "https://example.com/openai-gpt4", + "modelType": "text", + "cached": false, + "createdAt": 0 + }, + { + "id": "openai-whisper", + "name": "openai-whisper", + "version": 1, + "alias": "openai-whisper", + "providerType": "openai", + "uri": "https://example.com/openai-whisper", + "modelType": "audio", + "cached": false, + "createdAt": 0 + }, + { + "id": "local-phi-4", + "name": "local-phi-4", + "version": 1, + "alias": "phi-4", + "providerType": "onnx", + "uri": "https://example.com/phi-4", + "modelType": "text", + "cached": false, + "createdAt": 1700000000 + } +] diff --git a/sdk/cpp/test/testdata/real_models_list.json b/sdk/cpp/test/testdata/real_models_list.json new file mode 100644 index 00000000..45f456af --- /dev/null +++ b/sdk/cpp/test/testdata/real_models_list.json @@ -0,0 +1,88 @@ +[ + { + "id": "Phi-4-generic-gpu", + "name": "Phi-4-generic-gpu", + "version": 1, + "alias": "phi-4", + "displayName": "Phi-4 (GPU)", + "providerType": "onnx", + "uri": "https://example.com/phi-4-gpu", + "modelType": "text", + "publisher": "Microsoft", + "license": "MIT", + "fileSizeMb": 8192, + "supportsToolCalling": true, + "maxOutputTokens": 4096, + "cached": false, + "createdAt": 1700000000, + "runtime": { 
+ "deviceType": "GPU", + "executionProvider": "DML" + }, + "promptTemplate": { + "system": "<|system|>", + "user": "<|user|>", + "assistant": "<|assistant|>", + "prompt": "<|prompt|>" + } + }, + { + "id": "Phi-4-generic-cpu", + "name": "Phi-4-generic-cpu", + "version": 1, + "alias": "phi-4", + "displayName": "Phi-4 (CPU)", + "providerType": "onnx", + "uri": "https://example.com/phi-4-cpu", + "modelType": "text", + "publisher": "Microsoft", + "license": "MIT", + "fileSizeMb": 4096, + "supportsToolCalling": false, + "maxOutputTokens": 2048, + "cached": false, + "createdAt": 1700000000, + "runtime": { + "deviceType": "CPU", + "executionProvider": "ORT" + } + }, + { + "id": "Mistral-7b-v0.2-generic-gpu", + "name": "Mistral-7b-v0.2-generic-gpu", + "version": 1, + "alias": "mistral-7b-v0.2", + "displayName": "Mistral 7B v0.2 (GPU)", + "providerType": "onnx", + "uri": "https://example.com/mistral-gpu", + "modelType": "text", + "publisher": "Mistral AI", + "license": "Apache-2.0", + "fileSizeMb": 14000, + "cached": false, + "createdAt": 1700100000, + "runtime": { + "deviceType": "GPU", + "executionProvider": "DML" + } + }, + { + "id": "Mistral-7b-v0.2-generic-cpu", + "name": "Mistral-7b-v0.2-generic-cpu", + "version": 1, + "alias": "mistral-7b-v0.2", + "displayName": "Mistral 7B v0.2 (CPU)", + "providerType": "onnx", + "uri": "https://example.com/mistral-cpu", + "modelType": "text", + "publisher": "Mistral AI", + "license": "Apache-2.0", + "fileSizeMb": 7000, + "cached": false, + "createdAt": 1700100000, + "runtime": { + "deviceType": "CPU", + "executionProvider": "ORT" + } + } +] diff --git a/sdk/cpp/test/testdata/single_cached_model.json b/sdk/cpp/test/testdata/single_cached_model.json new file mode 100644 index 00000000..76efa8e7 --- /dev/null +++ b/sdk/cpp/test/testdata/single_cached_model.json @@ -0,0 +1 @@ +["multi-v1-cpu:1"] diff --git a/sdk/cpp/test/testdata/three_variants_one_model.json b/sdk/cpp/test/testdata/three_variants_one_model.json new file mode 100644 
index 00000000..e60581ee --- /dev/null +++ b/sdk/cpp/test/testdata/three_variants_one_model.json @@ -0,0 +1,41 @@ +[ + { + "id": "multi-v1-gpu", + "name": "multi-v1-gpu", + "version": 1, + "alias": "multi-model", + "displayName": "Multi Model v1 GPU", + "providerType": "onnx", + "uri": "https://example.com/multi-v1-gpu", + "modelType": "text", + "cached": false, + "createdAt": 1700000000, + "runtime": { "deviceType": "GPU", "executionProvider": "DML" } + }, + { + "id": "multi-v1-cpu", + "name": "multi-v1-cpu", + "version": 1, + "alias": "multi-model", + "displayName": "Multi Model v1 CPU", + "providerType": "onnx", + "uri": "https://example.com/multi-v1-cpu", + "modelType": "text", + "cached": true, + "createdAt": 1700000000, + "runtime": { "deviceType": "CPU", "executionProvider": "ORT" } + }, + { + "id": "multi-v1-npu", + "name": "multi-v1-npu", + "version": 1, + "alias": "multi-model", + "displayName": "Multi Model v1 NPU", + "providerType": "onnx", + "uri": "https://example.com/multi-v1-npu", + "modelType": "text", + "cached": false, + "createdAt": 1700000000, + "runtime": { "deviceType": "NPU", "executionProvider": "QNN" } + } +] diff --git a/sdk/cpp/test/testdata/valid_cached_models.json b/sdk/cpp/test/testdata/valid_cached_models.json new file mode 100644 index 00000000..2b144174 --- /dev/null +++ b/sdk/cpp/test/testdata/valid_cached_models.json @@ -0,0 +1 @@ +["Phi-4-generic-gpu:1", "Phi-4-generic-cpu:1"] diff --git a/sdk/cpp/test/testdata/valid_loaded_models.json b/sdk/cpp/test/testdata/valid_loaded_models.json new file mode 100644 index 00000000..4d2ef328 --- /dev/null +++ b/sdk/cpp/test/testdata/valid_loaded_models.json @@ -0,0 +1 @@ +["Phi-4-generic-gpu:1"]