diff --git a/sdk/cs/src/Catalog.cs b/sdk/cs/src/Catalog.cs index eb9ba0d7..667749a3 100644 --- a/sdk/cs/src/Catalog.cs +++ b/sdk/cs/src/Catalog.cs @@ -7,6 +7,7 @@ namespace Microsoft.AI.Foundry.Local; using System; using System.Collections.Generic; +using System.Linq; using System.Text.Json; using System.Threading.Tasks; @@ -77,6 +78,19 @@ public async Task> GetLoadedModelsAsync(CancellationToken? ct .ConfigureAwait(false); } + public async Task DownloadModelAsync(string modelUri, CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => DownloadModelImplAsync(modelUri, ct), + $"Error downloading model '{modelUri}'.", _logger) + .ConfigureAwait(false); + } + + public Task RegisterModelAsync(string modelIdentifier, CancellationToken? ct = null) + { + return Task.FromException(new NotSupportedException( + "RegisterModelAsync is only available on HuggingFace catalogs. Use AddCatalogAsync(\"https://huggingface.co\") to create a HuggingFace catalog.")); + } + public async Task GetModelVariantAsync(string modelId, CancellationToken? ct = null) { return await Utils.CallWithExceptionHandling(() => GetModelVariantImplAsync(modelId, ct), @@ -126,9 +140,29 @@ private async Task> GetLoadedModelsImplAsync(CancellationToke private async Task GetModelImplAsync(string modelAlias, CancellationToken? ct = null) { + var hfUrl = NormalizeToHuggingFaceUrl(modelAlias); + if (hfUrl != null) + { + // Force a fresh catalog refresh for HuggingFace lookups + _lastFetch = DateTime.MinValue; + await UpdateModels(ct).ConfigureAwait(false); + + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + var matchingVariant = _modelIdToModelVariant.Values.FirstOrDefault(v => + string.Equals(v.Info.Uri, hfUrl, StringComparison.OrdinalIgnoreCase)); + + if (matchingVariant != null) + { + _modelAliasToModel.TryGetValue(matchingVariant.Alias, out Model? hfModel); + return hfModel; + } + + return null; + } + await UpdateModels(ct).ConfigureAwait(false); - using var disposable = await _lock.LockAsync().ConfigureAwait(false); + using var d = await _lock.LockAsync().ConfigureAwait(false); _modelAliasToModel.TryGetValue(modelAlias, out Model? model); return model; @@ -143,6 +177,93 @@ private async Task> GetLoadedModelsImplAsync(CancellationToke return modelVariant; } + private async Task DownloadModelImplAsync(string modelUri, CancellationToken? ct) + { + // Validate that this is a HuggingFace identifier + if (NormalizeToHuggingFaceUrl(modelUri) == null) + { + throw new FoundryLocalException( + $"'{modelUri}' is not a valid HuggingFace URL or org/repo identifier.", _logger); + } + + // Send the original URI to Core — it handles full URLs with /tree/revision/ + // and raw org/repo/subdir strings. Do NOT send the normalized form, as Core's + // URL parser expects /tree/revision/ when the https:// prefix is present. + var downloadRequest = new CoreInteropRequest + { + Params = new Dictionary { { "Model", modelUri } } + }; + + var result = await _coreInterop.ExecuteCommandAsync("download_model", downloadRequest, ct) + .ConfigureAwait(false); + + if (result.Error != null) + { + throw new FoundryLocalException( + $"Error downloading model '{modelUri}': {result.Error}", _logger); + } + + // Force a catalog refresh to pick up the newly downloaded model + _lastFetch = DateTime.MinValue; + await UpdateModels(ct).ConfigureAwait(false); + + // The backend returns the org/model URI (e.g. "microsoft/Phi-3-mini") as result.Data + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + var expectedUri = $"https://huggingface.co/{result.Data}"; + var matchingVariant = _modelIdToModelVariant.Values.FirstOrDefault(v => + string.Equals(v.Info.Uri, expectedUri, StringComparison.OrdinalIgnoreCase)); + + if (matchingVariant != null) + { + _modelAliasToModel.TryGetValue(matchingVariant.Alias, out Model? hfModel); + return hfModel!; + } + + throw new FoundryLocalException( + $"Model '{modelUri}' was downloaded but could not be found in the catalog.", _logger); + } + + /// + /// Normalizes a model identifier to a canonical HuggingFace URL, or returns null if it's a plain alias. + /// Strips /tree/{revision}/ from full browser URLs so the result matches the stored Info.Uri format. + /// Handles: + /// - "https://huggingface.co/org/repo/tree/main/sub" -> "https://huggingface.co/org/repo/sub" + /// - "https://huggingface.co/org/repo" -> returned as-is + /// - "org/repo[/sub]" -> "https://huggingface.co/org/repo[/sub]" + /// - "phi-3-mini" (plain alias) -> null + /// + private static string? NormalizeToHuggingFaceUrl(string input) + { + const string hfPrefix = "https://huggingface.co/"; + + if (input.StartsWith(hfPrefix, StringComparison.OrdinalIgnoreCase)) + { + // Strip /tree/{revision}/ to match the canonical form stored by Core + var path = input[hfPrefix.Length..]; + var parts = path.Split('/'); + if (parts.Length >= 4 && + parts[2].Equals("tree", StringComparison.OrdinalIgnoreCase)) + { + // parts[0]=org, parts[1]=repo, parts[2]="tree", parts[3]=revision, parts[4..]=subpath + var org = parts[0]; + var repo = parts[1]; + var subPath = parts.Length > 4 ? string.Join("/", parts.Skip(4)) : null; + return subPath != null + ? $"{hfPrefix}{org}/{repo}/{subPath}" + : $"{hfPrefix}{org}/{repo}"; + } + + return input; + } + + if (input.Contains('/') && !input.StartsWith("azureml://", StringComparison.OrdinalIgnoreCase)) + { + return hfPrefix + input; + } + + return null; + } + private async Task UpdateModels(CancellationToken? ct) { // TODO: make this configurable diff --git a/sdk/cs/src/FoundryLocalManager.cs b/sdk/cs/src/FoundryLocalManager.cs index 639be3a2..c2427270 100644 --- a/sdk/cs/src/FoundryLocalManager.cs +++ b/sdk/cs/src/FoundryLocalManager.cs @@ -107,6 +107,21 @@ public async Task GetCatalogAsync(CancellationToken? ct = null) "Error getting Catalog.", _logger).ConfigureAwait(false); } + /// + /// Create a separate catalog for a HuggingFace model registry. + /// + /// URL of the catalog (must contain "huggingface.co"). + /// Optional authentication token for accessing private HuggingFace repositories. + /// Optional CancellationToken. + /// The HuggingFace catalog instance. + public async Task AddCatalogAsync(string catalogUrl, string? token = null, + CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => AddCatalogImplAsync(catalogUrl, token, ct), + $"Error adding catalog '{catalogUrl}'.", _logger) + .ConfigureAwait(false); + } + /// /// Start the optional web service. This will provide an OpenAI-compatible REST endpoint that supports /// /v1/chat_completions @@ -212,6 +227,22 @@ private async Task GetCatalogImplAsync(CancellationToken? ct = null) return _catalog; } + private async Task AddCatalogImplAsync(string catalogUrl, string? token, + CancellationToken? ct = null) + { + if (!catalogUrl.Contains("huggingface.co", StringComparison.OrdinalIgnoreCase)) + { + throw new FoundryLocalException( + $"Unsupported catalog URL '{catalogUrl}'. Only HuggingFace catalogs (huggingface.co) are supported.", + _logger); + } + +#pragma warning disable IDISP005 // Return type is not disposable + return await HuggingFaceCatalog.CreateAsync(_modelManager!, _coreInterop!, _logger, token, ct) + .ConfigureAwait(false); +#pragma warning restore IDISP005 + } + private async Task StartWebServiceImplAsync(CancellationToken? ct = null) { if (_config?.Web?.Urls == null) diff --git a/sdk/cs/src/FoundryModelInfo.cs b/sdk/cs/src/FoundryModelInfo.cs index 1f795d22..90025e51 100644 --- a/sdk/cs/src/FoundryModelInfo.cs +++ b/sdk/cs/src/FoundryModelInfo.cs @@ -65,6 +65,9 @@ public record ModelInfo [JsonPropertyName("version")] public int Version { get; init; } + [JsonPropertyName("hash")] + public string? Hash { get; init; } + [JsonPropertyName("alias")] public required string Alias { get; init; } diff --git a/sdk/cs/src/HuggingFaceCatalog.cs b/sdk/cs/src/HuggingFaceCatalog.cs new file mode 100644 index 00000000..0e63a7d5 --- /dev/null +++ b/sdk/cs/src/HuggingFaceCatalog.cs @@ -0,0 +1,423 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. All rights reserved. +// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading.Tasks; + +using Microsoft.AI.Foundry.Local.Detail; +using Microsoft.Extensions.Logging; + +internal sealed class HuggingFaceCatalog : ICatalog, IDisposable +{ + private readonly Dictionary _modelIdToModelVariant = new(); + private readonly Dictionary _modelIdToModel = new(); + + private readonly IModelLoadManager _modelLoadManager; + private readonly ICoreInterop _coreInterop; + private readonly ILogger _logger; + private readonly AsyncLock _lock = new(); + private readonly string? _token; + + public string Name { get; init; } + + private HuggingFaceCatalog(IModelLoadManager modelLoadManager, ICoreInterop coreInterop, ILogger logger, + string? token) + { + _modelLoadManager = modelLoadManager; + _coreInterop = coreInterop; + _logger = logger; + _token = token; + + Name = "HuggingFace"; + } + + internal static async Task CreateAsync(IModelLoadManager modelManager, + ICoreInterop coreInterop, + ILogger logger, + string? token = null, + CancellationToken? ct = null) + { + var catalog = new HuggingFaceCatalog(modelManager, coreInterop, logger, token); + await catalog.LoadRegistrationsAsync(ct).ConfigureAwait(false); + return catalog; + } + + public async Task> ListModelsAsync(CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => ListModelsImplAsync(ct), + "Error listing HuggingFace models.", _logger) + .ConfigureAwait(false); + } + + public async Task GetModelAsync(string modelAlias, CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => GetModelImplAsync(modelAlias, ct), + $"Error getting HuggingFace model '{modelAlias}'.", _logger) + .ConfigureAwait(false); + } + + public async Task DownloadModelAsync(string modelUri, CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling( + () => DownloadModelImplAsync(modelUri, ct), + $"Error downloading HuggingFace model '{modelUri}'.", _logger) + .ConfigureAwait(false); + } + + public async Task RegisterModelAsync(string modelIdentifier, CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => RegisterModelImplAsync(modelIdentifier, ct), + $"Error registering HuggingFace model '{modelIdentifier}'.", + _logger) + .ConfigureAwait(false); + } + + public async Task GetModelVariantAsync(string modelId, CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => GetModelVariantImplAsync(modelId, ct), + $"Error getting HuggingFace model variant '{modelId}'.", + _logger) + .ConfigureAwait(false); + } + + public async Task> GetCachedModelsAsync(CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => GetCachedModelsImplAsync(ct), + "Error getting cached HuggingFace models.", _logger) + .ConfigureAwait(false); + } + + public async Task> GetLoadedModelsAsync(CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling(() => GetLoadedModelsImplAsync(ct), + "Error getting loaded HuggingFace models.", _logger) + .ConfigureAwait(false); + } + + private async Task> ListModelsImplAsync(CancellationToken? ct = null) + { + // HuggingFace catalog returns one entry per registration (each variant is individually referenceable) + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + return _modelIdToModel.Values.OrderBy(m => m.Id).ToList(); + } + + private async Task GetModelImplAsync(string modelIdentifier, CancellationToken? ct = null) + { + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + + // Try direct Id lookup first + if (_modelIdToModel.TryGetValue(modelIdentifier, out Model? model)) + { + return model; + } + + // Try alias lookup (returns first match) + var aliasMatch = _modelIdToModel.Values.FirstOrDefault(m => + string.Equals(m.Alias, modelIdentifier, StringComparison.OrdinalIgnoreCase)); + if (aliasMatch != null) + { + return aliasMatch; + } + + // Try URI-based lookup + var normalizedUrl = NormalizeToHuggingFaceUrl(modelIdentifier); + if (normalizedUrl != null) + { + var normalizedUrlWithSlash = normalizedUrl.TrimEnd('/') + "/"; + foreach (var variant in _modelIdToModelVariant.Values) + { + if (string.Equals(variant.Info.Uri, normalizedUrl, StringComparison.OrdinalIgnoreCase) || + variant.Info.Uri.StartsWith(normalizedUrlWithSlash, StringComparison.OrdinalIgnoreCase)) + { + if (_modelIdToModel.TryGetValue(variant.Id, out Model? foundModel)) + { + return foundModel; + } + } + } + } + + return null; + } + + private async Task RegisterModelImplAsync(string modelIdentifier, CancellationToken? ct = null) + { + // Validate it's a HuggingFace URL or org/repo format + var normalizedUrl = NormalizeToHuggingFaceUrl(modelIdentifier); + if (normalizedUrl == null) + { + throw new FoundryLocalException( + $"'{modelIdentifier}' is not a valid HuggingFace URL or org/repo identifier.", _logger); + } + + // Call Core to register the model (fetch metadata, generate inference_model.json, persist to + // huggingface.modelinfo.json) + var registerRequest = new CoreInteropRequest + { + Params = new Dictionary + { + { "Model", modelIdentifier }, + { "Token", _token ?? "" } + } + }; + + var result = await _coreInterop.ExecuteCommandAsync("register_model", registerRequest, ct) + .ConfigureAwait(false); + + if (result.Error != null) + { + throw new FoundryLocalException($"Error registering HuggingFace model '{modelIdentifier}': {result.Error}", + _logger); + } + + // Deserialize the returned ModelInfo + var modelInfo = JsonSerializer.Deserialize(result.Data!, JsonSerializationContext.Default.ModelInfo); + if (modelInfo == null) + { + throw new FoundryLocalException($"Failed to deserialize registered model metadata.", _logger); + } + + // Add to internal dictionaries with lock + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + var variant = new ModelVariant(modelInfo, _modelLoadManager, _coreInterop, _logger, _token); + _modelIdToModelVariant[modelInfo.Id] = variant; + + // Each registration is a distinct entry, keyed by Id + var registeredModel = new Model(variant, _logger); + _modelIdToModel[modelInfo.Id] = registeredModel; + + // Persist registrations to local file + await SaveRegistrationsAsync(ct).ConfigureAwait(false); + + return registeredModel; + } + + private async Task DownloadModelImplAsync(string modelUri, CancellationToken? ct) + { + // Validate it's a HuggingFace URL or org/repo format + if (NormalizeToHuggingFaceUrl(modelUri) == null) + { + throw new FoundryLocalException( + $"'{modelUri}' is not a valid HuggingFace URL or org/repo identifier.", _logger); + } + + // Call Core's download_model command (same as existing Catalog) + var downloadRequest = new CoreInteropRequest + { + Params = new Dictionary + { + { "Model", modelUri }, + { "Token", _token ?? "" } + } + }; + + var result = await _coreInterop.ExecuteCommandAsync("download_model", downloadRequest, ct) + .ConfigureAwait(false); + + if (result.Error != null) + { + throw new FoundryLocalException($"Error downloading model '{modelUri}': {result.Error}", _logger); + } + + // The backend returns the org/model URI (e.g. "microsoft/Phi-3-mini") as result.Data + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + var expectedUri = $"https://huggingface.co/{result.Data}"; + var expectedUriWithSlash = expectedUri.TrimEnd('/') + "/"; + var matchingVariant = _modelIdToModelVariant.Values.FirstOrDefault(v => + string.Equals(v.Info.Uri, expectedUri, StringComparison.OrdinalIgnoreCase) || + v.Info.Uri.StartsWith(expectedUriWithSlash, StringComparison.OrdinalIgnoreCase) || + expectedUri.StartsWith(v.Info.Uri.TrimEnd('/') + "/", StringComparison.OrdinalIgnoreCase)); + + if (matchingVariant != null) + { + if (_modelIdToModel.TryGetValue(matchingVariant.Id, out Model? hfModel)) + { + return hfModel; + } + } + + throw new FoundryLocalException( + $"Model '{modelUri}' was downloaded but could not be found in the catalog.", _logger); + } + + private async Task> GetCachedModelsImplAsync(CancellationToken? ct = null) + { + var cachedModelIds = await Utils.GetCachedModelIdsAsync(_coreInterop, ct).ConfigureAwait(false); + + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + List cachedModels = new(); + foreach (var modelId in cachedModelIds) + { + if (_modelIdToModelVariant.TryGetValue(modelId, out ModelVariant? modelVariant)) + { + cachedModels.Add(modelVariant); + } + } + + return cachedModels; + } + + private async Task> GetLoadedModelsImplAsync(CancellationToken? ct = null) + { + var loadedModelIds = await _modelLoadManager.ListLoadedModelsAsync(ct).ConfigureAwait(false); + + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + List loadedModels = new(); + + foreach (var modelId in loadedModelIds) + { + if (_modelIdToModelVariant.TryGetValue(modelId, out ModelVariant? modelVariant)) + { + loadedModels.Add(modelVariant); + } + } + + return loadedModels; + } + + private async Task GetModelVariantImplAsync(string modelId, CancellationToken? ct = null) + { + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + _modelIdToModelVariant.TryGetValue(modelId, out ModelVariant? modelVariant); + return modelVariant; + } + + private string GetRegistrationsPath() + { + var result = _coreInterop.ExecuteCommand("get_cache_directory"); + var cacheDir = result.Data?.Trim().Trim('"') ?? throw new InvalidOperationException("Failed to get cache directory from Core"); + return Path.Combine(cacheDir, "HuggingFace", "huggingface.modelinfo.json"); + } + + private async Task LoadRegistrationsAsync(CancellationToken? ct = null) + { + // Load persisted HuggingFace registrations from cache directory + try + { + var registrationsPath = GetRegistrationsPath(); + + if (!File.Exists(registrationsPath)) + { + return; // No registrations yet + } + + var registrationsJson = await File.ReadAllTextAsync(registrationsPath).ConfigureAwait(false); + if (string.IsNullOrEmpty(registrationsJson)) + { + return; + } + + var models = JsonSerializer.Deserialize(registrationsJson, JsonSerializationContext.Default.ListModelInfo); + if (models == null) + { + _logger.LogDebug("Failed to deserialize HuggingFace registrations from file"); + return; + } + + using var disposable = await _lock.LockAsync().ConfigureAwait(false); + + foreach (var modelInfo in models) + { + var variant = new ModelVariant(modelInfo, _modelLoadManager, _coreInterop, _logger, _token); + _modelIdToModelVariant[modelInfo.Id] = variant; + _modelIdToModel[modelInfo.Id] = new Model(variant, _logger); + } + } + catch (Exception ex) + { + _logger.LogWarning($"Exception loading HuggingFace registrations: {ex.Message}"); + // Continue anyway — empty catalog is valid + } + } + + private async Task SaveRegistrationsAsync(CancellationToken? ct = null) + { + // Save persisted HuggingFace registrations to cache directory + try + { + var registrationsPath = GetRegistrationsPath(); + var registrationsDir = Path.GetDirectoryName(registrationsPath)!; + + // Ensure directory exists + Directory.CreateDirectory(registrationsDir); + + // Snapshot registered models under lock, then do file I/O outside it + List models; + using (await _lock.LockAsync().ConfigureAwait(false)) + { + models = _modelIdToModelVariant.Values + .Select(v => v.Info) + .Distinct() + .ToList(); + } + + // Serialize with pretty-printing (matching foundry.modelinfo.json style) + var prettyOptions = new JsonSerializerOptions { WriteIndented = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull }; + var prettyContext = new JsonSerializationContext(prettyOptions); + var json = JsonSerializer.Serialize(models, prettyContext.ListModelInfo); + await File.WriteAllTextAsync(registrationsPath, json).ConfigureAwait(false); + } + catch (Exception ex) + { + _logger.LogWarning($"Failed to save HuggingFace registrations: {ex.Message}"); + // Continue anyway — loss of persistence file is not critical + } + } + + private static string? NormalizeToHuggingFaceUrl(string input) + { + const string hfPrefix = "https://huggingface.co/"; + + if (input.StartsWith(hfPrefix, StringComparison.OrdinalIgnoreCase)) + { + // Strip /tree/{revision}/ to match the canonical form stored by Core + var path = input[hfPrefix.Length..]; + var parts = path.Split('/'); + if (parts.Length >= 4 && + parts[2].Equals("tree", StringComparison.OrdinalIgnoreCase)) + { + var org = parts[0]; + var repo = parts[1]; + var subPath = parts.Length > 4 ? string.Join("/", parts.Skip(4)) : null; + return subPath != null + ? $"{hfPrefix}{org}/{repo}/{subPath}" + : $"{hfPrefix}{org}/{repo}"; + } + + return input; + } + + if (input.Contains('/') && !input.StartsWith("azureml://", StringComparison.OrdinalIgnoreCase)) + { + // Strip /tree/{revision}/ from bare identifiers (e.g. "org/repo/tree/main/subpath") + var parts = input.Split('/'); + if (parts.Length >= 4 && + parts[2].Equals("tree", StringComparison.OrdinalIgnoreCase)) + { + var org = parts[0]; + var repo = parts[1]; + var subPath = parts.Length > 4 ? string.Join("/", parts.Skip(4)) : null; + return subPath != null + ? $"{hfPrefix}{org}/{repo}/{subPath}" + : $"{hfPrefix}{org}/{repo}"; + } + + return hfPrefix + input; + } + + return null; + } + + public void Dispose() + { + _lock?.Dispose(); + } +} diff --git a/sdk/cs/src/ICatalog.cs b/sdk/cs/src/ICatalog.cs index 35285736..90423d57 100644 --- a/sdk/cs/src/ICatalog.cs +++ b/sdk/cs/src/ICatalog.cs @@ -22,13 +22,31 @@ public interface ICatalog Task> ListModelsAsync(CancellationToken? ct = null); /// - /// Lookup a model by its alias. + /// Lookup a model by its alias, HuggingFace URL (https://huggingface.co/org/repo), or org/repo identifier. /// - /// Model alias. + /// Model alias, HuggingFace URL (https://huggingface.co/org/repo), or org/repo identifier. /// Optional CancellationToken. /// The matching Model, or null if no model with the given alias exists. Task GetModelAsync(string modelAlias, CancellationToken? ct = null); + /// + /// Download a model from its HuggingFace identifier (URL or org/repo format). + /// + /// HuggingFace URL or org/repo identifier. + /// Optional CancellationToken. + /// The downloaded Model. + Task DownloadModelAsync(string modelUri, CancellationToken? ct = null); + + /// + /// Register a HuggingFace model by downloading its config files and generating metadata. + /// Only available on HuggingFace catalogs created via . + /// + /// HuggingFace URL or org/repo identifier. + /// Optional CancellationToken. + /// The registered Model. + /// If called on a non-HuggingFace catalog. + Task RegisterModelAsync(string modelIdentifier, CancellationToken? ct = null); + /// /// Lookup a model variant by its unique model id. /// diff --git a/sdk/cs/src/Model.cs b/sdk/cs/src/Model.cs index bbbbcb5b..a46d178a 100644 --- a/sdk/cs/src/Model.cs +++ b/sdk/cs/src/Model.cs @@ -17,6 +17,7 @@ public class Model : IModel public string Alias { get; init; } public string Id => SelectedVariant.Id; + public string Uri => SelectedVariant.Info.Uri; /// /// Is the currently selected variant cached locally? diff --git a/sdk/cs/src/ModelVariant.cs b/sdk/cs/src/ModelVariant.cs index 6ca7cda7..142aab76 100644 --- a/sdk/cs/src/ModelVariant.cs +++ b/sdk/cs/src/ModelVariant.cs @@ -14,6 +14,7 @@ public class ModelVariant : IModel private readonly IModelLoadManager _modelLoadManager; private readonly ICoreInterop _coreInterop; private readonly ILogger _logger; + private readonly string? _token; public ModelInfo Info { get; } // expose the full info record @@ -21,9 +22,10 @@ public class ModelVariant : IModel public string Id => Info.Id; public string Alias => Info.Alias; public int Version { get; init; } // parsed from Info.Version if possible, else 0 + public string VersionDisplay => Info.Hash ?? Info.Version.ToString(System.Globalization.CultureInfo.InvariantCulture); internal ModelVariant(ModelInfo modelInfo, IModelLoadManager modelLoadManager, ICoreInterop coreInterop, - ILogger logger) + ILogger logger, string? token = null) { Info = modelInfo; Version = modelInfo.Version; @@ -31,6 +33,7 @@ internal ModelVariant(ModelInfo modelInfo, IModelLoadManager modelLoadManager, I _modelLoadManager = modelLoadManager; _coreInterop = coreInterop; _logger = logger; + _token = token; } @@ -131,9 +134,14 @@ private async Task DownloadImplAsync(Action? downloadProgress = null, { var request = new CoreInteropRequest { - Params = new() { { "Model", Id } } + Params = new() { { "Model", string.Equals(Info.ProviderType, "HuggingFace", StringComparison.OrdinalIgnoreCase) ? Info.Uri : Id } } }; + if (!string.IsNullOrEmpty(_token)) + { + request.Params["Token"] = _token; + } + ICoreInterop.Response? response; if (downloadProgress == null) diff --git a/sdk/cs/test/FoundryLocal.Tests/HuggingFaceCatalogTests.cs b/sdk/cs/test/FoundryLocal.Tests/HuggingFaceCatalogTests.cs new file mode 100644 index 00000000..5605b2c7 --- /dev/null +++ b/sdk/cs/test/FoundryLocal.Tests/HuggingFaceCatalogTests.cs @@ -0,0 +1,370 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. All rights reserved. +// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local.Tests; + +using System.Text.Json; +using Microsoft.AI.Foundry.Local.Detail; +using Microsoft.Extensions.Logging; +using Moq; + +/// +/// Unit tests for HuggingFaceCatalog — validates GetModelAsync lookup by alias and URI, +/// and that DownloadModelAsync accepts model URI from registration. +/// +public class HuggingFaceCatalogTests +{ + private static ModelInfo CreateTestModelInfo(string alias, string id, string uri) + { + return new ModelInfo + { + Id = id, + Name = alias, + DisplayName = alias, + Alias = alias, + Uri = uri, + ProviderType = "HuggingFace", + Version = 0, + ModelType = "ONNX", + Publisher = "test-org", + Task = "chat-completion", + License = "MIT", + FileSizeMb = 100 + }; + } + + private static (Mock coreInterop, Mock loadManager, Mock logger) + CreateMocks(ModelInfo modelInfo) + { + var logger = Utils.CreateCapturingLoggerMock([]); + var loadManager = new Mock(); + loadManager.Setup(x => x.ListLoadedModelsAsync(It.IsAny())) + .ReturnsAsync(Array.Empty()); + + var coreInterop = new Mock(); + + // Mock register_model to return the ModelInfo JSON + var modelInfoJson = JsonSerializer.Serialize(modelInfo, JsonSerializationContext.Default.ModelInfo); + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "register_model"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response { Data = modelInfoJson, Error = null }); + + // Mock get_cached_models to return empty + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "get_cached_models"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response { Data = "[]", Error = null }); + + return (coreInterop, loadManager, logger); + } + + [Test] + public async Task GetModelAsync_ByAlias_ReturnsRegisteredModel() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + // Register the model + var registered = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + await Assert.That(registered.Alias).IsEqualTo("phi-3-mini-4k"); + + // Lookup by alias + var found = await catalog.GetModelAsync("phi-3-mini-4k"); + await Assert.That(found).IsNotNull(); + await Assert.That(found!.Alias).IsEqualTo("phi-3-mini-4k"); + } + + [Test] + public async Task GetModelAsync_ByUri_ReturnsRegisteredModel() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + var registered = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + // Lookup by full URI (what SelectedVariant.Info.Uri returns) + var found = await catalog.GetModelAsync("https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + await Assert.That(found).IsNotNull(); + await Assert.That(found!.Alias).IsEqualTo("phi-3-mini-4k"); + } + + [Test] + public async Task GetModelAsync_ByOrgRepo_ReturnsRegisteredModel() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + // Lookup by org/repo identifier + var found = await catalog.GetModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + await Assert.That(found).IsNotNull(); + await Assert.That(found!.Alias).IsEqualTo("phi-3-mini-4k"); + } + + [Test] + public async Task GetModelAsync_NotRegistered_ReturnsNull() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + // Don't register anything — lookup should return null + var found = await catalog.GetModelAsync("nonexistent-model"); + await Assert.That(found).IsNull(); + } + + [Test] + public async Task GetModelAsync_BySubpathUri_ReturnsRegisteredModel() + { + var modelInfo = CreateTestModelInfo( + "gemma-3-4b-it", + "onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile:abcd1234", + "https://huggingface.co/onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + await catalog.RegisterModelAsync("onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + // Lookup by the full URI with subpath + var found = await catalog.GetModelAsync( + "https://huggingface.co/onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + await Assert.That(found).IsNotNull(); + await Assert.That(found!.Alias).IsEqualTo("gemma-3-4b-it"); + } + + [Test] + public async Task DownloadModelAsync_WithModelUri_SucceedsWhenModelRegistered() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + // Mock download_model to return the org/repo identifier (as Core does) + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "download_model"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response + { + Data = "microsoft/Phi-3-mini-4k-instruct-onnx", + Error = null + }); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + // Register first + var registered = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + // Download using model's URI (the simplified API pattern) + var uri = registered.SelectedVariant.Info.Uri; + var downloaded = await catalog.DownloadModelAsync(uri); + + await Assert.That(downloaded).IsNotNull(); + await Assert.That(downloaded.Alias).IsEqualTo("phi-3-mini-4k"); + } + + [Test] + public async Task DownloadModelAsync_WithSubpathUri_SucceedsWhenModelRegistered() + { + var modelInfo = CreateTestModelInfo( + "gemma-3-4b-it", + "onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile:abcd1234", + "https://huggingface.co/onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "download_model"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response + { + Data = "onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile", + Error = null + }); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + var registered = await catalog.RegisterModelAsync( + "onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + // Download using model's URI (subpath model) + var uri = registered.SelectedVariant.Info.Uri; + var downloaded = await catalog.DownloadModelAsync(uri); + + await Assert.That(downloaded).IsNotNull(); + await Assert.That(downloaded.Alias).IsEqualTo("gemma-3-4b-it"); + } + + [Test] + public async Task RegisterModelAsync_Idempotent_ReturnsSameAlias() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + var first = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + var second = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + await Assert.That(first.Alias).IsEqualTo(second.Alias); + } + + [Test] + public async Task SelectedVariantInfoUri_ReturnsExpectedValue() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + var model = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + // SelectedVariant.Info.Uri holds the HuggingFace URL + await Assert.That(model.SelectedVariant.Info.Uri) + .IsEqualTo("https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + } + + [Test] + public async Task DownloadModelAsync_UsingSelectedVariantInfoUri_SucceedsWithoutRepeatingIdentifier() + { + var modelInfo = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + + var (coreInterop, loadManager, logger) = CreateMocks(modelInfo); + + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "download_model"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response + { + Data = "microsoft/Phi-3-mini-4k-instruct-onnx", + Error = null + }); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + var registered = await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + + // Simplified pattern: use SelectedVariant.Info.Uri instead of repeating the raw string + var downloaded = await catalog.DownloadModelAsync(registered.SelectedVariant.Info.Uri); + + await Assert.That(downloaded).IsNotNull(); + await Assert.That(downloaded.Alias).IsEqualTo("phi-3-mini-4k"); + await Assert.That(downloaded.SelectedVariant.Info.Uri).IsEqualTo(registered.SelectedVariant.Info.Uri); + } + + [Test] + public async Task ListModelsAsync_ReturnsAllRegisteredModelsSortedByAlias() + { + // Register two models — need separate mock setup for sequential calls + var phi = CreateTestModelInfo( + "phi-3-mini-4k", + "microsoft/Phi-3-mini-4k-instruct-onnx:abcd1234", + "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + var gemma = CreateTestModelInfo( + "gemma-3-4b-it", + "onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile:efgh5678", + "https://huggingface.co/onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + var logger = Utils.CreateCapturingLoggerMock([]); + var loadManager = new Mock(); + loadManager.Setup(x => x.ListLoadedModelsAsync(It.IsAny())) + .ReturnsAsync(Array.Empty()); + var coreInterop = new Mock(); + + // Return phi on first call, gemma on second + var callIndex = 0; + var models = new[] { phi, gemma }; + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "register_model"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(() => + { + var json = JsonSerializer.Serialize(models[callIndex++], + JsonSerializationContext.Default.ModelInfo); + return new ICoreInterop.Response { Data = json, Error = null }; + }); + + coreInterop.Setup(x => x.ExecuteCommandAsync( + It.Is(s => s == "get_cached_models"), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(new ICoreInterop.Response { Data = "[]", Error = null }); + + using var catalog = await HuggingFaceCatalog.CreateAsync( + loadManager.Object, coreInterop.Object, logger.Object); + + await catalog.RegisterModelAsync("microsoft/Phi-3-mini-4k-instruct-onnx"); + await catalog.RegisterModelAsync("onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + + var list = await catalog.ListModelsAsync(); + + await Assert.That(list.Count).IsEqualTo(2); + // Sorted alphabetically by alias + await Assert.That(list[0].Alias).IsEqualTo("gemma-3-4b-it"); + await Assert.That(list[1].Alias).IsEqualTo("phi-3-mini-4k"); + // Each model has correct URI via SelectedVariant.Info.Uri + await Assert.That(list[0].SelectedVariant.Info.Uri) + .IsEqualTo("https://huggingface.co/onnxruntime/Gemma-3-ONNX/gemma-3-4b-it/cpu_and_mobile"); + await Assert.That(list[1].SelectedVariant.Info.Uri) + .IsEqualTo("https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"); + } +} diff --git a/sdk/js/examples/responses.ts b/sdk/js/examples/responses.ts index fa8a6d93..cf6915b4 100644 --- a/sdk/js/examples/responses.ts +++ b/sdk/js/examples/responses.ts @@ -19,7 +19,7 @@ async function main() { // Load a model const modelAlias = 'MODEL_ALIAS'; // Replace with a valid model alias const catalog = manager.catalog; - const model = await catalog.getModel(modelAlias); + const model = (await catalog.getModel(modelAlias))!; await model.load(); console.log(`✓ Model ${model.id} loaded`); diff --git a/sdk/js/src/catalog.ts b/sdk/js/src/catalog.ts index bf2ae5c9..0b989fd4 100644 --- a/sdk/js/src/catalog.ts +++ b/sdk/js/src/catalog.ts @@ -80,15 +80,17 @@ export class Catalog { /** * Retrieves a model by its alias. - * This method is asynchronous as it may ensure the catalog is up-to-date by fetching from a remote service. - * @param alias - The alias of the model to retrieve. - * @returns A Promise that resolves to the Model object if found, otherwise throws an error. + * For HuggingFace models, use {@link HuggingFaceCatalog} via {@link FoundryLocalManager.addCatalog}. + * @param alias - The alias of the model. + * @returns A Promise that resolves to the Model object if found. * @throws Error - If alias is null, undefined, or empty. + * @throws Error - If the alias is not found in the catalog. */ public async getModel(alias: string): Promise { if (typeof alias !== 'string' || alias.trim() === '') { throw new Error('Model alias must be a non-empty string.'); } + await this.updateModels(); const model = this.modelAliasToModel.get(alias); if (!model) { diff --git a/sdk/js/src/detail/coreInterop.ts b/sdk/js/src/detail/coreInterop.ts index 167784e7..53eccd8a 100644 --- a/sdk/js/src/detail/coreInterop.ts +++ b/sdk/js/src/detail/coreInterop.ts @@ -129,6 +129,55 @@ export class CoreInterop { } } + public executeCommandWithCallback(command: string, params: any, callback: (chunk: string) => void): Promise { + const cmdBuf = koffi.alloc('char', command.length + 1); + koffi.encode(cmdBuf, 'char', command, command.length + 1); + + const dataStr = params ? JSON.stringify(params) : ''; + const dataBytes = this._toBytes(dataStr); + const dataBuf = koffi.alloc('char', dataBytes.length + 1); + koffi.encode(dataBuf, 'char', dataStr, dataBytes.length + 1); + + const cb = koffi.register((data: any, length: number, userData: any) => { + const chunk = koffi.decode(data, 'char', length); + callback(chunk); + }, koffi.pointer(CallbackType)); + + return new Promise((resolve, reject) => { + const req = { + Command: koffi.address(cmdBuf), + CommandLength: command.length, + Data: koffi.address(dataBuf), + DataLength: dataBytes.length + }; + const res = { Data: 0, DataLength: 0, Error: 0, ErrorLength: 0 }; + + this.execute_command_with_callback.async(req, res, cb, null, (err: any) => { + koffi.unregister(cb); + koffi.free(cmdBuf); + koffi.free(dataBuf); + + if (err) { + reject(err); + return; + } + + try { + if (res.Error) { + const errorMsg = koffi.decode(res.Error, 'char', res.ErrorLength); + reject(new Error(`Command '${command}' failed: ${errorMsg}`)); + } else { + const data = res.Data ? koffi.decode(res.Data, 'char', res.DataLength) : ""; + resolve(data); + } + } finally { + if (res.Data) koffi.free(res.Data); + if (res.Error) koffi.free(res.Error); + } + }); + }); + } + public executeCommandStreaming(command: string, params: any, callback: (chunk: string) => void): Promise { const cmdBuf = koffi.alloc('char', command.length + 1); koffi.encode(cmdBuf, 'char', command, command.length + 1); diff --git a/sdk/js/src/foundryLocalManager.ts b/sdk/js/src/foundryLocalManager.ts index bc408f78..0a0f5f5f 100644 --- a/sdk/js/src/foundryLocalManager.ts +++ b/sdk/js/src/foundryLocalManager.ts @@ -2,6 +2,7 @@ import { Configuration, FoundryLocalConfig } from './configuration.js'; import { CoreInterop } from './detail/coreInterop.js'; import { ModelLoadManager } from './detail/modelLoadManager.js'; import { Catalog } from './catalog.js'; +import { HuggingFaceCatalog } from './huggingFaceCatalog.js'; import { ResponsesClient } from './openai/responsesClient.js'; /** @@ -52,6 +53,31 @@ export class FoundryLocalManager { return this._catalog; } + /** + * Creates a separate HuggingFace catalog for registering and downloading + * models from HuggingFace. + * + * Three-step flow: + * 1. `addCatalog("https://huggingface.co")` — create the catalog + * 2. `catalog.registerModel("org/repo")` — register (config-only download) + * 3. `model.download()` — download ONNX files + * + * Each call creates a new instance with registrations loaded from disk. + * + * @param catalogUrl - Must contain "huggingface.co". + * @param token - Optional authentication token for private repositories. + * @returns A new HuggingFaceCatalog instance. + */ + public async addCatalog(catalogUrl: string, token?: string): Promise { + if (!catalogUrl.toLowerCase().includes("huggingface.co")) { + throw new Error( + `Unsupported catalog URL '${catalogUrl}'. Only HuggingFace catalogs (huggingface.co) are supported.` + ); + } + + return HuggingFaceCatalog.create(this.coreInterop, this._modelLoadManager, token); + } + /** * Gets the URLs where the web service is listening. * Returns an empty array if the web service is not running. diff --git a/sdk/js/src/huggingFaceCatalog.ts b/sdk/js/src/huggingFaceCatalog.ts new file mode 100644 index 00000000..06e8291f --- /dev/null +++ b/sdk/js/src/huggingFaceCatalog.ts @@ -0,0 +1,321 @@ +import * as fs from 'fs'; +import * as path from 'path'; + +import { CoreInterop } from './detail/coreInterop.js'; +import { ModelLoadManager } from './detail/modelLoadManager.js'; +import { Model } from './model.js'; +import { ModelVariant } from './modelVariant.js'; +import { ModelInfo } from './types.js'; + +function trimTrailingSlashes(s: string): string { + let end = s.length; + while (end > 0 && s[end - 1] === '/') end--; + return s.substring(0, end); +} + +/** Filename for the HuggingFace registration persistence file. */ +const REGISTRATIONS_FILENAME = 'huggingface.modelinfo.json'; + +/** + * Normalizes a model identifier to a canonical HuggingFace URL, or returns null if it's a plain alias. + * Strips /tree/{revision}/ from full browser URLs so the result matches the stored Info.Uri format. + */ +function normalizeToHuggingFaceUrl(input: string): string | null { + const hfPrefix = "https://huggingface.co/"; + + if (input.toLowerCase().startsWith(hfPrefix)) { + const urlPath = input.substring(hfPrefix.length); + const parts = urlPath.split('/'); + if (parts.length >= 4 && parts[2].toLowerCase() === 'tree') { + const org = parts[0]; + const repo = parts[1]; + const subPath = parts.length > 4 ? parts.slice(4).join('/') : null; + return subPath + ? `${hfPrefix}${org}/${repo}/${subPath}` + : `${hfPrefix}${org}/${repo}`; + } + return input; + } + + if (input.includes('/') && !input.toLowerCase().startsWith("azureml://")) { + const parts = input.split('/'); + if (parts.length >= 4 && parts[2].toLowerCase() === 'tree') { + const org = parts[0]; + const repo = parts[1]; + const subPath = parts.length > 4 ? parts.slice(4).join('/') : null; + return subPath + ? `${hfPrefix}${org}/${repo}/${subPath}` + : `${hfPrefix}${org}/${repo}`; + } + return hfPrefix + input; + } + + return null; +} + +/** + * A catalog for HuggingFace models. + * + * Created via {@link FoundryLocalManager.addCatalog}. Each call creates a new + * instance with registrations loaded from disk. + * + * Three-step flow: + * ```typescript + * const hf = await manager.addCatalog("https://huggingface.co"); + * const model = await hf.registerModel("org/repo"); // config files only + * await model.download(); // ONNX files + * ``` + */ +export class HuggingFaceCatalog { + private coreInterop: CoreInterop; + private modelLoadManager: ModelLoadManager; + private token: string | undefined; + private variantsById: Map = new Map(); + private modelsById: Map = new Map(); + + private constructor(coreInterop: CoreInterop, modelLoadManager: ModelLoadManager, token?: string) { + this.coreInterop = coreInterop; + this.modelLoadManager = modelLoadManager; + this.token = token; + } + + /** + * Creates a new HuggingFaceCatalog and loads persisted registrations. + * @internal + */ + static async create(coreInterop: CoreInterop, modelLoadManager: ModelLoadManager, token?: string): Promise { + const catalog = new HuggingFaceCatalog(coreInterop, modelLoadManager, token); + catalog.loadRegistrations(); + return catalog; + } + + /** + * Gets the catalog name. + */ + public get name(): string { + return "HuggingFace"; + } + + /** + * Register a HuggingFace model by downloading its config files only (~50KB). + * + * Sends the `register_model` FFI command to the native core, which downloads + * config files (genai_config.json, config.json, tokenizer_config.json, etc.) + * and generates metadata. Returns a Model with `cached: false`. + * + * After registration, call `model.download()` to download the ONNX files. + * + * @param modelIdentifier - A HuggingFace URL or org/repo identifier. + * @returns The registered Model. + */ + public async registerModel(modelIdentifier: string): Promise { + if (!normalizeToHuggingFaceUrl(modelIdentifier)) { + throw new Error(`'${modelIdentifier}' is not a valid HuggingFace URL or org/repo identifier.`); + } + + const request = { + Params: { + Model: modelIdentifier, + Token: this.token ?? "" + } + }; + + const result = this.coreInterop.executeCommand("register_model", request); + + let modelInfo: ModelInfo; + try { + modelInfo = JSON.parse(result); + } catch { + throw new Error(`Failed to parse register_model response: ${result}`); + } + + const variant = new ModelVariant(modelInfo, this.coreInterop, this.modelLoadManager, this.token); + this.variantsById.set(modelInfo.id, variant); + + const model = new Model(variant); + this.modelsById.set(modelInfo.id, model); + + this.saveRegistrations(); + + return model; + } + + /** + * Look up a model by its ID, alias, or HuggingFace URL. + * + * Uses three-tier lookup: + * 1. Direct ID match + * 2. Alias match (case-insensitive) + * 3. URI-based match (normalize to HuggingFace URL and compare) + * + * @param identifier - Model ID, alias, or HuggingFace URL/identifier. + * @returns The Model if found, undefined otherwise. + */ + public async getModel(identifier: string): Promise { + if (typeof identifier !== 'string' || identifier.trim() === '') { + throw new Error('Model identifier must be a non-empty string.'); + } + + // 1. Direct ID match + const byId = this.modelsById.get(identifier); + if (byId) return byId; + + // 2. Alias match (case-insensitive) + for (const model of this.modelsById.values()) { + if (model.alias.toLowerCase() === identifier.toLowerCase()) { + return model; + } + } + + // 3. URI-based match + const normalizedUrl = normalizeToHuggingFaceUrl(identifier); + if (normalizedUrl) { + const normalizedLower = normalizedUrl.toLowerCase(); + const normalizedWithSlash = trimTrailingSlashes(normalizedLower) + '/'; + for (const variant of this.variantsById.values()) { + const uriLower = variant.modelInfo.uri.toLowerCase(); + if (uriLower === normalizedLower || uriLower.startsWith(normalizedWithSlash)) { + return this.modelsById.get(variant.id); + } + } + } + + return undefined; + } + + /** + * Downloads a HuggingFace model's ONNX files. + * + * The model should have been previously registered via {@link registerModel}. + * + * @param modelUri - A HuggingFace URL or org/repo identifier. + * @param progressCallback - Optional callback invoked with download progress percentage (0-100). + * @returns The downloaded Model. + */ + public async downloadModel(modelUri: string, progressCallback?: (progress: number) => void): Promise { + if (!normalizeToHuggingFaceUrl(modelUri)) { + throw new Error(`'${modelUri}' is not a valid HuggingFace URL or org/repo identifier.`); + } + + const request = { + Params: { + Model: modelUri, + Token: this.token ?? "" + } + }; + + let resultData: string; + if (progressCallback) { + resultData = await this.coreInterop.executeCommandWithCallback( + "download_model", + request, + (progressString: string) => { + try { + const progress = JSON.parse(progressString); + if (progress.percent != null) { + progressCallback(progress.percent); + } + } catch { /* ignore malformed progress */ } + } + ); + } else { + resultData = this.coreInterop.executeCommand("download_model", request); + } + + // Match result against registered models by URI + const expectedUri = `https://huggingface.co/${resultData}`; + const expectedLower = expectedUri.toLowerCase(); + const expectedWithSlash = trimTrailingSlashes(expectedLower) + '/'; + + for (const variant of this.variantsById.values()) { + const uriLower = variant.modelInfo.uri.toLowerCase(); + if (uriLower === expectedLower + || uriLower.startsWith(expectedWithSlash) + || expectedLower.startsWith(trimTrailingSlashes(uriLower) + '/')) { + const model = this.modelsById.get(variant.id); + if (model) return model; + } + } + + throw new Error(`Model '${modelUri}' was downloaded but could not be found in the catalog.`); + } + + /** + * Returns all registered models. + */ + public async getModels(): Promise { + return Array.from(this.modelsById.values()); + } + + /** + * Look up a specific model variant by its unique ID. + */ + public async getModelVariant(modelId: string): Promise { + return this.variantsById.get(modelId); + } + + /** + * Returns only the model variants that are currently cached on disk. + */ + public async getCachedModels(): Promise { + const cachedModelListJson = this.coreInterop.executeCommand("get_cached_models"); + let cachedModelIds: string[]; + try { + cachedModelIds = JSON.parse(cachedModelListJson); + } catch { + return []; + } + return cachedModelIds + .map(id => this.variantsById.get(id)) + .filter((v): v is ModelVariant => v !== undefined); + } + + /** + * Returns model variants that are currently loaded into memory. + */ + public async getLoadedModels(): Promise { + const loadedModelIds = await this.modelLoadManager.listLoaded(); + return loadedModelIds + .map(id => this.variantsById.get(id)) + .filter((v): v is ModelVariant => v !== undefined); + } + + // ── Persistence ────────────────────────────────────────────────────── + + private get registrationsPath(): string { + const cacheDir = this.coreInterop.executeCommand("get_cache_directory").trim().replace(/^"|"$/g, ''); + return path.join(cacheDir, 'HuggingFace', REGISTRATIONS_FILENAME); + } + + private loadRegistrations(): void { + try { + const filePath = this.registrationsPath; + if (!fs.existsSync(filePath)) return; + + const json = fs.readFileSync(filePath, 'utf-8'); + if (!json.trim()) return; + + const infos: ModelInfo[] = JSON.parse(json); + for (const info of infos) { + const variant = new ModelVariant(info, this.coreInterop, this.modelLoadManager, this.token); + this.variantsById.set(info.id, variant); + this.modelsById.set(info.id, new Model(variant)); + } + } catch { + // Gracefully skip on any error — empty catalog is valid + } + } + + private saveRegistrations(): void { + try { + const filePath = this.registrationsPath; + const dir = path.dirname(filePath); + fs.mkdirSync(dir, { recursive: true }); + + const infos = Array.from(this.variantsById.values()).map(v => v.modelInfo); + fs.writeFileSync(filePath, JSON.stringify(infos, null, 2)); + } catch { + // Non-critical — loss of persistence file is not fatal + } + } +} diff --git a/sdk/js/src/index.ts b/sdk/js/src/index.ts index 7d7ee17a..d8044b75 100644 --- a/sdk/js/src/index.ts +++ b/sdk/js/src/index.ts @@ -1,6 +1,7 @@ export { FoundryLocalManager } from './foundryLocalManager.js'; export type { FoundryLocalConfig } from './configuration.js'; export { Catalog } from './catalog.js'; +export { HuggingFaceCatalog } from './huggingFaceCatalog.js'; export { Model } from './model.js'; export { ModelVariant } from './modelVariant.js'; export type { IModel } from './imodel.js'; diff --git a/sdk/js/src/modelVariant.ts b/sdk/js/src/modelVariant.ts index 4d3e2bee..d58cdd37 100644 --- a/sdk/js/src/modelVariant.ts +++ b/sdk/js/src/modelVariant.ts @@ -14,11 +14,13 @@ export class ModelVariant implements IModel { private _modelInfo: ModelInfo; private coreInterop: CoreInterop; private modelLoadManager: ModelLoadManager; + private token?: string; - constructor(modelInfo: ModelInfo, coreInterop: CoreInterop, modelLoadManager: ModelLoadManager) { + constructor(modelInfo: ModelInfo, coreInterop: CoreInterop, modelLoadManager: ModelLoadManager, token?: string) { this._modelInfo = modelInfo; this.coreInterop = coreInterop; this.modelLoadManager = modelLoadManager; + this.token = token; } /** @@ -68,7 +70,13 @@ export class ModelVariant implements IModel { * @param progressCallback - Optional callback to report download progress (0-100). */ public async download(progressCallback?: (progress: number) => void): Promise { - const request = { Params: { Model: this._modelInfo.id } }; + const modelParam = this._modelInfo.providerType?.toLowerCase() === 'huggingface' + ? this._modelInfo.uri + : this._modelInfo.id; + const request: { Params: Record } = { Params: { Model: modelParam } }; + if (this.token) { + request.Params.Token = this.token; + } if (!progressCallback) { this.coreInterop.executeCommand("download_model", request); } else { diff --git a/sdk/js/src/types.ts b/sdk/js/src/types.ts index 639676de..c05e45c2 100644 --- a/sdk/js/src/types.ts +++ b/sdk/js/src/types.ts @@ -32,6 +32,7 @@ export interface ModelInfo { id: string; name: string; version: number; + hash?: string | null; alias: string; displayName?: string | null; providerType: string; @@ -49,7 +50,7 @@ export interface ModelInfo { supportsToolCalling?: boolean | null; maxOutputTokens?: number | null; minFLVersion?: string | null; - createdAtUnix: number; + createdAt: number; } export interface ResponseFormat { diff --git a/sdk/js/test/catalog.test.ts b/sdk/js/test/catalog.test.ts index df47d4f6..b823738a 100644 --- a/sdk/js/test/catalog.test.ts +++ b/sdk/js/test/catalog.test.ts @@ -28,7 +28,7 @@ describe('Catalog Tests', () => { const manager = getTestManager(); const catalog = manager.catalog; const model = await catalog.getModel(TEST_MODEL_ALIAS); - + expect(model.alias).to.equal(TEST_MODEL_ALIAS); }); @@ -95,7 +95,7 @@ describe('Catalog Tests', () => { it('should throw when getting model variant with unknown ID', async function() { const manager = getTestManager(); const catalog = manager.catalog; - + const unknownId = 'definitely-not-a-real-model-id-12345'; try { await catalog.getModelVariant(unknownId); diff --git a/sdk/js/test/huggingFaceCatalog.test.ts b/sdk/js/test/huggingFaceCatalog.test.ts new file mode 100644 index 00000000..c2e5c235 --- /dev/null +++ b/sdk/js/test/huggingFaceCatalog.test.ts @@ -0,0 +1,84 @@ +import { describe, it } from 'mocha'; +import { expect } from 'chai'; +import { getTestManager } from './testUtils.js'; + +const HF_URL = 'https://huggingface.co/onnxruntime/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4'; + +describe('HuggingFace Catalog Tests', () => { + it('should create huggingface catalog', async function() { + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + expect(hfCatalog.name).to.equal('HuggingFace'); + }); + + it('should reject non-huggingface url', async function() { + const manager = getTestManager(); + try { + await manager.addCatalog('https://example.com'); + expect.fail('Should have thrown an error for non-HuggingFace URL'); + } catch (error) { + expect(error).to.be.instanceOf(Error); + expect((error as Error).message).to.include('Unsupported catalog URL'); + } + }); + + it('should register model', async function() { + this.timeout(30000); + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + + const model = await hfCatalog.registerModel(HF_URL); + expect(model.alias).to.be.a('string'); + expect(model.alias.length).to.be.greaterThan(0); + expect(model.id).to.be.a('string'); + expect(model.id.length).to.be.greaterThan(0); + }); + + it('should find registered model by identifier', async function() { + this.timeout(30000); + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + + await hfCatalog.registerModel(HF_URL); + + const found = await hfCatalog.getModel(HF_URL); + expect(found).to.not.be.undefined; + }); + + it('should register then download model', async function() { + this.timeout(600000); + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + + const model = await hfCatalog.registerModel(HF_URL); + expect(model.alias.length).to.be.greaterThan(0); + + // Now download the ONNX files + await model.download(); + }); + + it('should reject registration of plain alias', async function() { + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + + try { + await hfCatalog.registerModel('phi-3-mini'); + expect.fail('Should have thrown an error for plain alias'); + } catch (error) { + expect(error).to.be.instanceOf(Error); + expect((error as Error).message).to.include('not a valid HuggingFace URL'); + } + }); + + it('should list registered models', async function() { + this.timeout(30000); + const manager = getTestManager(); + const hfCatalog = await manager.addCatalog('https://huggingface.co'); + + await hfCatalog.registerModel(HF_URL); + + const models = await hfCatalog.getModels(); + expect(models).to.be.an('array'); + expect(models.length).to.be.greaterThan(0); + }); +}); diff --git a/sdk/js/test/openai/audioClient.test.ts b/sdk/js/test/openai/audioClient.test.ts index a57c02e5..b5ef3562 100644 --- a/sdk/js/test/openai/audioClient.test.ts +++ b/sdk/js/test/openai/audioClient.test.ts @@ -19,7 +19,7 @@ describe('Audio Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === WHISPER_MODEL_ALIAS); expect(cachedVariant, 'whisper-tiny should be cached').to.not.be.undefined; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -57,7 +57,7 @@ describe('Audio Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === WHISPER_MODEL_ALIAS); expect(cachedVariant, 'whisper-tiny should be cached').to.not.be.undefined; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -95,7 +95,7 @@ describe('Audio Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === WHISPER_MODEL_ALIAS); expect(cachedVariant, 'whisper-tiny should be cached').to.not.be.undefined; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -136,7 +136,7 @@ describe('Audio Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === WHISPER_MODEL_ALIAS); expect(cachedVariant, 'whisper-tiny should be cached').to.not.be.undefined; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -169,7 +169,7 @@ describe('Audio Client Tests', () => { it('should throw when transcribing with empty audio file path', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; const audioClient = model.createAudioClient(); @@ -185,7 +185,7 @@ describe('Audio Client Tests', () => { it('should throw when transcribing streaming with empty audio file path', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; const audioClient = model.createAudioClient(); @@ -201,7 +201,7 @@ describe('Audio Client Tests', () => { it('should throw when transcribing streaming with invalid callback', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(WHISPER_MODEL_ALIAS); + const model = (await catalog.getModel(WHISPER_MODEL_ALIAS))!; const audioClient = model.createAudioClient(); const invalidCallbacks: any[] = [null, undefined, 42, {}, 'not-a-function']; for (const invalidCallback of invalidCallbacks) { diff --git a/sdk/js/test/openai/chatClient.test.ts b/sdk/js/test/openai/chatClient.test.ts index 5f612845..b44e4383 100644 --- a/sdk/js/test/openai/chatClient.test.ts +++ b/sdk/js/test/openai/chatClient.test.ts @@ -15,7 +15,7 @@ describe('Chat Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === TEST_MODEL_ALIAS); expect(cachedVariant).to.not.be.undefined; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -58,7 +58,7 @@ describe('Chat Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === TEST_MODEL_ALIAS); expect(cachedVariant).to.not.be.undefined; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -122,7 +122,7 @@ describe('Chat Client Tests', () => { it('should throw when completing chat with empty, null, or undefined messages', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; const client = model.createChatClient(); @@ -141,7 +141,7 @@ describe('Chat Client Tests', () => { it('should throw when completing chat with invalid message', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; const client = model.createChatClient(); @@ -165,7 +165,7 @@ describe('Chat Client Tests', () => { it('should throw when completing streaming chat with empty, null, or undefined messages', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; const client = model.createChatClient(); @@ -184,7 +184,7 @@ describe('Chat Client Tests', () => { it('should throw when completing streaming chat with invalid callback', async function() { const manager = getTestManager(); const catalog = manager.catalog; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; const client = model.createChatClient(); const messages = [{ role: 'user', content: 'Hello' }]; const invalidCallbacks: any[] = [null, undefined, {} as any, 'not a function' as any]; @@ -209,7 +209,7 @@ describe('Chat Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === TEST_MODEL_ALIAS); expect(cachedVariant).to.not.be.undefined; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; @@ -280,7 +280,7 @@ describe('Chat Client Tests', () => { const cachedVariant = cachedModels.find(m => m.alias === TEST_MODEL_ALIAS); expect(cachedVariant).to.not.be.undefined; - const model = await catalog.getModel(TEST_MODEL_ALIAS); + const model = (await catalog.getModel(TEST_MODEL_ALIAS))!; expect(model).to.not.be.undefined; if (!cachedVariant) return; diff --git a/sdk/js/test/openai/responsesClient.test.ts b/sdk/js/test/openai/responsesClient.test.ts index 925a2360..d251e40e 100644 --- a/sdk/js/test/openai/responsesClient.test.ts +++ b/sdk/js/test/openai/responsesClient.test.ts @@ -394,7 +394,7 @@ describe('ResponsesClient Tests', () => { return; } - model = await catalog.getModel(TEST_MODEL_ALIAS); + model = (await catalog.getModel(TEST_MODEL_ALIAS))!; model.selectVariant(cachedVariant); await model.load(); manager.startWebService(); diff --git a/sdk/rust/examples/chat_completion.rs b/sdk/rust/examples/chat_completion.rs index 3516aa60..a1342552 100644 --- a/sdk/rust/examples/chat_completion.rs +++ b/sdk/rust/examples/chat_completion.rs @@ -35,7 +35,7 @@ async fn main() -> Result<()> { .or_else(|| models.first().map(|m| m.alias().to_string())) .expect("No models available in the catalog"); - let model = manager.catalog().get_model(&model_alias).await?; + let model = manager.catalog().get_model(&model_alias).await?.expect("model not found"); if !model.is_cached().await? { println!("Downloading model '{}'…", model.alias()); diff --git a/sdk/rust/examples/interactive_chat.rs b/sdk/rust/examples/interactive_chat.rs index bd230155..1467e1b5 100644 --- a/sdk/rust/examples/interactive_chat.rs +++ b/sdk/rust/examples/interactive_chat.rs @@ -36,7 +36,7 @@ async fn main() -> Result<(), Box> { .map(|m| m.alias().to_string()) .unwrap_or_else(|| models[0].alias().to_string()); - let model = catalog.get_model(&alias).await?; + let model = catalog.get_model(&alias).await?.expect("model not found"); // Download if needed if !model.is_cached().await? { diff --git a/sdk/rust/src/catalog.rs b/sdk/rust/src/catalog.rs index 78485bff..b5156dbb 100644 --- a/sdk/rust/src/catalog.rs +++ b/sdk/rust/src/catalog.rs @@ -21,7 +21,7 @@ const CACHE_TTL: Duration = Duration::from_secs(6 * 60 * 60); // 6 hours pub(crate) struct CacheInvalidator(Arc); impl CacheInvalidator { - fn new() -> Self { + pub(crate) fn new() -> Self { Self(Arc::new(AtomicBool::new(false))) } @@ -126,20 +126,27 @@ impl Catalog { } /// Look up a model by its alias. - pub async fn get_model(&self, alias: &str) -> Result> { + /// + /// Returns an error if not found. For HuggingFace models, use + /// [`HuggingFaceCatalog`] via [`FoundryLocalManager::add_catalog`]. + pub async fn get_model(&self, alias: &str) -> Result>> { if alias.trim().is_empty() { return Err(FoundryLocalError::Validation { reason: "Model alias must be a non-empty string".into(), }); } + self.update_models().await?; let s = self.lock_state()?; - s.models_by_alias.get(alias).cloned().ok_or_else(|| { - let available: Vec<&String> = s.models_by_alias.keys().collect(); - FoundryLocalError::ModelOperation { - reason: format!("Unknown model alias '{alias}'. Available: {available:?}"), + match s.models_by_alias.get(alias) { + Some(model) => Ok(Some(Arc::clone(model))), + None => { + let available: Vec<&String> = s.models_by_alias.keys().collect(); + Err(FoundryLocalError::ModelOperation { + reason: format!("Unknown model alias '{alias}'. Available: {available:?}"), + }) } - }) + } } /// Look up a specific model variant by its unique id. @@ -221,6 +228,7 @@ impl Catalog { Arc::clone(&self.core), Arc::clone(&self.model_load_manager), self.invalidator.clone(), + None, ); let variant_arc = Arc::new(variant.clone()); id_map.insert(id, variant_arc); diff --git a/sdk/rust/src/foundry_local_manager.rs b/sdk/rust/src/foundry_local_manager.rs index f80a7176..c9891200 100644 --- a/sdk/rust/src/foundry_local_manager.rs +++ b/sdk/rust/src/foundry_local_manager.rs @@ -13,6 +13,7 @@ use crate::configuration::{Configuration, FoundryLocalConfig, Logger}; use crate::detail::core_interop::CoreInterop; use crate::detail::ModelLoadManager; use crate::error::{FoundryLocalError, Result}; +use crate::huggingface_catalog::HuggingFaceCatalog; /// Global singleton holder — only stores a successfully initialised manager. static INSTANCE: OnceLock = OnceLock::new(); @@ -25,6 +26,7 @@ static INIT_GUARD: Mutex<()> = Mutex::new(()); /// the existing instance. pub struct FoundryLocalManager { core: Arc, + model_load_manager: Arc, catalog: Catalog, urls: Mutex>, /// Application logger (stub — not yet wired into the native core). @@ -70,6 +72,7 @@ impl FoundryLocalManager { let manager = FoundryLocalManager { core, + model_load_manager, catalog, urls: Mutex::new(Vec::new()), _logger: logger, @@ -90,6 +93,38 @@ impl FoundryLocalManager { &self.catalog } + /// Create a separate HuggingFace catalog for registering and downloading + /// models from HuggingFace. + /// + /// The three-step flow: + /// 1. `add_catalog("https://huggingface.co", None)` — create the catalog + /// 2. `catalog.register_model("org/repo")` — register (config-only download) + /// 3. `model.download(None)` — download ONNX files + /// + /// The returned catalog is owned by the caller. Each call creates a new + /// instance with registrations loaded from disk. + pub async fn add_catalog( + &self, + catalog_url: &str, + token: Option, + ) -> Result { + if !catalog_url.to_lowercase().contains("huggingface.co") { + return Err(FoundryLocalError::Validation { + reason: format!( + "Unsupported catalog URL '{}'. Only HuggingFace catalogs (huggingface.co) are supported.", + catalog_url + ), + }); + } + + HuggingFaceCatalog::create( + Arc::clone(&self.core), + Arc::clone(&self.model_load_manager), + token, + ) + .await + } + /// URLs that the local web service is listening on. /// /// Empty until [`Self::start_web_service`] has been called. diff --git a/sdk/rust/src/hf_utils.rs b/sdk/rust/src/hf_utils.rs new file mode 100644 index 00000000..17c46849 --- /dev/null +++ b/sdk/rust/src/hf_utils.rs @@ -0,0 +1,117 @@ +//! Shared HuggingFace URL utilities. + +/// Normalise a model identifier to a canonical HuggingFace URL, or return +/// `None` if it is a plain alias. +/// +/// Strips `/tree/{revision}/` from full browser URLs so the result matches +/// the stored `ModelInfo.uri` format. +/// +/// # Examples +/// +/// ```text +/// "https://huggingface.co/org/repo/tree/main/sub" → Some("https://huggingface.co/org/repo/sub") +/// "https://huggingface.co/org/repo" → Some("https://huggingface.co/org/repo") +/// "org/repo/sub" → Some("https://huggingface.co/org/repo/sub") +/// "phi-3-mini" (plain alias) → None +/// ``` +pub(crate) fn normalize_to_huggingface_url(input: &str) -> Option { + const HF_PREFIX: &str = "https://huggingface.co/"; + + if input.to_lowercase().starts_with(&HF_PREFIX.to_lowercase()) { + let path = &input[HF_PREFIX.len()..]; + let parts: Vec<&str> = path.split('/').collect(); + if parts.len() >= 4 && parts[2].eq_ignore_ascii_case("tree") { + let org = parts[0]; + let repo = parts[1]; + let sub: Vec<&str> = parts[4..].to_vec(); + return if sub.is_empty() { + Some(format!("{HF_PREFIX}{org}/{repo}")) + } else { + Some(format!("{HF_PREFIX}{org}/{repo}/{}", sub.join("/"))) + }; + } + return Some(input.to_string()); + } + + if input.contains('/') && !input.to_lowercase().starts_with("azureml://") { + let parts: Vec<&str> = input.split('/').collect(); + if parts.len() >= 4 && parts[2].eq_ignore_ascii_case("tree") { + let org = parts[0]; + let repo = parts[1]; + let sub: Vec<&str> = parts[4..].to_vec(); + return if sub.is_empty() { + Some(format!("{HF_PREFIX}{org}/{repo}")) + } else { + Some(format!("{HF_PREFIX}{org}/{repo}/{}", sub.join("/"))) + }; + } + return Some(format!("{HF_PREFIX}{input}")); + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn plain_alias_returns_none() { + assert!(normalize_to_huggingface_url("phi-3-mini").is_none()); + } + + #[test] + fn org_repo_returns_hf_url() { + assert_eq!( + normalize_to_huggingface_url("microsoft/Phi-3-mini-4k-instruct-onnx"), + Some("https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx".into()) + ); + } + + #[test] + fn full_hf_url_passthrough() { + let url = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx"; + assert_eq!(normalize_to_huggingface_url(url), Some(url.into())); + } + + #[test] + fn browser_url_strips_tree_revision() { + let input = "https://huggingface.co/org/repo/tree/main/sub/path"; + assert_eq!( + normalize_to_huggingface_url(input), + Some("https://huggingface.co/org/repo/sub/path".into()) + ); + } + + #[test] + fn browser_url_no_subpath() { + let input = "https://huggingface.co/org/repo/tree/main"; + assert_eq!( + normalize_to_huggingface_url(input), + Some("https://huggingface.co/org/repo".into()) + ); + } + + #[test] + fn bare_identifier_with_tree_revision() { + let input = "org/repo/tree/main/sub"; + assert_eq!( + normalize_to_huggingface_url(input), + Some("https://huggingface.co/org/repo/sub".into()) + ); + } + + #[test] + fn org_repo_with_subpath() { + let input = "org/repo/sub/path"; + assert_eq!( + normalize_to_huggingface_url(input), + Some("https://huggingface.co/org/repo/sub/path".into()) + ); + } + + #[test] + fn azureml_url_returns_none() { + assert!(normalize_to_huggingface_url("azureml://some/model").is_none()); + } +} diff --git a/sdk/rust/src/huggingface_catalog.rs b/sdk/rust/src/huggingface_catalog.rs new file mode 100644 index 00000000..45d9daef --- /dev/null +++ b/sdk/rust/src/huggingface_catalog.rs @@ -0,0 +1,404 @@ +//! HuggingFace model catalog — register, download, and look up HuggingFace models. +//! +//! Created via [`FoundryLocalManager::add_catalog`]. Provides the three-step +//! flow: register (config-only download) → download (ONNX files) → inference. + +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; + +use serde_json::json; + +use crate::catalog::CacheInvalidator; +use crate::detail::core_interop::CoreInterop; +use crate::detail::ModelLoadManager; +use crate::error::{FoundryLocalError, Result}; +use crate::hf_utils::normalize_to_huggingface_url; +use crate::model::Model; +use crate::model_variant::ModelVariant; +use crate::types::ModelInfo; + +/// Filename for the HuggingFace registration persistence file. +const REGISTRATIONS_FILENAME: &str = "huggingface.modelinfo.json"; + +/// Internal state protected by a Mutex. +struct HuggingFaceCatalogState { + variants_by_id: HashMap>, + models_by_id: HashMap>, +} + +/// A catalog for HuggingFace models. +/// +/// Created via [`FoundryLocalManager::add_catalog`]. Each call to `add_catalog` +/// creates a new instance with registrations loaded from disk. +/// +/// # Three-step flow +/// +/// ```text +/// let hf = manager.add_catalog("https://huggingface.co", None).await?; +/// let model = hf.register_model("org/repo").await?; // config files only +/// model.download::(None).await?; // ONNX files +/// ``` +pub struct HuggingFaceCatalog { + core: Arc, + model_load_manager: Arc, + token: Option, + state: Mutex, + invalidator: CacheInvalidator, +} + +impl HuggingFaceCatalog { + pub(crate) async fn create( + core: Arc, + model_load_manager: Arc, + token: Option, + ) -> Result { + let invalidator = CacheInvalidator::new(); + let catalog = Self { + core, + model_load_manager, + token, + state: Mutex::new(HuggingFaceCatalogState { + variants_by_id: HashMap::new(), + models_by_id: HashMap::new(), + }), + invalidator, + }; + catalog.load_registrations()?; + Ok(catalog) + } + + /// Catalog name. + pub fn name(&self) -> &str { + "HuggingFace" + } + + /// Register a HuggingFace model by downloading its config files only (~50KB). + /// + /// Sends the `register_model` FFI command to the native core, which downloads + /// config files (genai_config.json, config.json, tokenizer_config.json, etc.) + /// and generates metadata. Returns a `Model` with `cached: false`. + /// + /// After registration, call [`Model::download`] to download the ONNX files. + pub async fn register_model(&self, model_identifier: &str) -> Result> { + if normalize_to_huggingface_url(model_identifier).is_none() { + return Err(FoundryLocalError::Validation { + reason: format!( + "'{model_identifier}' is not a valid HuggingFace URL or org/repo identifier." + ), + }); + } + + let params = json!({ + "Params": { + "Model": model_identifier, + "Token": self.token.as_deref().unwrap_or("") + } + }); + + let result = self + .core + .execute_command_async("register_model".into(), Some(params)) + .await?; + + let model_info: ModelInfo = serde_json::from_str(&result)?; + + let model = { + let mut s = self.lock_state()?; + let variant = ModelVariant::new( + model_info.clone(), + Arc::clone(&self.core), + Arc::clone(&self.model_load_manager), + self.invalidator.clone(), + self.token.clone(), + ); + let variant_arc = Arc::new(variant.clone()); + s.variants_by_id + .insert(model_info.id.clone(), variant_arc); + + let mut m = Model::new(model_info.alias.clone(), Arc::clone(&self.core)); + m.add_variant(variant); + let model = Arc::new(m); + s.models_by_id + .insert(model_info.id.clone(), Arc::clone(&model)); + model + }; + + self.save_registrations()?; + + Ok(model) + } + + /// Look up a model by its ID, alias, or HuggingFace URL. + /// + /// Uses three-tier lookup: + /// 1. Direct ID match + /// 2. Alias match (case-insensitive) + /// 3. URI-based match (normalize to HuggingFace URL and compare) + /// + /// Returns an error if the model is not found. + pub async fn get_model(&self, identifier: &str) -> Result> { + if identifier.trim().is_empty() { + return Err(FoundryLocalError::Validation { + reason: "Model identifier must be a non-empty string".into(), + }); + } + + let s = self.lock_state()?; + + // 1. Direct ID match + if let Some(model) = s.models_by_id.get(identifier) { + return Ok(Arc::clone(model)); + } + + // 2. Alias match (case-insensitive) + for model in s.models_by_id.values() { + if model.alias().eq_ignore_ascii_case(identifier) { + return Ok(Arc::clone(model)); + } + } + + // 3. URI-based match — prefer exact, fall back to prefix + if let Some(normalized_url) = normalize_to_huggingface_url(identifier) { + let normalized_lower = normalized_url.to_lowercase(); + let normalized_with_slash = + format!("{}/", normalized_lower.trim_end_matches('/')); + + // Exact match first + for variant in s.variants_by_id.values() { + let uri_lower = variant.info().uri.to_lowercase(); + if uri_lower == normalized_lower { + if let Some(model) = s.models_by_id.get(variant.id()) { + return Ok(Arc::clone(model)); + } + } + } + // Prefix match fallback + for variant in s.variants_by_id.values() { + let uri_lower = variant.info().uri.to_lowercase(); + if uri_lower.starts_with(&normalized_with_slash) { + if let Some(model) = s.models_by_id.get(variant.id()) { + return Ok(Arc::clone(model)); + } + } + } + } + + Err(FoundryLocalError::ModelOperation { + reason: format!("Model '{identifier}' not found in HuggingFace catalog."), + }) + } + + /// Download a HuggingFace model's ONNX files. + /// + /// Sends the `download_model` FFI command. The model should have been + /// previously registered via [`HuggingFaceCatalog::register_model`]. + /// + /// If `progress` is provided, it receives human-readable progress strings + /// as the download proceeds. + pub async fn download_model( + &self, + model_uri: &str, + progress: Option, + ) -> Result> + where + F: FnMut(&str) + Send + 'static, + { + if normalize_to_huggingface_url(model_uri).is_none() { + return Err(FoundryLocalError::Validation { + reason: format!( + "'{model_uri}' is not a valid HuggingFace URL or org/repo identifier." + ), + }); + } + + let params = json!({ + "Params": { + "Model": model_uri, + "Token": self.token.as_deref().unwrap_or("") + } + }); + + let result_data = match progress { + Some(cb) => { + self.core + .execute_command_streaming_async( + "download_model".into(), + Some(params), + cb, + ) + .await? + } + None => { + self.core + .execute_command_async("download_model".into(), Some(params)) + .await? + } + }; + + // Match result against registered models by URI + let expected_uri = format!("https://huggingface.co/{result_data}"); + let expected_lower = expected_uri.to_lowercase(); + let expected_with_slash = + format!("{}/", expected_lower.trim_end_matches('/')); + + let s = self.lock_state()?; + + // Exact match first + for variant in s.variants_by_id.values() { + let uri_lower = variant.info().uri.to_lowercase(); + if uri_lower == expected_lower { + if let Some(model) = s.models_by_id.get(variant.id()) { + return Ok(Arc::clone(model)); + } + } + } + // Prefix match fallback + for variant in s.variants_by_id.values() { + let uri_lower = variant.info().uri.to_lowercase(); + if uri_lower.starts_with(&expected_with_slash) + || expected_lower.starts_with( + &format!("{}/", uri_lower.trim_end_matches('/')), + ) + { + if let Some(model) = s.models_by_id.get(variant.id()) { + return Ok(Arc::clone(model)); + } + } + } + + Err(FoundryLocalError::ModelOperation { + reason: format!( + "Model '{model_uri}' was downloaded but could not be found in the catalog." + ), + }) + } + + /// Return all registered models. + pub async fn get_models(&self) -> Result>> { + let s = self.lock_state()?; + Ok(s.models_by_id.values().cloned().collect()) + } + + /// Look up a specific model variant by its unique id. + pub async fn get_model_variant( + &self, + id: &str, + ) -> Result>> { + let s = self.lock_state()?; + Ok(s.variants_by_id.get(id).cloned()) + } + + /// Return only the model variants that are currently cached on disk. + pub async fn get_cached_models(&self) -> Result>> { + let raw = self + .core + .execute_command_async("get_cached_models".into(), None) + .await?; + if raw.trim().is_empty() { + return Ok(Vec::new()); + } + let cached_ids: Vec = serde_json::from_str(&raw)?; + let s = self.lock_state()?; + Ok(cached_ids + .iter() + .filter_map(|id| s.variants_by_id.get(id).cloned()) + .collect()) + } + + /// Return model variants that are currently loaded into memory. + pub async fn get_loaded_models(&self) -> Result>> { + let loaded_ids = self.model_load_manager.list_loaded().await?; + let s = self.lock_state()?; + Ok(loaded_ids + .iter() + .filter_map(|id| s.variants_by_id.get(id).cloned()) + .collect()) + } + + // ── Persistence ────────────────────────────────────────────────────── + + fn registrations_path(&self) -> Result { + let cache_dir = self + .core + .execute_command("get_cache_directory".into(), None)?; + Ok(PathBuf::from(cache_dir.trim().trim_matches('"')) + .join("HuggingFace") + .join(REGISTRATIONS_FILENAME)) + } + + fn load_registrations(&self) -> Result<()> { + let path = match self.registrations_path() { + Ok(p) => p, + Err(_) => return Ok(()), // gracefully skip if home dir unknown + }; + if !path.exists() { + return Ok(()); + } + + let json = match std::fs::read_to_string(&path) { + Ok(s) => s, + Err(_) => return Ok(()), // gracefully skip on read error + }; + if json.trim().is_empty() { + return Ok(()); + } + + let infos: Vec = match serde_json::from_str(&json) { + Ok(v) => v, + Err(_) => return Ok(()), // gracefully skip on parse error + }; + + let mut s = self.lock_state()?; + for info in infos { + let variant = ModelVariant::new( + info.clone(), + Arc::clone(&self.core), + Arc::clone(&self.model_load_manager), + self.invalidator.clone(), + self.token.clone(), + ); + let variant_arc = Arc::new(variant.clone()); + s.variants_by_id.insert(info.id.clone(), variant_arc); + + let mut m = Model::new(info.alias.clone(), Arc::clone(&self.core)); + m.add_variant(variant); + s.models_by_id.insert(info.id.clone(), Arc::new(m)); + } + + Ok(()) + } + + fn save_registrations(&self) -> Result<()> { + let path = match self.registrations_path() { + Ok(p) => p, + Err(_) => return Ok(()), // gracefully skip + }; + if let Some(dir) = path.parent() { + let _ = std::fs::create_dir_all(dir); + } + + // Snapshot data under the lock, then release before doing I/O + let infos: Vec = { + let s = self.lock_state()?; + s.variants_by_id.values().map(|v| v.info().clone()).collect() + }; + + let json = serde_json::to_string_pretty(&infos) + .map_err(|e| FoundryLocalError::Internal { + reason: format!("Failed to serialize registrations: {e}"), + })?; + std::fs::write(&path, json).map_err(|e| FoundryLocalError::Internal { + reason: format!("Failed to write registrations file: {e}"), + })?; + Ok(()) + } + + fn lock_state( + &self, + ) -> Result> { + self.state.lock().map_err(|_| FoundryLocalError::Internal { + reason: "HuggingFace catalog state mutex poisoned".into(), + }) + } +} diff --git a/sdk/rust/src/lib.rs b/sdk/rust/src/lib.rs index c6d6e6c4..0bae2b07 100644 --- a/sdk/rust/src/lib.rs +++ b/sdk/rust/src/lib.rs @@ -6,6 +6,8 @@ mod catalog; mod configuration; mod error; mod foundry_local_manager; +mod hf_utils; +mod huggingface_catalog; mod model; mod model_variant; mod types; @@ -17,6 +19,7 @@ pub use self::catalog::Catalog; pub use self::configuration::{FoundryLocalConfig, LogLevel, Logger}; pub use self::error::FoundryLocalError; pub use self::foundry_local_manager::FoundryLocalManager; +pub use self::huggingface_catalog::HuggingFaceCatalog; pub use self::model::Model; pub use self::model_variant::ModelVariant; pub use self::types::{ diff --git a/sdk/rust/src/model_variant.rs b/sdk/rust/src/model_variant.rs index c4be6822..06c70db7 100644 --- a/sdk/rust/src/model_variant.rs +++ b/sdk/rust/src/model_variant.rs @@ -22,6 +22,7 @@ pub struct ModelVariant { core: Arc, model_load_manager: Arc, cache_invalidator: CacheInvalidator, + token: Option, } impl fmt::Debug for ModelVariant { @@ -39,12 +40,14 @@ impl ModelVariant { core: Arc, model_load_manager: Arc, cache_invalidator: CacheInvalidator, + token: Option, ) -> Self { Self { info, core, model_load_manager, cache_invalidator, + token, } } @@ -93,7 +96,17 @@ impl ModelVariant { where F: FnMut(&str) + Send + 'static, { - let params = json!({ "Params": { "Model": self.info.id } }); + let model_param = if self.info.provider_type.eq_ignore_ascii_case("huggingface") { + &self.info.uri + } else { + &self.info.id + }; + let mut params_map = serde_json::Map::new(); + params_map.insert("Model".into(), json!(model_param)); + if let Some(ref t) = self.token { + params_map.insert("Token".into(), json!(t)); + } + let params = json!({ "Params": params_map }); match progress { Some(cb) => { self.core diff --git a/sdk/rust/src/types.rs b/sdk/rust/src/types.rs index d1d1f002..6ea0f2b5 100644 --- a/sdk/rust/src/types.rs +++ b/sdk/rust/src/types.rs @@ -56,6 +56,8 @@ pub struct ModelInfo { pub id: String, pub name: String, pub version: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub hash: Option, pub alias: String, #[serde(skip_serializing_if = "Option::is_none")] pub display_name: Option, @@ -86,7 +88,7 @@ pub struct ModelInfo { #[serde(skip_serializing_if = "Option::is_none")] pub min_fl_version: Option, #[serde(default)] - pub created_at_unix: u64, + pub created_at: u64, } /// Desired response format for chat completions. diff --git a/sdk/rust/tests/integration/audio_client_test.rs b/sdk/rust/tests/integration/audio_client_test.rs index 47cef9d0..8b69867c 100644 --- a/sdk/rust/tests/integration/audio_client_test.rs +++ b/sdk/rust/tests/integration/audio_client_test.rs @@ -9,7 +9,8 @@ async fn setup_audio_client() -> (AudioClient, Arc) { let model = catalog .get_model(common::WHISPER_MODEL_ALIAS) .await - .expect("get_model(whisper-tiny) failed"); + .expect("get_model(whisper-tiny) failed") + .expect("model not found"); model.load().await.expect("model.load() failed"); (model.create_audio_client(), model) } diff --git a/sdk/rust/tests/integration/catalog_test.rs b/sdk/rust/tests/integration/catalog_test.rs index d418c7a7..d3f0fc8d 100644 --- a/sdk/rust/tests/integration/catalog_test.rs +++ b/sdk/rust/tests/integration/catalog_test.rs @@ -36,7 +36,8 @@ async fn should_get_model_by_alias() { let model = cat .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); } diff --git a/sdk/rust/tests/integration/chat_client_test.rs b/sdk/rust/tests/integration/chat_client_test.rs index b24f3804..34dab08a 100644 --- a/sdk/rust/tests/integration/chat_client_test.rs +++ b/sdk/rust/tests/integration/chat_client_test.rs @@ -15,7 +15,8 @@ async fn setup_chat_client() -> (ChatClient, Arc) { let model = catalog .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); model.load().await.expect("model.load() failed"); let client = model.create_chat_client().max_tokens(500).temperature(0.0); diff --git a/sdk/rust/tests/integration/huggingface_catalog_test.rs b/sdk/rust/tests/integration/huggingface_catalog_test.rs new file mode 100644 index 00000000..00308fd2 --- /dev/null +++ b/sdk/rust/tests/integration/huggingface_catalog_test.rs @@ -0,0 +1,116 @@ +use super::common; + +const HF_URL: &str = "https://huggingface.co/onnxruntime/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"; + +#[tokio::test] +async fn should_create_huggingface_catalog() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + assert_eq!(hf_catalog.name(), "HuggingFace"); +} + +#[tokio::test] +async fn should_reject_non_huggingface_url() { + let manager = common::get_test_manager(); + let result = manager.add_catalog("https://example.com", None).await; + assert!(result.is_err(), "Non-HuggingFace URL should be rejected"); +} + +#[tokio::test] +async fn should_register_model() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + + let model = hf_catalog + .register_model(HF_URL) + .await + .expect("register_model failed"); + + assert!(!model.alias().is_empty(), "Model alias should be non-empty"); + assert!(!model.id().is_empty(), "Model id should be non-empty"); +} + +#[tokio::test] +async fn should_find_registered_model_by_identifier() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + + let _model = hf_catalog + .register_model(HF_URL) + .await + .expect("register_model failed"); + + let found = hf_catalog + .get_model(HF_URL) + .await + .expect("get_model failed"); + + assert!(!found.alias().is_empty(), "Should find model by HuggingFace URL"); +} + +#[tokio::test] +async fn should_register_then_download_model() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + + let registered = hf_catalog + .register_model(HF_URL) + .await + .expect("register_model failed"); + + assert!(!registered.alias().is_empty()); + + // Now download the ONNX files + registered + .download::(None) + .await + .expect("download failed"); +} + +#[tokio::test] +async fn should_reject_registration_of_plain_alias() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + + let result = hf_catalog.register_model("phi-3-mini").await; + assert!(result.is_err(), "Plain alias should be rejected"); +} + +#[tokio::test] +async fn should_list_registered_models() { + let manager = common::get_test_manager(); + let hf_catalog = manager + .add_catalog("https://huggingface.co", None) + .await + .expect("add_catalog failed"); + + let _model = hf_catalog + .register_model(HF_URL) + .await + .expect("register_model failed"); + + let models = hf_catalog + .get_models() + .await + .expect("get_models failed"); + + assert!( + !models.is_empty(), + "Should have at least one registered model" + ); +} diff --git a/sdk/rust/tests/integration/main.rs b/sdk/rust/tests/integration/main.rs index 04de9a23..d4e25d1f 100644 --- a/sdk/rust/tests/integration/main.rs +++ b/sdk/rust/tests/integration/main.rs @@ -11,6 +11,7 @@ mod common; mod audio_client_test; mod catalog_test; mod chat_client_test; +mod huggingface_catalog_test; mod manager_test; mod model_test; mod web_service_test; diff --git a/sdk/rust/tests/integration/model_test.rs b/sdk/rust/tests/integration/model_test.rs index d2b68b77..02a7761e 100644 --- a/sdk/rust/tests/integration/model_test.rs +++ b/sdk/rust/tests/integration/model_test.rs @@ -38,7 +38,8 @@ async fn should_load_and_unload_model() { let model = catalog .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); model.load().await.expect("model.load() failed"); assert!( @@ -62,7 +63,8 @@ async fn should_expose_alias() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); } @@ -74,7 +76,8 @@ async fn should_expose_non_empty_id() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); println!("Model id: {}", model.id()); @@ -91,7 +94,8 @@ async fn should_have_at_least_one_variant() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); let variants = model.variants(); println!("Model has {} variant(s)", variants.len()); @@ -109,7 +113,8 @@ async fn should_have_selected_variant_matching_id() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); let selected = model.selected_variant(); assert_eq!( @@ -126,7 +131,8 @@ async fn should_report_cached_model_as_cached() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); let cached = model.is_cached().await.expect("is_cached() should succeed"); assert!( @@ -143,7 +149,8 @@ async fn should_return_non_empty_path_for_cached_model() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); let path = model.path().await.expect("path() should succeed"); println!("Model path: {}", path.display()); @@ -161,7 +168,8 @@ async fn should_select_variant_by_id() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); // Remember the original selection so we can restore it afterward. let original_id = model.id().to_string(); @@ -190,7 +198,8 @@ async fn should_fail_to_select_unknown_variant() { .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); let result = model.select_variant("nonexistent-variant-id"); assert!( @@ -214,6 +223,7 @@ async fn get_test_model() -> Arc { .get_model(common::TEST_MODEL_ALIAS) .await .expect("get_model failed") + .expect("model not found") } #[tokio::test] diff --git a/sdk/rust/tests/integration/web_service_test.rs b/sdk/rust/tests/integration/web_service_test.rs index 9222f9d4..e6511f65 100644 --- a/sdk/rust/tests/integration/web_service_test.rs +++ b/sdk/rust/tests/integration/web_service_test.rs @@ -10,7 +10,8 @@ async fn should_complete_chat_via_rest_api() { let model = catalog .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); model.load().await.expect("model.load() failed"); manager @@ -72,7 +73,8 @@ async fn should_stream_chat_via_rest_api() { let model = catalog .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed") + .expect("model not found"); model.load().await.expect("model.load() failed"); manager