From c71a94f36d56181875f9864424519af6cccbda5f Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Fri, 20 Mar 2026 00:51:43 -0700 Subject: [PATCH 01/11] Add HuggingFace download to C# SDK --- sdk/cs/src/Catalog.cs | 123 +++++- sdk/cs/src/Detail/CoreInterop.cs | 3 - sdk/cs/src/FoundryLocalManager.cs | 31 ++ sdk/cs/src/FoundryModelInfo.cs | 3 + sdk/cs/src/HuggingFaceCatalog.cs | 413 ++++++++++++++++++ sdk/cs/src/ICatalog.cs | 22 +- sdk/cs/src/Microsoft.AI.Foundry.Local.csproj | 2 +- sdk/cs/src/Model.cs | 1 + sdk/cs/src/ModelVariant.cs | 1 + .../HuggingFaceCatalogTests.cs | 370 ++++++++++++++++ 10 files changed, 962 insertions(+), 7 deletions(-) create mode 100644 sdk/cs/src/HuggingFaceCatalog.cs create mode 100644 sdk/cs/test/FoundryLocal.Tests/HuggingFaceCatalogTests.cs diff --git a/sdk/cs/src/Catalog.cs b/sdk/cs/src/Catalog.cs index eb9ba0d7..667749a3 100644 --- a/sdk/cs/src/Catalog.cs +++ b/sdk/cs/src/Catalog.cs @@ -7,6 +7,7 @@ namespace Microsoft.AI.Foundry.Local; using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading.Tasks; @@ -77,6 +78,19 @@ public async Task> GetLoadedModelsAsync(CancellationToken? ct .ConfigureAwait(false); } + public async Task DownloadModelAsync(string modelUri, CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => DownloadModelImplAsync(modelUri, ct), + $"Error downloading model '{modelUri}'.", _logger) + .ConfigureAwait(false); + } + + public Task RegisterModelAsync(string modelIdentifier, CancellationToken? ct = null) + { + return Task.FromException(new NotSupportedException( + "RegisterModelAsync is only available on HuggingFace catalogs. Use AddCatalogAsync(\"https://huggingface.co\") to create a HuggingFace catalog.")); + } + public async Task GetModelVariantAsync(string modelId, CancellationToken? ct = null) { return await Utils.CallWithExceptionHandling(() => GetModelVariantImplAsync(modelId, ct), @@ -126,9 +140,29 @@ private async Task> GetLoadedModelsImplAsync(CancellationToke private async Task GetModelImplAsync(string modelAlias, CancellationToken? ct = null) { + var hfUrl = NormalizeToHuggingFaceUrl(modelAlias); + if (hfUrl != null) + { + // Force a fresh catalog refresh for HuggingFace lookups + _lastFetch = DateTime.MinValue; + await UpdateModels(ct).ConfigureAwait(false); + + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + var matchingVariant = _modelIdToModelVariant.Values.FirstOrDefault(v => + string.Equals(v.Info.Uri, hfUrl, StringComparison.OrdinalIgnoreCase)); + + if (matchingVariant != null) + { + _modelAliasToModel.TryGetValue(matchingVariant.Alias, out Model? hfModel); + return hfModel; + } + + return null; + } + await UpdateModels(ct).ConfigureAwait(false); - using var disposable = await _lock.LockAsync().ConfigureAwait(false); + using var d = await _lock.LockAsync().ConfigureAwait(false); _modelAliasToModel.TryGetValue(modelAlias, out Model? model); return model; @@ -143,6 +177,93 @@ private async Task> GetLoadedModelsImplAsync(CancellationToke return modelVariant; } + private async Task DownloadModelImplAsync(string modelUri, CancellationToken? ct) + { + // Validate that this is a HuggingFace identifier + if (NormalizeToHuggingFaceUrl(modelUri) == null) + { + throw new FoundryLocalException( + $"'{modelUri}' is not a valid HuggingFace URL or org/repo identifier.", _logger); + } + + // Send the original URI to Core — it handles full URLs with /tree/revision/ + // and raw org/repo/subdir strings. Do NOT send the normalized form, as Core's + // URL parser expects /tree/revision/ when the https:// prefix is present. + var downloadRequest = new CoreInteropRequest + { + Params = new Dictionary { { "Model", modelUri } } + }; + + var result = await _coreInterop.ExecuteCommandAsync("download_model", downloadRequest, ct) + .ConfigureAwait(false); + + if (result.Error != null) + { + throw new FoundryLocalException( + $"Error downloading model '{modelUri}': {result.Error}", _logger); + } + + // Force a catalog refresh to pick up the newly downloaded model + _lastFetch = DateTime.MinValue; + await UpdateModels(ct).ConfigureAwait(false); + + // The backend returns the org/model URI (e.g. "microsoft/Phi-3-mini") as result.Data + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + var expectedUri = $"https://huggingface.co/{result.Data}"; + var matchingVariant = _modelIdToModelVariant.Values.FirstOrDefault(v => + string.Equals(v.Info.Uri, expectedUri, StringComparison.OrdinalIgnoreCase)); + + if (matchingVariant != null) + { + _modelAliasToModel.TryGetValue(matchingVariant.Alias, out Model? hfModel); + return hfModel!; + } + + throw new FoundryLocalException( + $"Model '{modelUri}' was downloaded but could not be found in the catalog.", _logger); + } + + /// + /// Normalizes a model identifier to a canonical HuggingFace URL, or returns null if it's a plain alias. + /// Strips /tree/{revision}/ from full browser URLs so the result matches the stored Info.Uri format. + /// Handles: + /// - "https://huggingface.co/org/repo/tree/main/sub" -> "https://huggingface.co/org/repo/sub" + /// - "https://huggingface.co/org/repo" -> returned as-is + /// - "org/repo[/sub]" -> "https://huggingface.co/org/repo[/sub]" + /// - "phi-3-mini" (plain alias) -> null + /// + private static string? NormalizeToHuggingFaceUrl(string input) + { + const string hfPrefix = "https://huggingface.co/"; + + if (input.StartsWith(hfPrefix, StringComparison.OrdinalIgnoreCase)) + { + // Strip /tree/{revision}/ to match the canonical form stored by Core + var path = input[hfPrefix.Length..]; + var parts = path.Split('/'); + if (parts.Length >= 4 && + parts[2].Equals("tree", StringComparison.OrdinalIgnoreCase)) + { + // parts[0]=org, parts[1]=repo, parts[2]="tree", parts[3]=revision, parts[4..]=subpath + var org = parts[0]; + var repo = parts[1]; + var subPath = parts.Length > 4 ? string.Join("/", parts.Skip(4)) : null; + return subPath != null + ? $"{hfPrefix}{org}/{repo}/{subPath}" + : $"{hfPrefix}{org}/{repo}"; + } + + return input; + } + + if (input.Contains('/') && !input.StartsWith("azureml://", StringComparison.OrdinalIgnoreCase)) + { + return hfPrefix + input; + } + + return null; + } + private async Task UpdateModels(CancellationToken? ct) { // TODO: make this configurable diff --git a/sdk/cs/src/Detail/CoreInterop.cs b/sdk/cs/src/Detail/CoreInterop.cs index 8411473b..95073b7a 100644 --- a/sdk/cs/src/Detail/CoreInterop.cs +++ b/sdk/cs/src/Detail/CoreInterop.cs @@ -6,7 +6,6 @@ namespace Microsoft.AI.Foundry.Local.Detail; -using System.Diagnostics; using System.Runtime.InteropServices; using Microsoft.Extensions.Logging; @@ -183,8 +182,6 @@ private static void HandleCallback(nint data, int length, nint callbackHelper) callbackData = System.Text.Encoding.UTF8.GetString(managedData); } - Debug.Assert(callbackHelper != IntPtr.Zero, "Callback helper pointer is required."); - helper = (CallbackHelper)GCHandle.FromIntPtr(callbackHelper).Target!; helper.Callback.Invoke(callbackData); } diff --git a/sdk/cs/src/FoundryLocalManager.cs b/sdk/cs/src/FoundryLocalManager.cs index 639be3a2..c2427270 100644 --- a/sdk/cs/src/FoundryLocalManager.cs +++ b/sdk/cs/src/FoundryLocalManager.cs @@ -107,6 +107,21 @@ public async Task GetCatalogAsync(CancellationToken? ct = null) "Error getting Catalog.", _logger).ConfigureAwait(false); } + /// + /// Create a separate catalog for a HuggingFace model registry. + /// + /// URL of the catalog (must contain "huggingface.co"). + /// Optional authentication token for accessing private HuggingFace repositories. + /// Optional CancellationToken. + /// The HuggingFace catalog instance. + public async Task AddCatalogAsync(string catalogUrl, string? token = null, + CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => AddCatalogImplAsync(catalogUrl, token, ct), + $"Error adding catalog '{catalogUrl}'.", _logger) + .ConfigureAwait(false); + } + /// /// Start the optional web service. This will provide an OpenAI-compatible REST endpoint that supports /// /v1/chat_completions @@ -212,6 +227,22 @@ private async Task GetCatalogImplAsync(CancellationToken? ct = null) return _catalog; } + private async Task AddCatalogImplAsync(string catalogUrl, string? token, + CancellationToken? ct = null) + { + if (!catalogUrl.Contains("huggingface.co", StringComparison.OrdinalIgnoreCase)) + { + throw new FoundryLocalException( + $"Unsupported catalog URL '{catalogUrl}'. Only HuggingFace catalogs (huggingface.co) are supported.", + _logger); + } + +#pragma warning disable IDISP005 // Return type is not disposable + return await HuggingFaceCatalog.CreateAsync(_modelManager!, _coreInterop!, _logger, token, ct) + .ConfigureAwait(false); +#pragma warning restore IDISP005 + } + private async Task StartWebServiceImplAsync(CancellationToken? ct = null) { if (_config?.Web?.Urls == null) diff --git a/sdk/cs/src/FoundryModelInfo.cs b/sdk/cs/src/FoundryModelInfo.cs index 1f795d22..90025e51 100644 --- a/sdk/cs/src/FoundryModelInfo.cs +++ b/sdk/cs/src/FoundryModelInfo.cs @@ -65,6 +65,9 @@ public record ModelInfo [JsonPropertyName("version")] public int Version { get; init; } + [JsonPropertyName("hash")] + public string? Hash { get; init; } + [JsonPropertyName("alias")] public required string Alias { get; init; } diff --git a/sdk/cs/src/HuggingFaceCatalog.cs b/sdk/cs/src/HuggingFaceCatalog.cs new file mode 100644 index 00000000..1c5cb907 --- /dev/null +++ b/sdk/cs/src/HuggingFaceCatalog.cs @@ -0,0 +1,413 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. All rights reserved. +// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading.Tasks; + +using Microsoft.AI.Foundry.Local.Detail; +using Microsoft.Extensions.Logging; + +internal sealed class HuggingFaceCatalog : ICatalog, IDisposable +{ + private readonly Dictionary _modelIdToModelVariant = new(); + private readonly Dictionary _modelIdToModel = new(); + + private readonly IModelLoadManager _modelLoadManager; + private readonly ICoreInterop _coreInterop; + private readonly ILogger _logger; + private readonly AsyncLock _lock = new(); + private readonly string? _token; + + public string Name { get; init; } + + private HuggingFaceCatalog(IModelLoadManager modelLoadManager, ICoreInterop coreInterop, ILogger logger, + string? token) + { + _modelLoadManager = modelLoadManager; + _coreInterop = coreInterop; + _logger = logger; + _token = token; + + Name = "HuggingFace"; + } + + internal static async Task CreateAsync(IModelLoadManager modelManager, + ICoreInterop coreInterop, + ILogger logger, + string? token = null, + CancellationToken? ct = null) + { + var catalog = new HuggingFaceCatalog(modelManager, coreInterop, logger, token); + await catalog.LoadRegistrationsAsync(ct).ConfigureAwait(false); + return catalog; + } + + public async Task> ListModelsAsync(CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => ListModelsImplAsync(ct), + "Error listing HuggingFace models.", _logger) + .ConfigureAwait(false); + } + + public async Task GetModelAsync(string modelAlias, CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => GetModelImplAsync(modelAlias, ct), + $"Error getting HuggingFace model '{modelAlias}'.", _logger) + .ConfigureAwait(false); + } + + public async Task DownloadModelAsync(string modelUri, CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling( + () => DownloadModelImplAsync(modelUri, ct), + $"Error downloading HuggingFace model '{modelUri}'.", _logger) + .ConfigureAwait(false); + } + + public async Task RegisterModelAsync(string modelIdentifier, CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => RegisterModelImplAsync(modelIdentifier, ct), + $"Error registering HuggingFace model '{modelIdentifier}'.", + _logger) + .ConfigureAwait(false); + } + + public async Task GetModelVariantAsync(string modelId, CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => GetModelVariantImplAsync(modelId, ct), + $"Error getting HuggingFace model variant '{modelId}'.", + _logger) + .ConfigureAwait(false); + } + + public async Task> GetCachedModelsAsync(CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => GetCachedModelsImplAsync(ct), + "Error getting cached HuggingFace models.", _logger) + .ConfigureAwait(false); + } + + public async Task> GetLoadedModelsAsync(CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => GetLoadedModelsImplAsync(ct), + "Error getting loaded HuggingFace models.", _logger) + .ConfigureAwait(false); + } + + private async Task> ListModelsImplAsync(CancellationToken? ct = null) + { + // HuggingFace catalog returns one entry per registration (each variant is individually referenceable) + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + return _modelIdToModel.Values.OrderBy(m => m.Id).ToList(); + } + + private async Task GetModelImplAsync(string modelIdentifier, CancellationToken? ct = null) + { + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + + // Try direct Id lookup first + if (_modelIdToModel.TryGetValue(modelIdentifier, out Model? model)) + { + return model; + } + + // Try alias lookup (returns first match) + var aliaMatch = _modelIdToModel.Values.FirstOrDefault(m => + string.Equals(m.Alias, modelIdentifier, StringComparison.OrdinalIgnoreCase)); + if (aliaMatch != null) + { + return aliaMatch; + } + + // Try URI-based lookup + var normalizedUrl = NormalizeToHuggingFaceUrl(modelIdentifier); + if (normalizedUrl != null) + { + var normalizedUrlWithSlash = normalizedUrl.TrimEnd('/') + "/"; + foreach (var variant in _modelIdToModelVariant.Values) + { + if (string.Equals(variant.Info.Uri, normalizedUrl, StringComparison.OrdinalIgnoreCase) || + variant.Info.Uri.StartsWith(normalizedUrlWithSlash, StringComparison.OrdinalIgnoreCase)) + { + if (_modelIdToModel.TryGetValue(variant.Id, out Model? foundModel)) + { + return foundModel; + } + } + } + } + + return null; + } + + private async Task RegisterModelImplAsync(string modelIdentifier, CancellationToken? ct = null) + { + // Validate it's a HuggingFace URL or org/repo format + var normalizedUrl = NormalizeToHuggingFaceUrl(modelIdentifier); + if (normalizedUrl == null) + { + throw new FoundryLocalException( + $"'{modelIdentifier}' is not a valid HuggingFace URL or org/repo identifier.", _logger); + } + + // Call Core to register the model (fetch metadata, generate inference_model.json, persist to + // huggingface.modelinfo.json) + var registerRequest = new CoreInteropRequest + { + Params = new Dictionary + { + { "Model", modelIdentifier }, + { "Token", _token ?? "" } + } + }; + + var result = await _coreInterop.ExecuteCommandAsync("register_model", registerRequest, ct) + .ConfigureAwait(false); + + if (result.Error != null) + { + throw new FoundryLocalException($"Error registering HuggingFace model '{modelIdentifier}': {result.Error}", + _logger); + } + + // Deserialize the returned ModelInfo + var modelInfo = JsonSerializer.Deserialize(result.Data!, JsonSerializationContext.Default.ModelInfo); + if (modelInfo == null) + { + throw new FoundryLocalException($"Failed to deserialize registered model metadata.", _logger); + } + + // Add to internal dictionaries with lock + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + var variant = new ModelVariant(modelInfo, _modelLoadManager, _coreInterop, _logger); + _modelIdToModelVariant[modelInfo.Id] = variant; + + // Each registration is a distinct entry, keyed by Id + var registeredModel = new Model(variant, _logger); + _modelIdToModel[modelInfo.Id] = registeredModel; + + // Persist registrations to local file + await SaveRegistrationsAsync(ct).ConfigureAwait(false); + + return registeredModel; + } + + private async Task DownloadModelImplAsync(string modelUri, CancellationToken? ct) + { + // Validate it's a HuggingFace URL or org/repo format + if (NormalizeToHuggingFaceUrl(modelUri) == null) + { + throw new FoundryLocalException( + $"'{modelUri}' is not a valid HuggingFace URL or org/repo identifier.", _logger); + } + + // Call Core's download_model command (same as existing Catalog) + var downloadRequest = new CoreInteropRequest + { + Params = new Dictionary + { + { "Model", modelUri }, + { "Token", _token ?? "" } + } + }; + + var result = await _coreInterop.ExecuteCommandAsync("download_model", downloadRequest, ct) + .ConfigureAwait(false); + + if (result.Error != null) + { + throw new FoundryLocalException($"Error downloading model '{modelUri}': {result.Error}", _logger); + } + + // The backend returns the org/model URI (e.g. "microsoft/Phi-3-mini") as result.Data + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + var expectedUri = $"https://huggingface.co/{result.Data}"; + var expectedUriWithSlash = expectedUri.TrimEnd('/') + "/"; + var matchingVariant = _modelIdToModelVariant.Values.FirstOrDefault(v => + string.Equals(v.Info.Uri, expectedUri, StringComparison.OrdinalIgnoreCase) || + v.Info.Uri.StartsWith(expectedUriWithSlash, StringComparison.OrdinalIgnoreCase) || + expectedUri.StartsWith(v.Info.Uri.TrimEnd('/') + "/", StringComparison.OrdinalIgnoreCase)); + + if (matchingVariant != null) + { + if (_modelIdToModel.TryGetValue(matchingVariant.Id, out Model? hfModel)) + { + return hfModel; + } + } + + throw new FoundryLocalException( + $"Model '{modelUri}' was downloaded but could not be found in the catalog.", _logger); + } + + private async Task> GetCachedModelsImplAsync(CancellationToken? ct = null) + { + var cachedModelIds = await Utils.GetCachedModelIdsAsync(_coreInterop, ct).ConfigureAwait(false); + + List cachedModels = new(); + foreach (var modelId in cachedModelIds) + { + if (_modelIdToModelVariant.TryGetValue(modelId, out ModelVariant? modelVariant)) + { + cachedModels.Add(modelVariant); + } + } + + return cachedModels; + } + + private async Task> GetLoadedModelsImplAsync(CancellationToken? ct = null) + { + var loadedModelIds = await _modelLoadManager.ListLoadedModelsAsync(ct).ConfigureAwait(false); + List loadedModels = new(); + + foreach (var modelId in loadedModelIds) + { + if (_modelIdToModelVariant.TryGetValue(modelId, out ModelVariant? modelVariant)) + { + loadedModels.Add(modelVariant); + } + } + + return loadedModels; + } + + private async Task GetModelVariantImplAsync(string modelId, CancellationToken? ct = null) + { + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + _modelIdToModelVariant.TryGetValue(modelId, out ModelVariant? modelVariant); + return modelVariant; + } + + private async Task LoadRegistrationsAsync(CancellationToken? ct = null) + { + // Load persisted HuggingFace registrations from local file + // File path: ~/.foundry-local/HuggingFace/huggingface.modelinfo.json + try + { + var homeDir = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); + var registrationsPath = Path.Combine(homeDir, ".foundry-local", "HuggingFace", "huggingface.modelinfo.json"); + + if (!File.Exists(registrationsPath)) + { + return; // No registrations yet + } + + var registrationsJson = await File.ReadAllTextAsync(registrationsPath).ConfigureAwait(false); + if (string.IsNullOrEmpty(registrationsJson)) + { + return; + } + + var models = JsonSerializer.Deserialize(registrationsJson, JsonSerializationContext.Default.ListModelInfo); + if (models == null) + { + _logger.LogDebug("Failed to deserialize HuggingFace registrations from file"); + return; + } + + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + + foreach (var modelInfo in models) + { + var variant = new ModelVariant(modelInfo, _modelLoadManager, _coreInterop, _logger); + _modelIdToModelVariant[modelInfo.Id] = variant; + _modelIdToModel[modelInfo.Id] = new Model(variant, _logger); + } + } + catch (Exception ex) + { + _logger.LogWarning($"Exception loading HuggingFace registrations: {ex.Message}"); + // Continue anyway — empty catalog is valid + } + } + + private async Task SaveRegistrationsAsync(CancellationToken? ct = null) + { + // Save persisted HuggingFace registrations to local file + // File path: ~/.foundry-local/HuggingFace/huggingface.modelinfo.json + try + { + var homeDir = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); + var registrationsDir = Path.Combine(homeDir, ".foundry-local", "HuggingFace"); + var registrationsPath = Path.Combine(registrationsDir, "huggingface.modelinfo.json"); + + // Ensure directory exists + Directory.CreateDirectory(registrationsDir); + + // Collect all registered models (from both dictionaries, using variants) + var models = _modelIdToModelVariant.Values + .Select(v => v.Info) + .Distinct() + .ToList(); + + // Serialize with pretty-printing (matching foundry.modelinfo.json style) + var prettyOptions = new JsonSerializerOptions { WriteIndented = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull }; + var prettyContext = new JsonSerializationContext(prettyOptions); + var json = JsonSerializer.Serialize(models, prettyContext.ListModelInfo); + await File.WriteAllTextAsync(registrationsPath, json).ConfigureAwait(false); + } + catch (Exception ex) + { + _logger.LogWarning($"Failed to save HuggingFace registrations: {ex.Message}"); + // Continue anyway — loss of persistence file is not critical + } + } + + private static string? NormalizeToHuggingFaceUrl(string input) + { + const string hfPrefix = "https://huggingface.co/"; + + if (input.StartsWith(hfPrefix, StringComparison.OrdinalIgnoreCase)) + { + // Strip /tree/{revision}/ to match the canonical form stored by Core + var path = input[hfPrefix.Length..]; + var parts = path.Split('/'); + if (parts.Length >= 4 && + parts[2].Equals("tree", StringComparison.OrdinalIgnoreCase)) + { + var org = parts[0]; + var repo = parts[1]; + var subPath = parts.Length > 4 ? string.Join("/", parts.Skip(4)) : null; + return subPath != null + ? $"{hfPrefix}{org}/{repo}/{subPath}" + : $"{hfPrefix}{org}/{repo}"; + } + + return input; + } + + if (input.Contains('/') && !input.StartsWith("azureml://", StringComparison.OrdinalIgnoreCase)) + { + // Strip /tree/{revision}/ from bare identifiers (e.g. "org/repo/tree/main/subpath") + var parts = input.Split('/'); + if (parts.Length >= 4 && + parts[2].Equals("tree", StringComparison.OrdinalIgnoreCase)) + { + var org = parts[0]; + var repo = parts[1]; + var subPath = parts.Length > 4 ? string.Join("/", parts.Skip(4)) : null; + return subPath != null + ? $"{hfPrefix}{org}/{repo}/{subPath}" + : $"{hfPrefix}{org}/{repo}"; + } + + return hfPrefix + input; + } + + return null; + } + + public void Dispose() + { + _lock?.Dispose(); + } +} diff --git a/sdk/cs/src/ICatalog.cs b/sdk/cs/src/ICatalog.cs index 35285736..90423d57 100644 --- a/sdk/cs/src/ICatalog.cs +++ b/sdk/cs/src/ICatalog.cs @@ -22,13 +22,31 @@ public interface ICatalog Task> ListModelsAsync(CancellationToken? ct = null); /// - /// Lookup a model by its alias. + /// Lookup a model by its alias, HuggingFace URL (https://huggingface.co/org/repo), or org/repo identifier. /// - /// Model alias. + /// Model alias, HuggingFace URL (https://huggingface.co/org/repo), or org/repo identifier. /// Optional CancellationToken. /// The matching Model, or null if no model with the given alias exists. Task GetModelAsync(string modelAlias, CancellationToken? ct = null); + /// + /// Download a model from its HuggingFace identifier (URL or org/repo format). + /// + /// HuggingFace URL or org/repo identifier. + /// Optional CancellationToken. + /// The downloaded Model. + Task DownloadModelAsync(string modelUri, CancellationToken? ct = null); + + /// + /// Register a HuggingFace model by downloading its config files and generating metadata. + /// Only available on HuggingFace catalogs created via . + /// + /// HuggingFace URL or org/repo identifier. + /// Optional CancellationToken. + /// The registered Model. + /// If called on a non-HuggingFace catalog. + Task RegisterModelAsync(string modelIdentifier, CancellationToken? ct = null); + /// /// Lookup a model variant by its unique model id. /// diff --git a/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj b/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj index 905f9652..b4051044 100644 --- a/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj +++ b/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj @@ -118,7 +118,7 @@ + Include="Microsoft.AI.Foundry.Local.Core" Version="0.9.0-0.local.20260320004727" /> diff --git a/sdk/cs/src/Model.cs b/sdk/cs/src/Model.cs index bbbbcb5b..a46d178a 100644 --- a/sdk/cs/src/Model.cs +++ b/sdk/cs/src/Model.cs @@ -17,6 +17,7 @@ public class Model : IModel public string Alias { get; init; } public string Id => SelectedVariant.Id; + public string Uri => SelectedVariant.Info.Uri; /// /// Is the currently selected variant cached locally? diff --git a/sdk/cs/src/ModelVariant.cs b/sdk/cs/src/ModelVariant.cs index 6ca7cda7..8ed6edc6 100644 --- a/sdk/cs/src/ModelVariant.cs +++ b/sdk/cs/src/ModelVariant.cs @@ -21,6 +21,7 @@ public class ModelVariant : IModel public string Id => Info.Id; public string Alias => Info.Alias; public int Version { get; init; } // parsed from Info.Version if possible, else 0 + public string VersionDisplay => Info.Hash ?? Info.Version.ToString(System.Globalization.CultureInfo.InvariantCulture); internal ModelVariant(ModelInfo modelInfo, IModelLoadManager modelLoadManager, ICoreInterop coreInterop, ILogger logger) diff --git a/sdk/cs/test/FoundryLocal.Tests/HuggingFaceCatalogTests.cs b/sdk/cs/test/FoundryLocal.Tests/HuggingFaceCatalogTests.cs new file mode 100644 index 00000000..5605b2c7 --- /dev/null +++ b/sdk/cs/test/FoundryLocal.Tests/HuggingFaceCatalogTests.cs @@ -0,0 +1,370 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. All rights reserved. +// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local.Tests; + +using System.Text.Json; +using Microsoft.AI.Foundry.Local.Detail; +using Microsoft.Extensions.Logging; +using Moq; + +/// +/// Unit tests for HuggingFaceCatalog — validates GetModelAsync lookup by alias and URI, +/// and that DownloadModelAsync accepts model URI from registration. +/// +public class HuggingFaceCatalogTests +{ + private static ModelInfo CreateTestModelInfo(string alias, string id, string uri) + { + return new ModelInfo + { + Id = id, + Name = alias, + DisplayName = alias, + Alias = alias, + Uri = uri, + ProviderType = "HuggingFace", + Version = 0, + ModelType = "ONNX", + Publisher = "test-org", + Task = "chat-completion", + License = "MIT", + FileSizeMb = 100 + }; + } + + private static (Mock coreInterop, Mock loadManager, Mock logger) + CreateMocks(ModelInfo modelInfo) + { + var logger = Utils.CreateCapturingLoggerMock([]); + var loadManager = new Mock(); + loadManager.Setup(x => x.ListLoadedModelsAsync(It.IsAny())) + .ReturnsAsync(Array.Empty()); + + var coreInterop = new Mock(); + + // Mock register_model to return the ModelInfo JSON + var modelInfoJson = JsonSerializer.Serialize(modelInfo, JsonSerializationContext.Default.ModelInfo); + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "register_model"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response { Data = modelInfoJson, Error = null }); + + // Mock get_cached_models to return empty + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "get_cached_models"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response { Data = "[]", Error = null }); + + return (coreInterop, loadManager, logger); + } + + [Test] + public async Task GetModelAsync_ByAlias_ReturnsRegisteredModel() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + // Register the model + var registered = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + await Assert.That(registered.Alias).IsEqualTo("phi-3-mini-4k"); + + // Lookup by alias + var found = await catalog.GetModelAsync("phi-3-mini-4k"); + await Assert.That(found).IsNotNull(); + await Assert.That(found!.Alias).IsEqualTo("phi-3-mini-4k"); + } + + [Test] + public async Task GetModelAsync_ByUri_ReturnsRegisteredModel() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + var registered = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + // Lookup by full URI (what SelectedVariant.Info.Uri returns) + var found = await catalog.GetModelAsync("https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + await Assert.That(found).IsNotNull(); + await Assert.That(found!.Alias).IsEqualTo("phi-3-mini-4k"); + } + + [Test] + public async Task GetModelAsync_ByOrgRepo_ReturnsRegisteredModel() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + // Lookup by org/repo identifier + var found = await catalog.GetModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + await Assert.That(found).IsNotNull(); + await Assert.That(found!.Alias).IsEqualTo("phi-3-mini-4k"); + } + + [Test] + public async Task GetModelAsync_NotRegistered_ReturnsNull() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + // Don't register anything — lookup should return null + var found = await catalog.GetModelAsync("nonexistent-model"); + await Assert.That(found).IsNull(); + } + + [Test] + public async Task GetModelAsync_BySubpathUri_ReturnsRegisteredModel() + { + var modelInfo = CreateTestModelInfo( + "gemma-3-4b-it", + "onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile:abcd1234", + "https://huggingface.co/onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + await catalog.RegisterModelAsync("onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + // Lookup by the full URI with subpath + var found = await catalog.GetModelAsync( + "https://huggingface.co/onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + await Assert.That(found).IsNotNull(); + await Assert.That(found!.Alias).IsEqualTo("gemma-3-4b-it"); + } + + [Test] + public async Task DownloadModelAsync_WithModelUri_SucceedsWhenModelRegistered() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + // Mock download_model to return the org/repo identifier (as Core does) + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "download_model"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response + { + Data = "microsoft/Phi-3-mini-4k-instruct-onnx", + Error = null + }); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + // Register first + var registered = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + // Download using model's URI (the simplified API pattern) + var uri = registered.SelectedVariant.Info.Uri; + var downloaded = await catalog.DownloadModelAsync(uri); + + await Assert.That(downloaded).IsNotNull(); + await Assert.That(downloaded.Alias).IsEqualTo("phi-3-mini-4k"); + } + + [Test] + public async Task DownloadModelAsync_WithSubpathUri_SucceedsWhenModelRegistered() + { + var modelInfo = CreateTestModelInfo( + "gemma-3-4b-it", + "onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile:abcd1234", + "https://huggingface.co/onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "download_model"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response + { + Data = "onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile", + Error = null + }); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + var registered = await catalog.RegisterModelAsync( + "onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + // Download using model's URI (subpath model) + var uri = registered.SelectedVariant.Info.Uri; + var downloaded = await catalog.DownloadModelAsync(uri); + + await Assert.That(downloaded).IsNotNull(); + await Assert.That(downloaded.Alias).IsEqualTo("gemma-3-4b-it"); + } + + [Test] + public async Task RegisterModelAsync_Idempotent_ReturnsSameAlias() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + var first = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + var second = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + await Assert.That(first.Alias).IsEqualTo(second.Alias); + } + + [Test] + public async Task SelectedVariantInfoUri_ReturnsExpectedValue() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + var model = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + // SelectedVariant.Info.Uri holds the HuggingFace URL + await Assert.That(model.SelectedVariant.Info.Uri) + .IsEqualTo("https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + } + + [Test] + public async Task DownloadModelAsync_UsingSelectedVariantInfoUri_SucceedsWithoutRepeatingIdentifier() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "download_model"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response + { + Data = "microsoft/Phi-3-mini-4k-instruct-onnx", + Error = null + }); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + var registered = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + // Simplified pattern: use SelectedVariant.Info.Uri instead of repeating the raw string + var downloaded = await catalog.DownloadModelAsync(registered.SelectedVariant.Info.Uri); + + await Assert.That(downloaded).IsNotNull(); + await Assert.That(downloaded.Alias).IsEqualTo("phi-3-mini-4k"); + await Assert.That(downloaded.SelectedVariant.Info.Uri).IsEqualTo(registered.SelectedVariant.Info.Uri); + } + + [Test] + public async Task ListModelsAsync_ReturnsAllRegisteredModelsSortedByAlias() + { + // Register two models — need separate mock setup for sequential calls + var phi = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + var gemma = CreateTestModelInfo( + "gemma-3-4b-it", + "onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile:efgh5678", + "https://huggingface.co/onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + var logger = Utils.CreateCapturingLoggerMock([]); + var loadManager = new Mock(); + loadManager.Setup(x => x.ListLoadedModelsAsync(It.IsAny())) + .ReturnsAsync(Array.Empty()); + var coreInterop = new Mock(); + + // Return phi on first call, gemma on second + var callIndex = 0; + var models = new[] { phi, gemma }; + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "register_model"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(() => + { + var json = JsonSerializer.Serialize(models[callIndex++], + JsonSerializationContext.Default.ModelInfo); + return new ICoreInterop.Response { Data = json, Error = null }; + }); + + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "get_cached_models"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response { Data = "[]", Error = null }); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + await catalog.RegisterModelAsync("onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + var list = await catalog.ListModelsAsync(); + + await Assert.That(list.Count).IsEqualTo(2); + // Sorted alphabetically by alias + await Assert.That(list[0].Alias).IsEqualTo("gemma-3-4b-it"); + await Assert.That(list[1].Alias).IsEqualTo("phi-3-mini-4k"); + // Each model has correct URI via SelectedVariant.Info.Uri + await Assert.That(list[0].SelectedVariant.Info.Uri) + .IsEqualTo("https://huggingface.co/onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + await Assert.That(list[1].SelectedVariant.Info.Uri) + .IsEqualTo("https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + } +} From aa5f5d54b17b161b132371b9f1c979d5d56a3902 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Fri, 20 Mar 2026 01:23:06 -0700 Subject: [PATCH 02/11] Add support for HuggingFace download to JS SDK --- sdk/js/examples/responses.ts | 2 +- sdk/js/src/catalog.ts | 119 ++++++++++++++++++++- sdk/js/src/detail/coreInterop.ts | 49 +++++++++ sdk/js/src/types.ts | 1 + sdk/js/test/catalog.test.ts | 75 ++++++++++++- sdk/js/test/openai/audioClient.test.ts | 14 +-- sdk/js/test/openai/chatClient.test.ts | 16 +-- sdk/js/test/openai/responsesClient.test.ts | 2 +- 8 files changed, 253 insertions(+), 25 deletions(-) diff --git a/sdk/js/examples/responses.ts b/sdk/js/examples/responses.ts index fa8a6d93..cf6915b4 100644 --- a/sdk/js/examples/responses.ts +++ b/sdk/js/examples/responses.ts @@ -19,7 +19,7 @@ async function main() { // Load a model const modelAlias = 'MODEL_ALIAS'; // Replace with a valid model alias const catalog = manager.catalog; - const model = await catalog.getModel(modelAlias); + const model = (await catalog.getModel(modelAlias))!; await model.load(); console.log(`✓ Model ${model.id} loaded`); diff --git a/sdk/js/src/catalog.ts b/sdk/js/src/catalog.ts index bf2ae5c9..f6b9578a 100644 --- a/sdk/js/src/catalog.ts +++ b/sdk/js/src/catalog.ts @@ -79,16 +79,34 @@ export class Catalog { } /** - * Retrieves a model by its alias. - * This method is asynchronous as it may ensure the catalog is up-to-date by fetching from a remote service. - * @param alias - The alias of the model to retrieve. - * @returns A Promise that resolves to the Model object if found, otherwise throws an error. + * Retrieves a model by its alias, HuggingFace URL, or org/repo identifier. + * For plain aliases, throws if the model is not found. + * For HuggingFace URLs or org/repo identifiers, returns undefined if not found. + * @param alias - The alias of the model, a HuggingFace URL, or an org/repo identifier. + * @returns A Promise that resolves to the Model object if found. * @throws Error - If alias is null, undefined, or empty. + * @throws Error - If a plain alias is not found in the catalog. */ - public async getModel(alias: string): Promise { + public async getModel(alias: string): Promise { if (typeof alias !== 'string' || alias.trim() === '') { throw new Error('Model alias must be a non-empty string.'); } + + const hfUrl = Catalog.normalizeToHuggingFaceUrl(alias); + if (hfUrl) { + // Force a fresh catalog refresh for HuggingFace lookups + this.lastFetch = 0; + await this.updateModels(); + + for (const [, variant] of this.modelIdToModelVariant) { + if (variant.modelInfo.uri.toLowerCase() === hfUrl.toLowerCase()) { + return this.modelAliasToModel.get(variant.alias); + } + } + + return undefined; + } + await this.updateModels(); const model = this.modelAliasToModel.get(alias); if (!model) { @@ -98,6 +116,97 @@ export class Catalog { return model; } + /** + * Downloads a model by its HuggingFace URL or org/repo identifier and adds it to the catalog. + * If the model is already cached, this is a no-op and returns the existing model. + * @param modelUri - A HuggingFace URL (https://huggingface.co/org/repo) or org/repo identifier. + * @param progressCallback - Optional callback invoked with download progress percentage (0-100). + * @returns A Promise that resolves to the downloaded Model. + * @throws Error if the URI is not a valid HuggingFace identifier or if the download fails. + */ + public async downloadModel(modelUri: string, progressCallback?: (progress: number) => void): Promise { + // Validate that this is a HuggingFace identifier + if (!Catalog.normalizeToHuggingFaceUrl(modelUri)) { + throw new Error(`'${modelUri}' is not a valid HuggingFace URL or org/repo identifier.`); + } + + // Send the original URI to Core — it handles full URLs with /tree/revision/ + // and raw org/repo/subdir strings. Do NOT send the normalized form, as Core's + // URL parser expects /tree/revision/ when the https:// prefix is present. + const request = { Params: { Model: modelUri } }; + let resultData: string; + + if (progressCallback) { + resultData = await this.coreInterop.executeCommandWithCallback( + "download_model", + request, + (progressString: string) => { + try { + const progress = JSON.parse(progressString); + if (progress.percent != null) { + progressCallback(progress.percent); + } + } catch { /* ignore malformed progress */ } + } + ); + } else { + resultData = this.coreInterop.executeCommand("download_model", request); + } + + // Force a catalog refresh to pick up the newly downloaded model + this.lastFetch = 0; + await this.updateModels(); + + // The backend returns the org/model[/subpath] identifier as resultData + const expectedUri = `https://huggingface.co/${resultData}`; + for (const [, variant] of this.modelIdToModelVariant) { + if (variant.modelInfo.uri.toLowerCase() === expectedUri.toLowerCase()) { + const model = this.modelAliasToModel.get(variant.alias); + if (model) { + return model; + } + } + } + + throw new Error(`Model '${modelUri}' was downloaded but could not be found in the catalog.`); + } + + /** + * Normalizes a model identifier to a canonical HuggingFace URL, or returns null if it's a plain alias. + * Strips /tree/{revision}/ from full browser URLs so the result matches the stored Info.Uri format. + * Handles: + * - "https://huggingface.co/org/repo/tree/main/sub" -> "https://huggingface.co/org/repo/sub" + * - "https://huggingface.co/org/repo" -> returned as-is + * - "org/repo[/sub]" -> "https://huggingface.co/org/repo[/sub]" + * - "phi-3-mini" (plain alias) -> null + */ + private static normalizeToHuggingFaceUrl(input: string): string | null { + const hfPrefix = "https://huggingface.co/"; + + if (input.toLowerCase().startsWith(hfPrefix)) { + // Strip /tree/{revision}/ to match the canonical form stored by Core + const path = input.substring(hfPrefix.length); + const parts = path.split('/'); + if (parts.length >= 4 && parts[2].toLowerCase() === 'tree') { + // parts[0]=org, parts[1]=repo, parts[2]="tree", parts[3]=revision, parts[4..]=subpath + const org = parts[0]; + const repo = parts[1]; + const subPath = parts.length > 4 ? parts.slice(4).join('/') : null; + return subPath + ? `${hfPrefix}${org}/${repo}/${subPath}` + : `${hfPrefix}${org}/${repo}`; + } + + return input; + } + + if (input.includes('/') && !input.toLowerCase().startsWith("azureml://")) { + return hfPrefix + input; + } + + return null; + } + /** * Retrieves a specific model variant by its ID. * This method is asynchronous as it may ensure the catalog is up-to-date by fetching from a remote service. diff --git a/sdk/js/src/detail/coreInterop.ts b/sdk/js/src/detail/coreInterop.ts index 167784e7..53eccd8a 100644 --- a/sdk/js/src/detail/coreInterop.ts +++ b/sdk/js/src/detail/coreInterop.ts @@ -129,6 +129,55 @@ export class CoreInterop { } } + public executeCommandWithCallback(command: string, params: any, callback: (chunk: string) => void): Promise { + const cmdBuf = koffi.alloc('char', command.length + 1); + koffi.encode(cmdBuf, 'char', command, command.length + 1); + + const dataStr = params ? JSON.stringify(params) : ''; + const dataBytes = this._toBytes(dataStr); + const dataBuf = koffi.alloc('char', dataBytes.length + 1); + koffi.encode(dataBuf, 'char', dataStr, dataBytes.length + 1); + + const cb = koffi.register((data: any, length: number, userData: any) => { + const chunk = koffi.decode(data, 'char', length); + callback(chunk); + }, koffi.pointer(CallbackType)); + + return new Promise((resolve, reject) => { + const req = { + Command: koffi.address(cmdBuf), + CommandLength: command.length, + Data: koffi.address(dataBuf), + DataLength: dataBytes.length + }; + const res = { Data: 0, DataLength: 0, Error: 0, ErrorLength: 0 }; + + this.execute_command_with_callback.async(req, res, cb, null, (err: any) => { + koffi.unregister(cb); + koffi.free(cmdBuf); + koffi.free(dataBuf); + + if (err) { + reject(err); + return; + } + + try { + if (res.Error) { + const errorMsg = koffi.decode(res.Error, 'char', res.ErrorLength); + reject(new Error(`Command '${command}' failed: ${errorMsg}`)); + } else { + const data = res.Data ? koffi.decode(res.Data, 'char', res.DataLength) : ""; + resolve(data); + } + } finally { + if (res.Data) koffi.free(res.Data); + if (res.Error) koffi.free(res.Error); + } + }); + }); + } + public executeCommandStreaming(command: string, params: any, callback: (chunk: string) => void): Promise { const cmdBuf = koffi.alloc('char', command.length + 1); koffi.encode(cmdBuf, 'char', command, command.length + 1); diff --git a/sdk/js/src/types.ts b/sdk/js/src/types.ts index 639676de..b5515511 100644 --- a/sdk/js/src/types.ts +++ b/sdk/js/src/types.ts @@ -32,6 +32,7 @@ export interface ModelInfo { id: string; name: string; version: number; + hash?: string | null; alias: string; displayName?: string | null; providerType: string; diff --git a/sdk/js/test/catalog.test.ts b/sdk/js/test/catalog.test.ts index df47d4f6..12f46184 100644 --- a/sdk/js/test/catalog.test.ts +++ b/sdk/js/test/catalog.test.ts @@ -27,8 +27,8 @@ describe('Catalog Tests', () => { it('should get model by alias', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(TEST_MODEL_ALIAS); - + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; + expect(model.alias).to.equal(TEST_MODEL_ALIAS); }); @@ -95,7 +95,7 @@ describe('Catalog Tests', () => { it('should throw when getting model variant with unknown ID', async function() { const manager = getTestManager(); const catalog = manager.catalog; - + const unknownId = 'definitely-not-a-real-model-id-12345'; try { await catalog.getModelVariant(unknownId); @@ -107,3 +107,72 @@ describe('Catalog Tests', () => { } }); }); + +describe('Catalog HuggingFace Tests', () => { + const HF_URL = 'https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4'; + + it('should return undefined for a HuggingFace model that is not cached', async function() { + this.timeout(10000); + const manager = getTestManager(); + const catalog = manager.catalog; + // getModel should NOT auto-download; it should return undefined if not cached + const model = await catalog.getModel(HF_URL); + // Model may or may not be cached depending on the test environment + // This test just verifies getModel doesn't throw + }); + + it('should download and return a HuggingFace model by URL', async function() { + this.timeout(600000); // 10 minutes - downloads can be large + const manager = getTestManager(); + const catalog = manager.catalog; + const model = await catalog.downloadModel(HF_URL); + + expect(model).to.not.be.undefined; + expect(model?.alias).to.be.a('string'); + expect(model?.id).to.be.a('string'); + }); + + it('should download HuggingFace model with progress callback', async function() { + this.timeout(600000); + const manager = getTestManager(); + const catalog = manager.catalog; + + const progressUpdates: number[] = []; + const model = await catalog.downloadModel(HF_URL, (progress: number) => { + progressUpdates.push(progress); + }); + + expect(model).to.not.be.undefined; + expect(model?.alias).to.be.a('string'); + // If the model was already cached, there may be no progress updates + // If it was downloaded, we should have received some + }); + + it('should find HuggingFace model in cached models after download', async function() { + this.timeout(600000); + const manager = getTestManager(); + const catalog = manager.catalog; + + const model = await catalog.downloadModel(HF_URL); + expect(model).to.not.be.undefined; + if (!model) return; + + const cachedModels = await catalog.getCachedModels(); + const found = cachedModels.find(m => m.id === model.id); + expect(found, 'HuggingFace model should appear in cached models').to.not.be.undefined; + }); + + it('should return same model on repeated download calls with HuggingFace URL', async function() { + this.timeout(600000); + const manager = getTestManager(); + const catalog = manager.catalog; + + const model1 = await catalog.downloadModel(HF_URL); + const model2 = await catalog.downloadModel(HF_URL); + + expect(model1).to.not.be.undefined; + expect(model2).to.not.be.undefined; + expect(model1?.id).to.equal(model2?.id); + expect(model1?.alias).to.equal(model2?.alias); + }); +}); diff --git a/sdk/js/test/openai/audioClient.test.ts b/sdk/js/test/openai/audioClient.test.ts index a57c02e5..b5ef3562 100644 --- a/sdk/js/test/openai/audioClient.test.ts +++ b/sdk/js/test/openai/audioClient.test.ts @@ -19,7 +19,7 @@ describe('Audio Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === WHISPER_MODEL_ALIAS); expect(cachedVariant, 'whisper-tiny should be cached').to.not.be.undefined; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -57,7 +57,7 @@ describe('Audio Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === WHISPER_MODEL_ALIAS); expect(cachedVariant, 'whisper-tiny should be cached').to.not.be.undefined; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -95,7 +95,7 @@ describe('Audio Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === WHISPER_MODEL_ALIAS); expect(cachedVariant, 'whisper-tiny should be cached').to.not.be.undefined; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -136,7 +136,7 @@ describe('Audio Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === WHISPER_MODEL_ALIAS); expect(cachedVariant, 'whisper-tiny should be cached').to.not.be.undefined; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -169,7 +169,7 @@ describe('Audio Client Tests', () => { it('should throw when transcribing with empty audio file path', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; const audioClient = model.createAudioClient(); @@ -185,7 +185,7 @@ describe('Audio Client Tests', () => { it('should throw when transcribing streaming with empty audio file path', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; const audioClient = model.createAudioClient(); @@ -201,7 +201,7 @@ describe('Audio Client Tests', () => { it('should throw when transcribing streaming with invalid callback', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; const audioClient = model.createAudioClient(); const invalidCallbacks: any[] = [null, undefined, 42, {}, 'not-a-function']; for (const invalidCallback of invalidCallbacks) { diff --git a/sdk/js/test/openai/chatClient.test.ts b/sdk/js/test/openai/chatClient.test.ts index 5f612845..b44e4383 100644 --- a/sdk/js/test/openai/chatClient.test.ts +++ b/sdk/js/test/openai/chatClient.test.ts @@ -15,7 +15,7 @@ describe('Chat Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === TEST_MODEL_ALIAS); expect(cachedVariant).to.not.be.undefined; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -58,7 +58,7 @@ describe('Chat Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === TEST_MODEL_ALIAS); expect(cachedVariant).to.not.be.undefined; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -122,7 +122,7 @@ describe('Chat Client Tests', () => { it('should throw when completing chat with empty, null, or undefined messages', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; const client = model.createChatClient(); @@ -141,7 +141,7 @@ describe('Chat Client Tests', () => { it('should throw when completing chat with invalid message', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; const client = model.createChatClient(); @@ -165,7 +165,7 @@ describe('Chat Client Tests', () => { it('should throw when completing streaming chat with empty, null, or undefined messages', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; const client = model.createChatClient(); @@ -184,7 +184,7 @@ describe('Chat Client Tests', () => { it('should throw when completing streaming chat with invalid callback', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; const client = model.createChatClient(); const messages = [{ role: 'user', content: 'Hello' }]; const invalidCallbacks: any[] = [null, undefined, {} as any, 'not a function' as any]; @@ -209,7 +209,7 @@ describe('Chat Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === TEST_MODEL_ALIAS); expect(cachedVariant).to.not.be.undefined; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -280,7 +280,7 @@ describe('Chat Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === TEST_MODEL_ALIAS); expect(cachedVariant).to.not.be.undefined; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; diff --git a/sdk/js/test/openai/responsesClient.test.ts b/sdk/js/test/openai/responsesClient.test.ts index 925a2360..d251e40e 100644 --- a/sdk/js/test/openai/responsesClient.test.ts +++ b/sdk/js/test/openai/responsesClient.test.ts @@ -394,7 +394,7 @@ describe('ResponsesClient Tests', () => { return; } - model = await catalog.getModel(TEST_MODEL_ALIAS); + model = (await catalog.getModel(TEST_MODEL_ALIAS))!; model.selectVariant(cachedVariant); await model.load(); manager.startWebService(); From 1a688942a4d64b85a2d50fd1bad85846947bfc93 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Fri, 20 Mar 2026 10:56:37 -0700 Subject: [PATCH 03/11] Add HuggingFace download to Rust SDK, and unify JS SDK --- sdk/js/src/catalog.ts | 115 +----- sdk/js/src/foundryLocalManager.ts | 26 ++ sdk/js/src/huggingFaceCatalog.ts | 315 ++++++++++++++ sdk/js/src/index.ts | 1 + sdk/js/test/catalog.test.ts | 69 ---- sdk/js/test/huggingFaceCatalog.test.ts | 84 ++++ sdk/rust/examples/chat_completion.rs | 2 +- sdk/rust/examples/interactive_chat.rs | 2 +- sdk/rust/src/catalog.rs | 21 +- sdk/rust/src/foundry_local_manager.rs | 35 ++ sdk/rust/src/hf_utils.rs | 117 ++++++ sdk/rust/src/huggingface_catalog.rs | 389 ++++++++++++++++++ sdk/rust/src/lib.rs | 3 + sdk/rust/src/types.rs | 2 + .../tests/integration/audio_client_test.rs | 3 +- sdk/rust/tests/integration/catalog_test.rs | 3 +- .../tests/integration/chat_client_test.rs | 3 +- .../integration/huggingface_catalog_test.rs | 116 ++++++ sdk/rust/tests/integration/main.rs | 1 + sdk/rust/tests/integration/model_test.rs | 28 +- .../tests/integration/web_service_test.rs | 6 +- 21 files changed, 1138 insertions(+), 203 deletions(-) create mode 100644 sdk/js/src/huggingFaceCatalog.ts create mode 100644 sdk/js/test/huggingFaceCatalog.test.ts create mode 100644 sdk/rust/src/hf_utils.rs create mode 100644 sdk/rust/src/huggingface_catalog.rs create mode 100644 sdk/rust/tests/integration/huggingface_catalog_test.rs diff --git a/sdk/js/src/catalog.ts b/sdk/js/src/catalog.ts index f6b9578a..6a1ecde3 100644 --- a/sdk/js/src/catalog.ts +++ b/sdk/js/src/catalog.ts @@ -79,34 +79,18 @@ export class Catalog { } /** - * Retrieves a model by its alias, HuggingFace URL, or org/repo identifier. - * For plain aliases, throws if the model is not found. - * For HuggingFace URLs or org/repo identifiers, returns undefined if not found. - * @param alias - The alias of the model, a HuggingFace URL, or an org/repo identifier. + * Retrieves a model by its alias. + * For HuggingFace models, use {@link HuggingFaceCatalog} via {@link FoundryLocalManager.addCatalog}. + * @param alias - The alias of the model. * @returns A Promise that resolves to the Model object if found. * @throws Error - If alias is null, undefined, or empty. - * @throws Error - If a plain alias is not found in the catalog. + * @throws Error - If the alias is not found in the catalog. */ public async getModel(alias: string): Promise { if (typeof alias !== 'string' || alias.trim() === '') { throw new Error('Model alias must be a non-empty string.'); } - const hfUrl = Catalog.normalizeToHuggingFaceUrl(alias); - if (hfUrl) { - // Force a fresh catalog refresh for HuggingFace lookups - this.lastFetch = 0; - await this.updateModels(); - - for (const [, variant] of this.modelIdToModelVariant) { - if (variant.modelInfo.uri.toLowerCase() === hfUrl.toLowerCase()) { - return this.modelAliasToModel.get(variant.alias); - } - } - - return undefined; - } - await this.updateModels(); const model = this.modelAliasToModel.get(alias); if (!model) { @@ -116,97 +100,6 @@ export class Catalog { return model; } - /** - * Downloads a model by its HuggingFace URL or org/repo identifier and adds it to the catalog. - * If the model is already cached, this is a no-op and returns the existing model. - * @param modelUri - A HuggingFace URL (https://huggingface.co/org/repo) or org/repo identifier. - * @param progressCallback - Optional callback invoked with download progress percentage (0-100). - * @returns A Promise that resolves to the downloaded Model. - * @throws Error if the URI is not a valid HuggingFace identifier or if the download fails. - */ - public async downloadModel(modelUri: string, progressCallback?: (progress: number) => void): Promise { - // Validate that this is a HuggingFace identifier - if (!Catalog.normalizeToHuggingFaceUrl(modelUri)) { - throw new Error(`'${modelUri}' is not a valid HuggingFace URL or org/repo identifier.`); - } - - // Send the original URI to Core — it handles full URLs with /tree/revision/ - // and raw org/repo/subdir strings. Do NOT send the normalized form, as Core's - // URL parser expects /tree/revision/ when the https:// prefix is present. - const request = { Params: { Model: modelUri } }; - let resultData: string; - - if (progressCallback) { - resultData = await this.coreInterop.executeCommandWithCallback( - "download_model", - request, - (progressString: string) => { - try { - const progress = JSON.parse(progressString); - if (progress.percent != null) { - progressCallback(progress.percent); - } - } catch { /* ignore malformed progress */ } - } - ); - } else { - resultData = this.coreInterop.executeCommand("download_model", request); - } - - // Force a catalog refresh to pick up the newly downloaded model - this.lastFetch = 0; - await this.updateModels(); - - // The backend returns the org/model[/subpath] identifier as resultData - const expectedUri = `https://huggingface.co/${resultData}`; - for (const [, variant] of this.modelIdToModelVariant) { - if (variant.modelInfo.uri.toLowerCase() === expectedUri.toLowerCase()) { - const model = this.modelAliasToModel.get(variant.alias); - if (model) { - return model; - } - } - } - - throw new Error(`Model '${modelUri}' was downloaded but could not be found in the catalog.`); - } - - /** - * Normalizes a model identifier to a canonical HuggingFace URL, or returns null if it's a plain alias. - * Strips /tree/{revision}/ from full browser URLs so the result matches the stored Info.Uri format. - * Handles: - * - "https://huggingface.co/org/repo/tree/main/sub" -> "https://huggingface.co/org/repo/sub" - * - "https://huggingface.co/org/repo" -> returned as-is - * - "org/repo[/sub]" -> "https://huggingface.co/org/repo[/sub]" - * - "phi-3-mini" (plain alias) -> null - */ - private static normalizeToHuggingFaceUrl(input: string): string | null { - const hfPrefix = "https://huggingface.co/"; - - if (input.toLowerCase().startsWith(hfPrefix)) { - // Strip /tree/{revision}/ to match the canonical form stored by Core - const path = input.substring(hfPrefix.length); - const parts = path.split('/'); - if (parts.length >= 4 && parts[2].toLowerCase() === 'tree') { - // parts[0]=org, parts[1]=repo, parts[2]="tree", parts[3]=revision, parts[4..]=subpath - const org = parts[0]; - const repo = parts[1]; - const subPath = parts.length > 4 ? parts.slice(4).join('/') : null; - return subPath - ? `${hfPrefix}${org}/${repo}/${subPath}` - : `${hfPrefix}${org}/${repo}`; - } - - return input; - } - - if (input.includes('/') && !input.toLowerCase().startsWith("azureml://")) { - return hfPrefix + input; - } - - return null; - } - /** * Retrieves a specific model variant by its ID. * This method is asynchronous as it may ensure the catalog is up-to-date by fetching from a remote service. diff --git a/sdk/js/src/foundryLocalManager.ts b/sdk/js/src/foundryLocalManager.ts index bc408f78..0a0f5f5f 100644 --- a/sdk/js/src/foundryLocalManager.ts +++ b/sdk/js/src/foundryLocalManager.ts @@ -2,6 +2,7 @@ import { Configuration, FoundryLocalConfig } from './configuration.js'; import { CoreInterop } from './detail/coreInterop.js'; import { ModelLoadManager } from './detail/modelLoadManager.js'; import { Catalog } from './catalog.js'; +import { HuggingFaceCatalog } from './huggingFaceCatalog.js'; import { ResponsesClient } from './openai/responsesClient.js'; /** @@ -52,6 +53,31 @@ export class FoundryLocalManager { return this._catalog; } + /** + * Creates a separate HuggingFace catalog for registering and downloading + * models from HuggingFace. + * + * Three-step flow: + * 1. `addCatalog("https://huggingface.co")` — create the catalog + * 2. `catalog.registerModel("org/repo")` — register (config-only download) + * 3. `model.download()` — download ONNX files + * + * Each call creates a new instance with registrations loaded from disk. + * + * @param catalogUrl - Must contain "huggingface.co". + * @param token - Optional authentication token for private repositories. + * @returns A new HuggingFaceCatalog instance. + */ + public async addCatalog(catalogUrl: string, token?: string): Promise { + if (!catalogUrl.toLowerCase().includes("huggingface.co")) { + throw new Error( + `Unsupported catalog URL '${catalogUrl}'. Only HuggingFace catalogs (huggingface.co) are supported.` + ); + } + + return HuggingFaceCatalog.create(this.coreInterop, this._modelLoadManager, token); + } + /** * Gets the URLs where the web service is listening. * Returns an empty array if the web service is not running. diff --git a/sdk/js/src/huggingFaceCatalog.ts b/sdk/js/src/huggingFaceCatalog.ts new file mode 100644 index 00000000..ec856f26 --- /dev/null +++ b/sdk/js/src/huggingFaceCatalog.ts @@ -0,0 +1,315 @@ +import * as fs from 'fs'; +import * as path from 'path'; +import * as os from 'os'; + +import { CoreInterop } from './detail/coreInterop.js'; +import { ModelLoadManager } from './detail/modelLoadManager.js'; +import { Model } from './model.js'; +import { ModelVariant } from './modelVariant.js'; +import { ModelInfo } from './types.js'; + +/** Persistence file path relative to user home directory. */ +const REGISTRATIONS_SUBPATH = path.join('.foundry-local', 'HuggingFace', 'huggingface.modelinfo.json'); + +/** + * Normalizes a model identifier to a canonical HuggingFace URL, or returns null if it's a plain alias. + * Strips /tree/{revision}/ from full browser URLs so the result matches the stored Info.Uri format. + */ +function normalizeToHuggingFaceUrl(input: string): string | null { + const hfPrefix = "https://huggingface.co/"; + + if (input.toLowerCase().startsWith(hfPrefix)) { + const urlPath = input.substring(hfPrefix.length); + const parts = urlPath.split('/'); + if (parts.length >= 4 && parts[2].toLowerCase() === 'tree') { + const org = parts[0]; + const repo = parts[1]; + const subPath = parts.length > 4 ? parts.slice(4).join('/') : null; + return subPath + ? `${hfPrefix}${org}/${repo}/${subPath}` + : `${hfPrefix}${org}/${repo}`; + } + return input; + } + + if (input.includes('/') && !input.toLowerCase().startsWith("azureml://")) { + const parts = input.split('/'); + if (parts.length >= 4 && parts[2].toLowerCase() === 'tree') { + const org = parts[0]; + const repo = parts[1]; + const subPath = parts.length > 4 ? parts.slice(4).join('/') : null; + return subPath + ? `${hfPrefix}${org}/${repo}/${subPath}` + : `${hfPrefix}${org}/${repo}`; + } + return hfPrefix + input; + } + + return null; +} + +/** + * A catalog for HuggingFace models. + * + * Created via {@link FoundryLocalManager.addCatalog}. Each call creates a new + * instance with registrations loaded from disk. + * + * Three-step flow: + * ```typescript + * const hf = await manager.addCatalog("https://huggingface.co"); + * const model = await hf.registerModel("org/repo"); // config files only + * await model.download(); // ONNX files + * ``` + */ +export class HuggingFaceCatalog { + private coreInterop: CoreInterop; + private modelLoadManager: ModelLoadManager; + private token: string | undefined; + private variantsById: Map = new Map(); + private modelsById: Map = new Map(); + + private constructor(coreInterop: CoreInterop, modelLoadManager: ModelLoadManager, token?: string) { + this.coreInterop = coreInterop; + this.modelLoadManager = modelLoadManager; + this.token = token; + } + + /** + * Creates a new HuggingFaceCatalog and loads persisted registrations. + * @internal + */ + static async create(coreInterop: CoreInterop, modelLoadManager: ModelLoadManager, token?: string): Promise { + const catalog = new HuggingFaceCatalog(coreInterop, modelLoadManager, token); + catalog.loadRegistrations(); + return catalog; + } + + /** + * Gets the catalog name. + */ + public get name(): string { + return "HuggingFace"; + } + + /** + * Register a HuggingFace model by downloading its config files only (~50KB). + * + * Sends the `register_model` FFI command to the native core, which downloads + * config files (genai_config.json, config.json, tokenizer_config.json, etc.) + * and generates metadata. Returns a Model with `cached: false`. + * + * After registration, call `model.download()` to download the ONNX files. + * + * @param modelIdentifier - A HuggingFace URL or org/repo identifier. + * @returns The registered Model. + */ + public async registerModel(modelIdentifier: string): Promise { + if (!normalizeToHuggingFaceUrl(modelIdentifier)) { + throw new Error(`'${modelIdentifier}' is not a valid HuggingFace URL or org/repo identifier.`); + } + + const request = { + Params: { + Model: modelIdentifier, + Token: this.token ?? "" + } + }; + + const result = this.coreInterop.executeCommand("register_model", request); + + let modelInfo: ModelInfo; + try { + modelInfo = JSON.parse(result); + } catch { + throw new Error(`Failed to parse register_model response: ${result}`); + } + + const variant = new ModelVariant(modelInfo, this.coreInterop, this.modelLoadManager); + this.variantsById.set(modelInfo.id, variant); + + const model = new Model(variant); + this.modelsById.set(modelInfo.id, model); + + this.saveRegistrations(); + + return model; + } + + /** + * Look up a model by its ID, alias, or HuggingFace URL. + * + * Uses three-tier lookup: + * 1. Direct ID match + * 2. Alias match (case-insensitive) + * 3. URI-based match (normalize to HuggingFace URL and compare) + * + * @param identifier - Model ID, alias, or HuggingFace URL/identifier. + * @returns The Model if found, undefined otherwise. + */ + public async getModel(identifier: string): Promise { + if (typeof identifier !== 'string' || identifier.trim() === '') { + throw new Error('Model identifier must be a non-empty string.'); + } + + // 1. Direct ID match + const byId = this.modelsById.get(identifier); + if (byId) return byId; + + // 2. Alias match (case-insensitive) + for (const model of this.modelsById.values()) { + if (model.alias.toLowerCase() === identifier.toLowerCase()) { + return model; + } + } + + // 3. URI-based match + const normalizedUrl = normalizeToHuggingFaceUrl(identifier); + if (normalizedUrl) { + const normalizedLower = normalizedUrl.toLowerCase(); + const normalizedWithSlash = normalizedLower.replace(/\/+$/, '') + '/'; + for (const variant of this.variantsById.values()) { + const uriLower = variant.modelInfo.uri.toLowerCase(); + if (uriLower === normalizedLower || uriLower.startsWith(normalizedWithSlash)) { + return this.modelsById.get(variant.id); + } + } + } + + return undefined; + } + + /** + * Downloads a HuggingFace model's ONNX files. + * + * The model should have been previously registered via {@link registerModel}. + * + * @param modelUri - A HuggingFace URL or org/repo identifier. + * @param progressCallback - Optional callback invoked with download progress percentage (0-100). + * @returns The downloaded Model. + */ + public async downloadModel(modelUri: string, progressCallback?: (progress: number) => void): Promise { + if (!normalizeToHuggingFaceUrl(modelUri)) { + throw new Error(`'${modelUri}' is not a valid HuggingFace URL or org/repo identifier.`); + } + + const request = { + Params: { + Model: modelUri, + Token: this.token ?? "" + } + }; + + let resultData: string; + if (progressCallback) { + resultData = await this.coreInterop.executeCommandWithCallback( + "download_model", + request, + (progressString: string) => { + try { + const progress = JSON.parse(progressString); + if (progress.percent != null) { + progressCallback(progress.percent); + } + } catch { /* ignore malformed progress */ } + } + ); + } else { + resultData = this.coreInterop.executeCommand("download_model", request); + } + + // Match result against registered models by URI + const expectedUri = `https://huggingface.co/${resultData}`; + const expectedLower = expectedUri.toLowerCase(); + const expectedWithSlash = expectedLower.replace(/\/+$/, '') + '/'; + + for (const variant of this.variantsById.values()) { + const uriLower = variant.modelInfo.uri.toLowerCase(); + if (uriLower === expectedLower + || uriLower.startsWith(expectedWithSlash) + || expectedLower.startsWith(uriLower.replace(/\/+$/, '') + '/')) { + const model = this.modelsById.get(variant.id); + if (model) return model; + } + } + + throw new Error(`Model '${modelUri}' was downloaded but could not be found in the catalog.`); + } + + /** + * Returns all registered models. + */ + public async getModels(): Promise { + return Array.from(this.modelsById.values()); + } + + /** + * Look up a specific model variant by its unique ID. + */ + public async getModelVariant(modelId: string): Promise { + return this.variantsById.get(modelId); + } + + /** + * Returns only the model variants that are currently cached on disk. + */ + public async getCachedModels(): Promise { + const cachedModelListJson = this.coreInterop.executeCommand("get_cached_models"); + let cachedModelIds: string[]; + try { + cachedModelIds = JSON.parse(cachedModelListJson); + } catch { + return []; + } + return cachedModelIds + .map(id => this.variantsById.get(id)) + .filter((v): v is ModelVariant => v !== undefined); + } + + /** + * Returns model variants that are currently loaded into memory. + */ + public async getLoadedModels(): Promise { + const loadedModelIds = await this.modelLoadManager.listLoaded(); + return loadedModelIds + .map(id => this.variantsById.get(id)) + .filter((v): v is ModelVariant => v !== undefined); + } + + // ── Persistence ────────────────────────────────────────────────────── + + private static get registrationsPath(): string { + return path.join(os.homedir(), REGISTRATIONS_SUBPATH); + } + + private loadRegistrations(): void { + try { + const filePath = HuggingFaceCatalog.registrationsPath; + if (!fs.existsSync(filePath)) return; + + const json = fs.readFileSync(filePath, 'utf-8'); + if (!json.trim()) return; + + const infos: ModelInfo[] = JSON.parse(json); + for (const info of infos) { + const variant = new ModelVariant(info, this.coreInterop, this.modelLoadManager); + this.variantsById.set(info.id, variant); + this.modelsById.set(info.id, new Model(variant)); + } + } catch { + // Gracefully skip on any error — empty catalog is valid + } + } + + private saveRegistrations(): void { + try { + const filePath = HuggingFaceCatalog.registrationsPath; + const dir = path.dirname(filePath); + fs.mkdirSync(dir, { recursive: true }); + + const infos = Array.from(this.variantsById.values()).map(v => v.modelInfo); + fs.writeFileSync(filePath, JSON.stringify(infos, null, 2)); + } catch { + // Non-critical — loss of persistence file is not fatal + } + } +} diff --git a/sdk/js/src/index.ts b/sdk/js/src/index.ts index 7d7ee17a..d8044b75 100644 --- a/sdk/js/src/index.ts +++ b/sdk/js/src/index.ts @@ -1,6 +1,7 @@ export { FoundryLocalManager } from './foundryLocalManager.js'; export type { FoundryLocalConfig } from './configuration.js'; export { Catalog } from './catalog.js'; +export { HuggingFaceCatalog } from './huggingFaceCatalog.js'; export { Model } from './model.js'; export { ModelVariant } from './modelVariant.js'; export type { IModel } from './imodel.js'; diff --git a/sdk/js/test/catalog.test.ts b/sdk/js/test/catalog.test.ts index 12f46184..67da60d6 100644 --- a/sdk/js/test/catalog.test.ts +++ b/sdk/js/test/catalog.test.ts @@ -107,72 +107,3 @@ describe('Catalog Tests', () => { } }); }); - -describe('Catalog HuggingFace Tests', () => { - const HF_URL = 'https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4'; - - it('should return undefined for a HuggingFace model that is not cached', async function() { - this.timeout(10000); - const manager = getTestManager(); - const catalog = manager.catalog; - // getModel should NOT auto-download; it should return undefined if not cached - const model = await catalog.getModel(HF_URL); - // Model may or may not be cached depending on the test environment - // This test just verifies getModel doesn't throw - }); - - it('should download and return a HuggingFace model by URL', async function() { - this.timeout(600000); // 10 minutes - downloads can be large - const manager = getTestManager(); - const catalog = manager.catalog; - const model = await catalog.downloadModel(HF_URL); - - expect(model).to.not.be.undefined; - expect(model?.alias).to.be.a('string'); - expect(model?.id).to.be.a('string'); - }); - - it('should download HuggingFace model with progress callback', async function() { - this.timeout(600000); - const manager = getTestManager(); - const catalog = manager.catalog; - - const progressUpdates: number[] = []; - const model = await catalog.downloadModel(HF_URL, (progress: number) => { - progressUpdates.push(progress); - }); - - expect(model).to.not.be.undefined; - expect(model?.alias).to.be.a('string'); - // If the model was already cached, there may be no progress updates - // If it was downloaded, we should have received some - }); - - it('should find HuggingFace model in cached models after download', async function() { - this.timeout(600000); - const manager = getTestManager(); - const catalog = manager.catalog; - - const model = await catalog.downloadModel(HF_URL); - expect(model).to.not.be.undefined; - if (!model) return; - - const cachedModels = await catalog.getCachedModels(); - const found = cachedModels.find(m => m.id === model.id); - expect(found, 'HuggingFace model should appear in cached models').to.not.be.undefined; - }); - - it('should return same model on repeated download calls with HuggingFace URL', async function() { - this.timeout(600000); - const manager = getTestManager(); - const catalog = manager.catalog; - - const model1 = await catalog.downloadModel(HF_URL); - const model2 = await catalog.downloadModel(HF_URL); - - expect(model1).to.not.be.undefined; - expect(model2).to.not.be.undefined; - expect(model1?.id).to.equal(model2?.id); - expect(model1?.alias).to.equal(model2?.alias); - }); -}); diff --git a/sdk/js/test/huggingFaceCatalog.test.ts b/sdk/js/test/huggingFaceCatalog.test.ts new file mode 100644 index 00000000..c2e5c235 --- /dev/null +++ b/sdk/js/test/huggingFaceCatalog.test.ts @@ -0,0 +1,84 @@ +import { describe, it } from 'mocha'; +import { expect } from 'chai'; +import { getTestManager } from './testUtils.js'; + +const HF_URL = 'https://huggingface.co/onnxruntime/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4'; + +describe('HuggingFace Catalog Tests', () => { + it('should create huggingface catalog', async function() { + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + expect(hfCatalog.name).to.equal('HuggingFace'); + }); + + it('should reject non-huggingface url', async function() { + const manager = getTestManager(); + try { + await manager.addCatalog('https://example.com'); + expect.fail('Should have thrown an error for non-HuggingFace URL'); + } catch (error) { + expect(error).to.be.instanceOf(Error); + expect((error as Error).message).to.include('Unsupported catalog URL'); + } + }); + + it('should register model', async function() { + this.timeout(30000); + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + + const model = await hfCatalog.registerModel(HF_URL); + expect(model.alias).to.be.a('string'); + expect(model.alias.length).to.be.greaterThan(0); + expect(model.id).to.be.a('string'); + expect(model.id.length).to.be.greaterThan(0); + }); + + it('should find registered model by identifier', async function() { + this.timeout(30000); + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + + await hfCatalog.registerModel(HF_URL); + + const found = await hfCatalog.getModel(HF_URL); + expect(found).to.not.be.undefined; + }); + + it('should register then download model', async function() { + this.timeout(600000); + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + + const model = await hfCatalog.registerModel(HF_URL); + expect(model.alias.length).to.be.greaterThan(0); + + // Now download the ONNX files + await model.download(); + }); + + it('should reject registration of plain alias', async function() { + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + + try { + await hfCatalog.registerModel('phi-3-mini'); + expect.fail('Should have thrown an error for plain alias'); + } catch (error) { + expect(error).to.be.instanceOf(Error); + expect((error as Error).message).to.include('not a valid HuggingFace URL'); + } + }); + + it('should list registered models', async function() { + this.timeout(30000); + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + + await hfCatalog.registerModel(HF_URL); + + const models = await hfCatalog.getModels(); + expect(models).to.be.an('array'); + expect(models.length).to.be.greaterThan(0); + }); +}); diff --git a/sdk/rust/examples/chat_completion.rs b/sdk/rust/examples/chat_completion.rs index 3516aa60..a1342552 100644 --- a/sdk/rust/examples/chat_completion.rs +++ b/sdk/rust/examples/chat_completion.rs @@ -35,7 +35,7 @@ async fn main() -> Result<()> { .or_else(|| models.first().map(|m| m.alias().to_string())) .expect("No models available in the catalog"); - let model = manager.catalog().get_model(&model_alias).await?; + let model = manager.catalog().get_model(&model_alias).await?.expect("model not found"); if !model.is_cached().await? { println!("Downloading model '{}'…", model.alias()); diff --git a/sdk/rust/examples/interactive_chat.rs b/sdk/rust/examples/interactive_chat.rs index bd230155..1467e1b5 100644 --- a/sdk/rust/examples/interactive_chat.rs +++ b/sdk/rust/examples/interactive_chat.rs @@ -36,7 +36,7 @@ async fn main() -> Result<(), Box> { .map(|m| m.alias().to_string()) .unwrap_or_else(|| models[0].alias().to_string()); - let model = catalog.get_model(&alias).await?; + let model = catalog.get_model(&alias).await?.expect("model not found"); // Download if needed if !model.is_cached().await? { diff --git a/sdk/rust/src/catalog.rs b/sdk/rust/src/catalog.rs index 78485bff..8400d012 100644 --- a/sdk/rust/src/catalog.rs +++ b/sdk/rust/src/catalog.rs @@ -21,7 +21,7 @@ const CACHE_TTL: Duration = Duration::from_secs(6 * 60 * 60); // 6 hours pub(crate) struct CacheInvalidator(Arc); impl CacheInvalidator { - fn new() -> Self { + pub(crate) fn new() -> Self { Self(Arc::new(AtomicBool::new(false))) } @@ -126,20 +126,27 @@ impl Catalog { } /// Look up a model by its alias. - pub async fn get_model(&self, alias: &str) -> Result> { + /// + /// Returns an error if not found. For HuggingFace models, use + /// [`HuggingFaceCatalog`] via [`FoundryLocalManager::add_catalog`]. + pub async fn get_model(&self, alias: &str) -> Result>> { if alias.trim().is_empty() { return Err(FoundryLocalError::Validation { reason: "Model alias must be a non-empty string".into(), }); } + self.update_models().await?; let s = self.lock_state()?; - s.models_by_alias.get(alias).cloned().ok_or_else(|| { - let available: Vec<&String> = s.models_by_alias.keys().collect(); - FoundryLocalError::ModelOperation { - reason: format!("Unknown model alias '{alias}'. Available: {available:?}"), + match s.models_by_alias.get(alias) { + Some(model) => Ok(Some(Arc::clone(model))), + None => { + let available: Vec<&String> = s.models_by_alias.keys().collect(); + Err(FoundryLocalError::ModelOperation { + reason: format!("Unknown model alias '{alias}'. Available: {available:?}"), + }) } - }) + } } /// Look up a specific model variant by its unique id. diff --git a/sdk/rust/src/foundry_local_manager.rs b/sdk/rust/src/foundry_local_manager.rs index f80a7176..c9891200 100644 --- a/sdk/rust/src/foundry_local_manager.rs +++ b/sdk/rust/src/foundry_local_manager.rs @@ -13,6 +13,7 @@ use crate::configuration::{Configuration, FoundryLocalConfig, Logger}; use crate::detail::core_interop::CoreInterop; use crate::detail::ModelLoadManager; use crate::error::{FoundryLocalError, Result}; +use crate::huggingface_catalog::HuggingFaceCatalog; /// Global singleton holder — only stores a successfully initialised manager. static INSTANCE: OnceLock = OnceLock::new(); @@ -25,6 +26,7 @@ static INIT_GUARD: Mutex<()> = Mutex::new(()); /// the existing instance. pub struct FoundryLocalManager { core: Arc, + model_load_manager: Arc, catalog: Catalog, urls: Mutex>, /// Application logger (stub — not yet wired into the native core). @@ -70,6 +72,7 @@ impl FoundryLocalManager { let manager = FoundryLocalManager { core, + model_load_manager, catalog, urls: Mutex::new(Vec::new()), _logger: logger, @@ -90,6 +93,38 @@ impl FoundryLocalManager { &self.catalog } + /// Create a separate HuggingFace catalog for registering and downloading + /// models from HuggingFace. + /// + /// The three-step flow: + /// 1. `add_catalog("https://huggingface.co", None)` — create the catalog + /// 2. `catalog.register_model("org/repo")` — register (config-only download) + /// 3. `model.download(None)` — download ONNX files + /// + /// The returned catalog is owned by the caller. Each call creates a new + /// instance with registrations loaded from disk. + pub async fn add_catalog( + &self, + catalog_url: &str, + token: Option, + ) -> Result { + if !catalog_url.to_lowercase().contains("huggingface.co") { + return Err(FoundryLocalError::Validation { + reason: format!( + "Unsupported catalog URL '{}'. Only HuggingFace catalogs (huggingface.co) are supported.", + catalog_url + ), + }); + } + + HuggingFaceCatalog::create( + Arc::clone(&self.core), + Arc::clone(&self.model_load_manager), + token, + ) + .await + } + /// URLs that the local web service is listening on. /// /// Empty until [`Self::start_web_service`] has been called. diff --git a/sdk/rust/src/hf_utils.rs b/sdk/rust/src/hf_utils.rs new file mode 100644 index 00000000..17c46849 --- /dev/null +++ b/sdk/rust/src/hf_utils.rs @@ -0,0 +1,117 @@ +//! Shared HuggingFace URL utilities. + +/// Normalise a model identifier to a canonical HuggingFace URL, or return +/// `None` if it is a plain alias. +/// +/// Strips `/tree/{revision}/` from full browser URLs so the result matches +/// the stored `ModelInfo.uri` format. +/// +/// # Examples +/// +/// ```text +/// "https://huggingface.co/org/repo/tree/main/sub" → Some("https://huggingface.co/org/repo/sub") +/// "https://huggingface.co/org/repo" → Some("https://huggingface.co/org/repo") +/// "org/repo/sub" → Some("https://huggingface.co/org/repo/sub") +/// "phi-3-mini" (plain alias) → None +/// ``` +pub(crate) fn normalize_to_huggingface_url(input: &str) -> Option { + const HF_PREFIX: &str = "https://huggingface.co/"; + + if input.to_lowercase().starts_with(&HF_PREFIX.to_lowercase()) { + let path = &input[HF_PREFIX.len()..]; + let parts: Vec<&str> = path.split('/').collect(); + if parts.len() >= 4 && parts[2].eq_ignore_ascii_case("tree") { + let org = parts[0]; + let repo = parts[1]; + let sub: Vec<&str> = parts[4..].to_vec(); + return if sub.is_empty() { + Some(format!("{HF_PREFIX}{org}/{repo}")) + } else { + Some(format!("{HF_PREFIX}{org}/{repo}/{}", sub.join("/"))) + }; + } + return Some(input.to_string()); + } + + if input.contains('/') && !input.to_lowercase().starts_with("azureml://") { + let parts: Vec<&str> = input.split('/').collect(); + if parts.len() >= 4 && parts[2].eq_ignore_ascii_case("tree") { + let org = parts[0]; + let repo = parts[1]; + let sub: Vec<&str> = parts[4..].to_vec(); + return if sub.is_empty() { + Some(format!("{HF_PREFIX}{org}/{repo}")) + } else { + Some(format!("{HF_PREFIX}{org}/{repo}/{}", sub.join("/"))) + }; + } + return Some(format!("{HF_PREFIX}{input}")); + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn plain_alias_returns_none() { + assert!(normalize_to_huggingface_url("phi-3-mini").is_none()); + } + + #[test] + fn org_repo_returns_hf_url() { + assert_eq!( + normalize_to_huggingface_url("microsoft/Phi-3-mini-4k-instruct-onnx"), + Some("https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx".into()) + ); + } + + #[test] + fn full_hf_url_passthrough() { + let url = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"; + assert_eq!(normalize_to_huggingface_url(url), Some(url.into())); + } + + #[test] + fn browser_url_strips_tree_revision() { + let input = "https://huggingface.co/org/repo/tree/main/sub/path"; + assert_eq!( + normalize_to_huggingface_url(input), + Some("https://huggingface.co/org/repo/sub/path".into()) + ); + } + + #[test] + fn browser_url_no_subpath() { + let input = "https://huggingface.co/org/repo/tree/main"; + assert_eq!( + normalize_to_huggingface_url(input), + Some("https://huggingface.co/org/repo".into()) + ); + } + + #[test] + fn bare_identifier_with_tree_revision() { + let input = "org/repo/tree/main/sub"; + assert_eq!( + normalize_to_huggingface_url(input), + Some("https://huggingface.co/org/repo/sub".into()) + ); + } + + #[test] + fn org_repo_with_subpath() { + let input = "org/repo/sub/path"; + assert_eq!( + normalize_to_huggingface_url(input), + Some("https://huggingface.co/org/repo/sub/path".into()) + ); + } + + #[test] + fn azureml_url_returns_none() { + assert!(normalize_to_huggingface_url("azureml://some/model").is_none()); + } +} diff --git a/sdk/rust/src/huggingface_catalog.rs b/sdk/rust/src/huggingface_catalog.rs new file mode 100644 index 00000000..47f74a45 --- /dev/null +++ b/sdk/rust/src/huggingface_catalog.rs @@ -0,0 +1,389 @@ +//! HuggingFace model catalog — register, download, and look up HuggingFace models. +//! +//! Created via [`FoundryLocalManager::add_catalog`]. Provides the three-step +//! flow: register (config-only download) → download (ONNX files) → inference. + +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; + +use serde_json::json; + +use crate::catalog::CacheInvalidator; +use crate::detail::core_interop::CoreInterop; +use crate::detail::ModelLoadManager; +use crate::error::{FoundryLocalError, Result}; +use crate::hf_utils::normalize_to_huggingface_url; +use crate::model::Model; +use crate::model_variant::ModelVariant; +use crate::types::ModelInfo; + +/// Persistence file path relative to the user home directory. +const REGISTRATIONS_SUBPATH: &str = ".foundry-local/HuggingFace/huggingface.modelinfo.json"; + +/// Internal state protected by a Mutex. +struct HuggingFaceCatalogState { + variants_by_id: HashMap>, + models_by_id: HashMap>, +} + +/// A catalog for HuggingFace models. +/// +/// Created via [`FoundryLocalManager::add_catalog`]. Each call to `add_catalog` +/// creates a new instance with registrations loaded from disk. +/// +/// # Three-step flow +/// +/// ```text +/// let hf = manager.add_catalog("https://huggingface.co", None).await?; +/// let model = hf.register_model("org/repo").await?; // config files only +/// model.download::(None).await?; // ONNX files +/// ``` +pub struct HuggingFaceCatalog { + core: Arc, + model_load_manager: Arc, + token: Option, + state: Mutex, + invalidator: CacheInvalidator, +} + +impl HuggingFaceCatalog { + pub(crate) async fn create( + core: Arc, + model_load_manager: Arc, + token: Option, + ) -> Result { + let invalidator = CacheInvalidator::new(); + let catalog = Self { + core, + model_load_manager, + token, + state: Mutex::new(HuggingFaceCatalogState { + variants_by_id: HashMap::new(), + models_by_id: HashMap::new(), + }), + invalidator, + }; + catalog.load_registrations()?; + Ok(catalog) + } + + /// Catalog name. + pub fn name(&self) -> &str { + "HuggingFace" + } + + /// Register a HuggingFace model by downloading its config files only (~50KB). + /// + /// Sends the `register_model` FFI command to the native core, which downloads + /// config files (genai_config.json, config.json, tokenizer_config.json, etc.) + /// and generates metadata. Returns a `Model` with `cached: false`. + /// + /// After registration, call [`Model::download`] to download the ONNX files. + pub async fn register_model(&self, model_identifier: &str) -> Result> { + if normalize_to_huggingface_url(model_identifier).is_none() { + return Err(FoundryLocalError::Validation { + reason: format!( + "'{model_identifier}' is not a valid HuggingFace URL or org/repo identifier." + ), + }); + } + + let params = json!({ + "Params": { + "Model": model_identifier, + "Token": self.token.as_deref().unwrap_or("") + } + }); + + let result = self + .core + .execute_command_async("register_model".into(), Some(params)) + .await?; + + let model_info: ModelInfo = serde_json::from_str(&result)?; + + let model = { + let mut s = self.lock_state()?; + let variant = ModelVariant::new( + model_info.clone(), + Arc::clone(&self.core), + Arc::clone(&self.model_load_manager), + self.invalidator.clone(), + ); + let variant_arc = Arc::new(variant.clone()); + s.variants_by_id + .insert(model_info.id.clone(), variant_arc); + + let mut m = Model::new(model_info.alias.clone(), Arc::clone(&self.core)); + m.add_variant(variant); + let model = Arc::new(m); + s.models_by_id + .insert(model_info.id.clone(), Arc::clone(&model)); + model + }; + + self.save_registrations()?; + + Ok(model) + } + + /// Look up a model by its ID, alias, or HuggingFace URL. + /// + /// Uses three-tier lookup: + /// 1. Direct ID match + /// 2. Alias match (case-insensitive) + /// 3. URI-based match (normalize to HuggingFace URL and compare) + /// + /// Returns `Ok(None)` if the model is not found. + pub async fn get_model(&self, identifier: &str) -> Result>> { + if identifier.trim().is_empty() { + return Err(FoundryLocalError::Validation { + reason: "Model identifier must be a non-empty string".into(), + }); + } + + let s = self.lock_state()?; + + // 1. Direct ID match + if let Some(model) = s.models_by_id.get(identifier) { + return Ok(Some(Arc::clone(model))); + } + + // 2. Alias match (case-insensitive) + for model in s.models_by_id.values() { + if model.alias().eq_ignore_ascii_case(identifier) { + return Ok(Some(Arc::clone(model))); + } + } + + // 3. URI-based match + if let Some(normalized_url) = normalize_to_huggingface_url(identifier) { + let normalized_lower = normalized_url.to_lowercase(); + let normalized_with_slash = + format!("{}/", normalized_lower.trim_end_matches('/')); + for variant in s.variants_by_id.values() { + let uri_lower = variant.info().uri.to_lowercase(); + if uri_lower == normalized_lower + || uri_lower.starts_with(&normalized_with_slash) + { + if let Some(model) = s.models_by_id.get(variant.id()) { + return Ok(Some(Arc::clone(model))); + } + } + } + } + + Ok(None) + } + + /// Download a HuggingFace model's ONNX files. + /// + /// Sends the `download_model` FFI command. The model should have been + /// previously registered via [`HuggingFaceCatalog::register_model`]. + /// + /// If `progress` is provided, it receives human-readable progress strings + /// as the download proceeds. + pub async fn download_model( + &self, + model_uri: &str, + progress: Option, + ) -> Result> + where + F: FnMut(&str) + Send + 'static, + { + if normalize_to_huggingface_url(model_uri).is_none() { + return Err(FoundryLocalError::Validation { + reason: format!( + "'{model_uri}' is not a valid HuggingFace URL or org/repo identifier." + ), + }); + } + + let params = json!({ + "Params": { + "Model": model_uri, + "Token": self.token.as_deref().unwrap_or("") + } + }); + + let result_data = match progress { + Some(cb) => { + self.core + .execute_command_streaming_async( + "download_model".into(), + Some(params), + cb, + ) + .await? + } + None => { + self.core + .execute_command_async("download_model".into(), Some(params)) + .await? + } + }; + + // Match result against registered models by URI + let expected_uri = format!("https://huggingface.co/{result_data}"); + let expected_lower = expected_uri.to_lowercase(); + let expected_with_slash = + format!("{}/", expected_lower.trim_end_matches('/')); + + let s = self.lock_state()?; + for variant in s.variants_by_id.values() { + let uri_lower = variant.info().uri.to_lowercase(); + if uri_lower == expected_lower + || uri_lower.starts_with(&expected_with_slash) + || expected_lower.starts_with( + &format!("{}/", uri_lower.trim_end_matches('/')), + ) + { + if let Some(model) = s.models_by_id.get(variant.id()) { + return Ok(Arc::clone(model)); + } + } + } + + Err(FoundryLocalError::ModelOperation { + reason: format!( + "Model '{model_uri}' was downloaded but could not be found in the catalog." + ), + }) + } + + /// Return all registered models. + pub async fn get_models(&self) -> Result>> { + let s = self.lock_state()?; + Ok(s.models_by_id.values().cloned().collect()) + } + + /// Look up a specific model variant by its unique id. + pub async fn get_model_variant( + &self, + id: &str, + ) -> Result>> { + let s = self.lock_state()?; + Ok(s.variants_by_id.get(id).cloned()) + } + + /// Return only the model variants that are currently cached on disk. + pub async fn get_cached_models(&self) -> Result>> { + let raw = self + .core + .execute_command_async("get_cached_models".into(), None) + .await?; + if raw.trim().is_empty() { + return Ok(Vec::new()); + } + let cached_ids: Vec = serde_json::from_str(&raw)?; + let s = self.lock_state()?; + Ok(cached_ids + .iter() + .filter_map(|id| s.variants_by_id.get(id).cloned()) + .collect()) + } + + /// Return model variants that are currently loaded into memory. + pub async fn get_loaded_models(&self) -> Result>> { + let loaded_ids = self.model_load_manager.list_loaded().await?; + let s = self.lock_state()?; + Ok(loaded_ids + .iter() + .filter_map(|id| s.variants_by_id.get(id).cloned()) + .collect()) + } + + // ── Persistence ────────────────────────────────────────────────────── + + fn registrations_path() -> Result { + let home = home_dir().ok_or_else(|| FoundryLocalError::Internal { + reason: "Could not determine home directory".into(), + })?; + Ok(home.join(REGISTRATIONS_SUBPATH)) + } + + fn load_registrations(&self) -> Result<()> { + let path = match Self::registrations_path() { + Ok(p) => p, + Err(_) => return Ok(()), // gracefully skip if home dir unknown + }; + if !path.exists() { + return Ok(()); + } + + let json = match std::fs::read_to_string(&path) { + Ok(s) => s, + Err(_) => return Ok(()), // gracefully skip on read error + }; + if json.trim().is_empty() { + return Ok(()); + } + + let infos: Vec = match serde_json::from_str(&json) { + Ok(v) => v, + Err(_) => return Ok(()), // gracefully skip on parse error + }; + + let mut s = self.lock_state()?; + for info in infos { + let variant = ModelVariant::new( + info.clone(), + Arc::clone(&self.core), + Arc::clone(&self.model_load_manager), + self.invalidator.clone(), + ); + let variant_arc = Arc::new(variant.clone()); + s.variants_by_id.insert(info.id.clone(), variant_arc); + + let mut m = Model::new(info.alias.clone(), Arc::clone(&self.core)); + m.add_variant(variant); + s.models_by_id.insert(info.id.clone(), Arc::new(m)); + } + + Ok(()) + } + + fn save_registrations(&self) -> Result<()> { + let path = match Self::registrations_path() { + Ok(p) => p, + Err(_) => return Ok(()), // gracefully skip + }; + if let Some(dir) = path.parent() { + let _ = std::fs::create_dir_all(dir); + } + + let s = self.lock_state()?; + let infos: Vec<&ModelInfo> = + s.variants_by_id.values().map(|v| v.info()).collect(); + + let json = serde_json::to_string_pretty(&infos) + .map_err(|e| FoundryLocalError::Internal { + reason: format!("Failed to serialize registrations: {e}"), + })?; + std::fs::write(&path, json).map_err(|e| FoundryLocalError::Internal { + reason: format!("Failed to write registrations file: {e}"), + })?; + Ok(()) + } + + fn lock_state( + &self, + ) -> Result> { + self.state.lock().map_err(|_| FoundryLocalError::Internal { + reason: "HuggingFace catalog state mutex poisoned".into(), + }) + } +} + +/// Platform-aware home directory detection. +fn home_dir() -> Option { + #[cfg(unix)] + { + std::env::var("HOME").ok().map(PathBuf::from) + } + #[cfg(windows)] + { + std::env::var("USERPROFILE").ok().map(PathBuf::from) + } +} diff --git a/sdk/rust/src/lib.rs b/sdk/rust/src/lib.rs index c6d6e6c4..0bae2b07 100644 --- a/sdk/rust/src/lib.rs +++ b/sdk/rust/src/lib.rs @@ -6,6 +6,8 @@ mod catalog; mod configuration; mod error; mod foundry_local_manager; +mod hf_utils; +mod huggingface_catalog; mod model; mod model_variant; mod types; @@ -17,6 +19,7 @@ pub use self::catalog::Catalog; pub use self::configuration::{FoundryLocalConfig, LogLevel, Logger}; pub use self::error::FoundryLocalError; pub use self::foundry_local_manager::FoundryLocalManager; +pub use self::huggingface_catalog::HuggingFaceCatalog; pub use self::model::Model; pub use self::model_variant::ModelVariant; pub use self::types::{ diff --git a/sdk/rust/src/types.rs b/sdk/rust/src/types.rs index d1d1f002..e336e075 100644 --- a/sdk/rust/src/types.rs +++ b/sdk/rust/src/types.rs @@ -56,6 +56,8 @@ pub struct ModelInfo { pub id: String, pub name: String, pub version: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub hash: Option, pub alias: String, #[serde(skip_serializing_if = "Option::is_none")] pub display_name: Option, diff --git a/sdk/rust/tests/integration/audio_client_test.rs b/sdk/rust/tests/integration/audio_client_test.rs index 47cef9d0..8b69867c 100644 --- a/sdk/rust/tests/integration/audio_client_test.rs +++ b/sdk/rust/tests/integration/audio_client_test.rs @@ -9,7 +9,8 @@ async fn setup_audio_client() -> (AudioClient, Arc) { let model = catalog .get_model(common::WHISPER_MODEL_ALIAS) .await - .expect("get_model(whisper-tiny) failed"); + .expect("get_model(whisper-tiny) failed") + .expect("model not found"); model.load().await.expect("model.load() failed"); (model.create_audio_client(), model) } diff --git a/sdk/rust/tests/integration/catalog_test.rs b/sdk/rust/tests/integration/catalog_test.rs index d418c7a7..d3f0fc8d 100644 --- a/sdk/rust/tests/integration/catalog_test.rs +++ b/sdk/rust/tests/integration/catalog_test.rs @@ -36,7 +36,8 @@ async fn should_get_model_by_alias() { let model = cat .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); } diff --git a/sdk/rust/tests/integration/chat_client_test.rs b/sdk/rust/tests/integration/chat_client_test.rs index b24f3804..34dab08a 100644 --- a/sdk/rust/tests/integration/chat_client_test.rs +++ b/sdk/rust/tests/integration/chat_client_test.rs @@ -15,7 +15,8 @@ async fn setup_chat_client() -> (ChatClient, Arc) { let model = catalog .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); model.load().await.expect("model.load() failed"); let client = model.create_chat_client().max_tokens(500).temperature(0.0); diff --git a/sdk/rust/tests/integration/huggingface_catalog_test.rs b/sdk/rust/tests/integration/huggingface_catalog_test.rs new file mode 100644 index 00000000..907ab2f0 --- /dev/null +++ b/sdk/rust/tests/integration/huggingface_catalog_test.rs @@ -0,0 +1,116 @@ +use super::common; + +const HF_URL: &str = "https://huggingface.co/onnxruntime/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"; + +#[tokio::test] +async fn should_create_huggingface_catalog() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + assert_eq!(hf_catalog.name(), "HuggingFace"); +} + +#[tokio::test] +async fn should_reject_non_huggingface_url() { + let manager = common::get_test_manager(); + let result = manager.add_catalog("https://example.com", None).await; + assert!(result.is_err(), "Non-HuggingFace URL should be rejected"); +} + +#[tokio::test] +async fn should_register_model() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + + let model = hf_catalog + .register_model(HF_URL) + .await + .expect("register_model failed"); + + assert!(!model.alias().is_empty(), "Model alias should be non-empty"); + assert!(!model.id().is_empty(), "Model id should be non-empty"); +} + +#[tokio::test] +async fn should_find_registered_model_by_identifier() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + + let _model = hf_catalog + .register_model(HF_URL) + .await + .expect("register_model failed"); + + let found = hf_catalog + .get_model(HF_URL) + .await + .expect("get_model failed"); + + assert!(found.is_some(), "Should find model by HuggingFace URL"); +} + +#[tokio::test] +async fn should_register_then_download_model() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + + let registered = hf_catalog + .register_model(HF_URL) + .await + .expect("register_model failed"); + + assert!(!registered.alias().is_empty()); + + // Now download the ONNX files + registered + .download::(None) + .await + .expect("download failed"); +} + +#[tokio::test] +async fn should_reject_registration_of_plain_alias() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + + let result = hf_catalog.register_model("phi-3-mini").await; + assert!(result.is_err(), "Plain alias should be rejected"); +} + +#[tokio::test] +async fn should_list_registered_models() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + + let _model = hf_catalog + .register_model(HF_URL) + .await + .expect("register_model failed"); + + let models = hf_catalog + .get_models() + .await + .expect("get_models failed"); + + assert!( + !models.is_empty(), + "Should have at least one registered model" + ); +} diff --git a/sdk/rust/tests/integration/main.rs b/sdk/rust/tests/integration/main.rs index 04de9a23..d4e25d1f 100644 --- a/sdk/rust/tests/integration/main.rs +++ b/sdk/rust/tests/integration/main.rs @@ -11,6 +11,7 @@ mod common; mod audio_client_test; mod catalog_test; mod chat_client_test; +mod huggingface_catalog_test; mod manager_test; mod model_test; mod web_service_test; diff --git a/sdk/rust/tests/integration/model_test.rs b/sdk/rust/tests/integration/model_test.rs index d2b68b77..02a7761e 100644 --- a/sdk/rust/tests/integration/model_test.rs +++ b/sdk/rust/tests/integration/model_test.rs @@ -38,7 +38,8 @@ async fn should_load_and_unload_model() { let model = catalog .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); model.load().await.expect("model.load() failed"); assert!( @@ -62,7 +63,8 @@ async fn should_expose_alias() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); } @@ -74,7 +76,8 @@ async fn should_expose_non_empty_id() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); println!("Model id: {}", model.id()); @@ -91,7 +94,8 @@ async fn should_have_at_least_one_variant() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); let variants = model.variants(); println!("Model has {} variant(s)", variants.len()); @@ -109,7 +113,8 @@ async fn should_have_selected_variant_matching_id() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); let selected = model.selected_variant(); assert_eq!( @@ -126,7 +131,8 @@ async fn should_report_cached_model_as_cached() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); let cached = model.is_cached().await.expect("is_cached() should succeed"); assert!( @@ -143,7 +149,8 @@ async fn should_return_non_empty_path_for_cached_model() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); let path = model.path().await.expect("path() should succeed"); println!("Model path: {}", path.display()); @@ -161,7 +168,8 @@ async fn should_select_variant_by_id() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); // Remember the original selection so we can restore it afterward. let original_id = model.id().to_string(); @@ -190,7 +198,8 @@ async fn should_fail_to_select_unknown_variant() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); let result = model.select_variant("nonexistent-variant-id"); assert!( @@ -214,6 +223,7 @@ async fn get_test_model() -> Arc { .get_model(common::TEST_MODEL_ALIAS) .await .expect("get_model failed") + .expect("model not found") } #[tokio::test] diff --git a/sdk/rust/tests/integration/web_service_test.rs b/sdk/rust/tests/integration/web_service_test.rs index 9222f9d4..e6511f65 100644 --- a/sdk/rust/tests/integration/web_service_test.rs +++ b/sdk/rust/tests/integration/web_service_test.rs @@ -10,7 +10,8 @@ async fn should_complete_chat_via_rest_api() { let model = catalog .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); model.load().await.expect("model.load() failed"); manager @@ -72,7 +73,8 @@ async fn should_stream_chat_via_rest_api() { let model = catalog .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); model.load().await.expect("model.load() failed"); manager From 3c12fca1d134483bc18808b8a04fb55af60eadd9 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Fri, 20 Mar 2026 12:33:27 -0700 Subject: [PATCH 04/11] Clean up --- sdk/cs/src/Detail/CoreInterop.cs | 2 ++ sdk/cs/src/Microsoft.AI.Foundry.Local.csproj | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sdk/cs/src/Detail/CoreInterop.cs b/sdk/cs/src/Detail/CoreInterop.cs index 95073b7a..2246c4b2 100644 --- a/sdk/cs/src/Detail/CoreInterop.cs +++ b/sdk/cs/src/Detail/CoreInterop.cs @@ -182,6 +182,8 @@ private static void HandleCallback(nint data, int length, nint callbackHelper) callbackData = System.Text.Encoding.UTF8.GetString(managedData); } + Debug.Assert(callbackHelper != IntPtr.Zero, "Callback helper pointer is required."); + helper = (CallbackHelper)GCHandle.FromIntPtr(callbackHelper).Target!; helper.Callback.Invoke(callbackData); } diff --git a/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj b/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj index b4051044..905f9652 100644 --- a/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj +++ b/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj @@ -118,7 +118,7 @@ + Include="Microsoft.AI.Foundry.Local.Core" Version="$(FoundryLocalCoreVersion)" /> From 6ba964a41932022cc93a48f7232538b6020195ae Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Fri, 20 Mar 2026 15:31:43 -0700 Subject: [PATCH 05/11] Ensure that huggingface catalog data is persisted to the right location --- sdk/cs/src/Detail/CoreInterop.cs | 1 + sdk/cs/src/HuggingFaceCatalog.cs | 21 +++++++------ sdk/cs/src/Microsoft.AI.Foundry.Local.csproj | 2 +- sdk/js/src/huggingFaceCatalog.ts | 14 ++++----- sdk/rust/src/huggingface_catalog.rs | 32 +++++++------------- 5 files changed, 32 insertions(+), 38 deletions(-) diff --git a/sdk/cs/src/Detail/CoreInterop.cs b/sdk/cs/src/Detail/CoreInterop.cs index 2246c4b2..8411473b 100644 --- a/sdk/cs/src/Detail/CoreInterop.cs +++ b/sdk/cs/src/Detail/CoreInterop.cs @@ -6,6 +6,7 @@ namespace Microsoft.AI.Foundry.Local.Detail; +using System.Diagnostics; using System.Runtime.InteropServices; using Microsoft.Extensions.Logging; diff --git a/sdk/cs/src/HuggingFaceCatalog.cs b/sdk/cs/src/HuggingFaceCatalog.cs index 1c5cb907..a564b7c4 100644 --- a/sdk/cs/src/HuggingFaceCatalog.cs +++ b/sdk/cs/src/HuggingFaceCatalog.cs @@ -287,14 +287,19 @@ private async Task> GetLoadedModelsImplAsync(CancellationToke return modelVariant; } + private string GetRegistrationsPath() + { + var result = _coreInterop.ExecuteCommand("get_cache_directory"); + var cacheDir = result.Data?.Trim().Trim('"') ?? throw new InvalidOperationException("Failed to get cache directory from Core"); + return Path.Combine(cacheDir, "HuggingFace", "huggingface.modelinfo.json"); + } + private async Task LoadRegistrationsAsync(CancellationToken? ct = null) { - // Load persisted HuggingFace registrations from local file - // File path: ~/.foundry-local/HuggingFace/huggingface.modelinfo.json + // Load persisted HuggingFace registrations from cache directory try { - var homeDir = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); - var registrationsPath = Path.Combine(homeDir, ".foundry-local", "HuggingFace", "huggingface.modelinfo.json"); + var registrationsPath = GetRegistrationsPath(); if (!File.Exists(registrationsPath)) { @@ -332,13 +337,11 @@ private async Task LoadRegistrationsAsync(CancellationToken? ct = null) private async Task SaveRegistrationsAsync(CancellationToken? ct = null) { - // Save persisted HuggingFace registrations to local file - // File path: ~/.foundry-local/HuggingFace/huggingface.modelinfo.json + // Save persisted HuggingFace registrations to cache directory try { - var homeDir = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); - var registrationsDir = Path.Combine(homeDir, ".foundry-local", "HuggingFace"); - var registrationsPath = Path.Combine(registrationsDir, "huggingface.modelinfo.json"); + var registrationsPath = GetRegistrationsPath(); + var registrationsDir = Path.GetDirectoryName(registrationsPath)!; // Ensure directory exists Directory.CreateDirectory(registrationsDir); diff --git a/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj b/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj index 905f9652..8f2aedc1 100644 --- a/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj +++ b/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj @@ -118,7 +118,7 @@ + Include="Microsoft.AI.Foundry.Local.Core" Version="0.9.0-0.local.20260320140747" /> diff --git a/sdk/js/src/huggingFaceCatalog.ts b/sdk/js/src/huggingFaceCatalog.ts index ec856f26..c49e97c5 100644 --- a/sdk/js/src/huggingFaceCatalog.ts +++ b/sdk/js/src/huggingFaceCatalog.ts @@ -1,6 +1,5 @@ import * as fs from 'fs'; import * as path from 'path'; -import * as os from 'os'; import { CoreInterop } from './detail/coreInterop.js'; import { ModelLoadManager } from './detail/modelLoadManager.js'; @@ -8,8 +7,8 @@ import { Model } from './model.js'; import { ModelVariant } from './modelVariant.js'; import { ModelInfo } from './types.js'; -/** Persistence file path relative to user home directory. */ -const REGISTRATIONS_SUBPATH = path.join('.foundry-local', 'HuggingFace', 'huggingface.modelinfo.json'); +/** Filename for the HuggingFace registration persistence file. */ +const REGISTRATIONS_FILENAME = 'huggingface.modelinfo.json'; /** * Normalizes a model identifier to a canonical HuggingFace URL, or returns null if it's a plain alias. @@ -277,13 +276,14 @@ export class HuggingFaceCatalog { // ── Persistence ────────────────────────────────────────────────────── - private static get registrationsPath(): string { - return path.join(os.homedir(), REGISTRATIONS_SUBPATH); + private get registrationsPath(): string { + const cacheDir = this.coreInterop.executeCommand("get_cache_directory").trim().replace(/^"|"$/g, ''); + return path.join(cacheDir, 'HuggingFace', REGISTRATIONS_FILENAME); } private loadRegistrations(): void { try { - const filePath = HuggingFaceCatalog.registrationsPath; + const filePath = this.registrationsPath; if (!fs.existsSync(filePath)) return; const json = fs.readFileSync(filePath, 'utf-8'); @@ -302,7 +302,7 @@ export class HuggingFaceCatalog { private saveRegistrations(): void { try { - const filePath = HuggingFaceCatalog.registrationsPath; + const filePath = this.registrationsPath; const dir = path.dirname(filePath); fs.mkdirSync(dir, { recursive: true }); diff --git a/sdk/rust/src/huggingface_catalog.rs b/sdk/rust/src/huggingface_catalog.rs index 47f74a45..0ec07036 100644 --- a/sdk/rust/src/huggingface_catalog.rs +++ b/sdk/rust/src/huggingface_catalog.rs @@ -18,8 +18,8 @@ use crate::model::Model; use crate::model_variant::ModelVariant; use crate::types::ModelInfo; -/// Persistence file path relative to the user home directory. -const REGISTRATIONS_SUBPATH: &str = ".foundry-local/HuggingFace/huggingface.modelinfo.json"; +/// Filename for the HuggingFace registration persistence file. +const REGISTRATIONS_FILENAME: &str = "huggingface.modelinfo.json"; /// Internal state protected by a Mutex. struct HuggingFaceCatalogState { @@ -296,15 +296,17 @@ impl HuggingFaceCatalog { // ── Persistence ────────────────────────────────────────────────────── - fn registrations_path() -> Result { - let home = home_dir().ok_or_else(|| FoundryLocalError::Internal { - reason: "Could not determine home directory".into(), - })?; - Ok(home.join(REGISTRATIONS_SUBPATH)) + fn registrations_path(&self) -> Result { + let cache_dir = self + .core + .execute_command("get_cache_directory".into(), None)?; + Ok(PathBuf::from(cache_dir.trim().trim_matches('"')) + .join("HuggingFace") + .join(REGISTRATIONS_FILENAME)) } fn load_registrations(&self) -> Result<()> { - let path = match Self::registrations_path() { + let path = match self.registrations_path() { Ok(p) => p, Err(_) => return Ok(()), // gracefully skip if home dir unknown }; @@ -345,7 +347,7 @@ impl HuggingFaceCatalog { } fn save_registrations(&self) -> Result<()> { - let path = match Self::registrations_path() { + let path = match self.registrations_path() { Ok(p) => p, Err(_) => return Ok(()), // gracefully skip }; @@ -375,15 +377,3 @@ impl HuggingFaceCatalog { }) } } - -/// Platform-aware home directory detection. -fn home_dir() -> Option { - #[cfg(unix)] - { - std::env::var("HOME").ok().map(PathBuf::from) - } - #[cfg(windows)] - { - std::env::var("USERPROFILE").ok().map(PathBuf::from) - } -} From 11f94767d819abb1b3352bb881b8592bcd74f5dd Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Fri, 20 Mar 2026 18:02:01 -0700 Subject: [PATCH 06/11] Moved HuggingFace catalog referencing to core --- sdk/cs/src/ModelVariant.cs | 2 +- sdk/js/src/modelVariant.ts | 5 ++++- sdk/js/src/types.ts | 2 +- sdk/rust/src/model_variant.rs | 7 ++++++- sdk/rust/src/types.rs | 2 +- 5 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sdk/cs/src/ModelVariant.cs b/sdk/cs/src/ModelVariant.cs index 8ed6edc6..8dbff97c 100644 --- a/sdk/cs/src/ModelVariant.cs +++ b/sdk/cs/src/ModelVariant.cs @@ -132,7 +132,7 @@ private async Task DownloadImplAsync(Action? downloadProgress = null, { var request = new CoreInteropRequest { - Params = new() { { "Model", Id } } + Params = new() { { "Model", string.Equals(Info.ProviderType, "HuggingFace", StringComparison.OrdinalIgnoreCase) ? Info.Uri : Id } } }; ICoreInterop.Response? response; diff --git a/sdk/js/src/modelVariant.ts b/sdk/js/src/modelVariant.ts index 4d3e2bee..a79e3d49 100644 --- a/sdk/js/src/modelVariant.ts +++ b/sdk/js/src/modelVariant.ts @@ -68,7 +68,10 @@ export class ModelVariant implements IModel { * @param progressCallback - Optional callback to report download progress (0-100). */ public async download(progressCallback?: (progress: number) => void): Promise { - const request = { Params: { Model: this._modelInfo.id } }; + const modelParam = this._modelInfo.providerType?.toLowerCase() === 'huggingface' + ? this._modelInfo.uri + : this._modelInfo.id; + const request = { Params: { Model: modelParam } }; if (!progressCallback) { this.coreInterop.executeCommand("download_model", request); } else { diff --git a/sdk/js/src/types.ts b/sdk/js/src/types.ts index b5515511..c05e45c2 100644 --- a/sdk/js/src/types.ts +++ b/sdk/js/src/types.ts @@ -50,7 +50,7 @@ export interface ModelInfo { supportsToolCalling?: boolean | null; maxOutputTokens?: number | null; minFLVersion?: string | null; - createdAtUnix: number; + createdAt: number; } export interface ResponseFormat { diff --git a/sdk/rust/src/model_variant.rs b/sdk/rust/src/model_variant.rs index c4be6822..8e6ddc27 100644 --- a/sdk/rust/src/model_variant.rs +++ b/sdk/rust/src/model_variant.rs @@ -93,7 +93,12 @@ impl ModelVariant { where F: FnMut(&str) + Send + 'static, { - let params = json!({ "Params": { "Model": self.info.id } }); + let model_param = if self.info.provider_type.eq_ignore_ascii_case("huggingface") { + &self.info.uri + } else { + &self.info.id + }; + let params = json!({ "Params": { "Model": model_param } }); match progress { Some(cb) => { self.core diff --git a/sdk/rust/src/types.rs b/sdk/rust/src/types.rs index e336e075..6ea0f2b5 100644 --- a/sdk/rust/src/types.rs +++ b/sdk/rust/src/types.rs @@ -88,7 +88,7 @@ pub struct ModelInfo { #[serde(skip_serializing_if = "Option::is_none")] pub min_fl_version: Option, #[serde(default)] - pub created_at_unix: u64, + pub created_at: u64, } /// Desired response format for chat completions. From 48ee68bdbf3509d80a67573a8544fff07e5212a0 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Fri, 20 Mar 2026 19:25:35 -0700 Subject: [PATCH 07/11] Route token through SDK --- sdk/cs/src/HuggingFaceCatalog.cs | 4 ++-- sdk/cs/src/ModelVariant.cs | 9 ++++++++- sdk/js/src/huggingFaceCatalog.ts | 4 ++-- sdk/js/src/modelVariant.ts | 9 +++++++-- sdk/rust/src/catalog.rs | 1 + sdk/rust/src/huggingface_catalog.rs | 2 ++ sdk/rust/src/model_variant.rs | 10 +++++++++- 7 files changed, 31 insertions(+), 8 deletions(-) diff --git a/sdk/cs/src/HuggingFaceCatalog.cs b/sdk/cs/src/HuggingFaceCatalog.cs index a564b7c4..a7f3c7c5 100644 --- a/sdk/cs/src/HuggingFaceCatalog.cs +++ b/sdk/cs/src/HuggingFaceCatalog.cs @@ -187,7 +187,7 @@ private async Task RegisterModelImplAsync(string modelIdentifier, Cancell // Add to internal dictionaries with lock using var disposable = await _lock.LockAsync().ConfigureAwait(false); - var variant = new ModelVariant(modelInfo, _modelLoadManager, _coreInterop, _logger); + var variant = new ModelVariant(modelInfo, _modelLoadManager, _coreInterop, _logger, _token); _modelIdToModelVariant[modelInfo.Id] = variant; // Each registration is a distinct entry, keyed by Id @@ -323,7 +323,7 @@ private async Task LoadRegistrationsAsync(CancellationToken? ct = null) foreach (var modelInfo in models) { - var variant = new ModelVariant(modelInfo, _modelLoadManager, _coreInterop, _logger); + var variant = new ModelVariant(modelInfo, _modelLoadManager, _coreInterop, _logger, _token); _modelIdToModelVariant[modelInfo.Id] = variant; _modelIdToModel[modelInfo.Id] = new Model(variant, _logger); } diff --git a/sdk/cs/src/ModelVariant.cs b/sdk/cs/src/ModelVariant.cs index 8dbff97c..142aab76 100644 --- a/sdk/cs/src/ModelVariant.cs +++ b/sdk/cs/src/ModelVariant.cs @@ -14,6 +14,7 @@ public class ModelVariant : IModel private readonly IModelLoadManager _modelLoadManager; private readonly ICoreInterop _coreInterop; private readonly ILogger _logger; + private readonly string? _token; public ModelInfo Info { get; } // expose the full info record @@ -24,7 +25,7 @@ public class ModelVariant : IModel public string VersionDisplay => Info.Hash ?? Info.Version.ToString(System.Globalization.CultureInfo.InvariantCulture); internal ModelVariant(ModelInfo modelInfo, IModelLoadManager modelLoadManager, ICoreInterop coreInterop, - ILogger logger) + ILogger logger, string? token = null) { Info = modelInfo; Version = modelInfo.Version; @@ -32,6 +33,7 @@ internal ModelVariant(ModelInfo modelInfo, IModelLoadManager modelLoadManager, I _modelLoadManager = modelLoadManager; _coreInterop = coreInterop; _logger = logger; + _token = token; } @@ -135,6 +137,11 @@ private async Task DownloadImplAsync(Action? downloadProgress = null, Params = new() { { "Model", string.Equals(Info.ProviderType, "HuggingFace", StringComparison.OrdinalIgnoreCase) ? Info.Uri : Id } } }; + if (!string.IsNullOrEmpty(_token)) + { + request.Params["Token"] = _token; + } + ICoreInterop.Response? response; if (downloadProgress == null) diff --git a/sdk/js/src/huggingFaceCatalog.ts b/sdk/js/src/huggingFaceCatalog.ts index c49e97c5..fa745a22 100644 --- a/sdk/js/src/huggingFaceCatalog.ts +++ b/sdk/js/src/huggingFaceCatalog.ts @@ -123,7 +123,7 @@ export class HuggingFaceCatalog { throw new Error(`Failed to parse register_model response: ${result}`); } - const variant = new ModelVariant(modelInfo, this.coreInterop, this.modelLoadManager); + const variant = new ModelVariant(modelInfo, this.coreInterop, this.modelLoadManager, this.token); this.variantsById.set(modelInfo.id, variant); const model = new Model(variant); @@ -291,7 +291,7 @@ export class HuggingFaceCatalog { const infos: ModelInfo[] = JSON.parse(json); for (const info of infos) { - const variant = new ModelVariant(info, this.coreInterop, this.modelLoadManager); + const variant = new ModelVariant(info, this.coreInterop, this.modelLoadManager, this.token); this.variantsById.set(info.id, variant); this.modelsById.set(info.id, new Model(variant)); } diff --git a/sdk/js/src/modelVariant.ts b/sdk/js/src/modelVariant.ts index a79e3d49..d58cdd37 100644 --- a/sdk/js/src/modelVariant.ts +++ b/sdk/js/src/modelVariant.ts @@ -14,11 +14,13 @@ export class ModelVariant implements IModel { private _modelInfo: ModelInfo; private coreInterop: CoreInterop; private modelLoadManager: ModelLoadManager; + private token?: string; - constructor(modelInfo: ModelInfo, coreInterop: CoreInterop, modelLoadManager: ModelLoadManager) { + constructor(modelInfo: ModelInfo, coreInterop: CoreInterop, modelLoadManager: ModelLoadManager, token?: string) { this._modelInfo = modelInfo; this.coreInterop = coreInterop; this.modelLoadManager = modelLoadManager; + this.token = token; } /** @@ -71,7 +73,10 @@ export class ModelVariant implements IModel { const modelParam = this._modelInfo.providerType?.toLowerCase() === 'huggingface' ? this._modelInfo.uri : this._modelInfo.id; - const request = { Params: { Model: modelParam } }; + const request: { Params: Record } = { Params: { Model: modelParam } }; + if (this.token) { + request.Params.Token = this.token; + } if (!progressCallback) { this.coreInterop.executeCommand("download_model", request); } else { diff --git a/sdk/rust/src/catalog.rs b/sdk/rust/src/catalog.rs index 8400d012..b5156dbb 100644 --- a/sdk/rust/src/catalog.rs +++ b/sdk/rust/src/catalog.rs @@ -228,6 +228,7 @@ impl Catalog { Arc::clone(&self.core), Arc::clone(&self.model_load_manager), self.invalidator.clone(), + None, ); let variant_arc = Arc::new(variant.clone()); id_map.insert(id, variant_arc); diff --git a/sdk/rust/src/huggingface_catalog.rs b/sdk/rust/src/huggingface_catalog.rs index 0ec07036..b25c6bdd 100644 --- a/sdk/rust/src/huggingface_catalog.rs +++ b/sdk/rust/src/huggingface_catalog.rs @@ -110,6 +110,7 @@ impl HuggingFaceCatalog { Arc::clone(&self.core), Arc::clone(&self.model_load_manager), self.invalidator.clone(), + self.token.clone(), ); let variant_arc = Arc::new(variant.clone()); s.variants_by_id @@ -334,6 +335,7 @@ impl HuggingFaceCatalog { Arc::clone(&self.core), Arc::clone(&self.model_load_manager), self.invalidator.clone(), + self.token.clone(), ); let variant_arc = Arc::new(variant.clone()); s.variants_by_id.insert(info.id.clone(), variant_arc); diff --git a/sdk/rust/src/model_variant.rs b/sdk/rust/src/model_variant.rs index 8e6ddc27..06c70db7 100644 --- a/sdk/rust/src/model_variant.rs +++ b/sdk/rust/src/model_variant.rs @@ -22,6 +22,7 @@ pub struct ModelVariant { core: Arc, model_load_manager: Arc, cache_invalidator: CacheInvalidator, + token: Option, } impl fmt::Debug for ModelVariant { @@ -39,12 +40,14 @@ impl ModelVariant { core: Arc, model_load_manager: Arc, cache_invalidator: CacheInvalidator, + token: Option, ) -> Self { Self { info, core, model_load_manager, cache_invalidator, + token, } } @@ -98,7 +101,12 @@ impl ModelVariant { } else { &self.info.id }; - let params = json!({ "Params": { "Model": model_param } }); + let mut params_map = serde_json::Map::new(); + params_map.insert("Model".into(), json!(model_param)); + if let Some(ref t) = self.token { + params_map.insert("Token".into(), json!(t)); + } + let params = json!({ "Params": params_map }); match progress { Some(cb) => { self.core From c7f1aa59e5a1897e69e2a3f22db79cd7c3ca6bad Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Sun, 22 Mar 2026 10:18:43 -0700 Subject: [PATCH 08/11] Fix thread safety, API consistency, and URI matching from code review C#: Acquire _lock in GetCachedModelsImplAsync, GetLoadedModelsImplAsync, and snapshot data under lock in SaveRegistrationsAsync before file I/O. JS: Remove '| undefined' from Catalog.getModel() return type to match its throwing behavior. Remove unnecessary non-null assertion in test. Rust: Change HuggingFaceCatalog::get_model to return Result> (not Option) for consistency with Catalog::get_model. Add two-pass exact-then-prefix URI matching to avoid ambiguous results. Release mutex lock before filesystem I/O in save_registrations. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk/cs/src/HuggingFaceCatalog.cs | 17 ++++-- sdk/js/src/catalog.ts | 2 +- sdk/js/test/catalog.test.ts | 2 +- sdk/rust/src/huggingface_catalog.rs | 53 +++++++++++++------ .../integration/huggingface_catalog_test.rs | 2 +- 5 files changed, 53 insertions(+), 23 deletions(-) diff --git a/sdk/cs/src/HuggingFaceCatalog.cs b/sdk/cs/src/HuggingFaceCatalog.cs index a7f3c7c5..d2dd3865 100644 --- a/sdk/cs/src/HuggingFaceCatalog.cs +++ b/sdk/cs/src/HuggingFaceCatalog.cs @@ -252,6 +252,7 @@ private async Task> GetCachedModelsImplAsync(CancellationToke { var cachedModelIds = await Utils.GetCachedModelIdsAsync(_coreInterop, ct).ConfigureAwait(false); + using var disposable = await _lock.LockAsync().ConfigureAwait(false); List cachedModels = new(); foreach (var modelId in cachedModelIds) { @@ -267,6 +268,8 @@ private async Task> GetCachedModelsImplAsync(CancellationToke private async Task> GetLoadedModelsImplAsync(CancellationToken? ct = null) { var loadedModelIds = await _modelLoadManager.ListLoadedModelsAsync(ct).ConfigureAwait(false); + + using var disposable = await _lock.LockAsync().ConfigureAwait(false); List loadedModels = new(); foreach (var modelId in loadedModelIds) @@ -346,11 +349,15 @@ private async Task SaveRegistrationsAsync(CancellationToken? ct = null) // Ensure directory exists Directory.CreateDirectory(registrationsDir); - // Collect all registered models (from both dictionaries, using variants) - var models = _modelIdToModelVariant.Values - .Select(v => v.Info) - .Distinct() - .ToList(); + // Snapshot registered models under lock, then do file I/O outside it + List models; + using (await _lock.LockAsync().ConfigureAwait(false)) + { + models = _modelIdToModelVariant.Values + .Select(v => v.Info) + .Distinct() + .ToList(); + } // Serialize with pretty-printing (matching foundry.modelinfo.json style) var prettyOptions = new JsonSerializerOptions { WriteIndented = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull }; diff --git a/sdk/js/src/catalog.ts b/sdk/js/src/catalog.ts index 6a1ecde3..0b989fd4 100644 --- a/sdk/js/src/catalog.ts +++ b/sdk/js/src/catalog.ts @@ -86,7 +86,7 @@ export class Catalog { * @throws Error - If alias is null, undefined, or empty. * @throws Error - If the alias is not found in the catalog. */ - public async getModel(alias: string): Promise { + public async getModel(alias: string): Promise { if (typeof alias !== 'string' || alias.trim() === '') { throw new Error('Model alias must be a non-empty string.'); } diff --git a/sdk/js/test/catalog.test.ts b/sdk/js/test/catalog.test.ts index 67da60d6..b823738a 100644 --- a/sdk/js/test/catalog.test.ts +++ b/sdk/js/test/catalog.test.ts @@ -27,7 +27,7 @@ describe('Catalog Tests', () => { it('should get model by alias', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; + const model = await catalog.getModel(TEST_MODEL_ALIAS); expect(model.alias).to.equal(TEST_MODEL_ALIAS); }); diff --git a/sdk/rust/src/huggingface_catalog.rs b/sdk/rust/src/huggingface_catalog.rs index b25c6bdd..45d9daef 100644 --- a/sdk/rust/src/huggingface_catalog.rs +++ b/sdk/rust/src/huggingface_catalog.rs @@ -136,8 +136,8 @@ impl HuggingFaceCatalog { /// 2. Alias match (case-insensitive) /// 3. URI-based match (normalize to HuggingFace URL and compare) /// - /// Returns `Ok(None)` if the model is not found. - pub async fn get_model(&self, identifier: &str) -> Result>> { + /// Returns an error if the model is not found. + pub async fn get_model(&self, identifier: &str) -> Result> { if identifier.trim().is_empty() { return Err(FoundryLocalError::Validation { reason: "Model identifier must be a non-empty string".into(), @@ -148,34 +148,45 @@ impl HuggingFaceCatalog { // 1. Direct ID match if let Some(model) = s.models_by_id.get(identifier) { - return Ok(Some(Arc::clone(model))); + return Ok(Arc::clone(model)); } // 2. Alias match (case-insensitive) for model in s.models_by_id.values() { if model.alias().eq_ignore_ascii_case(identifier) { - return Ok(Some(Arc::clone(model))); + return Ok(Arc::clone(model)); } } - // 3. URI-based match + // 3. URI-based match — prefer exact, fall back to prefix if let Some(normalized_url) = normalize_to_huggingface_url(identifier) { let normalized_lower = normalized_url.to_lowercase(); let normalized_with_slash = format!("{}/", normalized_lower.trim_end_matches('/')); + + // Exact match first + for variant in s.variants_by_id.values() { + let uri_lower = variant.info().uri.to_lowercase(); + if uri_lower == normalized_lower { + if let Some(model) = s.models_by_id.get(variant.id()) { + return Ok(Arc::clone(model)); + } + } + } + // Prefix match fallback for variant in s.variants_by_id.values() { let uri_lower = variant.info().uri.to_lowercase(); - if uri_lower == normalized_lower - || uri_lower.starts_with(&normalized_with_slash) - { + if uri_lower.starts_with(&normalized_with_slash) { if let Some(model) = s.models_by_id.get(variant.id()) { - return Ok(Some(Arc::clone(model))); + return Ok(Arc::clone(model)); } } } } - Ok(None) + Err(FoundryLocalError::ModelOperation { + reason: format!("Model '{identifier}' not found in HuggingFace catalog."), + }) } /// Download a HuggingFace model's ONNX files. @@ -232,10 +243,20 @@ impl HuggingFaceCatalog { format!("{}/", expected_lower.trim_end_matches('/')); let s = self.lock_state()?; + + // Exact match first for variant in s.variants_by_id.values() { let uri_lower = variant.info().uri.to_lowercase(); - if uri_lower == expected_lower - || uri_lower.starts_with(&expected_with_slash) + if uri_lower == expected_lower { + if let Some(model) = s.models_by_id.get(variant.id()) { + return Ok(Arc::clone(model)); + } + } + } + // Prefix match fallback + for variant in s.variants_by_id.values() { + let uri_lower = variant.info().uri.to_lowercase(); + if uri_lower.starts_with(&expected_with_slash) || expected_lower.starts_with( &format!("{}/", uri_lower.trim_end_matches('/')), ) @@ -357,9 +378,11 @@ impl HuggingFaceCatalog { let _ = std::fs::create_dir_all(dir); } - let s = self.lock_state()?; - let infos: Vec<&ModelInfo> = - s.variants_by_id.values().map(|v| v.info()).collect(); + // Snapshot data under the lock, then release before doing I/O + let infos: Vec = { + let s = self.lock_state()?; + s.variants_by_id.values().map(|v| v.info().clone()).collect() + }; let json = serde_json::to_string_pretty(&infos) .map_err(|e| FoundryLocalError::Internal { diff --git a/sdk/rust/tests/integration/huggingface_catalog_test.rs b/sdk/rust/tests/integration/huggingface_catalog_test.rs index 907ab2f0..00308fd2 100644 --- a/sdk/rust/tests/integration/huggingface_catalog_test.rs +++ b/sdk/rust/tests/integration/huggingface_catalog_test.rs @@ -54,7 +54,7 @@ async fn should_find_registered_model_by_identifier() { .await .expect("get_model failed"); - assert!(found.is_some(), "Should find model by HuggingFace URL"); + assert!(!found.alias().is_empty(), "Should find model by HuggingFace URL"); } #[tokio::test] From fab281547dd17a8cff280e3f18c821d593674b33 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Sun, 22 Mar 2026 10:21:31 -0700 Subject: [PATCH 09/11] Fix CodeQL ReDoS: replace polynomial regex with string loop Replace /\/+$/ regex in huggingFaceCatalog.ts with a simple trimTrailingSlashes helper that uses a while loop, eliminating the polynomial backtracking risk on uncontrolled input. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk/js/src/huggingFaceCatalog.ts | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sdk/js/src/huggingFaceCatalog.ts b/sdk/js/src/huggingFaceCatalog.ts index fa745a22..06e8291f 100644 --- a/sdk/js/src/huggingFaceCatalog.ts +++ b/sdk/js/src/huggingFaceCatalog.ts @@ -7,6 +7,12 @@ import { Model } from './model.js'; import { ModelVariant } from './modelVariant.js'; import { ModelInfo } from './types.js'; +function trimTrailingSlashes(s: string): string { + let end = s.length; + while (end > 0 && s[end - 1] === '/') end--; + return s.substring(0, end); +} + /** Filename for the HuggingFace registration persistence file. */ const REGISTRATIONS_FILENAME = 'huggingface.modelinfo.json'; @@ -165,7 +171,7 @@ export class HuggingFaceCatalog { const normalizedUrl = normalizeToHuggingFaceUrl(identifier); if (normalizedUrl) { const normalizedLower = normalizedUrl.toLowerCase(); - const normalizedWithSlash = normalizedLower.replace(/\/+$/, '') + '/'; + const normalizedWithSlash = trimTrailingSlashes(normalizedLower) + '/'; for (const variant of this.variantsById.values()) { const uriLower = variant.modelInfo.uri.toLowerCase(); if (uriLower === normalizedLower || uriLower.startsWith(normalizedWithSlash)) { @@ -219,13 +225,13 @@ export class HuggingFaceCatalog { // Match result against registered models by URI const expectedUri = `https://huggingface.co/${resultData}`; const expectedLower = expectedUri.toLowerCase(); - const expectedWithSlash = expectedLower.replace(/\/+$/, '') + '/'; + const expectedWithSlash = trimTrailingSlashes(expectedLower) + '/'; for (const variant of this.variantsById.values()) { const uriLower = variant.modelInfo.uri.toLowerCase(); if (uriLower === expectedLower || uriLower.startsWith(expectedWithSlash) - || expectedLower.startsWith(uriLower.replace(/\/+$/, '') + '/')) { + || expectedLower.startsWith(trimTrailingSlashes(uriLower) + '/')) { const model = this.modelsById.get(variant.id); if (model) return model; } From 3226bbcf7248d5d74d76b9319cfb6f0dd5b1b765 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Sun, 22 Mar 2026 10:25:17 -0700 Subject: [PATCH 10/11] Fix typo: aliaMatch -> aliasMatch in HuggingFaceCatalog.cs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk/cs/src/HuggingFaceCatalog.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/cs/src/HuggingFaceCatalog.cs b/sdk/cs/src/HuggingFaceCatalog.cs index d2dd3865..0e63a7d5 100644 --- a/sdk/cs/src/HuggingFaceCatalog.cs +++ b/sdk/cs/src/HuggingFaceCatalog.cs @@ -120,11 +120,11 @@ private async Task> ListModelsImplAsync(CancellationToken? ct = null } // Try alias lookup (returns first match) - var aliaMatch = _modelIdToModel.Values.FirstOrDefault(m => + var aliasMatch = _modelIdToModel.Values.FirstOrDefault(m => string.Equals(m.Alias, modelIdentifier, StringComparison.OrdinalIgnoreCase)); - if (aliaMatch != null) + if (aliasMatch != null) { - return aliaMatch; + return aliasMatch; } // Try URI-based lookup From b2ff8855e905093eb42fa1817d005fd6c5c9c89f Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Sun, 22 Mar 2026 10:28:39 -0700 Subject: [PATCH 11/11] Revert .csproj local Core version back to variable reference Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk/cs/src/Microsoft.AI.Foundry.Local.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj b/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj index 8f2aedc1..905f9652 100644 --- a/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj +++ b/sdk/cs/src/Microsoft.AI.Foundry.Local.csproj @@ -118,7 +118,7 @@ + Include="Microsoft.AI.Foundry.Local.Core" Version="$(FoundryLocalCoreVersion)" />