diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0956f9fa..0776fb4b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,6 +30,25 @@ jobs: - name: Run unit tests run: make test.unit + javascript: + name: javascript + runs-on: ubuntu-latest + + steps: + - name: checkout code + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '24' + + - name: Install dependencies + run: npm ci + + - name: Run JS client tests + run: make test.client.js + python: name: python (${{ matrix.python-version }}) runs-on: ubuntu-latest diff --git a/Cargo.lock b/Cargo.lock index 699a3cb2..ed0eebd1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2695,6 +2695,30 @@ dependencies = [ "foldhash 0.2.0", ] +[[package]] +name = "headers" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb" +dependencies = [ + "base64", + "bytes 1.11.1", + "headers-core", + "http 1.4.0", + "httpdate", + "mime", + "sha1 0.10.6", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http 1.4.0", +] + [[package]] name = "heck" version = "0.5.0" @@ -4889,6 +4913,8 @@ dependencies = [ "minijinja", "minijinja-contrib", "nanoid", + "paddler_cache_dir", + "paddler_download_manager", "paddler_types", "rand 0.9.4", "reqwest", @@ -4922,6 +4948,18 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "paddler_cache_dir" +version = "4.0.0" +dependencies = [ + "anyhow", + "fslock", + "sha2", + "tempfile", + "thiserror 2.0.18", + "tokio", +] + [[package]] name = "paddler_cli" version = "4.0.0" @@ -4985,6 +5023,21 @@ dependencies = [ "url", ] +[[package]] +name = "paddler_download_manager" +version = "4.0.0" +dependencies = [ + "anyhow", + "bytes 
1.11.1", + "futures-util", + "headers", + "reqwest", + "tempfile", + "thiserror 2.0.18", + "tokio", + "url", +] + [[package]] name = "paddler_gui" version = "4.0.0" diff --git a/Cargo.toml b/Cargo.toml index 235e702d..4be68b5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["paddler", "paddler_bootstrap", "paddler_cli", "paddler_client", "paddler_client_cli", "paddler_gui", "paddler_tests", "paddler_types"] +members = ["paddler", "paddler_bootstrap", "paddler_cache_dir", "paddler_cli", "paddler_client", "paddler_client_cli", "paddler_download_manager", "paddler_gui", "paddler_tests", "paddler_types"] resolver = "2" [workspace.package] @@ -30,8 +30,10 @@ dashmap = "6.1" encoding_rs = { version = "0.8", features = ["serde"] } env_logger = "0.11" esbuild-metafile = "0.5.2" +fslock = "=0.2.1" futures = "0.3" futures-util = { version = "0.3", features = ["tokio-io"] } +headers = "=0.4.1" hf-hub = { version = "0.4", features = ["tokio"] } image = "0.25" indoc = "2" @@ -56,6 +58,7 @@ rust-embed = { version = "8.9", features = ["interpolate-folder-path"] } serial_test = { version = "3", features = ["file_locks"] } serde = { version = "1", features = ["derive"] } serde_json = "1" +sha2 = "0.10" shellexpand = "3" iced = { version = "0.14", features = ["image", "svg", "tokio"] } if-addrs = "0.13" @@ -70,7 +73,9 @@ thiserror = "2" url = { version = "2.5", features = ["serde"] } paddler = { version = "4.0.0", path = "paddler" } paddler_bootstrap = { version = "4.0.0", path = "paddler_bootstrap" } +paddler_cache_dir = { version = "4.0.0", path = "paddler_cache_dir" } paddler_client = { version = "4.0.0", path = "paddler_client" } +paddler_download_manager = { version = "4.0.0", path = "paddler_download_manager" } paddler_tests = { version = "4.0.0", path = "paddler_tests" } paddler_types = { version = "4.0.0", path = "paddler_types" } diff --git a/Makefile b/Makefile index 1f90ea08..c9e8ea3b 100644 --- a/Makefile +++ b/Makefile @@ -2,8 +2,8 @@ 
RUST_LOG ?= debug -PADDLER_CLI_SOURCES := $(shell find paddler/src paddler_bootstrap/src paddler_cli/src paddler_client/src paddler_types/src -name '*.rs') -PADDLER_GUI_SOURCES := $(shell find paddler/src paddler_bootstrap/src paddler_gui/src paddler_types/src -name '*.rs') +COVERAGE_PACKAGES := -p paddler_cache_dir -p paddler_download_manager +PADDLER_SOURCES := $(shell find paddler/src paddler_bootstrap/src paddler_cache_dir/src paddler_cli/src paddler_client/src paddler_download_manager/src paddler_gui/src paddler_types/src -name '*.rs') FRONTEND_SOURCES := $(shell find resources -type f) $(wildcard jarmuz/*.mjs) # ----------------------------------------------------------------------------- @@ -20,28 +20,28 @@ node_modules: package-lock.json esbuild-meta.json: $(FRONTEND_SOURCES) jarmuz-static.mjs tsconfig.json package.json node_modules ./jarmuz-static.mjs -target/debug/paddler: $(PADDLER_CLI_SOURCES) +target/debug/paddler: $(PADDLER_SOURCES) cargo build -p paddler_cli -target/release/paddler: $(PADDLER_CLI_SOURCES) esbuild-meta.json +target/release/paddler: $(PADDLER_SOURCES) esbuild-meta.json cargo build --release -p paddler_cli --features web_admin_panel -target/cuda/debug/paddler: $(PADDLER_CLI_SOURCES) esbuild-meta.json +target/cuda/debug/paddler: $(PADDLER_SOURCES) esbuild-meta.json cargo build -p paddler_cli --features cuda,web_admin_panel --target-dir target/cuda -target/cuda/release/paddler: $(PADDLER_CLI_SOURCES) esbuild-meta.json +target/cuda/release/paddler: $(PADDLER_SOURCES) esbuild-meta.json cargo build --release -p paddler_cli --features cuda,web_admin_panel --target-dir target/cuda -target/metal/debug/paddler: $(PADDLER_CLI_SOURCES) esbuild-meta.json +target/metal/debug/paddler: $(PADDLER_SOURCES) esbuild-meta.json cargo build -p paddler_cli --features metal,web_admin_panel --target-dir target/metal -target/metal/release/paddler: $(PADDLER_CLI_SOURCES) esbuild-meta.json +target/metal/release/paddler: $(PADDLER_SOURCES) esbuild-meta.json cargo 
build --release -p paddler_cli --features metal,web_admin_panel --target-dir target/metal -target/vulkan/release/paddler: $(PADDLER_CLI_SOURCES) esbuild-meta.json +target/vulkan/release/paddler: $(PADDLER_SOURCES) esbuild-meta.json cargo build --release -p paddler_cli --features vulkan,web_admin_panel --target-dir target/vulkan -target/release/paddler_gui: $(PADDLER_GUI_SOURCES) esbuild-meta.json +target/release/paddler_gui: $(PADDLER_SOURCES) esbuild-meta.json cargo build --release -p paddler_gui --features web_admin_panel # ----------------------------------------------------------------------------- @@ -57,37 +57,59 @@ clean: .PHONY: clippy clippy: esbuild-meta.json - cargo clippy --workspace --all-targets --features web_admin_panel,tests_that_use_llms,tests_that_use_compiled_paddler + cargo clippy --workspace --all-targets --features web_admin_panel,tests_that_use_llms,tests_that_use_compiled_paddler,tests_that_use_in_process_cluster + +.PHONY: coverage +coverage: node_modules + cargo llvm-cov clean --workspace + cargo llvm-cov $(COVERAGE_PACKAGES) --no-report + cargo llvm-cov report --json --output-path target/llvm-cov.json + cargo llvm-cov report --lcov --output-path target/lcov.info + cargo llvm-cov report + npx rust-coverage-check target/llvm-cov.json \ + --workspace-root $(CURDIR) \ + --gated paddler_cache_dir=100 \ + --gated paddler_download_manager=99 + +.PHONY: coverage-clean +coverage-clean: + cargo llvm-cov clean --workspace + rm -rf target/llvm-cov-target + rm -f target/llvm-cov.json target/lcov.info + +.PHONY: coverage-report +coverage-report: + cargo llvm-cov $(COVERAGE_PACKAGES) --html .PHONY: fmt fmt: node_modules ./jarmuz-fmt.mjs .PHONY: test -test: test.unit test.integration +test: test.client.js test.unit test.integration .PHONY: test.integration test.integration: target/debug/paddler - cargo test -p paddler_tests --features tests_that_use_compiled_paddler,tests_that_use_llms + cargo test -p paddler_tests --features 
tests_that_use_compiled_paddler,tests_that_use_in_process_cluster,tests_that_use_llms .PHONY: test.integration.cuda test.integration.cuda: target/cuda/debug/paddler - PADDLER_BINARY_PATH=../target/cuda/debug/paddler PADDLER_TEST_DEVICE=cuda cargo test --target-dir target/cuda -p paddler_tests --features cuda,tests_that_use_compiled_paddler,tests_that_use_llms + PADDLER_BINARY_PATH=../target/cuda/debug/paddler PADDLER_TEST_DEVICE=cuda cargo test --target-dir target/cuda -p paddler_tests --features cuda,tests_that_use_compiled_paddler,tests_that_use_in_process_cluster,tests_that_use_llms .PHONY: test.integration.metal test.integration.metal: target/metal/debug/paddler - PADDLER_BINARY_PATH=../target/metal/debug/paddler PADDLER_TEST_DEVICE=metal cargo test --target-dir target/metal -p paddler_tests --features metal,tests_that_use_compiled_paddler,tests_that_use_llms + PADDLER_BINARY_PATH=../target/metal/debug/paddler PADDLER_TEST_DEVICE=metal cargo test --target-dir target/metal -p paddler_tests --features metal,tests_that_use_compiled_paddler,tests_that_use_in_process_cluster,tests_that_use_llms .PHONY: test.unit test.unit: esbuild-meta.json cargo test --features web_admin_panel .PHONY: build.client.js -build.client.js: +build.client.js: node_modules npm --workspace @intentee/paddler-client run build .PHONY: test.client.js -test.client.js: +test.client.js: node_modules npm --workspace @intentee/paddler-client test .PHONY: watch diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 00000000..59cb72d1 --- /dev/null +++ b/clippy.toml @@ -0,0 +1,2 @@ +allow-expect-in-tests = true +allow-unwrap-in-tests = true diff --git a/package-lock.json b/package-lock.json index 65a10d2e..5174e97c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -23,6 +23,7 @@ "zod": "^4.0.17" }, "devDependencies": { + "@intentee/rust-coverage-check": "^0.2.0", "@types/hotwired__turbo": "^8", "@types/react": "^19.1.10", "@types/react-dom": "^19.1.7", @@ -984,6 +985,19 @@ 
"resolved": "paddler_client_javascript", "link": true }, + "node_modules/@intentee/rust-coverage-check": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/@intentee/rust-coverage-check/-/rust-coverage-check-0.2.0.tgz", + "integrity": "sha512-RRCHxYdYLk5SjmBw4YAOsaE3GwjJYuzhpY/Gaz8frW0dLqXZNBz5Xtd2k79dI8januHjLDjH12/vYNHOoHNsSw==", + "dev": true, + "license": "MIT", + "bin": { + "rust-coverage-check": "src/main.mjs" + }, + "engines": { + "node": ">=24.0.0" + } + }, "node_modules/@isaacs/cliui": { "version": "8.0.2", "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", diff --git a/package.json b/package.json index cdf98de8..9b1d712a 100644 --- a/package.json +++ b/package.json @@ -1,5 +1,6 @@ { "devDependencies": { + "@intentee/rust-coverage-check": "^0.2.0", "@types/hotwired__turbo": "^8", "@types/react": "^19.1.10", "@types/react-dom": "^19.1.7", diff --git a/paddler/Cargo.toml b/paddler/Cargo.toml index 603b3f6c..117950c3 100644 --- a/paddler/Cargo.toml +++ b/paddler/Cargo.toml @@ -37,6 +37,8 @@ log = { workspace = true } minijinja = { workspace = true } minijinja-contrib = { workspace = true } nanoid = { workspace = true } +paddler_cache_dir = { workspace = true } +paddler_download_manager = { workspace = true } paddler_types = { workspace = true } thiserror = { workspace = true } rand = { workspace = true } diff --git a/paddler/src/agent_issue_fix.rs b/paddler/src/agent_issue_fix.rs index 278b64f1..ab4e2ec5 100644 --- a/paddler/src/agent_issue_fix.rs +++ b/paddler/src/agent_issue_fix.rs @@ -10,6 +10,8 @@ pub enum AgentIssueFix { ModelChatTemplateIsLoaded(ModelPath), ModelFileExists(ModelPath), ModelIsLoaded(ModelPath), + ModelDownloadCompleted(ModelPath), + ModelDownloadStarted(ModelPath), ModelStateIsReconciled, MultimodalProjectionIsLoaded(ModelPath), SlotStarted(u32), @@ -72,6 +74,24 @@ impl AgentIssueFix { Self::ModelStateIsReconciled => true, _ => false, }, + AgentIssue::CacheCannotAcquireLock(issue_model_path) + | 
AgentIssue::CacheDirectoryIsNotWritable(issue_model_path) + | AgentIssue::CacheStorageIsFull(issue_model_path) + | AgentIssue::DownloadInterrupted(issue_model_path) + | AgentIssue::DownloadServerDeniedAccess(issue_model_path) + | AgentIssue::DownloadServerErrored(issue_model_path) + | AgentIssue::DownloadServerIsUnreachable(issue_model_path) + | AgentIssue::DownloadServerRejectedRequest(issue_model_path) + | AgentIssue::DownloadUrlIsMalformed(issue_model_path) + | AgentIssue::ModelCacheIsCorrupted(issue_model_path) + | AgentIssue::ModelDoesNotExistAtUrl(issue_model_path) => match self { + Self::ModelDownloadCompleted(fix_model_path) + | Self::ModelDownloadStarted(fix_model_path) => { + issue_model_path.eq(fix_model_path) + } + Self::ModelStateIsReconciled => true, + _ => false, + }, } } } @@ -172,4 +192,119 @@ mod tests { assert!(fix.can_fix(&issue)); } + + #[test] + fn model_download_completed_fixes_model_does_not_exist_at_url_with_same_path() { + let fix = AgentIssueFix::ModelDownloadCompleted(model_path("https://example.com/m.gguf")); + let issue = AgentIssue::ModelDoesNotExistAtUrl(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_started_fixes_download_server_denied_access() { + let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); + let issue = AgentIssue::DownloadServerDeniedAccess(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_started_fixes_cache_directory_is_not_writable() { + let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); + let issue = + AgentIssue::CacheDirectoryIsNotWritable(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_started_fixes_cache_storage_is_full() { + let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); + let issue = 
AgentIssue::CacheStorageIsFull(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_started_fixes_download_server_is_unreachable() { + let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); + let issue = + AgentIssue::DownloadServerIsUnreachable(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_started_fixes_download_url_is_malformed() { + let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); + let issue = AgentIssue::DownloadUrlIsMalformed(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_started_fixes_model_cache_is_corrupted() { + let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); + let issue = AgentIssue::ModelCacheIsCorrupted(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_started_fixes_download_server_errored() { + let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); + let issue = AgentIssue::DownloadServerErrored(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_started_fixes_download_interrupted() { + let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); + let issue = + AgentIssue::DownloadInterrupted(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_completed_does_not_fix_different_url() { + let fix = AgentIssueFix::ModelDownloadCompleted(model_path("https://example.com/a.gguf")); + let issue = AgentIssue::ModelDoesNotExistAtUrl(model_path("https://example.com/b.gguf")); + + assert!(!fix.can_fix(&issue)); + } + + #[test] + fn model_state_is_reconciled_fixes_model_cache_is_corrupted() { + let fix = 
AgentIssueFix::ModelStateIsReconciled; + let issue = AgentIssue::ModelCacheIsCorrupted(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_started_fixes_cache_cannot_acquire_lock() { + let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); + let issue = AgentIssue::CacheCannotAcquireLock(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_completed_fixes_cache_cannot_acquire_lock() { + let fix = AgentIssueFix::ModelDownloadCompleted(model_path("https://example.com/m.gguf")); + let issue = AgentIssue::CacheCannotAcquireLock(model_path("https://example.com/m.gguf")); + + assert!(fix.can_fix(&issue)); + } + + #[test] + fn model_download_started_does_not_fix_huggingface_issues() { + let fix = AgentIssueFix::ModelDownloadStarted(model_path("https://example.com/m.gguf")); + let issue = AgentIssue::HuggingFaceModelDoesNotExist(model_path("https://example.com/m.gguf")); + + assert!(!fix.can_fix(&issue)); + } } diff --git a/paddler/src/atomic_value.rs b/paddler/src/atomic_value.rs index 9f0202d7..6061fe1a 100644 --- a/paddler/src/atomic_value.rs +++ b/paddler/src/atomic_value.rs @@ -1,5 +1,6 @@ use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; +use std::sync::atomic::AtomicU64; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; @@ -79,6 +80,37 @@ impl AtomicValue { } } +impl AtomicValue { + #[must_use] + pub const fn new(initial: u64) -> Self { + Self { + value: AtomicU64::new(initial), + } + } + + pub fn get(&self) -> u64 { + self.value.load(Ordering::SeqCst) + } + + pub fn increment_by(&self, increment: u64) { + self.value.fetch_add(increment, Ordering::SeqCst); + } + + pub fn set(&self, value: u64) { + self.value.store(value, Ordering::SeqCst); + } + + pub fn set_check(&self, value: u64) -> bool { + if self.get() == value { + false + } else { + self.set(value); + + true + } + } +} 
+ impl AtomicValue { #[must_use] pub const fn new(initial: usize) -> Self { diff --git a/paddler/src/balancer/agent_controller.rs b/paddler/src/balancer/agent_controller.rs index 6b852af8..27c8e3d7 100644 --- a/paddler/src/balancer/agent_controller.rs +++ b/paddler/src/balancer/agent_controller.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; -use std::sync::atomic::AtomicUsize; +use std::sync::atomic::AtomicU64; use anyhow::Result; use async_trait::async_trait; @@ -44,9 +44,10 @@ pub struct AgentController { pub chat_template_override_sender_collection: Arc, pub connection_close: CancellationToken, pub desired_slots_total: AtomicValue, - pub download_current: AtomicValue, + pub download_current: AtomicValue, pub download_filename: RwLock>, - pub download_total: AtomicValue, + pub download_indeterminate: AtomicValue, + pub download_total: AtomicValue, pub embedding_sender_collection: Arc, pub generate_tokens_sender_collection: Arc, pub id: String, @@ -145,6 +146,7 @@ impl AgentController { desired_slots_total, download_current, download_filename, + download_indeterminate, download_total, issues, model_path, @@ -167,6 +169,9 @@ impl AgentController { changed |= self.desired_slots_total.set_check(desired_slots_total); changed |= self.download_current.set_check(download_current); + changed |= self + .download_indeterminate + .set_check(download_indeterminate); changed |= self.download_total.set_check(download_total); changed |= self.slots_total.set_check(slots_total); changed |= self @@ -313,6 +318,7 @@ impl ProducesSnapshot for AgentController { desired_slots_total: self.desired_slots_total.get(), download_current: self.download_current.get(), download_filename: self.get_download_filename(), + download_indeterminate: self.download_indeterminate.get(), download_total: self.download_total.get(), id: self.id.clone(), issues: self.get_issues(), @@ -363,9 +369,10 @@ mod tests { ), 
connection_close: CancellationToken::new(), desired_slots_total: AtomicValue::::new(0), - download_current: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), download_filename: RwLock::new(None), - download_total: AtomicValue::::new(0), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), id: "agent-test".to_owned(), @@ -391,6 +398,7 @@ mod tests { desired_slots_total: 4, download_current: 10, download_filename: None, + download_indeterminate: false, download_total: 100, issues: BTreeSet::new(), model_path: None, diff --git a/paddler/src/balancer/buffered_request_manager.rs b/paddler/src/balancer/buffered_request_manager.rs index b6429916..7fc812f5 100644 --- a/paddler/src/balancer/buffered_request_manager.rs +++ b/paddler/src/balancer/buffered_request_manager.rs @@ -100,7 +100,7 @@ mod tests { use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; - use std::sync::atomic::AtomicUsize; + use std::sync::atomic::AtomicU64; use std::task::Poll; use paddler_types::agent_state_application_status::AgentStateApplicationStatus; @@ -162,9 +162,10 @@ mod tests { ), connection_close: CancellationToken::new(), desired_slots_total: AtomicValue::::new(1), - download_current: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), download_filename: RwLock::new(None), - download_total: AtomicValue::::new(0), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), id: "agent-1".to_owned(), @@ -211,9 +212,10 @@ mod tests { ), connection_close: CancellationToken::new(), desired_slots_total: AtomicValue::::new(1), - 
download_current: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), download_filename: RwLock::new(None), - download_total: AtomicValue::::new(0), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), id: "agent-pre".to_owned(), diff --git a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs b/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs index db057885..cd6c09d5 100644 --- a/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs +++ b/paddler/src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; -use std::sync::atomic::AtomicUsize; +use std::sync::atomic::AtomicU64; use actix_web::Error; use actix_web::HttpRequest; @@ -118,6 +118,7 @@ impl ControlsWebSocketEndpoint for AgentSocketController { desired_slots_total, download_current, download_filename, + download_indeterminate, download_total, issues, model_path, @@ -138,9 +139,10 @@ impl ControlsWebSocketEndpoint for AgentSocketController { .clone(), connection_close: connection_close.clone(), desired_slots_total: AtomicValue::::new(desired_slots_total), - download_current: AtomicValue::::new(download_current), + download_current: AtomicValue::::new(download_current), download_filename: RwLock::new(download_filename), - download_total: AtomicValue::::new(download_total), + download_indeterminate: AtomicValue::::new(download_indeterminate), + download_total: AtomicValue::::new(download_total), embedding_sender_collection: context.embedding_sender_collection.clone(), generate_tokens_sender_collection: context .generate_tokens_sender_collection diff --git 
a/paddler/src/desired_model_resolution.rs b/paddler/src/desired_model_resolution.rs index a825b19d..d5367ec7 100644 --- a/paddler/src/desired_model_resolution.rs +++ b/paddler/src/desired_model_resolution.rs @@ -1,5 +1,6 @@ use std::path::PathBuf; +#[derive(Debug)] pub enum DesiredModelResolution { NotConfigured, Resolved(PathBuf), diff --git a/paddler/src/download_huggingface_model.rs b/paddler/src/download_huggingface_model.rs deleted file mode 100644 index 787d8d40..00000000 --- a/paddler/src/download_huggingface_model.rs +++ /dev/null @@ -1,125 +0,0 @@ -use std::path::PathBuf; -use std::sync::Arc; - -use anyhow::Result; -use anyhow::anyhow; -use hf_hub::Cache; -use hf_hub::Repo; -use hf_hub::RepoType; -use hf_hub::api::tokio::ApiBuilder; -use hf_hub::api::tokio::ApiError; -use log::warn; -use paddler_types::agent_issue::AgentIssue; -use paddler_types::agent_issue_params::HuggingFaceDownloadLock; -use paddler_types::agent_issue_params::ModelPath; -use paddler_types::huggingface_model_reference::HuggingFaceModelReference; -use tokio::time::Duration; -use tokio::time::sleep; - -use crate::agent_issue_fix::AgentIssueFix; -use crate::slot_aggregated_status::SlotAggregatedStatus; -use crate::slot_aggregated_status_download_progress::SlotAggregatedStatusDownloadProgress; - -const LOCK_RETRY_TIMEOUT: Duration = Duration::from_secs(10); - -pub async fn download_huggingface_model( - reference: &HuggingFaceModelReference, - slot_aggregated_status: Arc, -) -> Result { - let HuggingFaceModelReference { - filename, - repo_id, - revision, - } = reference; - let model_path = format!("{repo_id}/{revision}/{filename}"); - - if slot_aggregated_status.has_issue(&AgentIssue::HuggingFaceModelDoesNotExist(ModelPath { - model_path: model_path.clone(), - })) { - return Err(anyhow!( - "Model '{model_path}' does not exist on Hugging Face. Not attempting to download it again." 
- )); - } - - let hf_cache = Cache::from_env(); - let hf_api = ApiBuilder::from_cache(hf_cache.clone()).build()?; - let hf_repo = hf_api.repo(Repo::with_revision( - repo_id.to_owned(), - RepoType::Model, - revision.to_owned(), - )); - - if let Some(cached_path) = hf_cache - .repo(Repo::new(repo_id.to_owned(), RepoType::Model)) - .get(filename) - { - slot_aggregated_status.reset_download(); - - return Ok(cached_path); - } - - match hf_repo - .download_with_progress( - filename, - SlotAggregatedStatusDownloadProgress::new(slot_aggregated_status.clone()), - ) - .await - { - Ok(resolved_filename) => { - slot_aggregated_status.register_fix(&AgentIssueFix::HuggingFaceDownloadedModel( - ModelPath { model_path }, - )); - - Ok(resolved_filename) - } - Err(ApiError::LockAcquisition(lock_path)) => { - slot_aggregated_status.register_issue(AgentIssue::HuggingFaceCannotAcquireLock( - HuggingFaceDownloadLock { - lock_path: lock_path.display().to_string(), - model_path: ModelPath { model_path }, - }, - )); - - warn!( - "Waiting to acquire download lock for '{}'. Sleeping for {} secs", - lock_path.display(), - LOCK_RETRY_TIMEOUT.as_secs() - ); - - sleep(LOCK_RETRY_TIMEOUT).await; - - Err(anyhow!( - "Failed to acquire download lock '{}'. Is more than one agent running on this machine?", - lock_path.display() - )) - } - Err(ApiError::RequestError(reqwest_error)) => match reqwest_error.status() { - Some(reqwest::StatusCode::NOT_FOUND) => { - slot_aggregated_status.register_issue(AgentIssue::HuggingFaceModelDoesNotExist( - ModelPath { - model_path: model_path.clone(), - }, - )); - - Err(anyhow!( - "Model '{model_path}' does not exist on Hugging Face." - )) - } - Some(reqwest::StatusCode::FORBIDDEN | reqwest::StatusCode::UNAUTHORIZED) => { - slot_aggregated_status.register_issue(AgentIssue::HuggingFacePermissions( - ModelPath { - model_path: model_path.clone(), - }, - )); - - Err(anyhow!( - "You do not have enough permissions to download '{model_path}' from Hugging Face." 
- )) - } - _ => Err(anyhow!( - "Failed to download model from Hugging Face: {reqwest_error}" - )), - }, - Err(err_other) => Err(err_other.into()), - } -} diff --git a/paddler/src/lib.rs b/paddler/src/lib.rs index 88a18d87..b6516edb 100644 --- a/paddler/src/lib.rs +++ b/paddler/src/lib.rs @@ -22,11 +22,12 @@ pub mod decoded_image; pub mod decoded_image_error; pub mod desired_model_resolution; pub mod dispenses_slots; -pub mod download_huggingface_model; pub mod embedding_input_tokenized; +pub mod model_source; pub mod produces_snapshot; pub mod resolve_desired_model; pub mod resolved_socket_addr; +pub mod resolves_model_source; pub mod sends_rpc_message; pub mod service; pub mod service_manager; diff --git a/paddler/src/model_source/huggingface.rs b/paddler/src/model_source/huggingface.rs new file mode 100644 index 00000000..5abdcf4f --- /dev/null +++ b/paddler/src/model_source/huggingface.rs @@ -0,0 +1,131 @@ +use std::sync::Arc; + +use anyhow::Result; +use anyhow::anyhow; +use async_trait::async_trait; +use hf_hub::Cache; +use hf_hub::Repo; +use hf_hub::RepoType; +use hf_hub::api::tokio::ApiBuilder; +use hf_hub::api::tokio::ApiError; +use log::warn; +use tokio::time::Duration; +use tokio::time::sleep; + +use paddler_types::agent_issue::AgentIssue; +use paddler_types::agent_issue_params::HuggingFaceDownloadLock; +use paddler_types::agent_issue_params::ModelPath; +use paddler_types::huggingface_model_reference::HuggingFaceModelReference; + +use crate::agent_issue_fix::AgentIssueFix; +use crate::desired_model_resolution::DesiredModelResolution; +use crate::resolves_model_source::ResolvesModelSource; +use crate::slot_aggregated_status::SlotAggregatedStatus; +use crate::slot_aggregated_status_download_progress::SlotAggregatedStatusDownloadProgress; + +const LOCK_RETRY_TIMEOUT: Duration = Duration::from_secs(10); + +#[async_trait] +impl ResolvesModelSource for HuggingFaceModelReference { + async fn resolve( + &self, + slot_aggregated_status: Arc, + ) -> Result { + let 
Self { + filename, + repo_id, + revision, + } = self; + let model_path = format!("{repo_id}/{revision}/{filename}"); + + if slot_aggregated_status.has_issue(&AgentIssue::HuggingFaceModelDoesNotExist(ModelPath { + model_path: model_path.clone(), + })) { + return Err(anyhow!( + "Model '{model_path}' does not exist on Hugging Face. Not attempting to download it again." + )); + } + + let hf_cache = Cache::from_env(); + let hf_api = ApiBuilder::from_cache(hf_cache.clone()).build()?; + let hf_repo = hf_api.repo(Repo::with_revision( + repo_id.to_owned(), + RepoType::Model, + revision.to_owned(), + )); + + if let Some(cached_path) = hf_cache + .repo(Repo::new(repo_id.to_owned(), RepoType::Model)) + .get(filename) + { + slot_aggregated_status.reset_download(); + + return Ok(DesiredModelResolution::Resolved(cached_path)); + } + + match hf_repo + .download_with_progress( + filename, + SlotAggregatedStatusDownloadProgress::new(slot_aggregated_status.clone()), + ) + .await + { + Ok(resolved_filename) => { + slot_aggregated_status.register_fix(&AgentIssueFix::HuggingFaceDownloadedModel( + ModelPath { model_path }, + )); + + Ok(DesiredModelResolution::Resolved(resolved_filename)) + } + Err(ApiError::LockAcquisition(lock_path)) => { + slot_aggregated_status.register_issue(AgentIssue::HuggingFaceCannotAcquireLock( + HuggingFaceDownloadLock { + lock_path: lock_path.display().to_string(), + model_path: ModelPath { model_path }, + }, + )); + + warn!( + "Waiting to acquire download lock for '{}'. Sleeping for {} secs", + lock_path.display(), + LOCK_RETRY_TIMEOUT.as_secs() + ); + + sleep(LOCK_RETRY_TIMEOUT).await; + + Err(anyhow!( + "Failed to acquire download lock '{}'. 
Is more than one agent running on this machine?", + lock_path.display() + )) + } + Err(ApiError::RequestError(reqwest_error)) => match reqwest_error.status() { + Some(reqwest::StatusCode::NOT_FOUND) => { + slot_aggregated_status.register_issue( + AgentIssue::HuggingFaceModelDoesNotExist(ModelPath { + model_path: model_path.clone(), + }), + ); + + Err(anyhow!( + "Model '{model_path}' does not exist on Hugging Face." + )) + } + Some(reqwest::StatusCode::FORBIDDEN | reqwest::StatusCode::UNAUTHORIZED) => { + slot_aggregated_status.register_issue(AgentIssue::HuggingFacePermissions( + ModelPath { + model_path: model_path.clone(), + }, + )); + + Err(anyhow!( + "You do not have enough permissions to download '{model_path}' from Hugging Face." + )) + } + _ => Err(anyhow!( + "Failed to download model from Hugging Face: {reqwest_error}" + )), + }, + Err(err_other) => Err(err_other.into()), + } + } +} diff --git a/paddler/src/model_source/local.rs b/paddler/src/model_source/local.rs new file mode 100644 index 00000000..cee294fd --- /dev/null +++ b/paddler/src/model_source/local.rs @@ -0,0 +1,36 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use anyhow::Result; +use async_trait::async_trait; + +use crate::desired_model_resolution::DesiredModelResolution; +use crate::resolves_model_source::ResolvesModelSource; +use crate::slot_aggregated_status::SlotAggregatedStatus; + +pub struct LocalModelPath { + pub path: String, +} + +impl LocalModelPath { + #[must_use] + pub const fn new(path: String) -> Self { + Self { path } + } +} + +#[async_trait] +impl ResolvesModelSource for LocalModelPath { + async fn resolve( + &self, + _slot_aggregated_status: Arc, + ) -> Result { + let local_path = PathBuf::from(&self.path); + + if tokio::fs::try_exists(&local_path).await? 
{ + Ok(DesiredModelResolution::Resolved(local_path)) + } else { + Ok(DesiredModelResolution::LocalFileMissing(local_path)) + } + } +} diff --git a/paddler/src/model_source/mod.rs b/paddler/src/model_source/mod.rs new file mode 100644 index 00000000..019c6fd8 --- /dev/null +++ b/paddler/src/model_source/mod.rs @@ -0,0 +1,3 @@ +pub mod huggingface; +pub mod local; +pub mod url; diff --git a/paddler/src/model_source/url.rs b/paddler/src/model_source/url.rs new file mode 100644 index 00000000..72992a5c --- /dev/null +++ b/paddler/src/model_source/url.rs @@ -0,0 +1,573 @@ +use std::io; +use std::sync::Arc; + +use anyhow::Result; +use anyhow::anyhow; +use async_trait::async_trait; +use url::Url; + +use paddler_cache_dir::CacheDir; +use paddler_cache_dir::CachedDownloadedModel; +use paddler_cache_dir::DownloadLockAcquisitionError; +use paddler_download_manager::download_error::DownloadError; +use paddler_download_manager::download_manager::DownloadManager; +use paddler_download_manager::progress_sink::ProgressSink; +use paddler_types::agent_issue::AgentIssue; +use paddler_types::agent_issue_params::ModelPath; +use paddler_types::url_model_reference::UrlModelReference; + +use crate::agent_issue_fix::AgentIssueFix; +use crate::desired_model_resolution::DesiredModelResolution; +use crate::resolves_model_source::ResolvesModelSource; +use crate::slot_aggregated_status::SlotAggregatedStatus; + +#[cfg(unix)] +fn is_disk_full(error: &io::Error) -> bool { + error.raw_os_error() == Some(28) +} + +#[cfg(windows)] +fn is_disk_full(error: &io::Error) -> bool { + error.raw_os_error() == Some(112) +} + +fn classify_cache_io_error(url_string: &str, error: &io::Error) -> AgentIssue { + let model_path = ModelPath { + model_path: url_string.to_owned(), + }; + + if error.kind() == io::ErrorKind::PermissionDenied { + AgentIssue::CacheDirectoryIsNotWritable(model_path) + } else if is_disk_full(error) { + AgentIssue::CacheStorageIsFull(model_path) + } else { + 
AgentIssue::ModelCacheIsCorrupted(model_path) + } +} + +fn agent_issue_for(error: &DownloadError, url_string: &str) -> AgentIssue { + let model_path = ModelPath { + model_path: url_string.to_owned(), + }; + + match error { + DownloadError::InvalidUrl { .. } | DownloadError::UnsupportedUrlScheme { .. } => { + AgentIssue::DownloadUrlIsMalformed(model_path) + } + DownloadError::NotFound { .. } => AgentIssue::ModelDoesNotExistAtUrl(model_path), + DownloadError::PermissionDenied { .. } => { + AgentIssue::DownloadServerDeniedAccess(model_path) + } + DownloadError::DownloadServerIsUnreachable { .. } => { + AgentIssue::DownloadServerIsUnreachable(model_path) + } + DownloadError::DownloadServerErrored { .. } => { + AgentIssue::DownloadServerErrored(model_path) + } + DownloadError::DownloadServerRejectedRequest { .. } => { + AgentIssue::DownloadServerRejectedRequest(model_path) + } + DownloadError::DownloadInterrupted { .. } => AgentIssue::DownloadInterrupted(model_path), + DownloadError::CachePermissionDenied { .. } => { + AgentIssue::CacheDirectoryIsNotWritable(model_path) + } + DownloadError::CacheDiskFull { .. } => AgentIssue::CacheStorageIsFull(model_path), + DownloadError::PartialFileStale { .. } | DownloadError::Io { .. 
} => { + AgentIssue::ModelCacheIsCorrupted(model_path) + } + } +} + +struct SlotAggregatedStatusSink { + basename: Option, + slot_aggregated_status: Arc, + url: String, +} + +impl ProgressSink for SlotAggregatedStatusSink { + fn on_started(&self, total_bytes: Option, already_downloaded: u64) { + self.slot_aggregated_status.set_download_status( + already_downloaded, + total_bytes, + self.basename.clone(), + ); + self.slot_aggregated_status + .register_fix(&AgentIssueFix::ModelDownloadStarted(ModelPath { + model_path: self.url.clone(), + })); + } + + fn on_chunk(&self, additional_bytes: u64) { + self.slot_aggregated_status + .increment_download_current(additional_bytes); + } + + fn on_finished(&self) { + self.slot_aggregated_status + .register_fix(&AgentIssueFix::ModelDownloadCompleted(ModelPath { + model_path: self.url.clone(), + })); + self.slot_aggregated_status.reset_download(); + } +} + +async fn resolve_url_into_cache( + url_string: &str, + cache_dir: &CacheDir, + slot_aggregated_status: Arc, +) -> Result { + let parsed_url = match Url::parse(url_string) { + Ok(url) => url, + Err(parse_error) => { + slot_aggregated_status.reset_download(); + slot_aggregated_status.register_issue(AgentIssue::DownloadUrlIsMalformed(ModelPath { + model_path: url_string.to_owned(), + })); + + return Err(anyhow::Error::new(parse_error) + .context(format!("Invalid URL '{url_string}'"))); + } + }; + + if !matches!(parsed_url.scheme(), "http" | "https") { + slot_aggregated_status.reset_download(); + slot_aggregated_status.register_issue(AgentIssue::DownloadUrlIsMalformed(ModelPath { + model_path: url_string.to_owned(), + })); + + return Err(anyhow!( + "Unsupported URL scheme '{}' for '{url_string}'; only http and https are supported", + parsed_url.scheme(), + )); + } + + let cached = CachedDownloadedModel::new(cache_dir, url_string)?; + + let is_cached = match cached.is_cached().await { + Ok(value) => value, + Err(io_error) => { + slot_aggregated_status.reset_download(); + 
slot_aggregated_status.register_issue(classify_cache_io_error(url_string, &io_error)); + + return Err(anyhow::Error::new(io_error)); + } + }; + + if is_cached { + slot_aggregated_status.reset_download(); + slot_aggregated_status.register_fix(&AgentIssueFix::ModelDownloadCompleted(ModelPath { + model_path: url_string.to_owned(), + })); + + return Ok(DesiredModelResolution::Resolved(cached.cache_file_path)); + } + + if let Err(io_error) = cached.ensure_cache_subdir_exists().await { + slot_aggregated_status.reset_download(); + slot_aggregated_status.register_issue(classify_cache_io_error(url_string, &io_error)); + + return Err(anyhow::Error::new(io_error)); + } + + let _lock_guard = match cached.try_acquire_download_lock() { + Ok(guard) => guard, + Err(DownloadLockAcquisitionError::AnotherProcessIsDownloading) => { + slot_aggregated_status.reset_download(); + slot_aggregated_status.register_issue(AgentIssue::CacheCannotAcquireLock(ModelPath { + model_path: url_string.to_owned(), + })); + + return Err(anyhow!( + "Another agent on this host is currently downloading '{url_string}'" + )); + } + Err(DownloadLockAcquisitionError::Io(io_error)) => { + slot_aggregated_status.reset_download(); + slot_aggregated_status.register_issue(classify_cache_io_error(url_string, &io_error)); + + return Err(anyhow::Error::new(io_error)); + } + }; + + let basename = cached + .cache_file_path + .file_name() + .and_then(|name| name.to_str()) + .map(str::to_owned); + let sink: Arc = Arc::new(SlotAggregatedStatusSink { + basename, + slot_aggregated_status: slot_aggregated_status.clone(), + url: url_string.to_owned(), + }); + + match DownloadManager::new()? 
+ .download(url_string, &cached.cache_file_path, sink) + .await + { + Ok(()) => Ok(DesiredModelResolution::Resolved(cached.cache_file_path)), + Err(error) => { + slot_aggregated_status.reset_download(); + slot_aggregated_status.register_issue(agent_issue_for(&error, url_string)); + + Err(anyhow::Error::new(error)) + } + } +} + +#[async_trait] +impl ResolvesModelSource for UrlModelReference { + async fn resolve( + &self, + slot_aggregated_status: Arc, + ) -> Result { + let cache_dir = CacheDir::from_process_env(); + + resolve_url_into_cache(&self.url, &cache_dir, slot_aggregated_status).await + } +} + +#[cfg(test)] +mod tests { + use std::io; + use std::path::PathBuf; + use std::sync::Arc; + + use anyhow::Context as _; + use anyhow::Result; + use anyhow::anyhow; + use paddler_cache_dir::CacheDir; + use paddler_cache_dir::CachedDownloadedModel; + use paddler_download_manager::download_error::DownloadError; + use paddler_types::agent_issue::AgentIssue; + use reqwest::StatusCode; + use tempfile::TempDir; + use url::Url; + + use crate::desired_model_resolution::DesiredModelResolution; + use crate::model_source::url::agent_issue_for; + use crate::model_source::url::classify_cache_io_error; + use crate::model_source::url::resolve_url_into_cache; + use crate::slot_aggregated_status::SlotAggregatedStatus; + + const TEST_URL: &str = "https://example.com/m.gguf"; + + fn fresh_status() -> Arc { + Arc::new(SlotAggregatedStatus::new(1)) + } + + fn cache_dir_at(path: &std::path::Path) -> CacheDir { + #[cfg(unix)] + { + CacheDir { + explicit: Some(path.to_string_lossy().into_owned()), + home: None, + xdg: None, + } + } + #[cfg(windows)] + { + CacheDir { + explicit: Some(path.to_string_lossy().into_owned()), + localappdata: None, + userprofile: None, + } + } + } + + #[tokio::test] + async fn cache_hit_returns_path_without_calling_download_manager() -> Result<()> { + let directory = TempDir::new()?; + let cache_dir = cache_dir_at(directory.path()); + let url_string = 
"https://host.example/cached.gguf"; + let cached = CachedDownloadedModel::new(&cache_dir, url_string)?; + cached.ensure_cache_subdir_exists().await?; + tokio::fs::write(&cached.cache_file_path, b"cached content").await?; + + let resolution = + resolve_url_into_cache(url_string, &cache_dir, fresh_status()).await?; + + match resolution { + DesiredModelResolution::Resolved(path) => { + assert_eq!(path, cached.cache_file_path); + } + other => return Err(anyhow!("expected Resolved, got {other:?}")), + } + + Ok(()) + } + + #[tokio::test] + async fn malformed_url_registers_download_url_is_malformed() -> Result<()> { + let directory = TempDir::new()?; + let cache_dir = cache_dir_at(directory.path()); + let url_string = "not a url"; + + let status = fresh_status(); + let result = resolve_url_into_cache(url_string, &cache_dir, status.clone()).await; + + assert!(result.is_err(), "malformed URL must produce an Err"); + assert!(status.has_issue(&AgentIssue::DownloadUrlIsMalformed( + paddler_types::agent_issue_params::ModelPath { + model_path: url_string.to_owned(), + }, + ))); + + Ok(()) + } + + #[tokio::test] + async fn unsupported_scheme_registers_download_url_is_malformed_without_creating_cache_state() + -> Result<()> { + let directory = TempDir::new()?; + let cache_dir = cache_dir_at(directory.path()); + let url_string = "ftp://example.invalid/m.gguf"; + + let status = fresh_status(); + let result = resolve_url_into_cache(url_string, &cache_dir, status.clone()).await; + + assert!(result.is_err(), "unsupported scheme must produce an Err"); + assert!(status.has_issue(&AgentIssue::DownloadUrlIsMalformed( + paddler_types::agent_issue_params::ModelPath { + model_path: url_string.to_owned(), + }, + ))); + assert!( + !directory.path().join("downloaded-models").exists(), + "no cache subdirectory must be created for an unsupported scheme" + ); + + Ok(()) + } + + #[tokio::test] + async fn lock_contention_registers_cache_cannot_acquire_lock() -> Result<()> { + let directory = 
TempDir::new()?; + let cache_dir = cache_dir_at(directory.path()); + let url_string = "https://host.example/contended.gguf"; + let cached = CachedDownloadedModel::new(&cache_dir, url_string)?; + cached.ensure_cache_subdir_exists().await?; + + let _blocker = cached.try_acquire_download_lock()?; + + let status = fresh_status(); + let result = resolve_url_into_cache(url_string, &cache_dir, status.clone()).await; + + assert!(result.is_err(), "lock contention must produce an Err"); + assert!(status.has_issue(&AgentIssue::CacheCannotAcquireLock( + paddler_types::agent_issue_params::ModelPath { + model_path: url_string.to_owned(), + }, + ))); + + Ok(()) + } + + #[test] + fn invalid_url_maps_to_download_url_is_malformed() -> Result<()> { + let parse_error = Url::parse("not a url") + .err() + .context("'not a url' should not parse")?; + let error = DownloadError::InvalidUrl { + url: "not a url".to_owned(), + source: parse_error, + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::DownloadUrlIsMalformed(_) + )); + + Ok(()) + } + + #[test] + fn unsupported_url_scheme_maps_to_download_url_is_malformed() { + let error = DownloadError::UnsupportedUrlScheme { + url: TEST_URL.to_owned(), + scheme: "ftp".to_owned(), + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::DownloadUrlIsMalformed(_) + )); + } + + #[test] + fn not_found_maps_to_model_does_not_exist_at_url() { + let error = DownloadError::NotFound { + url: TEST_URL.to_owned(), + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::ModelDoesNotExistAtUrl(_) + )); + } + + #[test] + fn permission_denied_maps_to_download_server_denied_access() { + let error = DownloadError::PermissionDenied { + url: TEST_URL.to_owned(), + status: StatusCode::FORBIDDEN, + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::DownloadServerDeniedAccess(_) + )); + } + + #[test] + fn partial_file_stale_maps_to_model_cache_is_corrupted() { + let error = 
DownloadError::PartialFileStale { + url: TEST_URL.to_owned(), + partial_path: PathBuf::from("/tmp/stale.partial"), + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::ModelCacheIsCorrupted(_) + )); + } + + #[test] + fn download_server_is_unreachable_maps_to_agent_issue() { + let error = DownloadError::DownloadServerIsUnreachable { + url: TEST_URL.to_owned(), + source: anyhow!("connection refused"), + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::DownloadServerIsUnreachable(_) + )); + } + + #[test] + fn download_server_errored_maps_to_agent_issue() { + let error = DownloadError::DownloadServerErrored { + url: TEST_URL.to_owned(), + status: StatusCode::INTERNAL_SERVER_ERROR, + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::DownloadServerErrored(_) + )); + } + + #[test] + fn download_server_rejected_request_maps_to_agent_issue() { + let error = DownloadError::DownloadServerRejectedRequest { + url: TEST_URL.to_owned(), + status: StatusCode::BAD_REQUEST, + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::DownloadServerRejectedRequest(_) + )); + } + + #[test] + fn download_interrupted_maps_to_agent_issue() { + let error = DownloadError::DownloadInterrupted { + url: TEST_URL.to_owned(), + source: anyhow!("stream dropped"), + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::DownloadInterrupted(_) + )); + } + + #[test] + fn cache_permission_denied_maps_to_cache_directory_is_not_writable() { + let error = DownloadError::CachePermissionDenied { + path: PathBuf::from("/tmp/locked/model.partial"), + source: io::Error::from(io::ErrorKind::PermissionDenied), + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::CacheDirectoryIsNotWritable(_) + )); + } + + #[test] + fn cache_disk_full_maps_to_cache_storage_is_full() { + let error = DownloadError::CacheDiskFull { + path: PathBuf::from("/tmp/full/model.partial"), + source: 
io::Error::from_raw_os_error(28), + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::CacheStorageIsFull(_) + )); + } + + #[test] + fn io_maps_to_model_cache_is_corrupted() { + let error = DownloadError::Io { + path: PathBuf::from("/tmp/anywhere/model.partial"), + source: io::Error::from(io::ErrorKind::NotFound), + }; + + assert!(matches!( + agent_issue_for(&error, TEST_URL), + AgentIssue::ModelCacheIsCorrupted(_) + )); + } + + #[test] + fn classify_cache_io_error_maps_permission_denied_to_cache_directory_is_not_writable() { + let error = io::Error::from(io::ErrorKind::PermissionDenied); + + assert!(matches!( + classify_cache_io_error(TEST_URL, &error), + AgentIssue::CacheDirectoryIsNotWritable(_) + )); + } + + #[test] + fn classify_cache_io_error_maps_enospc_to_cache_storage_is_full() { + let error = io::Error::from_raw_os_error(28); + + assert!(matches!( + classify_cache_io_error(TEST_URL, &error), + AgentIssue::CacheStorageIsFull(_) + )); + } + + #[test] + fn classify_cache_io_error_falls_back_to_model_cache_is_corrupted() { + let error = io::Error::from(io::ErrorKind::NotFound); + + assert!(matches!( + classify_cache_io_error(TEST_URL, &error), + AgentIssue::ModelCacheIsCorrupted(_) + )); + } + + #[tokio::test] + async fn ensure_cache_subdir_failure_registers_model_cache_is_corrupted() -> Result<()> { + let directory = TempDir::new()?; + tokio::fs::write(directory.path().join("downloaded-models"), b"blocker").await?; + let cache_dir = cache_dir_at(directory.path()); + let url_string = "https://host.example/blocked.gguf"; + + let status = fresh_status(); + let result = resolve_url_into_cache(url_string, &cache_dir, status.clone()).await; + + assert!(result.is_err(), "blocked cache subdir must produce an Err"); + assert!(status.has_issue_like(|issue| matches!( + issue, + AgentIssue::ModelCacheIsCorrupted(_) + ))); + + Ok(()) + } +} diff --git a/paddler/src/resolve_desired_model.rs b/paddler/src/resolve_desired_model.rs index 
25683ea8..51fde0d7 100644 --- a/paddler/src/resolve_desired_model.rs +++ b/paddler/src/resolve_desired_model.rs @@ -1,11 +1,11 @@ -use std::path::PathBuf; use std::sync::Arc; use anyhow::Result; use paddler_types::agent_desired_model::AgentDesiredModel; use crate::desired_model_resolution::DesiredModelResolution; -use crate::download_huggingface_model::download_huggingface_model; +use crate::model_source::local::LocalModelPath; +use crate::resolves_model_source::ResolvesModelSource; use crate::slot_aggregated_status::SlotAggregatedStatus; pub async fn resolve_desired_model( @@ -14,19 +14,14 @@ pub async fn resolve_desired_model( ) -> Result { match desired { AgentDesiredModel::HuggingFace(reference) => { - let path = download_huggingface_model(reference, slot_aggregated_status).await?; - - Ok(DesiredModelResolution::Resolved(path)) + reference.resolve(slot_aggregated_status).await } AgentDesiredModel::LocalToAgent(path) => { - let local_path = PathBuf::from(path); - - if tokio::fs::try_exists(&local_path).await? 
{ - Ok(DesiredModelResolution::Resolved(local_path)) - } else { - Ok(DesiredModelResolution::LocalFileMissing(local_path)) - } + LocalModelPath::new(path.clone()) + .resolve(slot_aggregated_status) + .await } + AgentDesiredModel::Url(reference) => reference.resolve(slot_aggregated_status).await, AgentDesiredModel::None => Ok(DesiredModelResolution::NotConfigured), } } diff --git a/paddler/src/resolves_model_source.rs b/paddler/src/resolves_model_source.rs new file mode 100644 index 00000000..fe3a0471 --- /dev/null +++ b/paddler/src/resolves_model_source.rs @@ -0,0 +1,15 @@ +use std::sync::Arc; + +use anyhow::Result; +use async_trait::async_trait; + +use crate::desired_model_resolution::DesiredModelResolution; +use crate::slot_aggregated_status::SlotAggregatedStatus; + +#[async_trait] +pub trait ResolvesModelSource { + async fn resolve( + &self, + slot_aggregated_status: Arc, + ) -> Result; +} diff --git a/paddler/src/slot_aggregated_status.rs b/paddler/src/slot_aggregated_status.rs index edb6299d..0ad9d2ff 100644 --- a/paddler/src/slot_aggregated_status.rs +++ b/paddler/src/slot_aggregated_status.rs @@ -1,7 +1,7 @@ use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; -use std::sync::atomic::AtomicUsize; +use std::sync::atomic::AtomicU64; use anyhow::Result; use dashmap::DashSet; @@ -18,9 +18,10 @@ use crate::subscribes_to_updates::SubscribesToUpdates; pub struct SlotAggregatedStatus { desired_slots_total: i32, - download_current: AtomicValue, + download_current: AtomicValue, download_filename: RwLock>, - download_total: AtomicValue, + download_indeterminate: AtomicValue, + download_total: AtomicValue, issues: DashSet, model_path: RwLock>, slots_processing: AtomicValue, @@ -38,9 +39,10 @@ impl SlotAggregatedStatus { Self { desired_slots_total, - download_current: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), download_filename: RwLock::new(None), - download_total: AtomicValue::::new(0), + 
download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), issues: DashSet::new(), model_path: RwLock::new(None), state_application_status_code: AtomicValue::::new( @@ -77,7 +79,7 @@ impl SlotAggregatedStatus { .any(|ref_multi| issue_like(ref_multi.key())) } - pub fn increment_download_current(&self, size: usize) { + pub fn increment_download_current(&self, size: u64) { self.download_current.increment_by(size); self.version.increment(); self.update_tx.send_replace(()); @@ -117,14 +119,26 @@ impl SlotAggregatedStatus { pub fn reset_download(&self) { self.download_current.set(0); self.download_total.set(0); + self.download_indeterminate.set(true); self.set_download_filename(None); self.version.increment(); self.update_tx.send_replace(()); } - pub fn set_download_status(&self, current: usize, total: usize, filename: Option) { + pub fn set_download_status( + &self, + current: u64, + total: Option, + filename: Option, + ) { self.download_current.set(current); - self.download_total.set(total); + if let Some(value) = total { + self.download_total.set(value); + self.download_indeterminate.set(false); + } else { + self.download_total.set(0); + self.download_indeterminate.set(true); + } self.set_download_filename(filename); } @@ -209,6 +223,7 @@ impl ProducesSnapshot for SlotAggregatedStatus { .read() .expect("Lock poisoned when getting download filename") .clone(), + download_indeterminate: self.download_indeterminate.get(), download_total: self.download_total.get(), model_path: self .model_path @@ -395,7 +410,7 @@ mod tests { fn set_download_status_updates_all_fields() -> Result<()> { let status = SlotAggregatedStatus::new(2); - status.set_download_status(100, 500, Some("model.gguf".to_owned())); + status.set_download_status(100, Some(500), Some("model.gguf".to_owned())); let snapshot = status.make_snapshot()?; @@ -406,11 +421,55 @@ mod tests { Ok(()) } + #[test] + fn set_download_status_with_indeterminate_total_keeps_flag_true() -> 
Result<()> { + let status = SlotAggregatedStatus::new(2); + + status.set_download_status(123, None, Some("model.gguf".to_owned())); + + let snapshot = status.make_snapshot()?; + + assert_eq!(snapshot.download_current, 123); + assert_eq!(snapshot.download_total, 0); + assert!(snapshot.download_indeterminate); + + Ok(()) + } + + #[test] + fn set_download_status_indeterminate_after_known_total_resets_download_total() -> Result<()> { + let status = SlotAggregatedStatus::new(2); + + status.set_download_status(0, Some(5000), Some("model.gguf".to_owned())); + status.set_download_status(10, None, Some("model.gguf".to_owned())); + + let snapshot = status.make_snapshot()?; + + assert_eq!(snapshot.download_total, 0); + assert!(snapshot.download_indeterminate); + + Ok(()) + } + + #[test] + fn set_download_status_with_known_total_flips_indeterminate_false() -> Result<()> { + let status = SlotAggregatedStatus::new(2); + + status.set_download_status(0, Some(5000), Some("model.gguf".to_owned())); + + let snapshot = status.make_snapshot()?; + + assert_eq!(snapshot.download_total, 5000); + assert!(!snapshot.download_indeterminate); + + Ok(()) + } + #[test] fn increment_download_current_accumulates() -> Result<()> { let status = SlotAggregatedStatus::new(2); - status.set_download_status(0, 1000, Some("model.gguf".to_owned())); + status.set_download_status(0, Some(1000), Some("model.gguf".to_owned())); status.increment_download_current(100); status.increment_download_current(200); @@ -426,13 +485,14 @@ mod tests { fn reset_download_clears_download_fields() -> Result<()> { let status = SlotAggregatedStatus::new(2); - status.set_download_status(500, 1000, Some("model.gguf".to_owned())); + status.set_download_status(500, Some(1000), Some("model.gguf".to_owned())); status.reset_download(); let snapshot = status.make_snapshot()?; assert_eq!(snapshot.download_current, 0); assert_eq!(snapshot.download_total, 0); + assert!(snapshot.download_indeterminate); 
assert_eq!(snapshot.download_filename, None); Ok(()) diff --git a/paddler/src/slot_aggregated_status_download_progress.rs b/paddler/src/slot_aggregated_status_download_progress.rs index 09f6e0b8..a97ffe27 100644 --- a/paddler/src/slot_aggregated_status_download_progress.rs +++ b/paddler/src/slot_aggregated_status_download_progress.rs @@ -27,11 +27,12 @@ impl Progress for SlotAggregatedStatusDownloadProgress { })); self.slot_aggregated_status - .set_download_status(0, size, Some(filename.to_owned())); + .set_download_status(0, Some(size as u64), Some(filename.to_owned())); } async fn update(&mut self, size: usize) { - self.slot_aggregated_status.increment_download_current(size); + self.slot_aggregated_status + .increment_download_current(size as u64); } async fn finish(&mut self) { diff --git a/paddler_cache_dir/Cargo.toml b/paddler_cache_dir/Cargo.toml new file mode 100644 index 00000000..3df411fa --- /dev/null +++ b/paddler_cache_dir/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "paddler_cache_dir" +authors.workspace = true +description.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +anyhow = { workspace = true } +fslock = { workspace = true } +sha2 = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } + +[dev-dependencies] +tempfile = { workspace = true } + +[lints] +workspace = true diff --git a/paddler_cache_dir/src/cache_dir/mod.rs b/paddler_cache_dir/src/cache_dir/mod.rs new file mode 100644 index 00000000..a2e566c3 --- /dev/null +++ b/paddler_cache_dir/src/cache_dir/mod.rs @@ -0,0 +1,9 @@ +#[cfg(unix)] +mod unix; +#[cfg(unix)] +pub use crate::cache_dir::unix::CacheDir; + +#[cfg(windows)] +mod windows; +#[cfg(windows)] +pub use crate::cache_dir::windows::CacheDir; diff --git a/paddler_cache_dir/src/cache_dir/unix.rs b/paddler_cache_dir/src/cache_dir/unix.rs new file mode 100644 index 00000000..7ac0d285 --- 
/dev/null +++ b/paddler_cache_dir/src/cache_dir/unix.rs @@ -0,0 +1,95 @@ +use std::path::PathBuf; + +use anyhow::Context as _; +use anyhow::Result; + +pub struct CacheDir { + pub explicit: Option, + pub home: Option, + pub xdg: Option, +} + +impl CacheDir { + #[must_use] + pub fn from_process_env() -> Self { + Self { + explicit: std::env::var("PADDLER_CACHE_DIR").ok(), + home: std::env::var("HOME").ok(), + xdg: std::env::var("XDG_CACHE_HOME").ok(), + } + } + + pub fn resolve(&self) -> Result { + if let Some(explicit) = &self.explicit { + return Ok(PathBuf::from(explicit)); + } + + if let Some(xdg) = &self.xdg { + return Ok(PathBuf::from(xdg).join("paddler")); + } + + let home = self + .home + .as_ref() + .context("HOME not set; cannot derive paddler cache directory")?; + + Ok(PathBuf::from(home).join(".cache").join("paddler")) + } +} + +#[cfg(test)] +mod tests { + use crate::cache_dir::unix::CacheDir; + + #[test] + fn explicit_value_wins_over_xdg_and_home() { + let cache = CacheDir { + explicit: Some("/explicit/cache".to_owned()), + home: Some("/home/user".to_owned()), + xdg: Some("/xdg/cache".to_owned()), + }; + let path = cache.resolve().unwrap_or_default(); + + assert_eq!(path.to_string_lossy(), "/explicit/cache"); + } + + #[test] + fn xdg_value_used_when_no_explicit() { + let cache = CacheDir { + explicit: None, + home: Some("/home/user".to_owned()), + xdg: Some("/xdg/cache".to_owned()), + }; + let path = cache.resolve().unwrap_or_default(); + + assert_eq!(path.to_string_lossy(), "/xdg/cache/paddler"); + } + + #[test] + fn falls_back_to_home_dot_cache_paddler() { + let cache = CacheDir { + explicit: None, + home: Some("/home/user".to_owned()), + xdg: None, + }; + let path = cache.resolve().unwrap_or_default(); + + assert_eq!(path.to_string_lossy(), "/home/user/.cache/paddler"); + } + + #[test] + fn errors_when_no_env_set() { + let cache = CacheDir { + explicit: None, + home: None, + xdg: None, + }; + + assert!(cache.resolve().is_err()); + } + + #[test] + fn 
from_process_env_constructs_without_panicking() { + let _ = CacheDir::from_process_env(); + } +} diff --git a/paddler_cache_dir/src/cache_dir/windows.rs b/paddler_cache_dir/src/cache_dir/windows.rs new file mode 100644 index 00000000..88606051 --- /dev/null +++ b/paddler_cache_dir/src/cache_dir/windows.rs @@ -0,0 +1,104 @@ +use std::path::PathBuf; + +use anyhow::Context as _; +use anyhow::Result; + +pub struct CacheDir { + pub explicit: Option, + pub localappdata: Option, + pub userprofile: Option, +} + +impl CacheDir { + #[must_use] + pub fn from_process_env() -> Self { + Self { + explicit: std::env::var("PADDLER_CACHE_DIR").ok(), + localappdata: std::env::var("LOCALAPPDATA").ok(), + userprofile: std::env::var("USERPROFILE").ok(), + } + } + + pub fn resolve(&self) -> Result { + if let Some(explicit) = &self.explicit { + return Ok(PathBuf::from(explicit)); + } + + if let Some(localappdata) = &self.localappdata { + return Ok(PathBuf::from(localappdata).join("paddler")); + } + + let userprofile = self + .userprofile + .as_ref() + .context("USERPROFILE not set; cannot derive paddler cache directory")?; + + Ok(PathBuf::from(userprofile) + .join("AppData") + .join("Local") + .join("paddler")) + } +} + +#[cfg(test)] +mod tests { + use crate::cache_dir::windows::CacheDir; + + #[test] + fn explicit_value_wins_over_localappdata_and_userprofile() { + let cache = CacheDir { + explicit: Some(r"D:\explicit\cache".to_owned()), + localappdata: Some(r"C:\Users\user\AppData\Local".to_owned()), + userprofile: Some(r"C:\Users\user".to_owned()), + }; + let path = cache.resolve().unwrap_or_default(); + + assert_eq!(path.to_string_lossy(), r"D:\explicit\cache"); + } + + #[test] + fn localappdata_used_when_no_explicit() { + let cache = CacheDir { + explicit: None, + localappdata: Some(r"C:\Users\user\AppData\Local".to_owned()), + userprofile: Some(r"C:\Users\user".to_owned()), + }; + let path = cache.resolve().unwrap_or_default(); + + assert_eq!( + path.to_string_lossy(), + 
r"C:\Users\user\AppData\Local\paddler" + ); + } + + #[test] + fn falls_back_to_userprofile_appdata_local_paddler() { + let cache = CacheDir { + explicit: None, + localappdata: None, + userprofile: Some(r"C:\Users\user".to_owned()), + }; + let path = cache.resolve().unwrap_or_default(); + + assert_eq!( + path.to_string_lossy(), + r"C:\Users\user\AppData\Local\paddler" + ); + } + + #[test] + fn errors_when_no_env_set() { + let cache = CacheDir { + explicit: None, + localappdata: None, + userprofile: None, + }; + + assert!(cache.resolve().is_err()); + } + + #[test] + fn from_process_env_constructs_without_panicking() { + let _ = CacheDir::from_process_env(); + } +} diff --git a/paddler_cache_dir/src/cached_downloaded_model.rs b/paddler_cache_dir/src/cached_downloaded_model.rs new file mode 100644 index 00000000..8c6bdcf2 --- /dev/null +++ b/paddler_cache_dir/src/cached_downloaded_model.rs @@ -0,0 +1,288 @@ +use std::fmt::Write as _; +use std::path::PathBuf; + +use anyhow::Result; +use fslock::LockFile; +use sha2::Digest; +use sha2::Sha256; +use tokio::fs; + +use crate::cache_dir::CacheDir; +use crate::cached_downloaded_model_lock::CachedDownloadedModelLock; +use crate::download_lock_acquisition_error::DownloadLockAcquisitionError; + +const DOWNLOADED_MODELS_SUBDIR: &str = "downloaded-models"; + +fn hex_lowercase(bytes: &[u8]) -> String { + bytes + .iter() + .fold(String::with_capacity(bytes.len() * 2), |mut acc, byte| { + let _ = write!(acc, "{byte:02x}"); + acc + }) +} + +pub struct CachedDownloadedModel { + pub cache_file_path: PathBuf, + pub cache_subdir: PathBuf, + pub lock_file_path: PathBuf, +} + +impl CachedDownloadedModel { + pub fn new(cache_dir: &CacheDir, url_string: &str) -> Result { + let cache_root = cache_dir.resolve()?; + let basename = hex_lowercase(&Sha256::digest(url_string.as_bytes())); + + let cache_subdir = cache_root.join(DOWNLOADED_MODELS_SUBDIR); + let cache_file_path = cache_subdir.join(&basename); + let lock_file_path = 
cache_subdir.join(format!("{basename}.lock")); + + Ok(Self { + cache_file_path, + cache_subdir, + lock_file_path, + }) + } + + pub async fn is_cached(&self) -> Result { + fs::try_exists(&self.cache_file_path).await + } + + pub async fn ensure_cache_subdir_exists(&self) -> Result<(), std::io::Error> { + fs::create_dir_all(&self.cache_subdir).await + } + + pub fn try_acquire_download_lock( + &self, + ) -> Result { + let (acquired, lock_file) = LockFile::open(&self.lock_file_path) + .and_then(|mut file| file.try_lock().map(|acquired| (acquired, file)))?; + if acquired { + Ok(CachedDownloadedModelLock::new(lock_file)) + } else { + Err(DownloadLockAcquisitionError::AnotherProcessIsDownloading) + } + } +} + +#[cfg(test)] +mod tests { + use fslock::LockFile; + use sha2::Digest; + use sha2::Sha256; + use tempfile::TempDir; + + use crate::cache_dir::CacheDir; + use crate::cached_downloaded_model::CachedDownloadedModel; + use crate::cached_downloaded_model::hex_lowercase; + + fn cache_dir_at(path: &std::path::Path) -> CacheDir { + #[cfg(unix)] + { + CacheDir { + explicit: Some(path.to_string_lossy().into_owned()), + home: None, + xdg: None, + } + } + #[cfg(windows)] + { + CacheDir { + explicit: Some(path.to_string_lossy().into_owned()), + localappdata: None, + userprofile: None, + } + } + } + + #[test] + fn cache_file_basename_is_only_lowercase_hex_for_traversal_url() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let url_string = "https://example.com/../../etc/passwd?token=secret"; + let cached = CachedDownloadedModel::new(&cache_dir, url_string).unwrap(); + + let file_name = cached + .cache_file_path + .file_name() + .and_then(|name| name.to_str()) + .unwrap(); + + assert_eq!(file_name.len(), 64, "SHA-256 hex is 64 chars"); + assert!( + file_name + .chars() + .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)), + "basename {file_name:?} must be lowercase hex only" + ); + } + + #[test] + fn 
cache_file_path_for_traversal_url_stays_directly_under_downloaded_models() { + let traversal_urls = [ + "https://example.com/..", + "https://example.com/../../etc/passwd", + "https://example.com//etc//passwd", + "https://example.com/foo%2Fbar", + "https://example.com/", + ]; + + for url_string in traversal_urls { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let cached = CachedDownloadedModel::new(&cache_dir, url_string).unwrap(); + let expected_parent = directory.path().join("downloaded-models"); + + assert_eq!( + cached.cache_file_path.parent(), + Some(expected_parent.as_path()), + "URL {url_string:?} produced cache file outside downloaded-models" + ); + } + } + + #[test] + fn cache_file_path_is_sha256_hex_under_downloaded_models() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let url_string = "https://host.example/folder/model.gguf"; + let cached = CachedDownloadedModel::new(&cache_dir, url_string).unwrap(); + + let expected_hex = hex_lowercase(&Sha256::digest(url_string.as_bytes())); + let expected_path = directory + .path() + .join("downloaded-models") + .join(&expected_hex); + + assert_eq!(cached.cache_file_path, expected_path); + } + + #[test] + fn lock_file_path_is_hex_dot_lock_next_to_cache_file() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let url_string = "https://host.example/model.gguf"; + let cached = CachedDownloadedModel::new(&cache_dir, url_string).unwrap(); + + let expected_hex = hex_lowercase(&Sha256::digest(url_string.as_bytes())); + let expected_lock = directory + .path() + .join("downloaded-models") + .join(format!("{expected_hex}.lock")); + + assert_eq!(cached.lock_file_path, expected_lock); + assert_eq!( + cached.cache_file_path.parent(), + cached.lock_file_path.parent() + ); + } + + #[tokio::test] + async fn is_cached_returns_false_when_cache_file_absent() { + let directory = 
TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let cached = + CachedDownloadedModel::new(&cache_dir, "https://host.example/missing.gguf").unwrap(); + + assert!(!cached.is_cached().await.unwrap()); + } + + #[tokio::test] + async fn is_cached_returns_true_when_cache_file_present() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let cached = + CachedDownloadedModel::new(&cache_dir, "https://host.example/present.gguf").unwrap(); + + cached.ensure_cache_subdir_exists().await.unwrap(); + tokio::fs::write(&cached.cache_file_path, b"cached") + .await + .unwrap(); + + assert!(cached.is_cached().await.unwrap()); + } + + #[tokio::test] + async fn try_acquire_download_lock_succeeds_when_uncontested() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let cached = + CachedDownloadedModel::new(&cache_dir, "https://host.example/model.gguf").unwrap(); + cached.ensure_cache_subdir_exists().await.unwrap(); + + let _guard = cached.try_acquire_download_lock().unwrap(); + } + + #[tokio::test] + async fn try_acquire_download_lock_returns_another_process_when_locked() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let cached = + CachedDownloadedModel::new(&cache_dir, "https://host.example/model.gguf").unwrap(); + cached.ensure_cache_subdir_exists().await.unwrap(); + + let mut blocker = LockFile::open(&cached.lock_file_path).unwrap(); + let blocker_acquired = blocker.try_lock().unwrap(); + assert!(blocker_acquired, "blocker must acquire the lock first"); + + let result = cached.try_acquire_download_lock(); + + assert!( + result + .unwrap_err() + .is_another_process_downloading() + ); + } + + #[test] + fn new_returns_error_when_cache_dir_cannot_resolve() { + let unresolvable; + #[cfg(unix)] + { + unresolvable = CacheDir { + explicit: None, + home: None, + xdg: None, + }; + } + #[cfg(windows)] + { + unresolvable = 
CacheDir { + explicit: None, + localappdata: None, + userprofile: None, + }; + } + + let result = CachedDownloadedModel::new(&unresolvable, "https://host.example/m.gguf"); + + assert!(result.is_err()); + } + + #[tokio::test] + async fn try_acquire_download_lock_returns_io_when_cache_subdir_missing() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let cached = + CachedDownloadedModel::new(&cache_dir, "https://host.example/model.gguf").unwrap(); + + let result = cached.try_acquire_download_lock(); + + assert!(result.unwrap_err().is_io()); + } + + #[tokio::test] + async fn lock_releases_on_drop_so_subsequent_acquire_succeeds() { + let directory = TempDir::new().unwrap(); + let cache_dir = cache_dir_at(directory.path()); + let cached = + CachedDownloadedModel::new(&cache_dir, "https://host.example/model.gguf").unwrap(); + cached.ensure_cache_subdir_exists().await.unwrap(); + + { + let _guard = cached.try_acquire_download_lock().unwrap(); + } + + let _second_guard = cached.try_acquire_download_lock().unwrap(); + } +} diff --git a/paddler_cache_dir/src/cached_downloaded_model_lock.rs b/paddler_cache_dir/src/cached_downloaded_model_lock.rs new file mode 100644 index 00000000..a961fc05 --- /dev/null +++ b/paddler_cache_dir/src/cached_downloaded_model_lock.rs @@ -0,0 +1,15 @@ +use fslock::LockFile; + +#[derive(Debug)] +pub struct CachedDownloadedModelLock { + _lock_file: LockFile, +} + +impl CachedDownloadedModelLock { + #[must_use] + pub const fn new(lock_file: LockFile) -> Self { + Self { + _lock_file: lock_file, + } + } +} diff --git a/paddler_cache_dir/src/download_lock_acquisition_error.rs b/paddler_cache_dir/src/download_lock_acquisition_error.rs new file mode 100644 index 00000000..edada94c --- /dev/null +++ b/paddler_cache_dir/src/download_lock_acquisition_error.rs @@ -0,0 +1,48 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum DownloadLockAcquisitionError { + #[error("another agent on this host is 
currently downloading this URL")] + AnotherProcessIsDownloading, + #[error(transparent)] + Io(#[from] std::io::Error), +} + +impl DownloadLockAcquisitionError { + #[must_use] + pub const fn is_another_process_downloading(&self) -> bool { + matches!(self, Self::AnotherProcessIsDownloading) + } + + #[must_use] + pub const fn is_io(&self) -> bool { + matches!(self, Self::Io(_)) + } +} + +#[cfg(test)] +mod tests { + use std::io; + + use crate::download_lock_acquisition_error::DownloadLockAcquisitionError; + + #[test] + fn is_another_process_downloading_returns_true_only_for_that_variant() { + let another_process = DownloadLockAcquisitionError::AnotherProcessIsDownloading; + let io_error = + DownloadLockAcquisitionError::Io(io::Error::from(io::ErrorKind::NotFound)); + + assert!(another_process.is_another_process_downloading()); + assert!(!io_error.is_another_process_downloading()); + } + + #[test] + fn is_io_returns_true_only_for_io_variant() { + let io_error = + DownloadLockAcquisitionError::Io(io::Error::from(io::ErrorKind::NotFound)); + let another_process = DownloadLockAcquisitionError::AnotherProcessIsDownloading; + + assert!(io_error.is_io()); + assert!(!another_process.is_io()); + } +} diff --git a/paddler_cache_dir/src/lib.rs b/paddler_cache_dir/src/lib.rs new file mode 100644 index 00000000..d891ebc8 --- /dev/null +++ b/paddler_cache_dir/src/lib.rs @@ -0,0 +1,9 @@ +mod cache_dir; +mod cached_downloaded_model; +mod cached_downloaded_model_lock; +mod download_lock_acquisition_error; + +pub use crate::cache_dir::CacheDir; +pub use crate::cached_downloaded_model::CachedDownloadedModel; +pub use crate::cached_downloaded_model_lock::CachedDownloadedModelLock; +pub use crate::download_lock_acquisition_error::DownloadLockAcquisitionError; diff --git a/paddler_client_javascript/src/schemas/Agent.ts b/paddler_client_javascript/src/schemas/Agent.ts index 5d5e6fe6..1ad87eaa 100644 --- a/paddler_client_javascript/src/schemas/Agent.ts +++ 
b/paddler_client_javascript/src/schemas/Agent.ts @@ -7,6 +7,7 @@ export const AgentSchema = z desired_slots_total: z.number(), download_current: z.number(), download_filename: z.string().nullable(), + download_indeterminate: z.boolean(), download_total: z.number(), id: z.string(), issues: z.array(AgentIssueSchema), diff --git a/paddler_client_javascript/src/schemas/AgentDesiredModel.ts b/paddler_client_javascript/src/schemas/AgentDesiredModel.ts index d2102349..b330b4cf 100644 --- a/paddler_client_javascript/src/schemas/AgentDesiredModel.ts +++ b/paddler_client_javascript/src/schemas/AgentDesiredModel.ts @@ -1,6 +1,7 @@ import { z } from "zod"; import { HuggingFaceModelReferenceSchema } from "./HuggingFaceModelReference"; +import { UrlModelReferenceSchema } from "./UrlModelReference"; export const AgentDesiredModelSchema = z.union([ z.object({ @@ -9,6 +10,9 @@ export const AgentDesiredModelSchema = z.union([ z.object({ LocalToAgent: z.string(), }), + z.object({ + Url: UrlModelReferenceSchema, + }), z.literal("None"), ]); diff --git a/paddler_client_javascript/src/schemas/AgentIssue.ts b/paddler_client_javascript/src/schemas/AgentIssue.ts index 93060ac0..b1247273 100644 --- a/paddler_client_javascript/src/schemas/AgentIssue.ts +++ b/paddler_client_javascript/src/schemas/AgentIssue.ts @@ -4,6 +4,15 @@ import { AgentIssueModelPathSchema } from "./AgentIssueModelPath"; import { HuggingFaceDownloadLockSchema } from "./HuggingFaceDownloadLock"; export const AgentIssueSchema = z.union([ + z.object({ + CacheCannotAcquireLock: AgentIssueModelPathSchema, + }), + z.object({ + CacheDirectoryIsNotWritable: AgentIssueModelPathSchema, + }), + z.object({ + CacheStorageIsFull: AgentIssueModelPathSchema, + }), z.object({ ChatTemplateDoesNotCompile: z.object({ error: z.string(), @@ -11,6 +20,24 @@ export const AgentIssueSchema = z.union([ template_content: z.string(), }), }), + z.object({ + DownloadServerDeniedAccess: AgentIssueModelPathSchema, + }), + z.object({ + 
DownloadServerErrored: AgentIssueModelPathSchema, + }), + z.object({ + DownloadServerIsUnreachable: AgentIssueModelPathSchema, + }), + z.object({ + DownloadServerRejectedRequest: AgentIssueModelPathSchema, + }), + z.object({ + DownloadInterrupted: AgentIssueModelPathSchema, + }), + z.object({ + DownloadUrlIsMalformed: AgentIssueModelPathSchema, + }), z.object({ HuggingFaceCannotAcquireLock: HuggingFaceDownloadLockSchema, }), @@ -20,9 +47,15 @@ export const AgentIssueSchema = z.union([ z.object({ HuggingFacePermissions: AgentIssueModelPathSchema, }), + z.object({ + ModelCacheIsCorrupted: AgentIssueModelPathSchema, + }), z.object({ ModelCannotBeLoaded: AgentIssueModelPathSchema, }), + z.object({ + ModelDoesNotExistAtUrl: AgentIssueModelPathSchema, + }), z.object({ ModelFileDoesNotExist: AgentIssueModelPathSchema, }), diff --git a/paddler_client_javascript/src/schemas/UrlModelReference.ts b/paddler_client_javascript/src/schemas/UrlModelReference.ts new file mode 100644 index 00000000..dce78a9a --- /dev/null +++ b/paddler_client_javascript/src/schemas/UrlModelReference.ts @@ -0,0 +1,7 @@ +import { z } from "zod"; + +export const UrlModelReferenceSchema = z.object({ + url: z.string(), +}); + +export type UrlModelReference = z.infer<typeof UrlModelReferenceSchema>; diff --git a/paddler_client_javascript/src/urlToAgentDesiredModel.ts b/paddler_client_javascript/src/urlToAgentDesiredModel.ts index 9372ad44..2a73b69d 100644 --- a/paddler_client_javascript/src/urlToAgentDesiredModel.ts +++ b/paddler_client_javascript/src/urlToAgentDesiredModel.ts @@ -14,5 +14,11 @@ export function urlToAgentDesiredModel(url: URL): AgentDesiredModel { }; } + if (url.protocol === "http:" || url.protocol === "https:") { + return { + Url: { url: url.toString() }, + }; + } + throw new Error("Unsupported URL format"); } diff --git a/paddler_client_javascript/tests/schemas/Agent.test.ts b/paddler_client_javascript/tests/schemas/Agent.test.ts index a01ff9c3..3ad0ba4e 100644 ---
a/paddler_client_javascript/tests/schemas/Agent.test.ts +++ b/paddler_client_javascript/tests/schemas/Agent.test.ts @@ -8,6 +8,7 @@ test("parses a fully populated agent payload", function () { desired_slots_total: 4, download_current: 0, download_filename: null, + download_indeterminate: false, download_total: 0, id: "agent-0", issues: [], @@ -29,6 +30,7 @@ test("rejects an unknown state_application_status", function () { desired_slots_total: 1, download_current: 0, download_filename: null, + download_indeterminate: false, download_total: 0, id: "agent-x", issues: [], diff --git a/paddler_client_javascript/tests/urlToAgentDesiredModel.test.ts b/paddler_client_javascript/tests/urlToAgentDesiredModel.test.ts index 06e6b517..671906cb 100644 --- a/paddler_client_javascript/tests/urlToAgentDesiredModel.test.ts +++ b/paddler_client_javascript/tests/urlToAgentDesiredModel.test.ts @@ -25,8 +25,8 @@ test("agent: URLs become LocalToAgent variant", function () { }); }); -test("unsupported URLs throw", function () { - const url = new URL("https://example.com/some/path"); +test("non-http(s), non-agent URLs throw", function () { + const url = new URL("ftp://example.com/file.gguf"); throws( function () { @@ -35,3 +35,33 @@ test("unsupported URLs throw", function () { { message: "Unsupported URL format" }, ); }); + +test("the user's Qwen 3.6 35B blob URL still routes to HuggingFace", function () { + const url = new URL( + "https://huggingface.co/unsloth/Qwen3.6-35B-A3B-GGUF/blob/main/Qwen3.6-35B-A3B-UD-Q4_K_M.gguf", + ); + + deepStrictEqual(urlToAgentDesiredModel(url), { + HuggingFace: { + filename: "Qwen3.6-35B-A3B-UD-Q4_K_M.gguf", + repo_id: "unsloth/Qwen3.6-35B-A3B-GGUF", + revision: "main", + }, + }); +}); + +test("https URLs off huggingface.co route to the Url variant", function () { + const url = new URL("https://example.com/path/to/model.gguf"); + + deepStrictEqual(urlToAgentDesiredModel(url), { + Url: { url: "https://example.com/path/to/model.gguf" }, + }); +}); + 
+test("plain http URLs route to the Url variant", function () { + const url = new URL("http://mirror.example.org/Qwen3-0.6B.gguf"); + + deepStrictEqual(urlToAgentDesiredModel(url), { + Url: { url: "http://mirror.example.org/Qwen3-0.6B.gguf" }, + }); +}); diff --git a/paddler_client_python/.gitignore b/paddler_client_python/.gitignore index 6b152e7e..4cf16182 100644 --- a/paddler_client_python/.gitignore +++ b/paddler_client_python/.gitignore @@ -18,3 +18,4 @@ venv/ *.log .coverage htmlcov/ +target/ diff --git a/paddler_client_python/paddler_client/agent_controller_snapshot.py b/paddler_client_python/paddler_client/agent_controller_snapshot.py index 522be9b4..72247e6c 100644 --- a/paddler_client_python/paddler_client/agent_controller_snapshot.py +++ b/paddler_client_python/paddler_client/agent_controller_snapshot.py @@ -10,6 +10,7 @@ class AgentControllerSnapshot(BaseModel): desired_slots_total: int download_current: int download_filename: str | None = None + download_indeterminate: bool download_total: int id: str issues: list[AgentIssue] = [] diff --git a/paddler_client_python/paddler_client/agent_desired_model.py b/paddler_client_python/paddler_client/agent_desired_model.py index 4b8abda8..74f94fee 100644 --- a/paddler_client_python/paddler_client/agent_desired_model.py +++ b/paddler_client_python/paddler_client/agent_desired_model.py @@ -5,6 +5,7 @@ from paddler_client.huggingface_model_reference import ( HuggingFaceModelReference, ) +from paddler_client.url_model_reference import UrlModelReference class AgentDesiredModel(BaseModel): @@ -13,6 +14,7 @@ class AgentDesiredModel(BaseModel): variant: str huggingface: HuggingFaceModelReference | None = None local_path: str | None = None + url: UrlModelReference | None = None @model_validator(mode="before") @classmethod @@ -35,6 +37,12 @@ def from_serde(cls, data: Any) -> dict[str, Any]: "local_path": typed_data["LocalToAgent"], } + if "Url" in typed_data: + return { + "variant": "Url", + "url": typed_data["Url"], + } 
+ if "variant" in typed_data: return typed_data @@ -56,6 +64,13 @@ def to_serde(self) -> str | dict[str, Any]: return {"LocalToAgent": self.local_path} + if self.variant == "Url": + if self.url is None: + msg = "url is required for Url" + raise ValueError(msg) + + return {"Url": self.url.model_dump()} + msg = f"Unknown AgentDesiredModel variant: {self.variant}" raise ValueError(msg) @@ -72,3 +87,7 @@ def from_huggingface( @classmethod def local_to_agent(cls, path: str) -> "AgentDesiredModel": return cls(variant="LocalToAgent", local_path=path) + + @classmethod + def from_url(cls, reference: UrlModelReference) -> "AgentDesiredModel": + return cls(variant="Url", url=reference) diff --git a/paddler_client_python/paddler_client/inference_parameters.py b/paddler_client_python/paddler_client/inference_parameters.py index 44e674f2..85e9f265 100644 --- a/paddler_client_python/paddler_client/inference_parameters.py +++ b/paddler_client_python/paddler_client/inference_parameters.py @@ -1,14 +1,19 @@ from pydantic import BaseModel +from paddler_client.kv_cache_dtype import KvCacheDtype from paddler_client.pooling_type import PoolingType class InferenceParameters(BaseModel): n_batch: int = 2048 context_size: int = 8192 + embedding_batch_size: int = 256 enable_embeddings: bool = False image_resize_to_fit: int = 1024 + k_cache_dtype: KvCacheDtype = KvCacheDtype.Q8_0 + v_cache_dtype: KvCacheDtype = KvCacheDtype.Q8_0 min_p: float = 0.05 + n_gpu_layers: int = 0 penalty_frequency: float = 0.0 penalty_last_n: int = -1 penalty_presence: float = 0.8 diff --git a/paddler_client_python/paddler_client/kv_cache_dtype.py b/paddler_client_python/paddler_client/kv_cache_dtype.py new file mode 100644 index 00000000..44650eda --- /dev/null +++ b/paddler_client_python/paddler_client/kv_cache_dtype.py @@ -0,0 +1,13 @@ +from enum import StrEnum + + +class KvCacheDtype(StrEnum): + F32 = "F32" + F16 = "F16" + BF16 = "BF16" + Q8_0 = "Q8_0" + Q4_0 = "Q4_0" + Q4_1 = "Q4_1" + IQ4_NL = "IQ4_NL" + Q5_0 = 
"Q5_0" + Q5_1 = "Q5_1" diff --git a/paddler_client_python/paddler_client/url_model_reference.py b/paddler_client_python/paddler_client/url_model_reference.py new file mode 100644 index 00000000..a760f80c --- /dev/null +++ b/paddler_client_python/paddler_client/url_model_reference.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class UrlModelReference(BaseModel): + url: str diff --git a/paddler_client_python/tests/test_agent_controller_snapshot.py b/paddler_client_python/tests/test_agent_controller_snapshot.py index 5e3aef85..f806e997 100644 --- a/paddler_client_python/tests/test_agent_controller_snapshot.py +++ b/paddler_client_python/tests/test_agent_controller_snapshot.py @@ -13,6 +13,7 @@ def test_agent_controller_snapshot_deserialization() -> None: "desired_slots_total": 4, "download_current": 100, "download_filename": "model.gguf", + "download_indeterminate": False, "download_total": 1000, "id": "agent-1", "issues": [{"SlotCannotStart": {"error": "OOM", "slot_index": 0}}], @@ -39,6 +40,7 @@ def test_agent_controller_pool_snapshot_deserialization() -> None: { "desired_slots_total": 2, "download_current": 0, + "download_indeterminate": True, "download_total": 0, "id": "a1", "issues": [], diff --git a/paddler_client_python/tests/test_agent_desired_model.py b/paddler_client_python/tests/test_agent_desired_model.py index 335a192f..0665e497 100644 --- a/paddler_client_python/tests/test_agent_desired_model.py +++ b/paddler_client_python/tests/test_agent_desired_model.py @@ -4,6 +4,7 @@ from paddler_client.huggingface_model_reference import ( HuggingFaceModelReference, ) +from paddler_client.url_model_reference import UrlModelReference def test_agent_desired_model_none_serialization() -> None: @@ -87,3 +88,28 @@ def test_agent_desired_model_local_to_agent_missing_path_raises() -> None: with pytest.raises(ValueError, match="local_path is required"): model.model_dump(mode="json") + + +def test_agent_desired_model_url_serialization() -> None: + reference = 
UrlModelReference(url="https://example.com/model.gguf") + model = AgentDesiredModel.from_url(reference) + dumped = model.model_dump(mode="json") + + assert dumped == {"Url": {"url": "https://example.com/model.gguf"}} + + +def test_agent_desired_model_url_deserialization() -> None: + model = AgentDesiredModel.model_validate( + {"Url": {"url": "https://example.com/model.gguf"}} + ) + + assert model.variant == "Url" + assert model.url is not None + assert model.url.url == "https://example.com/model.gguf" + + +def test_agent_desired_model_url_missing_reference_raises() -> None: + model = AgentDesiredModel(variant="Url", url=None) + + with pytest.raises(ValueError, match="url is required"): + model.model_dump(mode="json") diff --git a/paddler_client_python/tests/test_client_management.py b/paddler_client_python/tests/test_client_management.py index 603f2903..56bf8713 100644 --- a/paddler_client_python/tests/test_client_management.py +++ b/paddler_client_python/tests/test_client_management.py @@ -13,6 +13,7 @@ def _agent_snapshot_json() -> dict[str, object]: "desired_slots_total": 4, "download_current": 0, "download_filename": None, + "download_indeterminate": False, "download_total": 0, "id": "agent-1", "issues": [], diff --git a/paddler_client_python/tests/test_integration_inference.py b/paddler_client_python/tests/test_integration_inference.py index ccb03a69..317fc157 100644 --- a/paddler_client_python/tests/test_integration_inference.py +++ b/paddler_client_python/tests/test_integration_inference.py @@ -60,7 +60,7 @@ async def test_http_continue_from_conversation_history( ): _assert_not_error(message) - if message.kind == InferenceMessageKind.CONTENT_TOKEN: + if message.is_token: assert message.token is not None tokens.append(message.token) elif message.is_terminal: @@ -89,7 +89,7 @@ async def test_websocket_continue_from_conversation_history( async for message in stream: _assert_not_error(message) - if message.kind == InferenceMessageKind.CONTENT_TOKEN: + if 
message.is_token: assert message.token is not None tokens.append(message.token) elif message.is_terminal: @@ -114,7 +114,7 @@ async def test_websocket_continue_from_raw_prompt( async for message in stream: _assert_not_error(message) - if message.kind == InferenceMessageKind.CONTENT_TOKEN: + if message.is_token: assert message.token is not None tokens.append(message.token) elif message.is_terminal: diff --git a/paddler_client_python/tests/test_kv_cache_dtype.py b/paddler_client_python/tests/test_kv_cache_dtype.py new file mode 100644 index 00000000..b023c7ab --- /dev/null +++ b/paddler_client_python/tests/test_kv_cache_dtype.py @@ -0,0 +1,13 @@ +from paddler_client.kv_cache_dtype import KvCacheDtype + + +def test_kv_cache_dtype_values() -> None: + assert KvCacheDtype.F32.value == "F32" + assert KvCacheDtype.F16.value == "F16" + assert KvCacheDtype.BF16.value == "BF16" + assert KvCacheDtype.Q8_0.value == "Q8_0" + assert KvCacheDtype.Q4_0.value == "Q4_0" + assert KvCacheDtype.Q4_1.value == "Q4_1" + assert KvCacheDtype.IQ4_NL.value == "IQ4_NL" + assert KvCacheDtype.Q5_0.value == "Q5_0" + assert KvCacheDtype.Q5_1.value == "Q5_1" diff --git a/paddler_download_manager/Cargo.toml b/paddler_download_manager/Cargo.toml new file mode 100644 index 00000000..e44fa4e6 --- /dev/null +++ b/paddler_download_manager/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "paddler_download_manager" +authors.workspace = true +description.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +anyhow = { workspace = true } +bytes = { workspace = true } +futures-util = { workspace = true } +headers = { workspace = true } +reqwest = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } +url = { workspace = true } + +[dev-dependencies] +tempfile = { workspace = true } + +[lints] +workspace = true diff --git 
a/paddler_download_manager/src/download_attempt_error.rs b/paddler_download_manager/src/download_attempt_error.rs new file mode 100644 index 00000000..0aec8a8c --- /dev/null +++ b/paddler_download_manager/src/download_attempt_error.rs @@ -0,0 +1,31 @@ +use std::io; + +use reqwest::StatusCode; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum DownloadAttemptError { + #[error("client error: {0}")] + ClientError(StatusCode), + + #[error("io")] + Io(#[from] io::Error), + + #[error("not found")] + NotFound, + + #[error("partial file stale")] + PartialFileStale, + + #[error("permission denied: {0}")] + PermissionDenied(StatusCode), + + #[error("server returned error status: {0}")] + ServerError(StatusCode), + + #[error("download interrupted: {0}")] + Interrupted(anyhow::Error), + + #[error("server unreachable: {0}")] + Unreachable(anyhow::Error), +} diff --git a/paddler_download_manager/src/download_error.rs b/paddler_download_manager/src/download_error.rs new file mode 100644 index 00000000..35423ef9 --- /dev/null +++ b/paddler_download_manager/src/download_error.rs @@ -0,0 +1,76 @@ +use std::io; +use std::path::PathBuf; + +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum DownloadError { + #[error("URL '{url}' is malformed: {source}")] + InvalidUrl { + url: String, + #[source] + source: url::ParseError, + }, + + #[error("URL '{url}' has unsupported scheme '{scheme}'; expected http or https")] + UnsupportedUrlScheme { url: String, scheme: String }, + + #[error("URL '{url}' returned 404 Not Found")] + NotFound { url: String }, + + #[error("URL '{url}' returned {status}")] + PermissionDenied { + url: String, + status: reqwest::StatusCode, + }, + + #[error("URL '{url}' returned 416 Range Not Satisfiable; '{partial_path_display}' was discarded", partial_path_display = partial_path.display())] + PartialFileStale { url: String, partial_path: PathBuf }, + + #[error("server unreachable for URL '{url}': {source}")] + DownloadServerIsUnreachable { + url: 
String, + #[source] + source: anyhow::Error, + }, + + #[error("server returned error status {status} for URL '{url}'")] + DownloadServerErrored { + url: String, + status: reqwest::StatusCode, + }, + + #[error("server rejected request to URL '{url}' with status {status}")] + DownloadServerRejectedRequest { + url: String, + status: reqwest::StatusCode, + }, + + #[error("download interrupted while downloading URL '{url}': {source}")] + DownloadInterrupted { + url: String, + #[source] + source: anyhow::Error, + }, + + #[error("I/O on '{path_display}': {source}", path_display = path.display())] + Io { + path: PathBuf, + #[source] + source: io::Error, + }, + + #[error("cache write denied at '{path_display}': {source}", path_display = path.display())] + CachePermissionDenied { + path: PathBuf, + #[source] + source: io::Error, + }, + + #[error("cache disk full at '{path_display}': {source}", path_display = path.display())] + CacheDiskFull { + path: PathBuf, + #[source] + source: io::Error, + }, +} diff --git a/paddler_download_manager/src/download_manager.rs b/paddler_download_manager/src/download_manager.rs new file mode 100644 index 00000000..e06d8125 --- /dev/null +++ b/paddler_download_manager/src/download_manager.rs @@ -0,0 +1,210 @@ +use std::io; +use std::path::Path; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use headers::ContentRange; +use headers::HeaderMapExt as _; +use reqwest::Client; +use reqwest::Url; +use reqwest::header::RANGE; + +use crate::download_attempt_error::DownloadAttemptError; +use crate::download_error::DownloadError; +use crate::partial_file::PartialFile; +use crate::progress_sink::ProgressSink; +use crate::response_classification::ResponseClassification; +use crate::stream_to_partial_file::stream_to_partial_file; +use crate::stream_to_partial_file_error::StreamToPartialFileError; + +#[cfg(unix)] +fn is_disk_full(error: &io::Error) -> bool { + error.raw_os_error() == Some(28) +} + +#[cfg(windows)] +fn 
is_disk_full(error: &io::Error) -> bool { + error.raw_os_error() == Some(112) +} + +fn classify_cache_failure(path: PathBuf, source: io::Error) -> DownloadError { + if source.kind() == io::ErrorKind::PermissionDenied { + DownloadError::CachePermissionDenied { path, source } + } else if is_disk_full(&source) { + DownloadError::CacheDiskFull { path, source } + } else { + DownloadError::Io { path, source } + } +} + +pub struct DownloadManager { + client: Client, +} + +impl DownloadManager { + pub fn new() -> Result<Self, reqwest::Error> { + let client = Client::builder() + .connect_timeout(Duration::from_secs(10)) + .read_timeout(Duration::from_secs(10)) + .build()?; + + Ok(Self { client }) + } + + pub async fn download( + &self, + url: &str, + final_path: &Path, + progress_sink: Arc<dyn ProgressSink>, + ) -> Result<(), DownloadError> { + let parsed_url = Url::parse(url).map_err(|parse_error| DownloadError::InvalidUrl { + url: url.to_owned(), + source: parse_error, + })?; + + if !matches!(parsed_url.scheme(), "http" | "https") { + return Err(DownloadError::UnsupportedUrlScheme { + url: url.to_owned(), + scheme: parsed_url.scheme().to_owned(), + }); + } + + let partial = PartialFile::new(final_path.to_path_buf()); + + match self.attempt_download(url, &partial, &progress_sink).await { + Ok(()) => Ok(()), + Err(DownloadAttemptError::Unreachable(source)) => { + Err(DownloadError::DownloadServerIsUnreachable { + url: url.to_owned(), + source, + }) + } + Err(DownloadAttemptError::ServerError(status)) => { + Err(DownloadError::DownloadServerErrored { + url: url.to_owned(), + status, + }) + } + Err(DownloadAttemptError::ClientError(status)) => { + Err(DownloadError::DownloadServerRejectedRequest { + url: url.to_owned(), + status, + }) + } + Err(DownloadAttemptError::Interrupted(source)) => { + Err(DownloadError::DownloadInterrupted { + url: url.to_owned(), + source, + }) + } + Err(DownloadAttemptError::NotFound) => Err(DownloadError::NotFound { + url: url.to_owned(), + }), +
Err(DownloadAttemptError::PermissionDenied(status)) => { + Err(DownloadError::PermissionDenied { + url: url.to_owned(), + status, + }) + } + Err(DownloadAttemptError::PartialFileStale) => Err(DownloadError::PartialFileStale { + url: url.to_owned(), + partial_path: partial.partial_path.clone(), + }), + Err(DownloadAttemptError::Io(io_error)) => Err(classify_cache_failure( + partial.partial_path.clone(), + io_error, + )), + } + } + + async fn attempt_download( + &self, + url: &str, + partial: &PartialFile, + progress_sink: &Arc<dyn ProgressSink>, + ) -> Result<(), DownloadAttemptError> { + let mut offset = partial.current_size().await?; + let sent_range_header = offset > 0; + + let mut request = self.client.get(url); + if sent_range_header { + request = request.header(RANGE, format!("bytes={offset}-")); + } + + let response = match request.send().await { + Ok(response) => response, + Err(send_error) => { + return Err(DownloadAttemptError::Unreachable(anyhow::Error::new(send_error))); + } + }; + + let classification = + ResponseClassification::from_status(response.status(), sent_range_header); + + match classification { + ResponseClassification::NotFound => return Err(DownloadAttemptError::NotFound), + ResponseClassification::PermissionDenied(status) => { + return Err(DownloadAttemptError::PermissionDenied(status)); + } + ResponseClassification::PartialFileStale => { + partial.remove().await?; + return Err(DownloadAttemptError::PartialFileStale); + } + ResponseClassification::ServerError(status) => { + return Err(DownloadAttemptError::ServerError(status)); + } + ResponseClassification::ClientError(status) => { + return Err(DownloadAttemptError::ClientError(status)); + } + ResponseClassification::StreamFromStartIgnoringRange => { + partial.truncate().await?; + offset = 0; + } + ResponseClassification::StreamFromCurrentOffset + | ResponseClassification::StreamFromStart => {} + } + + if matches!(classification, ResponseClassification::StreamFromCurrentOffset) { + let server_start =
response + .headers() + .typed_get::<ContentRange>() + .and_then(|content_range| content_range.bytes_range()) + .map(|(start, _end)| start); + + if server_start != Some(offset) { + partial.remove().await?; + + return Err(DownloadAttemptError::PartialFileStale); + } + } + + let total = response + .content_length() + .map(|content_length| offset + content_length); + progress_sink.on_started(total, offset); + + let mut file = partial.open_for_append().await?; + + match stream_to_partial_file(response.bytes_stream(), &mut file, progress_sink).await { + Ok(()) => {} + Err(StreamToPartialFileError::Stream(stream_error)) => { + return Err(DownloadAttemptError::Interrupted(anyhow::Error::new( + stream_error, + ))); + } + Err(StreamToPartialFileError::Write(write_error)) => { + return Err(DownloadAttemptError::Io(write_error)); + } + } + + drop(file); + + partial.finalize().await?; + progress_sink.on_finished(); + + Ok(()) + } +} + + diff --git a/paddler_download_manager/src/lib.rs b/paddler_download_manager/src/lib.rs new file mode 100644 index 00000000..8eefba44 --- /dev/null +++ b/paddler_download_manager/src/lib.rs @@ -0,0 +1,8 @@ +pub mod download_attempt_error; +pub mod download_error; +pub mod download_manager; +pub mod partial_file; +pub mod progress_sink; +pub mod response_classification; +pub mod stream_to_partial_file; +pub mod stream_to_partial_file_error; diff --git a/paddler_download_manager/src/partial_file.rs b/paddler_download_manager/src/partial_file.rs new file mode 100644 index 00000000..335ac57f --- /dev/null +++ b/paddler_download_manager/src/partial_file.rs @@ -0,0 +1,331 @@ +use std::io; +use std::path::Path; +use std::path::PathBuf; + +use tokio::fs; +use tokio::fs::File; +use tokio::fs::OpenOptions; + +const PARTIAL_EXTENSION: &str = "partial"; + +pub struct PartialFile { + pub final_path: PathBuf, + pub partial_path: PathBuf, +} + +impl PartialFile { + #[must_use] + pub fn new(final_path: PathBuf) -> Self { + let partial_path =
final_path.with_extension(PARTIAL_EXTENSION); + + Self { + final_path, + partial_path, + } + } + + pub async fn current_size(&self) -> Result { + match fs::metadata(&self.partial_path).await { + Ok(metadata) => Ok(metadata.len()), + Err(metadata_error) if metadata_error.kind() == io::ErrorKind::NotFound => Ok(0), + Err(metadata_error) => Err(metadata_error), + } + } + + pub async fn open_for_append(&self) -> Result { + self.ensure_partial_parent_exists().await?; + + OpenOptions::new() + .append(true) + .create(true) + .open(&self.partial_path) + .await + } + + pub async fn truncate(&self) -> Result<(), io::Error> { + self.ensure_partial_parent_exists().await?; + + OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&self.partial_path) + .await?; + + Ok(()) + } + + pub async fn finalize(&self) -> Result<(), io::Error> { + fs::rename(&self.partial_path, &self.final_path).await + } + + pub async fn remove(&self) -> Result<(), io::Error> { + match fs::remove_file(&self.partial_path).await { + Ok(()) => Ok(()), + Err(remove_error) if remove_error.kind() == io::ErrorKind::NotFound => Ok(()), + Err(remove_error) => Err(remove_error), + } + } + + async fn ensure_partial_parent_exists(&self) -> Result<(), io::Error> { + let parent = self + .partial_path + .parent() + .unwrap_or_else(|| Path::new(".")); + + fs::create_dir_all(parent).await + } +} + +#[cfg(test)] +mod tests { + use tempfile::TempDir; + use tokio::io::AsyncWriteExt; + + use crate::partial_file::PartialFile; + + #[tokio::test] + async fn current_size_returns_zero_when_missing() { + let directory = TempDir::new().unwrap(); + let partial = PartialFile::new(directory.path().join("model.gguf")); + + let size = partial.current_size().await.unwrap(); + + assert_eq!(size, 0); + } + + #[tokio::test] + async fn current_size_returns_existing_size() { + let directory = TempDir::new().unwrap(); + let partial = PartialFile::new(directory.path().join("model.gguf")); + 
tokio::fs::write(&partial.partial_path, b"twelve bytes") + .await + .unwrap(); + + let size = partial.current_size().await.unwrap(); + + assert_eq!(size, 12); + } + + #[tokio::test] + async fn open_for_append_creates_when_missing() { + let directory = TempDir::new().unwrap(); + let partial = PartialFile::new(directory.path().join("model.gguf")); + + let mut file = partial.open_for_append().await.unwrap(); + file.write_all(b"hello").await.unwrap(); + file.flush().await.unwrap(); + + let bytes = tokio::fs::read(&partial.partial_path).await.unwrap(); + assert_eq!(bytes, b"hello"); + } + + #[tokio::test] + async fn open_for_append_appends_to_existing() { + let directory = TempDir::new().unwrap(); + let partial = PartialFile::new(directory.path().join("model.gguf")); + tokio::fs::write(&partial.partial_path, b"first") + .await + .unwrap(); + + let mut file = partial.open_for_append().await.unwrap(); + file.write_all(b"-second").await.unwrap(); + file.flush().await.unwrap(); + + let bytes = tokio::fs::read(&partial.partial_path).await.unwrap(); + assert_eq!(bytes, b"first-second"); + } + + #[tokio::test] + async fn truncate_resets_to_zero() { + let directory = TempDir::new().unwrap(); + let partial = PartialFile::new(directory.path().join("model.gguf")); + tokio::fs::write(&partial.partial_path, b"keep me?") + .await + .unwrap(); + + partial.truncate().await.unwrap(); + + let size = partial.current_size().await.unwrap(); + assert_eq!(size, 0); + } + + #[tokio::test] + async fn finalize_renames_partial_to_final() { + let directory = TempDir::new().unwrap(); + let partial = PartialFile::new(directory.path().join("model.gguf")); + tokio::fs::write(&partial.partial_path, b"complete") + .await + .unwrap(); + let final_path = partial.final_path.clone(); + + partial.finalize().await.unwrap(); + + let exists = tokio::fs::try_exists(&final_path).await.unwrap(); + assert!(exists); + let bytes = tokio::fs::read(&final_path).await.unwrap(); + assert_eq!(bytes, b"complete"); + } + + 
#[tokio::test] + async fn remove_deletes_partial() { + let directory = TempDir::new().unwrap(); + let partial = PartialFile::new(directory.path().join("model.gguf")); + tokio::fs::write(&partial.partial_path, b"go away") + .await + .unwrap(); + let partial_path = partial.partial_path.clone(); + + partial.remove().await.unwrap(); + + let exists = tokio::fs::try_exists(&partial_path).await.unwrap(); + assert!(!exists); + } + + #[tokio::test] + async fn remove_is_noop_when_missing() { + let directory = TempDir::new().unwrap(); + let partial = PartialFile::new(directory.path().join("model.gguf")); + + partial.remove().await.unwrap(); + } + + #[cfg(unix)] + #[tokio::test] + async fn current_size_propagates_non_notfound_error() { + let directory = TempDir::new().unwrap(); + let blocking_file = directory.path().join("blocker"); + tokio::fs::write(&blocking_file, b"a regular file") + .await + .unwrap(); + let partial = PartialFile::new(blocking_file.join("subdir").join("model.gguf")); + + let result = partial.current_size().await; + + assert!(result.is_err()); + } + + #[cfg(unix)] + #[tokio::test] + async fn truncate_returns_io_error_when_partial_is_a_directory() { + let directory = TempDir::new().unwrap(); + let partial = PartialFile::new(directory.path().join("model.gguf")); + tokio::fs::create_dir(&partial.partial_path).await.unwrap(); + + let result = partial.truncate().await; + + assert!(result.is_err()); + } + + #[cfg(unix)] + #[tokio::test] + async fn open_for_append_returns_io_error_when_partial_is_a_directory() { + let directory = TempDir::new().unwrap(); + let partial = PartialFile::new(directory.path().join("model.gguf")); + tokio::fs::create_dir(&partial.partial_path).await.unwrap(); + + let result = partial.open_for_append().await; + + assert!(result.is_err()); + } + + #[cfg(unix)] + #[tokio::test] + async fn finalize_returns_io_error_when_final_is_a_non_empty_directory() { + let directory = TempDir::new().unwrap(); + let partial = 
PartialFile::new(directory.path().join("model.gguf")); + tokio::fs::write(&partial.partial_path, b"complete") + .await + .unwrap(); + tokio::fs::create_dir(&partial.final_path).await.unwrap(); + tokio::fs::write(partial.final_path.join("blocker"), b"x") + .await + .unwrap(); + + let result = partial.finalize().await; + + assert!(result.is_err()); + } + + #[cfg(unix)] + #[tokio::test] + async fn remove_propagates_non_notfound_error() { + use std::os::unix::fs::PermissionsExt; + + let directory = TempDir::new().unwrap(); + let locked_parent = directory.path().join("locked"); + tokio::fs::create_dir(&locked_parent).await.unwrap(); + let partial = PartialFile::new(locked_parent.join("model.gguf")); + tokio::fs::write(&partial.partial_path, b"go away") + .await + .unwrap(); + let mut perms = tokio::fs::metadata(&locked_parent) + .await + .unwrap() + .permissions(); + perms.set_mode(0o500); + tokio::fs::set_permissions(&locked_parent, perms) + .await + .unwrap(); + + let result = partial.remove().await; + + let mut restore = tokio::fs::metadata(&locked_parent) + .await + .unwrap() + .permissions(); + restore.set_mode(0o700); + tokio::fs::set_permissions(&locked_parent, restore) + .await + .unwrap(); + + assert!(result.is_err()); + } + + #[cfg(unix)] + #[tokio::test] + async fn open_for_append_fails_when_parent_blocked_by_file() { + let directory = TempDir::new().unwrap(); + let blocker = directory.path().join("blocker"); + tokio::fs::write(&blocker, b"i am a file").await.unwrap(); + let partial = PartialFile::new(blocker.join("subdir").join("model.gguf")); + + let result = partial.open_for_append().await; + + assert!(result.is_err()); + } + + #[cfg(unix)] + #[tokio::test] + async fn truncate_fails_when_parent_blocked_by_file() { + let directory = TempDir::new().unwrap(); + let blocker = directory.path().join("blocker"); + tokio::fs::write(&blocker, b"i am a file").await.unwrap(); + let partial = PartialFile::new(blocker.join("subdir").join("model.gguf")); + + let result 
= partial.truncate().await; + + assert!(result.is_err()); + } + + #[cfg(unix)] + #[tokio::test] + async fn finalize_returns_io_error_when_parent_was_deleted_mid_download() { + let directory = TempDir::new().unwrap(); + let cache_subdir = directory.path().join("model-cache"); + let dest = cache_subdir.join("model.gguf"); + let partial = PartialFile::new(dest); + + tokio::fs::create_dir_all(&cache_subdir).await.unwrap(); + let mut file = partial.open_for_append().await.unwrap(); + file.write_all(b"partial data").await.unwrap(); + file.flush().await.unwrap(); + drop(file); + + tokio::fs::remove_dir_all(&cache_subdir).await.unwrap(); + + let result = partial.finalize().await; + + assert!(result.is_err()); + } +} diff --git a/paddler_download_manager/src/progress_sink.rs b/paddler_download_manager/src/progress_sink.rs new file mode 100644 index 00000000..e757358f --- /dev/null +++ b/paddler_download_manager/src/progress_sink.rs @@ -0,0 +1,5 @@ +pub trait ProgressSink: Send + Sync { + fn on_started(&self, total_bytes: Option, already_downloaded: u64); + fn on_chunk(&self, additional_bytes: u64); + fn on_finished(&self); +} diff --git a/paddler_download_manager/src/response_classification.rs b/paddler_download_manager/src/response_classification.rs new file mode 100644 index 00000000..3383b705 --- /dev/null +++ b/paddler_download_manager/src/response_classification.rs @@ -0,0 +1,142 @@ +use reqwest::StatusCode; + +#[derive(Debug, Eq, PartialEq)] +pub enum ResponseClassification { + ClientError(StatusCode), + NotFound, + PartialFileStale, + PermissionDenied(StatusCode), + ServerError(StatusCode), + StreamFromCurrentOffset, + StreamFromStart, + StreamFromStartIgnoringRange, +} + +impl ResponseClassification { + #[must_use] + pub fn from_status(status: StatusCode, sent_range_header: bool) -> Self { + if status == StatusCode::PARTIAL_CONTENT { + return Self::StreamFromCurrentOffset; + } + + if status == StatusCode::OK { + if sent_range_header { + return 
Self::StreamFromStartIgnoringRange; + } + return Self::StreamFromStart; + } + + if status == StatusCode::NOT_FOUND { + return Self::NotFound; + } + + if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN { + return Self::PermissionDenied(status); + } + + if status == StatusCode::RANGE_NOT_SATISFIABLE { + return Self::PartialFileStale; + } + + if status.is_client_error() { + return Self::ClientError(status); + } + + Self::ServerError(status) + } +} + +#[cfg(test)] +mod tests { + use reqwest::StatusCode; + + use crate::response_classification::ResponseClassification; + + #[test] + fn from_status_206_returns_stream_from_current_offset() { + assert_eq!( + ResponseClassification::from_status(StatusCode::PARTIAL_CONTENT, true), + ResponseClassification::StreamFromCurrentOffset + ); + } + + #[test] + fn from_status_200_on_range_request_returns_stream_from_start_ignoring_range() { + assert_eq!( + ResponseClassification::from_status(StatusCode::OK, true), + ResponseClassification::StreamFromStartIgnoringRange + ); + } + + #[test] + fn from_status_200_on_plain_request_returns_stream_from_start() { + assert_eq!( + ResponseClassification::from_status(StatusCode::OK, false), + ResponseClassification::StreamFromStart + ); + } + + #[test] + fn from_status_404_returns_not_found() { + assert_eq!( + ResponseClassification::from_status(StatusCode::NOT_FOUND, false), + ResponseClassification::NotFound + ); + } + + #[test] + fn from_status_401_returns_permission_denied() { + assert_eq!( + ResponseClassification::from_status(StatusCode::UNAUTHORIZED, false), + ResponseClassification::PermissionDenied(StatusCode::UNAUTHORIZED) + ); + } + + #[test] + fn from_status_403_returns_permission_denied() { + assert_eq!( + ResponseClassification::from_status(StatusCode::FORBIDDEN, false), + ResponseClassification::PermissionDenied(StatusCode::FORBIDDEN) + ); + } + + #[test] + fn from_status_416_returns_partial_file_stale() { + assert_eq!( + 
ResponseClassification::from_status(StatusCode::RANGE_NOT_SATISFIABLE, true), + ResponseClassification::PartialFileStale + ); + } + + #[test] + fn from_status_400_returns_client_error() { + assert_eq!( + ResponseClassification::from_status(StatusCode::BAD_REQUEST, false), + ResponseClassification::ClientError(StatusCode::BAD_REQUEST) + ); + } + + #[test] + fn from_status_429_returns_client_error() { + assert_eq!( + ResponseClassification::from_status(StatusCode::TOO_MANY_REQUESTS, false), + ResponseClassification::ClientError(StatusCode::TOO_MANY_REQUESTS) + ); + } + + #[test] + fn from_status_503_returns_server_error() { + assert_eq!( + ResponseClassification::from_status(StatusCode::SERVICE_UNAVAILABLE, false), + ResponseClassification::ServerError(StatusCode::SERVICE_UNAVAILABLE) + ); + } + + #[test] + fn from_status_500_returns_server_error() { + assert_eq!( + ResponseClassification::from_status(StatusCode::INTERNAL_SERVER_ERROR, false), + ResponseClassification::ServerError(StatusCode::INTERNAL_SERVER_ERROR) + ); + } +} diff --git a/paddler_download_manager/src/stream_to_partial_file.rs b/paddler_download_manager/src/stream_to_partial_file.rs new file mode 100644 index 00000000..349faa6f --- /dev/null +++ b/paddler_download_manager/src/stream_to_partial_file.rs @@ -0,0 +1,171 @@ +use std::sync::Arc; + +use bytes::Bytes; +use futures_util::Stream; +use futures_util::StreamExt as _; +use tokio::io::AsyncWrite; +use tokio::io::AsyncWriteExt as _; + +use crate::progress_sink::ProgressSink; +use crate::stream_to_partial_file_error::StreamToPartialFileError; + +pub async fn stream_to_partial_file( + mut body_stream: TStream, + writer: &mut TWriter, + progress_sink: &Arc, +) -> Result<(), StreamToPartialFileError> +where + TStream: Stream> + Unpin, + TWriter: AsyncWrite + Unpin, +{ + while let Some(next_chunk) = body_stream.next().await { + let bytes = next_chunk.map_err(StreamToPartialFileError::Stream)?; + + writer + .write_all(&bytes) + .await + 
.map_err(StreamToPartialFileError::Write)?; + + progress_sink.on_chunk(bytes.len() as u64); + } + + writer + .flush() + .await + .map_err(StreamToPartialFileError::Write)?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::sync::atomic::AtomicU64; + use std::sync::atomic::Ordering; + + use bytes::Bytes; + use futures_util::stream; + use tempfile::TempDir; + use tokio::fs::OpenOptions; + + use crate::progress_sink::ProgressSink; + use crate::stream_to_partial_file::stream_to_partial_file; + + struct CountingSink { + chunks: AtomicU64, + bytes: AtomicU64, + } + + impl CountingSink { + fn new() -> Self { + Self { + bytes: AtomicU64::new(0), + chunks: AtomicU64::new(0), + } + } + } + + impl ProgressSink for CountingSink { + fn on_started(&self, _total_bytes: Option, _already_downloaded: u64) {} + fn on_chunk(&self, additional_bytes: u64) { + self.bytes.fetch_add(additional_bytes, Ordering::Relaxed); + self.chunks.fetch_add(1, Ordering::Relaxed); + } + fn on_finished(&self) {} + } + + #[test] + fn counting_sink_lifecycle_methods_are_inert() { + let sink = CountingSink::new(); + + sink.on_started(Some(1024), 0); + sink.on_finished(); + + assert_eq!(sink.chunks.load(Ordering::Relaxed), 0); + assert_eq!(sink.bytes.load(Ordering::Relaxed), 0); + } + + #[tokio::test] + async fn writes_every_chunk_in_order() { + let directory = TempDir::new().unwrap(); + let path = directory.path().join("dest.bin"); + let chunks: Vec> = vec![ + Ok(Bytes::from_static(b"first")), + Ok(Bytes::from_static(b"second")), + ]; + let body_stream = stream::iter(chunks); + let mut file = OpenOptions::new() + .append(true) + .create(true) + .open(&path) + .await + .unwrap(); + let sink: Arc = Arc::new(CountingSink::new()); + + stream_to_partial_file(body_stream, &mut file, &sink) + .await + .unwrap(); + + let bytes = tokio::fs::read(&path).await.unwrap(); + assert_eq!(bytes, b"firstsecond"); + } + + #[tokio::test] + async fn calls_progress_sink_once_per_chunk() { + let directory 
= TempDir::new().unwrap(); + let path = directory.path().join("dest.bin"); + let chunks: Vec> = vec![ + Ok(Bytes::from_static(b"aaa")), + Ok(Bytes::from_static(b"bb")), + Ok(Bytes::from_static(b"c")), + ]; + let body_stream = stream::iter(chunks); + let mut file = OpenOptions::new() + .append(true) + .create(true) + .open(&path) + .await + .unwrap(); + let counting = Arc::new(CountingSink::new()); + let sink: Arc = counting.clone(); + + stream_to_partial_file(body_stream, &mut file, &sink) + .await + .unwrap(); + + assert_eq!(counting.chunks.load(Ordering::Relaxed), 3); + assert_eq!(counting.bytes.load(Ordering::Relaxed), 6); + } + + #[tokio::test] + async fn write_to_closed_duplex_returns_error() { + let (reader_half, mut writer_half) = tokio::io::duplex(0); + drop(reader_half); + + let chunks: Vec> = + vec![Ok(Bytes::from_static(b"data"))]; + let body_stream = stream::iter(chunks); + let sink: Arc = Arc::new(CountingSink::new()); + + let result = stream_to_partial_file(body_stream, &mut writer_half, &sink).await; + + assert!(result.is_err()); + } + + #[cfg(unix)] + #[tokio::test] + async fn flush_to_read_only_file_returns_error() { + let directory = TempDir::new().unwrap(); + let path = directory.path().join("read_only.bin"); + tokio::fs::write(&path, b"existing").await.unwrap(); + let chunks: Vec> = + vec![Ok(Bytes::from_static(b"more bytes"))]; + let body_stream = stream::iter(chunks); + let mut read_only_file = OpenOptions::new().read(true).open(&path).await.unwrap(); + let sink: Arc = Arc::new(CountingSink::new()); + + let result = stream_to_partial_file(body_stream, &mut read_only_file, &sink).await; + + assert!(result.is_err()); + } +} diff --git a/paddler_download_manager/src/stream_to_partial_file_error.rs b/paddler_download_manager/src/stream_to_partial_file_error.rs new file mode 100644 index 00000000..b2f53184 --- /dev/null +++ b/paddler_download_manager/src/stream_to_partial_file_error.rs @@ -0,0 +1,12 @@ +use std::io; + +use thiserror::Error; + 
+#[derive(Debug, Error)] +pub enum StreamToPartialFileError { + #[error("stream error: {0}")] + Stream(#[source] reqwest::Error), + + #[error("write error: {0}")] + Write(#[source] io::Error), +} diff --git a/paddler_download_manager/tests/download.rs b/paddler_download_manager/tests/download.rs new file mode 100644 index 00000000..7012ea5f --- /dev/null +++ b/paddler_download_manager/tests/download.rs @@ -0,0 +1,748 @@ +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; + +use anyhow::Result; +use anyhow::bail; +use paddler_download_manager::download_error::DownloadError; +use paddler_download_manager::download_manager::DownloadManager; +use paddler_download_manager::progress_sink::ProgressSink; + +use tempfile::TempDir; + +use crate::local_http_fixture::FixtureResponse; +use crate::local_http_fixture::LocalHttpFixture; +use crate::local_http_fixture::Scenario; + +mod local_http_fixture; + +struct RecordingSink { + chunk_count: AtomicU64, + chunk_bytes: AtomicU64, + finished_count: AtomicU64, + started_total: AtomicU64, + started_total_indeterminate: AtomicBool, + started_already: AtomicU64, +} + +impl RecordingSink { + const fn new() -> Self { + Self { + chunk_bytes: AtomicU64::new(0), + chunk_count: AtomicU64::new(0), + finished_count: AtomicU64::new(0), + started_already: AtomicU64::new(0), + started_total: AtomicU64::new(0), + started_total_indeterminate: AtomicBool::new(true), + } + } +} + +impl ProgressSink for RecordingSink { + fn on_started(&self, total_bytes: Option, already_downloaded: u64) { + match total_bytes { + Some(value) => { + self.started_total.store(value, Ordering::Relaxed); + self.started_total_indeterminate + .store(false, Ordering::Relaxed); + } + None => { + self.started_total_indeterminate + .store(true, Ordering::Relaxed); + } + } + self.started_already + .store(already_downloaded, Ordering::Relaxed); + } + fn on_chunk(&self, additional_bytes: u64) { + 
self.chunk_bytes + .fetch_add(additional_bytes, Ordering::Relaxed); + self.chunk_count.fetch_add(1, Ordering::Relaxed); + } + fn on_finished(&self) { + self.finished_count.fetch_add(1, Ordering::Relaxed); + } +} + +#[tokio::test] +async fn streams_200_response_to_disk_and_calls_progress_sink_per_chunk() -> Result<()> { + let directory = TempDir::new()?; + let body = b"Hello, GGUF world!".to_vec(); + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(body.clone()))).await?; + let sink = Arc::new(RecordingSink::new()); + let progress_sink: Arc = sink.clone(); + let dest = directory.path().join("model.gguf"); + + DownloadManager::new()? + .download(&fixture.url("/model.gguf"), &dest, progress_sink) + .await?; + + assert_eq!(tokio::fs::read(&dest).await?, body); + assert_eq!( + sink.started_total.load(Ordering::Relaxed), + body.len() as u64 + ); + assert_eq!(sink.started_already.load(Ordering::Relaxed), 0); + assert_eq!(sink.chunk_bytes.load(Ordering::Relaxed), body.len() as u64); + assert!(sink.chunk_count.load(Ordering::Relaxed) >= 1); + assert_eq!(sink.finished_count.load(Ordering::Relaxed), 1); + + Ok(()) +} + +#[tokio::test] +async fn resumes_from_existing_partial_file_with_range_request() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let partial_path = dest.with_extension("partial"); + tokio::fs::write(&partial_path, b"first half ").await?; + + let body = b"second half".to_vec(); + let total = 11_u64 + body.len() as u64; + let fixture = LocalHttpFixture::start(Scenario::always( + FixtureResponse::partial_content_with_range( + body.clone(), + format!("bytes 11-{}/{}", total - 1, total), + ), + )) + .await?; + let sink = Arc::new(RecordingSink::new()); + let progress_sink: Arc = sink.clone(); + + DownloadManager::new()? 
+ .download(&fixture.url("/model.gguf"), &dest, progress_sink) + .await?; + + assert_eq!(tokio::fs::read(&dest).await?, b"first half second half"); + assert_eq!(sink.started_already.load(Ordering::Relaxed), 11); + assert!( + fixture + .last_recorded_range_header() + .unwrap_or_default() + .contains("bytes=11-") + ); + + Ok(()) +} + +#[tokio::test] +async fn starts_over_when_server_returns_200_to_range_request() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let partial_path = dest.with_extension("partial"); + tokio::fs::write(&partial_path, b"stale partial bytes").await?; + + let body = b"fresh entire body".to_vec(); + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(body.clone()))).await?; + let sink = Arc::new(RecordingSink::new()); + let progress_sink: Arc = sink.clone(); + + DownloadManager::new()? + .download(&fixture.url("/model.gguf"), &dest, progress_sink) + .await?; + + assert_eq!(tokio::fs::read(&dest).await?, body); + + Ok(()) +} + +#[tokio::test] +async fn returns_not_found_on_404_without_retrying() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::status(404))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/missing.gguf"), &dest, sink) + .await; + + assert!(matches!(result, Err(DownloadError::NotFound { .. }))); + assert_eq!(fixture.request_count(), 1); + + Ok(()) +} + +#[tokio::test] +async fn returns_permission_denied_on_401_without_retrying() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::status(401))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? 
+ .download(&fixture.url("/private.gguf"), &dest, sink) + .await; + + assert!(matches!( + result, + Err(DownloadError::PermissionDenied { .. }) + )); + assert_eq!(fixture.request_count(), 1); + + Ok(()) +} + +#[tokio::test] +async fn returns_permission_denied_on_403_without_retrying() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::status(403))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/forbidden.gguf"), &dest, sink) + .await; + + assert!(matches!( + result, + Err(DownloadError::PermissionDenied { .. }) + )); + assert_eq!(fixture.request_count(), 1); + + Ok(()) +} + +#[tokio::test] +async fn returns_partial_file_stale_on_416_and_removes_partial() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let partial_path = dest.with_extension("partial"); + tokio::fs::write(&partial_path, b"stale").await?; + + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::status(416))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/model.gguf"), &dest, sink) + .await; + + assert!(matches!( + result, + Err(DownloadError::PartialFileStale { .. }) + )); + assert!(!tokio::fs::try_exists(&partial_path).await?); + + Ok(()) +} + +#[tokio::test] +async fn returns_partial_file_stale_on_416_even_when_no_partial_existed() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let partial_path = dest.with_extension("partial"); + + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::status(416))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? 
+ .download(&fixture.url("/model.gguf"), &dest, sink) + .await; + + assert!(matches!( + result, + Err(DownloadError::PartialFileStale { .. }) + )); + assert!(!tokio::fs::try_exists(&partial_path).await?); + + Ok(()) +} + +#[tokio::test] +async fn mismatched_content_range_is_treated_as_partial_file_stale() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let partial_path = dest.with_extension("partial"); + tokio::fs::write(&partial_path, b"first half ").await?; + + let body = b"second half".to_vec(); + let total = 11_u64 + body.len() as u64; + let fixture = LocalHttpFixture::start(Scenario::always( + FixtureResponse::partial_content_with_range( + body, + format!("bytes 999-{}/{}", 999 + 10, total), + ), + )) + .await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/model.gguf"), &dest, sink) + .await; + + assert!(matches!( + result, + Err(DownloadError::PartialFileStale { .. }) + )); + assert!(!tokio::fs::try_exists(&partial_path).await?); + + Ok(()) +} + +#[tokio::test] +async fn four_hundred_status_returns_download_server_rejected_request() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::status(400))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/model.gguf"), &dest, sink) + .await; + + let Err(DownloadError::DownloadServerRejectedRequest { status, .. 
}) = result else { + bail!("expected DownloadServerRejectedRequest, got {result:?}"); + }; + assert_eq!(status, reqwest::StatusCode::BAD_REQUEST); + assert_eq!(fixture.request_count(), 1); + + Ok(()) +} + +#[tokio::test] +async fn five_hundred_status_returns_download_server_errored() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::status(500))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/model.gguf"), &dest, sink) + .await; + + let Err(DownloadError::DownloadServerErrored { status, .. }) = result else { + bail!("expected DownloadServerErrored, got {result:?}"); + }; + assert_eq!(status, reqwest::StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(fixture.request_count(), 1); + + Ok(()) +} + +#[tokio::test] +async fn stream_drop_after_partial_body_returns_download_stream_interrupted() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let full_body = b"abcdefghijklmnop".to_vec(); + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::ok_drop_after( + full_body.clone(), + 6, + ))) + .await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/model.gguf"), &dest, sink) + .await; + + assert!(matches!( + result, + Err(DownloadError::DownloadInterrupted { .. 
}) + )); + assert_eq!(fixture.request_count(), 1); + + Ok(()) +} + +#[tokio::test] +async fn progress_sink_on_finished_fires_only_on_success() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + + let fixture_success = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(b"ok body".to_vec()))).await?; + let sink_success = Arc::new(RecordingSink::new()); + let progress_success: Arc = sink_success.clone(); + DownloadManager::new()? + .download(&fixture_success.url("/x"), &dest, progress_success) + .await?; + assert_eq!(sink_success.finished_count.load(Ordering::Relaxed), 1); + + tokio::fs::remove_file(&dest).await?; + + let fixture_404 = + LocalHttpFixture::start(Scenario::always(FixtureResponse::status(404))).await?; + let sink_404 = Arc::new(RecordingSink::new()); + let progress_404: Arc = sink_404.clone(); + let _ = DownloadManager::new()? + .download(&fixture_404.url("/x"), &dest, progress_404) + .await; + assert_eq!(sink_404.finished_count.load(Ordering::Relaxed), 0); + + let fixture_500 = + LocalHttpFixture::start(Scenario::always(FixtureResponse::status(500))).await?; + let sink_500 = Arc::new(RecordingSink::new()); + let progress_500: Arc = sink_500.clone(); + let _ = DownloadManager::new()? + .download(&fixture_500.url("/x"), &dest, progress_500) + .await; + assert_eq!(sink_500.finished_count.load(Ordering::Relaxed), 0); + + Ok(()) +} + +#[tokio::test] +async fn unsupported_url_scheme_returns_invalid_url_error_without_network_call() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download("ftp://example.invalid/model.gguf", &dest, sink) + .await; + + assert!(matches!( + result, + Err(DownloadError::UnsupportedUrlScheme { .. 
}) + )); + + Ok(()) +} + +#[tokio::test] +async fn invalid_url_returns_invalid_url_error_without_network_call() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download("not a valid url", &dest, sink) + .await; + + assert!(matches!(result, Err(DownloadError::InvalidUrl { .. }))); + + Ok(()) +} + +#[tokio::test] +async fn fixture_serves_configured_status_and_body() -> Result<()> { + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(b"hello".to_vec()))).await?; + let response = reqwest::get(fixture.url("/x")).await?; + + assert_eq!(response.status(), 200); + assert_eq!(response.bytes().await?.as_ref(), b"hello"); + assert_eq!(fixture.request_count(), 1); + + Ok(()) +} + +#[tokio::test] +async fn fixture_distinct_ports_per_instance() -> Result<()> { + let first = LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(Vec::new()))).await?; + let second = LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(Vec::new()))).await?; + + assert_ne!(first.port(), second.port()); + + Ok(()) +} + +#[tokio::test] +async fn fixture_drops_connection_when_configured_to() -> Result<()> { + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::ok_drop_after( + b"abcdefgh".to_vec(), + 4, + ))) + .await?; + let response = reqwest::get(fixture.url("/x")).await?; + let body_result = response.bytes().await; + + assert!(body_result.is_err(), "expected dropped connection during body read"); + + Ok(()) +} + +#[tokio::test] +async fn fixture_request_count_increments() -> Result<()> { + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(Vec::new()))).await?; + + let _ = reqwest::get(fixture.url("/a")).await?; + let _ = reqwest::get(fixture.url("/b")).await?; + let _ = reqwest::get(fixture.url("/c")).await?; + + assert_eq!(fixture.request_count(), 3); + + Ok(()) +} + 
+#[tokio::test] +async fn returns_io_error_when_destination_directory_does_not_exist_and_cannot_be_created() +-> Result<()> { + let directory = TempDir::new()?; + let blocker = directory.path().join("blocker"); + tokio::fs::write(&blocker, b"i am a file, not a directory").await?; + let dest = blocker.join("subdir").join("model.gguf"); + + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(b"body".to_vec()))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/x"), &dest, sink) + .await; + + assert!(matches!(result, Err(DownloadError::Io { .. }))); + + Ok(()) +} + +#[tokio::test] +async fn last_recorded_range_header_returns_none_when_no_range_was_sent() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(b"body".to_vec()))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + DownloadManager::new()? 
+ .download(&fixture.url("/x"), &dest, sink) + .await?; + + assert!(fixture.last_recorded_range_header().is_none()); + + Ok(()) +} + +#[tokio::test] +async fn read_timeout_fires_when_server_stalls_before_headers() -> Result<()> { + use std::time::Duration; + + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::stall_before_headers())).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let outcome = tokio::time::timeout( + Duration::from_secs(20), + DownloadManager::new()?.download(&fixture.url("/model.gguf"), &dest, sink), + ) + .await; + + let result = outcome.map_err(|_elapsed| { + anyhow::anyhow!("download did not return within test guard; read_timeout never fired") + })?; + + assert!( + result.is_err(), + "stalled server must produce an error, got Ok" + ); + + Ok(()) +} + +#[tokio::test] +async fn send_error_returns_download_server_is_unreachable() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let sink: Arc = Arc::new(RecordingSink::new()); + + let url = "http://127.0.0.1:1/never-listens".to_owned(); + let result = DownloadManager::new()?.download(&url, &dest, sink).await; + + let Err(DownloadError::DownloadServerIsUnreachable { + url: error_url, .. + }) = result + else { + bail!("expected DownloadServerIsUnreachable, got {result:?}"); + }; + assert_eq!(error_url, url); + + Ok(()) +} + +#[cfg(unix)] +#[tokio::test] +async fn open_for_append_error_returns_io_when_partial_path_is_a_directory() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let partial_path = dest.with_extension("partial"); + tokio::fs::create_dir(&partial_path).await?; + + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(b"body".to_vec()))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? 
+ .download(&fixture.url("/x"), &dest, sink) + .await; + + assert!(matches!(result, Err(DownloadError::Io { .. }))); + + Ok(()) +} + +#[cfg(unix)] +#[tokio::test] +async fn download_returns_cache_permission_denied_when_dir_is_read_only() -> Result<()> { + use std::io; + use std::os::unix::fs::PermissionsExt; + + let directory = TempDir::new()?; + let readonly_parent = directory.path().join("readonly"); + tokio::fs::create_dir(&readonly_parent).await?; + let dest = readonly_parent.join("model.gguf"); + let mut perms = tokio::fs::metadata(&readonly_parent).await?.permissions(); + perms.set_mode(0o500); + tokio::fs::set_permissions(&readonly_parent, perms).await?; + + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(b"body".to_vec()))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/x"), &dest, sink) + .await; + + let mut restore = tokio::fs::metadata(&readonly_parent).await?.permissions(); + restore.set_mode(0o700); + tokio::fs::set_permissions(&readonly_parent, restore).await?; + + let Err(DownloadError::CachePermissionDenied { source, .. }) = result else { + bail!("expected CachePermissionDenied, got {result:?}"); + }; + assert_eq!(source.kind(), io::ErrorKind::PermissionDenied); + + Ok(()) +} + +#[cfg(unix)] +#[tokio::test] +async fn finalize_error_returns_io_when_destination_is_a_non_empty_directory() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + tokio::fs::create_dir(&dest).await?; + tokio::fs::write(dest.join("blocker"), b"x").await?; + + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(b"body".to_vec()))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/x"), &dest, sink) + .await; + + assert!(matches!(result, Err(DownloadError::Io { .. 
}))); + + Ok(()) +} + +#[cfg(unix)] +#[tokio::test] +async fn partial_file_stale_with_unremovable_partial_returns_cache_permission_denied() -> Result<()> +{ + use std::os::unix::fs::PermissionsExt; + + let directory = TempDir::new()?; + let locked_parent = directory.path().join("locked"); + tokio::fs::create_dir(&locked_parent).await?; + let dest = locked_parent.join("model.gguf"); + let partial_path = dest.with_extension("partial"); + tokio::fs::write(&partial_path, b"stale").await?; + let mut perms = tokio::fs::metadata(&locked_parent).await?.permissions(); + perms.set_mode(0o500); + tokio::fs::set_permissions(&locked_parent, perms).await?; + + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::status(416))).await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/x"), &dest, sink) + .await; + + let mut restore = tokio::fs::metadata(&locked_parent).await?.permissions(); + restore.set_mode(0o700); + tokio::fs::set_permissions(&locked_parent, restore).await?; + + assert!(matches!( + result, + Err(DownloadError::CachePermissionDenied { .. }) + )); + + Ok(()) +} + +#[cfg(unix)] +#[tokio::test] +async fn truncate_error_during_ignore_range_returns_io() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let partial_path = dest.with_extension("partial"); + tokio::fs::create_dir(&partial_path).await?; + + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(b"body".to_vec()))) + .await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/x"), &dest, sink) + .await; + + assert!(matches!(result, Err(DownloadError::Io { .. 
}))); + + Ok(()) +} + +#[cfg(target_os = "linux")] +#[tokio::test] +async fn download_returns_cache_disk_full_when_target_is_dev_full() -> Result<()> { + let directory = TempDir::new()?; + let dest = directory.path().join("model.gguf"); + let partial_path = dest.with_extension("partial"); + std::os::unix::fs::symlink("/dev/full", &partial_path)?; + + let fixture = LocalHttpFixture::start(Scenario::always(FixtureResponse::ok( + b"this body will fail to write because /dev/full".to_vec(), + ))) + .await?; + let sink: Arc = Arc::new(RecordingSink::new()); + + let result = DownloadManager::new()? + .download(&fixture.url("/x"), &dest, sink) + .await; + + let Err(DownloadError::CacheDiskFull { source, .. }) = result else { + bail!("expected CacheDiskFull, got {result:?}"); + }; + assert_eq!(source.raw_os_error(), Some(28)); + + Ok(()) +} + +#[tokio::test] +async fn download_succeeds_after_cache_dir_was_deleted_between_calls() -> Result<()> { + let directory = TempDir::new()?; + let cache_subdir = directory.path().join("cache"); + let dest = cache_subdir.join("model.gguf"); + + let body = b"model bytes for the recreation test".to_vec(); + let fixture = + LocalHttpFixture::start(Scenario::always(FixtureResponse::ok(body.clone()))).await?; + let url = fixture.url("/x"); + + DownloadManager::new()? + .download( + &url, + &dest, + Arc::new(RecordingSink::new()) as Arc, + ) + .await?; + assert_eq!(tokio::fs::read(&dest).await?, body); + + tokio::fs::remove_dir_all(&cache_subdir).await?; + + DownloadManager::new()? 
+ .download( + &url, + &dest, + Arc::new(RecordingSink::new()) as Arc, + ) + .await?; + assert_eq!(tokio::fs::read(&dest).await?, body); + + Ok(()) +} diff --git a/paddler_download_manager/tests/local_http_fixture/mod.rs b/paddler_download_manager/tests/local_http_fixture/mod.rs new file mode 100644 index 00000000..12ff938e --- /dev/null +++ b/paddler_download_manager/tests/local_http_fixture/mod.rs @@ -0,0 +1,271 @@ +use std::io; +use std::sync::Arc; +use std::sync::atomic::AtomicU32; +use std::sync::atomic::Ordering; + +use anyhow::Result; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio::sync::oneshot; +use tokio::sync::watch; +use tokio::task::JoinHandle; + +const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2); + +#[derive(Clone)] +pub enum FixtureResponse { + Ok(Vec), + PartialContentWithRange { + body: Vec, + content_range: String, + }, + Status(u16), + OkDropAfter { + body: Vec, + bytes_before_drop: usize, + }, + StallBeforeHeaders, +} + +impl FixtureResponse { + pub const fn ok(body: Vec) -> Self { + Self::Ok(body) + } + + pub const fn partial_content_with_range(body: Vec, content_range: String) -> Self { + Self::PartialContentWithRange { + body, + content_range, + } + } + + pub const fn status(code: u16) -> Self { + Self::Status(code) + } + + pub const fn ok_drop_after(body: Vec, bytes_before_drop: usize) -> Self { + Self::OkDropAfter { + body, + bytes_before_drop, + } + } + + pub const fn stall_before_headers() -> Self { + Self::StallBeforeHeaders + } +} + +pub enum Scenario { + Always(FixtureResponse), +} + +impl Scenario { + pub const fn always(response: FixtureResponse) -> Self { + Self::Always(response) + } +} + +pub struct LocalHttpFixture { + accept_task: Option>, + last_range_rx: watch::Receiver>, + port: u16, + request_count: Arc, + shutdown_tx: Option>, +} + +impl LocalHttpFixture { + pub async fn start(scenario: 
Scenario) -> Result { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let port = listener.local_addr()?.port(); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); + let request_count = Arc::new(AtomicU32::new(0)); + let (last_range_tx, last_range_rx) = watch::channel(None::); + let last_range_tx = Arc::new(last_range_tx); + let scenario_state = Arc::new(ScenarioState::from(scenario)); + + let accept_request_count = request_count.clone(); + let accept_last_range_tx = last_range_tx; + let accept_scenario_state = scenario_state; + + let accept_task = tokio::spawn(async move { + loop { + tokio::select! { + _ = &mut shutdown_rx => break, + accept_result = listener.accept() => { + let Ok((socket, _addr)) = accept_result else { + break; + }; + let request_count_for_conn = accept_request_count.clone(); + let last_range_tx_for_conn = accept_last_range_tx.clone(); + let scenario_state_for_conn = accept_scenario_state.clone(); + + tokio::spawn(async move { + let _ = handle_connection( + socket, + request_count_for_conn, + last_range_tx_for_conn, + scenario_state_for_conn, + ) + .await; + }); + } + } + } + }); + + Ok(Self { + accept_task: Some(accept_task), + last_range_rx, + port, + request_count, + shutdown_tx: Some(shutdown_tx), + }) + } + + pub const fn port(&self) -> u16 { + self.port + } + + pub fn url(&self, path: &str) -> String { + format!("http://127.0.0.1:{}{path}", self.port) + } + + pub fn request_count(&self) -> u32 { + self.request_count.load(Ordering::Relaxed) + } + + pub fn last_recorded_range_header(&self) -> Option { + self.last_range_rx.borrow().clone() + } +} + +impl Drop for LocalHttpFixture { + fn drop(&mut self) { + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } + if let Some(task) = self.accept_task.take() { + task.abort(); + } + } +} + +enum ScenarioState { + Always(FixtureResponse), +} + +impl ScenarioState { + fn next(&self) -> FixtureResponse { + match self { + 
Self::Always(response) => response.clone(), + } + } +} + +impl From for ScenarioState { + fn from(scenario: Scenario) -> Self { + match scenario { + Scenario::Always(response) => Self::Always(response), + } + } +} + +async fn handle_connection( + mut socket: TcpStream, + request_count: Arc, + last_range_tx: Arc>>, + scenario_state: Arc, +) -> Result<()> { + let (reader_half, mut writer_half) = socket.split(); + let mut reader = BufReader::new(reader_half); + + let mut request_line = String::new(); + tokio::time::timeout(READ_TIMEOUT, reader.read_line(&mut request_line)).await??; + + let mut range_header_value: Option = None; + loop { + let mut header_line = String::new(); + let bytes_read = + tokio::time::timeout(READ_TIMEOUT, reader.read_line(&mut header_line)).await??; + if bytes_read == 0 || header_line == "\r\n" || header_line == "\n" { + break; + } + if let Some(rest) = header_line.strip_prefix("Range:") { + range_header_value = Some(rest.trim().to_owned()); + } else if let Some(rest) = header_line.strip_prefix("range:") { + range_header_value = Some(rest.trim().to_owned()); + } + } + + last_range_tx.send_replace(range_header_value); + request_count.fetch_add(1, Ordering::Relaxed); + + let response = scenario_state.next(); + write_response(&mut writer_half, response).await?; + Ok(()) +} + +async fn write_response(writer: &mut TWriter, response: FixtureResponse) -> io::Result<()> +where + TWriter: AsyncWriteExt + Unpin, +{ + match response { + FixtureResponse::Ok(body) => { + let header = format!( + "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + body.len() + ); + writer.write_all(header.as_bytes()).await?; + writer.write_all(&body).await?; + writer.shutdown().await?; + } + FixtureResponse::PartialContentWithRange { + body, + content_range, + } => { + let header = format!( + "HTTP/1.1 206 Partial Content\r\nContent-Length: {}\r\nContent-Range: {}\r\nConnection: close\r\n\r\n", + body.len(), + content_range, + ); + 
writer.write_all(header.as_bytes()).await?; + writer.write_all(&body).await?; + writer.shutdown().await?; + } + FixtureResponse::Status(code) => { + let status_text = match code { + 401 => "Unauthorized", + 403 => "Forbidden", + 404 => "Not Found", + 416 => "Range Not Satisfiable", + 500 => "Internal Server Error", + 503 => "Service Unavailable", + _ => "Other", + }; + let header = format!( + "HTTP/1.1 {code} {status_text}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n", + ); + writer.write_all(header.as_bytes()).await?; + writer.shutdown().await?; + } + FixtureResponse::OkDropAfter { + body, + bytes_before_drop, + } => { + let header = format!( + "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + body.len() + ); + writer.write_all(header.as_bytes()).await?; + let truncated_len = bytes_before_drop.min(body.len()); + writer.write_all(&body[..truncated_len]).await?; + } + FixtureResponse::StallBeforeHeaders => { + std::future::pending::<()>().await; + } + } + Ok(()) +} diff --git a/paddler_gui/src/agent_running_data.rs b/paddler_gui/src/agent_running_data.rs index e2c6d042..91fce71a 100644 --- a/paddler_gui/src/agent_running_data.rs +++ b/paddler_gui/src/agent_running_data.rs @@ -14,6 +14,7 @@ impl AgentRunningData { desired_slots_total: status.desired_slots_total, download_current: status.download_current, download_filename: status.download_filename, + download_indeterminate: status.download_indeterminate, download_total: status.download_total, id: String::new(), issues: status.issues, diff --git a/paddler_gui/src/running_balancer_snapshot.rs b/paddler_gui/src/running_balancer_snapshot.rs index dcf57484..c2ecab0b 100644 --- a/paddler_gui/src/running_balancer_snapshot.rs +++ b/paddler_gui/src/running_balancer_snapshot.rs @@ -50,7 +50,7 @@ mod tests { use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; - use std::sync::atomic::AtomicUsize; + use std::sync::atomic::AtomicU64; use anyhow::Result; use 
paddler::atomic_value::AtomicValue; @@ -81,9 +81,10 @@ mod tests { ), connection_close: CancellationToken::new(), desired_slots_total: AtomicValue::::new(0), - download_current: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), download_filename: RwLock::new(None), - download_total: AtomicValue::::new(0), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), id: id.to_owned(), diff --git a/paddler_gui/src/screen.rs b/paddler_gui/src/screen.rs index a89daf7d..a46fff74 100644 --- a/paddler_gui/src/screen.rs +++ b/paddler_gui/src/screen.rs @@ -75,6 +75,7 @@ impl Screen { desired_slots_total: 0, download_current: 0, download_filename: None, + download_indeterminate: true, download_total: 0, id: String::new(), issues: BTreeSet::new(), diff --git a/paddler_gui/src/ui/view_running_balancer.rs b/paddler_gui/src/ui/view_running_balancer.rs index a8c4b6f3..1b7bf2f3 100644 --- a/paddler_gui/src/ui/view_running_balancer.rs +++ b/paddler_gui/src/ui/view_running_balancer.rs @@ -32,6 +32,7 @@ fn format_desired_model(desired_model: &AgentDesiredModel) -> String { ) } AgentDesiredModel::LocalToAgent(path) => format!("Local: {path}"), + AgentDesiredModel::Url(reference) => format!("URL: {}", reference.url), AgentDesiredModel::None => "(not set)".to_owned(), } } diff --git a/paddler_tests/Cargo.toml b/paddler_tests/Cargo.toml index cd927f8f..f1006554 100644 --- a/paddler_tests/Cargo.toml +++ b/paddler_tests/Cargo.toml @@ -13,6 +13,7 @@ default = [] cuda = ["paddler/cuda"] metal = ["paddler/metal"] tests_that_use_compiled_paddler = [] +tests_that_use_in_process_cluster = [] tests_that_use_llms = [] web_admin_panel = ["paddler/web_admin_panel", "paddler_bootstrap/web_admin_panel"] diff --git a/paddler_tests/src/agents_stream_watcher.rs 
b/paddler_tests/src/agents_stream_watcher.rs index 0f93aba9..c77f5d08 100644 --- a/paddler_tests/src/agents_stream_watcher.rs +++ b/paddler_tests/src/agents_stream_watcher.rs @@ -194,6 +194,7 @@ mod tests { desired_slots_total: 1, download_current: 0, download_filename: None, + download_indeterminate: true, download_total: 0, id: agent_id.to_owned(), issues, diff --git a/paddler_tests/src/lib.rs b/paddler_tests/src/lib.rs index 0cffe675..1afef49d 100644 --- a/paddler_tests/src/lib.rs +++ b/paddler_tests/src/lib.rs @@ -17,6 +17,7 @@ pub mod in_process_cluster_params; pub mod inference_http_client; pub mod inference_message_stream; pub mod load_test_image_data_uri; +pub mod local_http_fixture; pub mod make_agent_controller_without_remote_agent; pub mod make_inference_parameters_deterministic; pub mod ministral_3_in_process_cluster_params; diff --git a/paddler_tests/src/local_http_fixture.rs b/paddler_tests/src/local_http_fixture.rs new file mode 100644 index 00000000..723072aa --- /dev/null +++ b/paddler_tests/src/local_http_fixture.rs @@ -0,0 +1,143 @@ +use std::sync::Arc; + +use anyhow::Context as _; +use anyhow::Result; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpListener; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; + +pub struct LocalHttpFixture { + accept_task: Option>, + port: u16, + shutdown_tx: Option>, +} + +impl LocalHttpFixture { + pub async fn start(status_line: &'static str, body: Vec) -> Result { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .context("Failed to bind 127.0.0.1:0 for LocalHttpFixture")?; + let port = listener + .local_addr() + .context("LocalHttpFixture listener has no local addr")? + .port(); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); + let body_arc = Arc::new(body); + let accept_task = tokio::spawn(async move { + loop { + tokio::select! 
{ + _ = &mut shutdown_rx => break, + connection = listener.accept() => { + let Ok((mut socket, _addr)) = connection else { + break; + }; + let body_for_connection = body_arc.clone(); + tokio::spawn(async move { + let mut buffer = [0_u8; 1024]; + let _read = socket.read(&mut buffer).await; + + let response = format!( + "{status_line}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + body_for_connection.len() + ); + let _written_headers = socket.write_all(response.as_bytes()).await; + let _written_body = socket.write_all(&body_for_connection).await; + let _flushed = socket.shutdown().await; + }); + } + } + } + }); + + Ok(Self { + accept_task: Some(accept_task), + port, + shutdown_tx: Some(shutdown_tx), + }) + } + + #[must_use] + pub const fn port(&self) -> u16 { + self.port + } + + #[must_use] + pub fn url(&self, path: &str) -> String { + format!("http://127.0.0.1:{}{path}", self.port) + } +} + +impl Drop for LocalHttpFixture { + fn drop(&mut self) { + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } + + if let Some(accept_task) = self.accept_task.take() { + accept_task.abort(); + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::local_http_fixture::LocalHttpFixture; + + #[tokio::test] + async fn serves_configured_status_and_body() -> Result<()> { + let fixture = LocalHttpFixture::start("HTTP/1.1 200 OK", b"hello bytes".to_vec()).await?; + let response = reqwest::get(fixture.url("/whatever")).await?; + + assert_eq!(response.status(), reqwest::StatusCode::OK); + assert_eq!(response.bytes().await?.as_ref(), b"hello bytes"); + + Ok(()) + } + + #[tokio::test] + async fn serves_404_when_configured() -> Result<()> { + let fixture = LocalHttpFixture::start("HTTP/1.1 404 Not Found", Vec::new()).await?; + let response = reqwest::get(fixture.url("/missing")).await?; + + assert_eq!(response.status(), reqwest::StatusCode::NOT_FOUND); + + Ok(()) + } + + #[tokio::test] + async fn 
each_fixture_gets_a_distinct_port() -> Result<()> { + let first = LocalHttpFixture::start("HTTP/1.1 200 OK", Vec::new()).await?; + let second = LocalHttpFixture::start("HTTP/1.1 200 OK", Vec::new()).await?; + + assert_ne!(first.port(), second.port()); + + Ok(()) + } + + #[tokio::test] + async fn drop_stops_accepting_connections() -> Result<()> { + let fixture = LocalHttpFixture::start("HTTP/1.1 200 OK", b"alive".to_vec()).await?; + let url = fixture.url("/alive"); + + let still_alive_response = reqwest::get(&url).await?; + assert_eq!(still_alive_response.status(), reqwest::StatusCode::OK); + + drop(fixture); + + let after_drop = reqwest::Client::new() + .get(&url) + .timeout(std::time::Duration::from_millis(500)) + .send() + .await; + assert!( + after_drop.is_err(), + "fixture should be unreachable after drop" + ); + + Ok(()) + } +} diff --git a/paddler_tests/src/make_agent_controller_without_remote_agent.rs b/paddler_tests/src/make_agent_controller_without_remote_agent.rs index ea360ff6..263f2ad6 100644 --- a/paddler_tests/src/make_agent_controller_without_remote_agent.rs +++ b/paddler_tests/src/make_agent_controller_without_remote_agent.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use std::sync::RwLock; use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicI32; -use std::sync::atomic::AtomicUsize; +use std::sync::atomic::AtomicU64; use paddler::atomic_value::AtomicValue; use paddler::balancer::agent_controller::AgentController; @@ -26,9 +26,10 @@ pub fn make_agent_controller_without_remote_agent(id: &str) -> AgentController { ), connection_close: CancellationToken::new(), desired_slots_total: AtomicValue::::new(0), - download_current: AtomicValue::::new(0), + download_current: AtomicValue::::new(0), download_filename: RwLock::new(None), - download_total: AtomicValue::::new(0), + download_indeterminate: AtomicValue::::new(true), + download_total: AtomicValue::::new(0), embedding_sender_collection: Arc::new(EmbeddingSenderCollection::default()), 
generate_tokens_sender_collection: Arc::new(GenerateTokensSenderCollection::default()), id: id.to_owned(), diff --git a/paddler_tests/tests/balancer_persists_url_model_in_desired_state.rs b/paddler_tests/tests/balancer_persists_url_model_in_desired_state.rs new file mode 100644 index 00000000..e5f83175 --- /dev/null +++ b/paddler_tests/tests/balancer_persists_url_model_in_desired_state.rs @@ -0,0 +1,54 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use anyhow::Context as _; +use anyhow::Result; +use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; +use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::url_model_reference::UrlModelReference; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_persists_url_model_in_desired_state() -> Result<()> { + let configured_url = "https://example.invalid/persisted-model.gguf".to_owned(); + + let cluster = start_subprocess_cluster(SubprocessClusterParams { + agents: Vec::new(), + wait_for_slots_ready: false, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters::default(), + model: AgentDesiredModel::Url(UrlModelReference { + url: configured_url.clone(), + }), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + ..SubprocessClusterParams::default() + }) + .await?; + + let retrieved = cluster + .paddler_client + .management() + .get_balancer_desired_state() + .await + .map_err(anyhow::Error::new) + .context("failed to read balancer desired state")?; + + assert_eq!( + retrieved.model, + AgentDesiredModel::Url(UrlModelReference { + url: 
configured_url, + }) + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_recovers_when_template_replaced.rs b/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_recovers_when_template_replaced.rs new file mode 100644 index 00000000..f067a482 --- /dev/null +++ b/paddler_tests/tests/balancer_reports_chat_template_does_not_compile_recovers_when_template_replaced.rs @@ -0,0 +1,104 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use std::time::Duration; + +use anyhow::Context as _; +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::model_card::ModelCard; +use paddler_tests::model_card::qwen3_0_6b::qwen3_0_6b; +use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; +use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::agent_issue::AgentIssue; +use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_types::chat_template::ChatTemplate; +use paddler_types::inference_parameters::InferenceParameters; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_reports_chat_template_does_not_compile_recovers_when_template_replaced( +) -> Result<()> { + let ModelCard { reference, .. 
} = qwen3_0_6b(); + + let invalid_template = ChatTemplate { + content: "{{invalid jinja template".to_owned(), + }; + let valid_template = ChatTemplate { + content: "{% for message in messages %}{{ message.content }}{% endfor %}".to_owned(), + }; + + let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + agents: AgentConfig::uniform(1, 1), + wait_for_slots_ready: false, + desired_state: Some(BalancerDesiredState { + chat_template_override: Some(invalid_template), + inference_parameters: InferenceParameters::default(), + model: AgentDesiredModel::HuggingFace(reference.clone()), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: true, + }), + ..SubprocessClusterParams::default() + }) + .await?; + + let agent_id = cluster + .agent_ids + .first() + .context("cluster must have one registered agent")? + .clone(); + + let predicate_agent_id = agent_id.clone(); + cluster + .agents + .until(move |snapshot| { + snapshot.agents.iter().any(|agent| { + agent.id == predicate_agent_id + && agent + .issues + .iter() + .any(|issue| matches!(issue, AgentIssue::ChatTemplateDoesNotCompile(_))) + }) + }) + .await + .context("balancer should report ChatTemplateDoesNotCompile for invalid Jinja syntax")?; + + let recovered_state = BalancerDesiredState { + chat_template_override: Some(valid_template), + inference_parameters: InferenceParameters::default(), + model: AgentDesiredModel::HuggingFace(reference), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: true, + }; + + cluster + .paddler_client + .management() + .put_balancer_desired_state(&recovered_state) + .await + .map_err(anyhow::Error::new) + .context("balancer should accept the recovered desired state")?; + + let predicate_agent_id_for_recovery = agent_id; + tokio::time::timeout( + Duration::from_secs(3), + cluster.agents.until(move |snapshot| { + snapshot.agents.iter().any(|agent| { + agent.id == predicate_agent_id_for_recovery + && agent + .issues + 
.iter() + .all(|issue| !matches!(issue, AgentIssue::ChatTemplateDoesNotCompile(_))) + }) + }), + ) + .await + .context("reconciliation should clear ChatTemplateDoesNotCompile within 3 seconds")??; + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_reports_download_server_denied_access.rs b/paddler_tests/tests/balancer_reports_download_server_denied_access.rs new file mode 100644 index 00000000..39e03e25 --- /dev/null +++ b/paddler_tests/tests/balancer_reports_download_server_denied_access.rs @@ -0,0 +1,75 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use anyhow::Context as _; +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::local_http_fixture::LocalHttpFixture; +use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; +use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::agent_issue::AgentIssue; +use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::url_model_reference::UrlModelReference; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_reports_download_server_denied_access() -> Result<()> { + let fixture = LocalHttpFixture::start("HTTP/1.1 403 Forbidden", Vec::new()).await?; + let model_url = fixture.url("/private.gguf"); + + let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + agents: AgentConfig::uniform(1, 1), + wait_for_slots_ready: false, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters::default(), + model: AgentDesiredModel::Url(UrlModelReference { + url: model_url.clone(), + }), + multimodal_projection: AgentDesiredModel::None, + 
use_chat_template_override: false, + }), + ..SubprocessClusterParams::default() + }) + .await?; + + let agent_id = cluster + .agent_ids + .first() + .context("cluster must have one registered agent")? + .clone(); + + let snapshot = cluster + .agents + .until(move |snapshot| { + snapshot.agents.iter().any(|agent| { + agent.id == agent_id + && agent + .issues + .iter() + .any(|issue| matches!(issue, AgentIssue::DownloadServerDeniedAccess(_))) + }) + }) + .await + .context("balancer should report DownloadServerDeniedAccess for a 403 URL")?; + + let saw_expected_url = snapshot.agents.iter().any(|agent| { + agent.issues.iter().any(|issue| { + matches!(issue, AgentIssue::DownloadServerDeniedAccess(model_path) + if model_path.model_path == model_url) + }) + }); + + assert!( + saw_expected_url, + "DownloadServerDeniedAccess should reference the configured URL" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_reports_download_server_errored.rs b/paddler_tests/tests/balancer_reports_download_server_errored.rs new file mode 100644 index 00000000..73fed463 --- /dev/null +++ b/paddler_tests/tests/balancer_reports_download_server_errored.rs @@ -0,0 +1,76 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use anyhow::Context as _; +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::local_http_fixture::LocalHttpFixture; +use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; +use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::agent_issue::AgentIssue; +use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::url_model_reference::UrlModelReference; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor 
= "multi_thread")] +async fn balancer_reports_download_server_errored() -> Result<()> { + let fixture = + LocalHttpFixture::start("HTTP/1.1 500 Internal Server Error", Vec::new()).await?; + let model_url = fixture.url("/broken.gguf"); + + let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + agents: AgentConfig::uniform(1, 1), + wait_for_slots_ready: false, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters::default(), + model: AgentDesiredModel::Url(UrlModelReference { + url: model_url.clone(), + }), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + ..SubprocessClusterParams::default() + }) + .await?; + + let agent_id = cluster + .agent_ids + .first() + .context("cluster must have one registered agent")? + .clone(); + + let snapshot = cluster + .agents + .until(move |snapshot| { + snapshot.agents.iter().any(|agent| { + agent.id == agent_id + && agent + .issues + .iter() + .any(|issue| matches!(issue, AgentIssue::DownloadServerErrored(_))) + }) + }) + .await + .context("balancer should report DownloadServerErrored when the server returns 500")?; + + let saw_expected_url = snapshot.agents.iter().any(|agent| { + agent.issues.iter().any(|issue| { + matches!(issue, AgentIssue::DownloadServerErrored(model_path) + if model_path.model_path == model_url) + }) + }); + + assert!( + saw_expected_url, + "DownloadServerErrored should reference the configured URL" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_reports_download_server_is_unreachable.rs b/paddler_tests/tests/balancer_reports_download_server_is_unreachable.rs new file mode 100644 index 00000000..8b542da7 --- /dev/null +++ b/paddler_tests/tests/balancer_reports_download_server_is_unreachable.rs @@ -0,0 +1,75 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use anyhow::Context as _; +use 
anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; +use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::agent_issue::AgentIssue; +use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::url_model_reference::UrlModelReference; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_reports_download_server_is_unreachable() -> Result<()> { + let model_url = "http://127.0.0.1:1/model.gguf".to_owned(); + + let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + agents: AgentConfig::uniform(1, 1), + wait_for_slots_ready: false, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters::default(), + model: AgentDesiredModel::Url(UrlModelReference { + url: model_url.clone(), + }), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + ..SubprocessClusterParams::default() + }) + .await?; + + let agent_id = cluster + .agent_ids + .first() + .context("cluster must have one registered agent")? 
+ .clone(); + + let snapshot = cluster + .agents + .until(move |snapshot| { + snapshot.agents.iter().any(|agent| { + agent.id == agent_id + && agent + .issues + .iter() + .any(|issue| matches!(issue, AgentIssue::DownloadServerIsUnreachable(_))) + }) + }) + .await + .context( + "balancer should report DownloadServerIsUnreachable when the URL points at a dead port", + )?; + + let saw_expected_url = snapshot.agents.iter().any(|agent| { + agent.issues.iter().any(|issue| { + matches!(issue, AgentIssue::DownloadServerIsUnreachable(model_path) + if model_path.model_path == model_url) + }) + }); + + assert!( + saw_expected_url, + "DownloadServerIsUnreachable should reference the configured URL" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_reports_download_url_is_malformed.rs b/paddler_tests/tests/balancer_reports_download_url_is_malformed.rs new file mode 100644 index 00000000..e711d191 --- /dev/null +++ b/paddler_tests/tests/balancer_reports_download_url_is_malformed.rs @@ -0,0 +1,73 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use anyhow::Context as _; +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; +use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::agent_issue::AgentIssue; +use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::url_model_reference::UrlModelReference; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_reports_download_url_is_malformed() -> Result<()> { + let malformed_url = "not a valid url".to_owned(); + + let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + 
agents: AgentConfig::uniform(1, 1), + wait_for_slots_ready: false, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters::default(), + model: AgentDesiredModel::Url(UrlModelReference { + url: malformed_url.clone(), + }), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + ..SubprocessClusterParams::default() + }) + .await?; + + let agent_id = cluster + .agent_ids + .first() + .context("cluster must have one registered agent")? + .clone(); + + let snapshot = cluster + .agents + .until(move |snapshot| { + snapshot.agents.iter().any(|agent| { + agent.id == agent_id + && agent + .issues + .iter() + .any(|issue| matches!(issue, AgentIssue::DownloadUrlIsMalformed(_))) + }) + }) + .await + .context("balancer should report DownloadUrlIsMalformed for an invalid URL")?; + + let saw_expected_url = snapshot.agents.iter().any(|agent| { + agent.issues.iter().any(|issue| { + matches!(issue, AgentIssue::DownloadUrlIsMalformed(model_path) + if model_path.model_path == malformed_url) + }) + }); + + assert!( + saw_expected_url, + "DownloadUrlIsMalformed should reference the configured URL" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/balancer_reports_model_does_not_exist_at_url.rs b/paddler_tests/tests/balancer_reports_model_does_not_exist_at_url.rs new file mode 100644 index 00000000..502761fc --- /dev/null +++ b/paddler_tests/tests/balancer_reports_model_does_not_exist_at_url.rs @@ -0,0 +1,75 @@ +#![cfg(all( + feature = "tests_that_use_compiled_paddler", + feature = "tests_that_use_llms" +))] + +use anyhow::Context as _; +use anyhow::Result; +use paddler_tests::agent_config::AgentConfig; +use paddler_tests::local_http_fixture::LocalHttpFixture; +use paddler_tests::start_subprocess_cluster::start_subprocess_cluster; +use paddler_tests::subprocess_cluster_params::SubprocessClusterParams; +use 
paddler_types::agent_desired_model::AgentDesiredModel; +use paddler_types::agent_issue::AgentIssue; +use paddler_types::balancer_desired_state::BalancerDesiredState; +use paddler_types::inference_parameters::InferenceParameters; +use paddler_types::url_model_reference::UrlModelReference; + +#[serial_test::file_serial(model_load, path => "../target/model_load.lock")] +#[tokio::test(flavor = "multi_thread")] +async fn balancer_reports_model_does_not_exist_at_url() -> Result<()> { + let fixture = LocalHttpFixture::start("HTTP/1.1 404 Not Found", Vec::new()).await?; + let model_url = fixture.url("/missing.gguf"); + + let mut cluster = start_subprocess_cluster(SubprocessClusterParams { + agents: AgentConfig::uniform(1, 1), + wait_for_slots_ready: false, + desired_state: Some(BalancerDesiredState { + chat_template_override: None, + inference_parameters: InferenceParameters::default(), + model: AgentDesiredModel::Url(UrlModelReference { + url: model_url.clone(), + }), + multimodal_projection: AgentDesiredModel::None, + use_chat_template_override: false, + }), + ..SubprocessClusterParams::default() + }) + .await?; + + let agent_id = cluster + .agent_ids + .first() + .context("cluster must have one registered agent")? 
+ .clone(); + + let snapshot = cluster + .agents + .until(move |snapshot| { + snapshot.agents.iter().any(|agent| { + agent.id == agent_id + && agent + .issues + .iter() + .any(|issue| matches!(issue, AgentIssue::ModelDoesNotExistAtUrl(_))) + }) + }) + .await + .context("balancer should report ModelDoesNotExistAtUrl for 404 URL")?; + + let saw_expected_url = snapshot.agents.iter().any(|agent| { + agent.issues.iter().any(|issue| { + matches!(issue, AgentIssue::ModelDoesNotExistAtUrl(model_path) + if model_path.model_path == model_url) + }) + }); + + assert!( + saw_expected_url, + "ModelDoesNotExistAtUrl should reference the configured URL" + ); + + cluster.shutdown().await?; + + Ok(()) +} diff --git a/paddler_tests/tests/harness_agents_watcher.rs b/paddler_tests/tests/harness_agents_watcher.rs index fac8d7d9..e5f17292 100644 --- a/paddler_tests/tests/harness_agents_watcher.rs +++ b/paddler_tests/tests/harness_agents_watcher.rs @@ -17,6 +17,7 @@ fn make_snapshot(agent_id: &str, slots_total: i32) -> AgentControllerPoolSnapsho desired_slots_total: slots_total, download_current: 0, download_filename: None, + download_indeterminate: true, download_total: 0, id: agent_id.to_owned(), issues: BTreeSet::new(), diff --git a/paddler_tests/tests/in_process_cluster_lifecycle_under_concurrent_load_does_not_hang.rs b/paddler_tests/tests/in_process_cluster_lifecycle_under_concurrent_load_does_not_hang.rs index c77ef076..6a9e5639 100644 --- a/paddler_tests/tests/in_process_cluster_lifecycle_under_concurrent_load_does_not_hang.rs +++ b/paddler_tests/tests/in_process_cluster_lifecycle_under_concurrent_load_does_not_hang.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "tests_that_use_in_process_cluster")] + use std::time::Duration; use std::time::Instant; diff --git a/paddler_types/src/agent_controller_snapshot.rs b/paddler_types/src/agent_controller_snapshot.rs index d28b0fe0..559ffefa 100644 --- a/paddler_types/src/agent_controller_snapshot.rs +++ b/paddler_types/src/agent_controller_snapshot.rs 
@@ -10,9 +10,10 @@ use crate::agent_state_application_status::AgentStateApplicationStatus; #[serde(deny_unknown_fields)] pub struct AgentControllerSnapshot { pub desired_slots_total: i32, - pub download_current: usize, + pub download_current: u64, pub download_filename: Option, - pub download_total: usize, + pub download_indeterminate: bool, + pub download_total: u64, pub id: String, pub issues: BTreeSet, pub model_path: Option, diff --git a/paddler_types/src/agent_desired_model.rs b/paddler_types/src/agent_desired_model.rs index 3ce0d044..02e6ea73 100644 --- a/paddler_types/src/agent_desired_model.rs +++ b/paddler_types/src/agent_desired_model.rs @@ -2,12 +2,14 @@ use serde::Deserialize; use serde::Serialize; use crate::huggingface_model_reference::HuggingFaceModelReference; +use crate::url_model_reference::UrlModelReference; #[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)] #[serde(deny_unknown_fields)] pub enum AgentDesiredModel { HuggingFace(HuggingFaceModelReference), LocalToAgent(String), + Url(UrlModelReference), #[default] None, } diff --git a/paddler_types/src/agent_issue.rs b/paddler_types/src/agent_issue.rs index 4160bb5b..60c9c47e 100644 --- a/paddler_types/src/agent_issue.rs +++ b/paddler_types/src/agent_issue.rs @@ -9,11 +9,22 @@ use crate::agent_issue_params::SlotCannotStartParams; #[derive(Clone, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] #[serde(deny_unknown_fields)] pub enum AgentIssue { + CacheCannotAcquireLock(ModelPath), + CacheDirectoryIsNotWritable(ModelPath), + CacheStorageIsFull(ModelPath), ChatTemplateDoesNotCompile(ChatTemplateDoesNotCompileParams), + DownloadInterrupted(ModelPath), + DownloadServerDeniedAccess(ModelPath), + DownloadServerErrored(ModelPath), + DownloadServerIsUnreachable(ModelPath), + DownloadServerRejectedRequest(ModelPath), + DownloadUrlIsMalformed(ModelPath), HuggingFaceCannotAcquireLock(HuggingFaceDownloadLock), HuggingFaceModelDoesNotExist(ModelPath), 
HuggingFacePermissions(ModelPath), + ModelCacheIsCorrupted(ModelPath), ModelCannotBeLoaded(ModelPath), + ModelDoesNotExistAtUrl(ModelPath), ModelFileDoesNotExist(ModelPath), MultimodalProjectionCannotBeLoaded(ModelPath), SlotCannotStart(SlotCannotStartParams), diff --git a/paddler_types/src/lib.rs b/paddler_types/src/lib.rs index e2a333e0..b5d3bb8b 100644 --- a/paddler_types/src/lib.rs +++ b/paddler_types/src/lib.rs @@ -41,4 +41,5 @@ pub mod request_params; pub mod rpc_message; pub mod slot_aggregated_status_snapshot; pub mod streamable_result; +pub mod url_model_reference; pub mod validates; diff --git a/paddler_types/src/slot_aggregated_status_snapshot.rs b/paddler_types/src/slot_aggregated_status_snapshot.rs index ca201f8d..94dfdbc7 100644 --- a/paddler_types/src/slot_aggregated_status_snapshot.rs +++ b/paddler_types/src/slot_aggregated_status_snapshot.rs @@ -10,9 +10,10 @@ use crate::agent_state_application_status::AgentStateApplicationStatus; #[serde(deny_unknown_fields)] pub struct SlotAggregatedStatusSnapshot { pub desired_slots_total: i32, - pub download_current: usize, + pub download_current: u64, pub download_filename: Option, - pub download_total: usize, + pub download_indeterminate: bool, + pub download_total: u64, pub issues: BTreeSet, pub model_path: Option, pub slots_processing: i32, diff --git a/paddler_types/src/url_model_reference.rs b/paddler_types/src/url_model_reference.rs new file mode 100644 index 00000000..740decc4 --- /dev/null +++ b/paddler_types/src/url_model_reference.rs @@ -0,0 +1,8 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)] +#[serde(deny_unknown_fields)] +pub struct UrlModelReference { + pub url: String, +} diff --git a/resources/ts/components/AgentIssues.tsx b/resources/ts/components/AgentIssues.tsx index 013cdf95..c9c5a162 100644 --- a/resources/ts/components/AgentIssues.tsx +++ b/resources/ts/components/AgentIssues.tsx @@ -224,6 +224,248 @@ export function 
AgentIssues({ issues }: { issues: Array }) { ); } + if ("DownloadUrlIsMalformed" in issue) { + return ( +
  • + + Download URL is malformed:{" "} + {issue.DownloadUrlIsMalformed.model_path} + + What will Paddler do?{" "} +

    + Paddler will keep re-checking, but the same malformed URL + will keep failing the same way. +

    + What can you do?{" "} +

    + + Edit the model URL on the model configuration page + {" "} + to a valid http or https URL. +

    +
  • + ); + } + + if ("ModelDoesNotExistAtUrl" in issue) { + return ( +
  • + + Model does not exist at URL:{" "} + {issue.ModelDoesNotExistAtUrl.model_path} + + What will Paddler do?{" "} +

    + Paddler will keep re-checking, but the same 404 will keep + firing until the remote server publishes the file at that URL. +

    + What can you do?{" "} +

    + Check the URL — the file may have moved or been removed.{" "} + Update the URL or replace it with one + that resolves. +

    +
  • + ); + } + + if ("DownloadServerDeniedAccess" in issue) { + return ( +
  • + + Download server denied access:{" "} + {issue.DownloadServerDeniedAccess.model_path} + + What will Paddler do?{" "} +

    + Paddler will keep re-checking; if the server starts + accepting the request, the next attempt will succeed. +

    + What can you do?{" "} +

    + Confirm the URL is correct and reachable without auth. If it's + a private model, switch to a URL that doesn't require + credentials, or use the HuggingFace integration instead. +

    +
  • + ); + } + + if ("DownloadServerErrored" in issue) { + return ( +
  • + + Download server returned an error status:{" "} + {issue.DownloadServerErrored.model_path} + + What will Paddler do?{" "} +

    + Paddler will keep re-checking. If the server starts + answering normally, the next attempt will succeed. +

    + What can you do?{" "} +

    + The remote server is reachable but returning a 5xx response. + Check the server's status page or logs if you control it; + otherwise wait — overload and maintenance windows usually clear + on the server's end. +

    +
  • + ); + } + + if ("DownloadInterrupted" in issue) { + return ( +
  • + + Download was interrupted:{" "} + {issue.DownloadInterrupted.model_path} + + What will Paddler do?{" "} +

    + Paddler will keep re-checking. The next attempt resumes + from the bytes already on disk if the server supports Range + requests; otherwise it starts fresh. +

    + What can you do?{" "} +

    + Often transient — check network stability and whether the + remote server is being restarted or rate-limiting. No action + needed if it clears on its own. +

    +
  • + ); + } + + if ("DownloadServerIsUnreachable" in issue) { + return ( +
  • + + Download server is unreachable:{" "} + {issue.DownloadServerIsUnreachable.model_path} + + What will Paddler do?{" "} +

    + Paddler will keep re-checking; if the network comes back, + the next attempt will succeed. +

    + What can you do?{" "} +

    + Check the agent's internet connection, firewall rules, and the + remote server's status. +

    +
  • + ); + } + + if ("DownloadServerRejectedRequest" in issue) { + return ( +
  • + + Download server rejected the request:{" "} + {issue.DownloadServerRejectedRequest.model_path} + + What will Paddler do?{" "} +

    + Paddler will keep re-checking. If the server starts accepting + the request, the next attempt will succeed. +

    + What can you do?{" "} +

    + The server responded with a 4xx status, meaning the request was + rejected (for example bad URL, throttling, or unsupported + method). Verify the model URL is correct and that the host + isn't rate-limiting the agent. +

    +
  • + ); + } + + if ("CacheCannotAcquireLock" in issue) { + return ( +
  • + + Cannot acquire download lock:{" "} + {issue.CacheCannotAcquireLock.model_path} + + What will Paddler do?{" "} +

    + Paddler will reattempt to download the model every few seconds + until the download lock can be acquired. +

    + What can you do?{" "} +

    + This is likely a temporary issue. It happens when another + process is currently downloading this URL into the shared cache + directory. +

    +
  • + ); + } + + if ("CacheDirectoryIsNotWritable" in issue) { + return ( +
  • + + Cache directory is not writable:{" "} + {issue.CacheDirectoryIsNotWritable.model_path} + + What will Paddler do?{" "} +

    + Paddler will keep re-checking; the moment write permission + is restored, the next attempt will succeed. +

    + What can you do?{" "} +

    + Grant write permission to the cache directory ( + $XDG_CACHE_HOME/paddler on Linux/macOS,{" "} + %LOCALAPPDATA%\paddler on Windows), or set{" "} + PADDLER_CACHE_DIR to a writable location. +

    +
  • + ); + } + + if ("CacheStorageIsFull" in issue) { + return ( +
  • + + Cache storage is full while downloading:{" "} + {issue.CacheStorageIsFull.model_path} + + What will Paddler do?{" "} +

    + Paddler will keep re-checking; the moment space is + available, the next attempt will succeed. +

    + What can you do?{" "} +

    Free space on the disk that hosts the cache directory.

    +
  • + ); + } + + if ("ModelCacheIsCorrupted" in issue) { + return ( +
  • + + Model cache is corrupted:{" "} + {issue.ModelCacheIsCorrupted.model_path} + + What will Paddler do?{" "} +

    + Paddler will keep re-checking; the cache will be rebuilt + on the next attempt. +

    + What can you do?{" "} +

    + If the issue persists, manually clear the{" "} + downloaded-models subdirectory of the cache and + let Paddler rebuild it. +

    +
  • + ); + } + return (
  • Unknown issue: {JSON.stringify(issue)} diff --git a/resources/ts/components/AgentList.tsx b/resources/ts/components/AgentList.tsx index 7b082544..1bebcc78 100644 --- a/resources/ts/components/AgentList.tsx +++ b/resources/ts/components/AgentList.tsx @@ -50,6 +50,7 @@ export function AgentList({ const { download_current, download_filename, + download_indeterminate, download_total, id, issues, @@ -89,9 +90,13 @@ export function AgentList({ /> )} - {download_total > 0 ? ( + {download_filename !== null ? (
    - + {download_indeterminate ? ( + + ) : ( + + )} Download diff --git a/resources/ts/components/ChangeModelPage.tsx b/resources/ts/components/ChangeModelPage.tsx index bb4fa4fd..57ef4e13 100644 --- a/resources/ts/components/ChangeModelPage.tsx +++ b/resources/ts/components/ChangeModelPage.tsx @@ -24,6 +24,10 @@ function modelSchemaToUrl(model: AgentDesiredModel): string { return `agent://${model.LocalToAgent}`; } + if ("Url" in model) { + return model.Url.url; + } + throw new Error(`Unsupported model schema: ${JSON.stringify(model)}`); }