diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1ce380c5..1d4a3ec9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -75,7 +75,7 @@ jobs: - working-directory: workers run: | uv sync --dev - uv run pytest -v --cov=workers --cov-report=xml + uv run pytest -v --cov=workers --cov-report=xml -m "not slow" build-web: runs-on: ubuntu-latest diff --git a/CLAUDE.md b/CLAUDE.md index f9fdb02c..0cc70d30 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -58,6 +58,59 @@ Plan: `thoughts/shared/plans/active-2026-05-06-deliver-orchestrator-capacity-det Investigation: `thoughts/shared/investigations/2026-05-06-ollama-think-suppression-empirical.md` Runbook: [`docs/admin/llm-config.md`](docs/admin/llm-config.md#backend-parallelism-and-the-max_concurrent_calls-field) +**2026-05-06 worker LLM concurrency refactor** — 7 commits, `3274ade..6eed915`. +Universal per-`(provider, base_url)` gate registry in the Python worker: +every LLM and embedding call passes through a host or per-kind semaphore, +jitter-aware tenacity retry loop, optional RPM limiter, and tok/s ring +buffer. Eliminates the 5×3=15-attempt storm from stacked hand-rolled +retries. `GetProviderCapabilities.max_concurrent_calls` is now sourced +from the gate's effective cap for the resolved context, not bootstrap +config, so Go and Python agree on capacity by construction. Phase 7 +extends `/api/v1/admin/llm/activity` with a `gate_snapshot` field and +adds a live "LLM Gate Activity" section to the admin monitor page. + +Load-bearing constraints for future-Claude: + +- **Don't re-enable SDK retry** (`max_retries=0` on `AsyncOpenAI` and + `AsyncAnthropic`). The tenacity wrapper owns retry. Re-enabling SDK + retry produces 5×3=15-attempt storms per Decision 3. +- **Don't add a `[llm.concurrency]` TOML section.** Concurrency is + operator-tunable via env vars, not `config.toml`. Decision 7. +- **The kill switch is the rollback path**: `SOURCEBRIDGE_LLM_CONCURRENCY_WRAPPER_ENABLED=false` + reverts to pre-refactor behavior without redeploy. Use it before + assuming the gate is the problem. +- **Registry is constructed-once-passed-by-reference** — constructed in + `workers/__main__.py` and `workers/common/cli_main.py` only. No + module-level singletons. Every factory call (`create_llm_provider`, + `create_embedding_provider`, etc.) receives `gate_registry=` as a + required kwarg. +- **Don't delete the empty-content retry** at + `workers/common/llm/openai_compat.py` (around lines 249–313). It + handles ``-budget exhaustion (`stop_reason=length` + empty + visible content) — it is NOT a network retry and is explicitly distinct + from the tenacity wrapper retry. +- **Gate is authoritative for `GetProviderCapabilities`**: the worker's + `GetProviderCapabilities` handler reads the registry's effective cap + for the resolved-context `(provider, base_url)` via + `workers/reasoning/servicer.py`. Don't bypass back to + `WorkerConfig.llm_max_concurrent_calls` except in the legacy fallback + path (kill switch off). +- **Host gate vs. per-kind gate classification is per-provider, not + configurable per call.** Local providers (`ollama`, `vllm`, + `llama-cpp`, `sglang`, `lmstudio`) share one host gate across LLM and + embedding. Cloud providers (`openai`, `anthropic`, `gemini`, + `openrouter`) use per-kind gates. `openai-compatible` defaults host; + flip with `SOURCEBRIDGE_LLM_PROVIDER_OPENAI_COMPATIBLE_GATING=per_kind`. +- **Don't fork the cross-language plumbing.** `/api/v1/admin/llm/activity` + (REST) and `KnowledgeStreamProgress` (proto) are the sole channels for + gate snapshot and per-job tok/s. Don't add a new endpoint or proto + field; extend these. + +Plan: `thoughts/shared/plans/active-2026-05-06-deliver-worker-llm-concurrency.md` +Investigation: `thoughts/shared/investigations/2026-05-06-diagnose-llm-throughput-rotten.md` +Decisions log: `thoughts/shared/plans/active-2026-05-06-deliver-worker-llm-concurrency.decisions.md` +Runbook: [`docs/admin/llm-config.md`](docs/admin/llm-config.md#operator-concurrency-tuning) + **2026-05-05 web runtime API proxy fix** — 3 commits, `1fee78b..873bc53`. Replaces `next.config.ts rewrites()` with a Next.js middleware at `web/src/middleware.ts` that proxies `/api/*`, `/auth/*`, `/healthz`, diff --git a/docs/admin/llm-config.md b/docs/admin/llm-config.md index 8d6ce5c1..9f063e36 100644 --- a/docs/admin/llm-config.md +++ b/docs/admin/llm-config.md @@ -1297,6 +1297,176 @@ and a reinforced no-think system prompt. Total attempt budget is 3 --- +## Operator concurrency tuning + +The Python worker enforces per-provider concurrency through a gate +registry built from `workers/common/llm/concurrency.py`. Every LLM and +embedding call passes through a provider gate that holds the semaphore, +an optional RPM limiter, the retry loop, and the tok/s ring buffer. +This section documents the operator-visible knobs. + +The gate registry is the **runtime source of truth** for in-worker LLM +concurrency. The `GetProviderCapabilities.max_concurrent_calls` gRPC +field (consumed by the Go orchestrator at +`internal/qa/lazy_agent_synth.go:340`) is now sourced from the gate +registry's effective cap for the resolved-context `(provider, base_url)`, +not from the bootstrap config. Setting a per-provider env var changes +both the worker's semaphore and the value the orchestrator uses to clamp +its goroutine pool — the same knob applies on both sides. + +### Env var resolution order + +Resolution is first-match-wins, top to bottom. + +| Env var | Scope | Default | Notes | +|---|---|---|---| +| `SOURCEBRIDGE_LLM_PROVIDER__MAX_CONCURRENT` | per-provider LLM (or host-total for local providers) | see table below | `` per canonical table; e.g., `SOURCEBRIDGE_LLM_PROVIDER_OLLAMA_MAX_CONCURRENT=4` | +| `SOURCEBRIDGE_EMBEDDING_PROVIDER__MAX_CONCURRENT` | per-provider embedding (frontier only; ignored when host-gated) | same as LLM cap | unused for Ollama — host gate combines both kinds | +| `SOURCEBRIDGE_LLM_PROVIDER__RPM` | per-provider rate limit | unset (no limiter) | applies to all providers; see tier-1 cloud values below | +| `SOURCEBRIDGE_LLM_PROVIDER_OPENAI_COMPATIBLE_GATING` | `openai-compatible` gate mode | `host` | set to `per_kind` if pointing at a managed endpoint with separate chat/embedding quotas | +| `SOURCEBRIDGE_LLM_CONCURRENCY_WRAPPER_ENABLED` | kill switch | `true` | set to `false` to revert to pre-refactor behavior without redeploy | +| `SOURCEBRIDGE_LLM_RETRY_MAX_ATTEMPTS` | tenacity retry attempts | `5` | reduce to 2–3 on unreliable networks; increase risks storms | +| `SOURCEBRIDGE_LLM_METRICS_AGGREGATION_INTERVAL_SECONDS` | gate-metrics log interval | `30` | lower to 5 for debugging; not for production steady-state | +| `SOURCEBRIDGE_WORKER_LLM_MAX_CONCURRENT_CALLS` | **legacy seed** — seeds the active LLM provider's gate cap when no per-provider override is set | unset | deprecated for new deployments; preserved for backward compat | +| `SOURCEBRIDGE_LLM_PARALLEL_HINT` | alias for the legacy seed above | unset | deprecated alias; kept for backward compat | + +**Canonical `` tokens** (`key.upper().replace("-", "_")`): + +| Provider | Env-var token | +|---|---| +| `openai` | `OPENAI` | +| `anthropic` | `ANTHROPIC` | +| `ollama` | `OLLAMA` | +| `vllm` | `VLLM` | +| `llama-cpp` | `LLAMA_CPP` | +| `sglang` | `SGLANG` | +| `gemini` | `GEMINI` | +| `openrouter` | `OPENROUTER` | +| `lmstudio` | `LMSTUDIO` | +| `openai-compatible` | `OPENAI_COMPATIBLE` | + +The worker validates env-var tokens at startup and rejects unknown +spellings (e.g., `SOURCEBRIDGE_LLM_PROVIDER_OPENAICOMPAT_MAX_CONCURRENT`) +with an actionable error naming the canonical table. + +### Default values by provider + +| Provider | Gating | LLM cap | Embedding cap | RPM default | +|---|---|---|---|---| +| `ollama` | host | 1 | (shared via host gate) | none | +| `vllm` | host | 4 | (shared) | none | +| `llama-cpp` | host | 4 | (shared) | none | +| `sglang` | host | 4 | (shared) | none | +| `lmstudio` | host | 2 | (shared) | none | +| `openai-compatible` | host (operator-flippable to `per_kind`) | 4 | (shared) | none | +| `openai` | per-kind | 8 | 16 | none | +| `anthropic` | per-kind | 4 | n/a | none | +| `gemini` | per-kind | 8 | 16 | none | +| `openrouter` | per-kind | 8 | n/a | none | + +**Host vs. per-kind gating**: local providers (`ollama`, `vllm`, +`llama-cpp`, `sglang`, `lmstudio`) use one host gate that combines LLM +and embedding calls. Cloud providers (`openai`, `anthropic`, `gemini`, +`openrouter`) use separate per-kind gates. `openai-compatible` defaults +to host; flip with `SOURCEBRIDGE_LLM_PROVIDER_OPENAI_COMPATIBLE_GATING=per_kind` +if the endpoint has separate chat vs. embedding quotas. + +These are conservative real caps, not sentinels. Operators with +high-tier cloud accounts should raise the cap rather than disable the +gate: + +```bash +SOURCEBRIDGE_LLM_PROVIDER_OPENAI_MAX_CONCURRENT=64 +``` + +The hard ceiling is 256 concurrent calls (enforced at the gate, the +Go-side adapter clamp, and the SurrealDB `ASSERT` constraint). + +### Ollama: one knob covers everything + +For Ollama, `SOURCEBRIDGE_LLM_PROVIDER_OLLAMA_MAX_CONCURRENT` is the +**only** concurrency knob needed. The host gate combines LLM and +embedding calls against the same normalized origin +(`http://localhost:11434` regardless of whether the SDK uses +`/v1` or `/api`), so both kinds share one semaphore. There is no +separate `SOURCEBRIDGE_EMBEDDING_PROVIDER_OLLAMA_MAX_CONCURRENT` — it +is ignored for host-gated providers. + +Set it to match `OLLAMA_NUM_PARALLEL` on the Ollama daemon. + +### Capacity contract + +After this refactor, `GetProviderCapabilities.max_concurrent_calls` is +sourced from the gate registry's effective cap for the **resolved +context** `(provider, base_url)` — not from the bootstrap config. The +Go orchestrator at `internal/qa/lazy_agent_synth.go:340` clamps +`MaxConcurrency` to this value. Setting the per-provider env var changes +both the worker's semaphore and the orchestrator's clamp in one step. + +When the wrapper is disabled via the kill switch, the legacy +`SOURCEBRIDGE_WORKER_LLM_MAX_CONCURRENT_CALLS` / `SOURCEBRIDGE_LLM_PARALLEL_HINT` +path is used instead (same behavior as before this refactor). + +### RPM values for tier-1 cloud accounts + +The defaults ship with no RPM limiter (`None`). Operators on high-tier +accounts can layer on an RPM limit to prevent hitting provider-side +rate ceilings under burst load. + +| Provider | Tier | Recommended env var | +|---|---|---| +| OpenAI | Tier 4 | `SOURCEBRIDGE_LLM_PROVIDER_OPENAI_RPM=10000` (chat models; per-model limits vary for embeddings — check the OpenAI usage dashboard) | +| Anthropic | Tier 2 | `SOURCEBRIDGE_LLM_PROVIDER_ANTHROPIC_RPM=4000` (Claude 3.5 Sonnet; other models differ) | +| Gemini | Tier 2 / Pro | per-model — see [Google's rate limit documentation](https://ai.google.dev/gemini-api/docs/rate-limits); set `SOURCEBRIDGE_LLM_PROVIDER_GEMINI_RPM=` | +| OpenRouter | varies | leave unset; OpenRouter enforces its own rate limits and returns 429s; the retry wrapper handles them | + +### Server-side companion knobs (Ollama) + +These variables go on the **Ollama daemon**, not in SourceBridge. They +are the dominant bottleneck for Ollama throughput. Investigation +`thoughts/shared/investigations/2026-05-06-diagnose-llm-throughput-rotten.md` +confirmed that `OLLAMA_NUM_PARALLEL=1` (the stock default) is the +single largest throughput bottleneck — more impactful than any +SourceBridge-side tuning. + +| Ollama env var | Recommended value | Effect | +|---|---|---| +| `OLLAMA_NUM_PARALLEL` | `4`–`8` (sufficient RAM) / `2`–`4` (16–32 GB) / `1` (≤8 GB) | Max in-flight requests per daemon. **Set `SOURCEBRIDGE_LLM_PROVIDER_OLLAMA_MAX_CONCURRENT` to the same value.** | +| `OLLAMA_KEEP_ALIVE` | `-1` or `24h` | Prevents model unload between Living Wiki pages and between consecutive jobs. Default `5m` causes full reload latency (~30–90 s) at the start of each page after idle. | +| `OLLAMA_MAX_LOADED_MODELS` | number of distinct models the workload uses + 1 | Prevents model thrashing when `OLLAMA_NUM_PARALLEL > 1` and multiple models are configured. Default `1` is conservative; raise when you have LLM + embedding models that must coexist in VRAM. | + +Set these in your Ollama service file (e.g., +`/etc/systemd/system/ollama.service` `[Service] Environment=...` or +`/Library/LaunchDaemons/com.ollama.serve.plist`) and restart the +daemon. Then update `SOURCEBRIDGE_LLM_PROVIDER_OLLAMA_MAX_CONCURRENT` +to match the new `OLLAMA_NUM_PARALLEL`. + +### Where to look in the UI + +Go to **Admin → Monitor** (`/admin/monitor`) on your SourceBridge +instance. + +The **"LLM Gate Activity"** section shows live counters per active gate: + +| Column | What it means | +|---|---| +| Provider / endpoint | Gate key: provider name + normalized base URL | +| Kind | `llm` or `embedding` | +| In-flight / cap | Current in-flight calls vs. the gate's `max_concurrent` | +| Queued | Calls waiting for a slot (Decision 11 waiter counter) | +| tok/s | 60-second rolling tokens-per-second for this gate | +| 429s | Rate-limit errors since the gate was created | +| Retries | Total tenacity retry attempts since start | + +The per-job **tok/s pill** on each `ActiveJobCard` shows live +throughput for that specific Living Wiki generation run. + +The underlying data comes from `GET /api/v1/admin/llm/activity` +(`gate_snapshot` field), populated by the `GetLLMGateSnapshot` gRPC +method on the worker's `ReasoningService`. + +--- + ## Living Wiki page-count and ops behavior For how Living Wiki determines the number of pages to generate, how diff --git a/gen/go/common/v1/knowledge_progress.pb.go b/gen/go/common/v1/knowledge_progress.pb.go index a6f29054..82e629d9 100644 --- a/gen/go/common/v1/knowledge_progress.pb.go +++ b/gen/go/common/v1/knowledge_progress.pb.go @@ -139,8 +139,14 @@ type KnowledgeStreamProgress struct { FileCacheHits int32 `protobuf:"varint,11,opt,name=file_cache_hits,json=fileCacheHits,proto3" json:"file_cache_hits,omitempty"` PackageCacheHits int32 `protobuf:"varint,12,opt,name=package_cache_hits,json=packageCacheHits,proto3" json:"package_cache_hits,omitempty"` RootCacheHits int32 `protobuf:"varint,13,opt,name=root_cache_hits,json=rootCacheHits,proto3" json:"root_cache_hits,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // Instantaneous throughput from the active LLM gate's 60-second ring + // buffer at the moment this progress event was emitted. Zero when the + // gate has not recorded any completions yet (cold start, non-streaming + // path, or kill-switch disabled). Consumers MUST treat zero as "unknown", + // not as "zero tokens per second". Added Phase 6 (worker LLM concurrency). + CurrentTokensPerSecond float32 `protobuf:"fixed32,14,opt,name=current_tokens_per_second,json=currentTokensPerSecond,proto3" json:"current_tokens_per_second,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *KnowledgeStreamProgress) Reset() { @@ -236,6 +242,13 @@ func (x *KnowledgeStreamProgress) GetRootCacheHits() int32 { return 0 } +func (x *KnowledgeStreamProgress) GetCurrentTokensPerSecond() float32 { + if x != nil { + return x.CurrentTokensPerSecond + } + return 0 +} + // KnowledgeStreamPhaseMarker fires once per phase transition so the // client can update the phase label deterministically. The detail // field is optional context (for example the strategy name on the @@ -296,7 +309,7 @@ var File_common_v1_knowledge_progress_proto protoreflect.FileDescriptor const file_common_v1_knowledge_progress_proto_rawDesc = "" + "\n" + - "\"common/v1/knowledge_progress.proto\x12\x16sourcebridge.common.v1\"\xfe\x02\n" + + "\"common/v1/knowledge_progress.proto\x12\x16sourcebridge.common.v1\"\xb9\x03\n" + "\x17KnowledgeStreamProgress\x12<\n" + "\x05phase\x18\x01 \x01(\x0e2&.sourcebridge.common.v1.KnowledgePhaseR\x05phase\x12'\n" + "\x0fcompleted_units\x18\x02 \x01(\x05R\x0ecompletedUnits\x12\x1f\n" + @@ -308,7 +321,8 @@ const file_common_v1_knowledge_progress_proto_rawDesc = "" + " \x01(\x05R\rleafCacheHits\x12&\n" + "\x0ffile_cache_hits\x18\v \x01(\x05R\rfileCacheHits\x12,\n" + "\x12package_cache_hits\x18\f \x01(\x05R\x10packageCacheHits\x12&\n" + - "\x0froot_cache_hits\x18\r \x01(\x05R\rrootCacheHits\"r\n" + + "\x0froot_cache_hits\x18\r \x01(\x05R\rrootCacheHits\x129\n" + + "\x19current_tokens_per_second\x18\x0e \x01(\x02R\x16currentTokensPerSecond\"r\n" + "\x1aKnowledgeStreamPhaseMarker\x12<\n" + "\x05phase\x18\x01 \x01(\x0e2&.sourcebridge.common.v1.KnowledgePhaseR\x05phase\x12\x16\n" + "\x06detail\x18\x02 \x01(\tR\x06detail*\x9e\x02\n" + diff --git a/gen/go/reasoning/v1/reasoning.pb.go b/gen/go/reasoning/v1/reasoning.pb.go index fb40204a..98b72582 100644 --- a/gen/go/reasoning/v1/reasoning.pb.go +++ b/gen/go/reasoning/v1/reasoning.pb.go @@ -2467,6 +2467,222 @@ func (x *GetProviderCapabilitiesResponse) GetMaxConcurrentCallsKnown() bool { return false } +// GetLLMGateSnapshotRequest is intentionally empty (not google.protobuf.Empty) +// for forward compatibility — future filter fields (provider, kind) can be +// added here without redeclaring the RPC. +type GetLLMGateSnapshotRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetLLMGateSnapshotRequest) Reset() { + *x = GetLLMGateSnapshotRequest{} + mi := &file_reasoning_v1_reasoning_proto_msgTypes[32] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetLLMGateSnapshotRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetLLMGateSnapshotRequest) ProtoMessage() {} + +func (x *GetLLMGateSnapshotRequest) ProtoReflect() protoreflect.Message { + mi := &file_reasoning_v1_reasoning_proto_msgTypes[32] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetLLMGateSnapshotRequest.ProtoReflect.Descriptor instead. +func (*GetLLMGateSnapshotRequest) Descriptor() ([]byte, []int) { + return file_reasoning_v1_reasoning_proto_rawDescGZIP(), []int{32} +} + +type GetLLMGateSnapshotResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Gates []*LLMGateEntry `protobuf:"bytes,1,rep,name=gates,proto3" json:"gates,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetLLMGateSnapshotResponse) Reset() { + *x = GetLLMGateSnapshotResponse{} + mi := &file_reasoning_v1_reasoning_proto_msgTypes[33] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetLLMGateSnapshotResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetLLMGateSnapshotResponse) ProtoMessage() {} + +func (x *GetLLMGateSnapshotResponse) ProtoReflect() protoreflect.Message { + mi := &file_reasoning_v1_reasoning_proto_msgTypes[33] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetLLMGateSnapshotResponse.ProtoReflect.Descriptor instead. +func (*GetLLMGateSnapshotResponse) Descriptor() ([]byte, []int) { + return file_reasoning_v1_reasoning_proto_rawDescGZIP(), []int{33} +} + +func (x *GetLLMGateSnapshotResponse) GetGates() []*LLMGateEntry { + if x != nil { + return x.Gates + } + return nil +} + +// LLMGateEntry is one row in the gate snapshot — one entry per +// (provider, base_url_normalized, kind) triple that has been +// registered in the gate registry since the worker started. +type LLMGateEntry struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Canonical provider name (e.g. "ollama", "openai"). + Provider string `protobuf:"bytes,1,opt,name=provider,proto3" json:"provider,omitempty"` + // Normalized origin URL — "scheme://host:port" with path/query stripped. + // Host-gated providers collapse to the daemon origin; frontier providers + // use the raw base_url per Decision 1 in the concurrency plan. + BaseUrlNormalized string `protobuf:"bytes,2,opt,name=base_url_normalized,json=baseUrlNormalized,proto3" json:"base_url_normalized,omitempty"` + // "llm" or "embedding". Host-gated providers (Ollama, vLLM, …) emit one + // row per kind sharing a single max_concurrent and tokens_per_second. + Kind string `protobuf:"bytes,3,opt,name=kind,proto3" json:"kind,omitempty"` + // Current number of LLM calls executing inside the semaphore right now. + InFlight int32 `protobuf:"varint,4,opt,name=in_flight,json=inFlight,proto3" json:"in_flight,omitempty"` + // Number of callers waiting for a semaphore slot (Decision 11 waiter count). + Queued int32 `protobuf:"varint,5,opt,name=queued,proto3" json:"queued,omitempty"` + // Effective semaphore size — the cap enforced by this gate. + MaxConcurrent int32 `protobuf:"varint,6,opt,name=max_concurrent,json=maxConcurrent,proto3" json:"max_concurrent,omitempty"` + // Cumulative retry count since the worker started. + RetriesSinceStart int64 `protobuf:"varint,7,opt,name=retries_since_start,json=retriesSinceStart,proto3" json:"retries_since_start,omitempty"` + // Cumulative 429 / RateLimitError count since the worker started. + Recent_429Count int64 `protobuf:"varint,8,opt,name=recent_429_count,json=recent429Count,proto3" json:"recent_429_count,omitempty"` + // Output tokens per second averaged over the last 60 seconds (60-second + // ring buffer). Zero means no completions have been recorded yet. + TokensPerSecond float64 `protobuf:"fixed64,9,opt,name=tokens_per_second,json=tokensPerSecond,proto3" json:"tokens_per_second,omitempty"` + // Configured requests-per-minute limit for this gate; 0 = no RPM limiter. + Rpm int32 `protobuf:"varint,10,opt,name=rpm,proto3" json:"rpm,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *LLMGateEntry) Reset() { + *x = LLMGateEntry{} + mi := &file_reasoning_v1_reasoning_proto_msgTypes[34] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *LLMGateEntry) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LLMGateEntry) ProtoMessage() {} + +func (x *LLMGateEntry) ProtoReflect() protoreflect.Message { + mi := &file_reasoning_v1_reasoning_proto_msgTypes[34] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LLMGateEntry.ProtoReflect.Descriptor instead. +func (*LLMGateEntry) Descriptor() ([]byte, []int) { + return file_reasoning_v1_reasoning_proto_rawDescGZIP(), []int{34} +} + +func (x *LLMGateEntry) GetProvider() string { + if x != nil { + return x.Provider + } + return "" +} + +func (x *LLMGateEntry) GetBaseUrlNormalized() string { + if x != nil { + return x.BaseUrlNormalized + } + return "" +} + +func (x *LLMGateEntry) GetKind() string { + if x != nil { + return x.Kind + } + return "" +} + +func (x *LLMGateEntry) GetInFlight() int32 { + if x != nil { + return x.InFlight + } + return 0 +} + +func (x *LLMGateEntry) GetQueued() int32 { + if x != nil { + return x.Queued + } + return 0 +} + +func (x *LLMGateEntry) GetMaxConcurrent() int32 { + if x != nil { + return x.MaxConcurrent + } + return 0 +} + +func (x *LLMGateEntry) GetRetriesSinceStart() int64 { + if x != nil { + return x.RetriesSinceStart + } + return 0 +} + +func (x *LLMGateEntry) GetRecent_429Count() int64 { + if x != nil { + return x.Recent_429Count + } + return 0 +} + +func (x *LLMGateEntry) GetTokensPerSecond() float64 { + if x != nil { + return x.TokensPerSecond + } + return 0 +} + +func (x *LLMGateEntry) GetRpm() int32 { + if x != nil { + return x.Rpm + } + return 0 +} + var File_reasoning_v1_reasoning_proto protoreflect.FileDescriptor const file_reasoning_v1_reasoning_proto_rawDesc = "" + @@ -2668,7 +2884,22 @@ const file_reasoning_v1_reasoning_proto_rawDesc = "" + "\x12tool_use_supported\x18\x03 \x01(\bR\x10toolUseSupported\x128\n" + "\x18prompt_caching_supported\x18\x04 \x01(\bR\x16promptCachingSupported\x120\n" + "\x14max_concurrent_calls\x18\x05 \x01(\x05R\x12maxConcurrentCalls\x12;\n" + - "\x1amax_concurrent_calls_known\x18\x06 \x01(\bR\x17maxConcurrentCallsKnown2\xb1\f\n" + + "\x1amax_concurrent_calls_known\x18\x06 \x01(\bR\x17maxConcurrentCallsKnown\"\x1b\n" + + "\x19GetLLMGateSnapshotRequest\"[\n" + + "\x1aGetLLMGateSnapshotResponse\x12=\n" + + "\x05gates\x18\x01 \x03(\v2'.sourcebridge.reasoning.v1.LLMGateEntryR\x05gates\"\xe2\x02\n" + + "\fLLMGateEntry\x12\x1a\n" + + "\bprovider\x18\x01 \x01(\tR\bprovider\x12.\n" + + "\x13base_url_normalized\x18\x02 \x01(\tR\x11baseUrlNormalized\x12\x12\n" + + "\x04kind\x18\x03 \x01(\tR\x04kind\x12\x1b\n" + + "\tin_flight\x18\x04 \x01(\x05R\binFlight\x12\x16\n" + + "\x06queued\x18\x05 \x01(\x05R\x06queued\x12%\n" + + "\x0emax_concurrent\x18\x06 \x01(\x05R\rmaxConcurrent\x12.\n" + + "\x13retries_since_start\x18\a \x01(\x03R\x11retriesSinceStart\x12(\n" + + "\x10recent_429_count\x18\b \x01(\x03R\x0erecent429Count\x12*\n" + + "\x11tokens_per_second\x18\t \x01(\x01R\x0ftokensPerSecond\x12\x10\n" + + "\x03rpm\x18\n" + + " \x01(\x05R\x03rpm2\xb5\r\n" + "\x10ReasoningService\x12r\n" + "\rAnalyzeSymbol\x12/.sourcebridge.reasoning.v1.AnalyzeSymbolRequest\x1a0.sourcebridge.reasoning.v1.AnalyzeSymbolResponse\x12\x84\x01\n" + "\x13ExplainRelationship\x125.sourcebridge.reasoning.v1.ExplainRelationshipRequest\x1a6.sourcebridge.reasoning.v1.ExplainRelationshipResponse\x12u\n" + @@ -2679,7 +2910,8 @@ const file_reasoning_v1_reasoning_proto_rawDesc = "" + "\x11GenerateEmbedding\x123.sourcebridge.reasoning.v1.GenerateEmbeddingRequest\x1a4.sourcebridge.reasoning.v1.GenerateEmbeddingResponse\x12u\n" + "\x0eSimulateChange\x120.sourcebridge.reasoning.v1.SimulateChangeRequest\x1a1.sourcebridge.reasoning.v1.SimulateChangeResponse\x12\x90\x01\n" + "\x17AnswerQuestionWithTools\x129.sourcebridge.reasoning.v1.AnswerQuestionWithToolsRequest\x1a:.sourcebridge.reasoning.v1.AnswerQuestionWithToolsResponse\x12\x90\x01\n" + - "\x17GetProviderCapabilities\x129.sourcebridge.reasoning.v1.GetProviderCapabilitiesRequest\x1a:.sourcebridge.reasoning.v1.GetProviderCapabilitiesResponse\x12{\n" + + "\x17GetProviderCapabilities\x129.sourcebridge.reasoning.v1.GetProviderCapabilitiesRequest\x1a:.sourcebridge.reasoning.v1.GetProviderCapabilitiesResponse\x12\x81\x01\n" + + "\x12GetLLMGateSnapshot\x124.sourcebridge.reasoning.v1.GetLLMGateSnapshotRequest\x1a5.sourcebridge.reasoning.v1.GetLLMGateSnapshotResponse\x12{\n" + "\x10ClassifyQuestion\x122.sourcebridge.reasoning.v1.ClassifyQuestionRequest\x1a3.sourcebridge.reasoning.v1.ClassifyQuestionResponse\x12~\n" + "\x11DecomposeQuestion\x123.sourcebridge.reasoning.v1.DecomposeQuestionRequest\x1a4.sourcebridge.reasoning.v1.DecomposeQuestionResponse\x12\x99\x01\n" + "\x1aSynthesizeDecomposedAnswer\x12<.sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerRequest\x1a=.sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerResponseBFZDgithub.com/sourcebridge/sourcebridge/gen/go/reasoning/v1;reasoningv1b\x06proto3" @@ -2696,7 +2928,7 @@ func file_reasoning_v1_reasoning_proto_rawDescGZIP() []byte { return file_reasoning_v1_reasoning_proto_rawDescData } -var file_reasoning_v1_reasoning_proto_msgTypes = make([]protoimpl.MessageInfo, 32) +var file_reasoning_v1_reasoning_proto_msgTypes = make([]protoimpl.MessageInfo, 35) var file_reasoning_v1_reasoning_proto_goTypes = []any{ (*AnalyzeSymbolRequest)(nil), // 0: sourcebridge.reasoning.v1.AnalyzeSymbolRequest (*AnalyzeSymbolResponse)(nil), // 1: sourcebridge.reasoning.v1.AnalyzeSymbolResponse @@ -2730,75 +2962,81 @@ var file_reasoning_v1_reasoning_proto_goTypes = []any{ (*SynthesizeDecomposedAnswerResponse)(nil), // 29: sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerResponse (*GetProviderCapabilitiesRequest)(nil), // 30: sourcebridge.reasoning.v1.GetProviderCapabilitiesRequest (*GetProviderCapabilitiesResponse)(nil), // 31: sourcebridge.reasoning.v1.GetProviderCapabilitiesResponse - (*v1.CodeSymbol)(nil), // 32: sourcebridge.common.v1.CodeSymbol - (*v1.LLMUsage)(nil), // 33: sourcebridge.common.v1.LLMUsage - (v1.Confidence)(0), // 34: sourcebridge.common.v1.Confidence - (v1.Language)(0), // 35: sourcebridge.common.v1.Language - (*v1.Embedding)(nil), // 36: sourcebridge.common.v1.Embedding + (*GetLLMGateSnapshotRequest)(nil), // 32: sourcebridge.reasoning.v1.GetLLMGateSnapshotRequest + (*GetLLMGateSnapshotResponse)(nil), // 33: sourcebridge.reasoning.v1.GetLLMGateSnapshotResponse + (*LLMGateEntry)(nil), // 34: sourcebridge.reasoning.v1.LLMGateEntry + (*v1.CodeSymbol)(nil), // 35: sourcebridge.common.v1.CodeSymbol + (*v1.LLMUsage)(nil), // 36: sourcebridge.common.v1.LLMUsage + (v1.Confidence)(0), // 37: sourcebridge.common.v1.Confidence + (v1.Language)(0), // 38: sourcebridge.common.v1.Language + (*v1.Embedding)(nil), // 39: sourcebridge.common.v1.Embedding } var file_reasoning_v1_reasoning_proto_depIdxs = []int32{ - 32, // 0: sourcebridge.reasoning.v1.AnalyzeSymbolRequest.symbol:type_name -> sourcebridge.common.v1.CodeSymbol - 33, // 1: sourcebridge.reasoning.v1.AnalyzeSymbolResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage - 32, // 2: sourcebridge.reasoning.v1.ExplainRelationshipRequest.source:type_name -> sourcebridge.common.v1.CodeSymbol - 32, // 3: sourcebridge.reasoning.v1.ExplainRelationshipRequest.target:type_name -> sourcebridge.common.v1.CodeSymbol - 34, // 4: sourcebridge.reasoning.v1.ExplainRelationshipResponse.confidence:type_name -> sourcebridge.common.v1.Confidence - 33, // 5: sourcebridge.reasoning.v1.ExplainRelationshipResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage - 32, // 6: sourcebridge.reasoning.v1.AnswerQuestionRequest.context_symbols:type_name -> sourcebridge.common.v1.CodeSymbol - 35, // 7: sourcebridge.reasoning.v1.AnswerQuestionRequest.language:type_name -> sourcebridge.common.v1.Language - 32, // 8: sourcebridge.reasoning.v1.AnswerQuestionResponse.referenced_symbols:type_name -> sourcebridge.common.v1.CodeSymbol - 33, // 9: sourcebridge.reasoning.v1.AnswerQuestionResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage - 32, // 10: sourcebridge.reasoning.v1.AnswerQuestionStreamRequest.context_symbols:type_name -> sourcebridge.common.v1.CodeSymbol - 35, // 11: sourcebridge.reasoning.v1.AnswerQuestionStreamRequest.language:type_name -> sourcebridge.common.v1.Language - 32, // 12: sourcebridge.reasoning.v1.AnswerQuestionStreamResponse.referenced_symbols:type_name -> sourcebridge.common.v1.CodeSymbol - 33, // 13: sourcebridge.reasoning.v1.AnswerQuestionStreamResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage + 35, // 0: sourcebridge.reasoning.v1.AnalyzeSymbolRequest.symbol:type_name -> sourcebridge.common.v1.CodeSymbol + 36, // 1: sourcebridge.reasoning.v1.AnalyzeSymbolResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage + 35, // 2: sourcebridge.reasoning.v1.ExplainRelationshipRequest.source:type_name -> sourcebridge.common.v1.CodeSymbol + 35, // 3: sourcebridge.reasoning.v1.ExplainRelationshipRequest.target:type_name -> sourcebridge.common.v1.CodeSymbol + 37, // 4: sourcebridge.reasoning.v1.ExplainRelationshipResponse.confidence:type_name -> sourcebridge.common.v1.Confidence + 36, // 5: sourcebridge.reasoning.v1.ExplainRelationshipResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage + 35, // 6: sourcebridge.reasoning.v1.AnswerQuestionRequest.context_symbols:type_name -> sourcebridge.common.v1.CodeSymbol + 38, // 7: sourcebridge.reasoning.v1.AnswerQuestionRequest.language:type_name -> sourcebridge.common.v1.Language + 35, // 8: sourcebridge.reasoning.v1.AnswerQuestionResponse.referenced_symbols:type_name -> sourcebridge.common.v1.CodeSymbol + 36, // 9: sourcebridge.reasoning.v1.AnswerQuestionResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage + 35, // 10: sourcebridge.reasoning.v1.AnswerQuestionStreamRequest.context_symbols:type_name -> sourcebridge.common.v1.CodeSymbol + 38, // 11: sourcebridge.reasoning.v1.AnswerQuestionStreamRequest.language:type_name -> sourcebridge.common.v1.Language + 35, // 12: sourcebridge.reasoning.v1.AnswerQuestionStreamResponse.referenced_symbols:type_name -> sourcebridge.common.v1.CodeSymbol + 36, // 13: sourcebridge.reasoning.v1.AnswerQuestionStreamResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage 8, // 14: sourcebridge.reasoning.v1.AnswerQuestionStreamResponse.progress:type_name -> sourcebridge.reasoning.v1.ProgressEvent - 35, // 15: sourcebridge.reasoning.v1.ReviewFileRequest.language:type_name -> sourcebridge.common.v1.Language + 38, // 15: sourcebridge.reasoning.v1.ReviewFileRequest.language:type_name -> sourcebridge.common.v1.Language 11, // 16: sourcebridge.reasoning.v1.ReviewFileResponse.findings:type_name -> sourcebridge.reasoning.v1.ReviewFinding - 33, // 17: sourcebridge.reasoning.v1.ReviewFileResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage - 36, // 18: sourcebridge.reasoning.v1.GenerateEmbeddingResponse.embedding:type_name -> sourcebridge.common.v1.Embedding - 33, // 19: sourcebridge.reasoning.v1.GenerateEmbeddingResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage - 32, // 20: sourcebridge.reasoning.v1.SimulateChangeRequest.symbols:type_name -> sourcebridge.common.v1.CodeSymbol + 36, // 17: sourcebridge.reasoning.v1.ReviewFileResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage + 39, // 18: sourcebridge.reasoning.v1.GenerateEmbeddingResponse.embedding:type_name -> sourcebridge.common.v1.Embedding + 36, // 19: sourcebridge.reasoning.v1.GenerateEmbeddingResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage + 35, // 20: sourcebridge.reasoning.v1.SimulateChangeRequest.symbols:type_name -> sourcebridge.common.v1.CodeSymbol 15, // 21: sourcebridge.reasoning.v1.SimulateChangeResponse.resolved_symbols:type_name -> sourcebridge.reasoning.v1.SimulatedSymbolMatch - 33, // 22: sourcebridge.reasoning.v1.SimulateChangeResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage + 36, // 22: sourcebridge.reasoning.v1.SimulateChangeResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage 19, // 23: sourcebridge.reasoning.v1.AgentMessage.tool_calls:type_name -> sourcebridge.reasoning.v1.ToolCall 20, // 24: sourcebridge.reasoning.v1.AgentMessage.tool_results:type_name -> sourcebridge.reasoning.v1.ToolResult 18, // 25: sourcebridge.reasoning.v1.AnswerQuestionWithToolsRequest.messages:type_name -> sourcebridge.reasoning.v1.AgentMessage 17, // 26: sourcebridge.reasoning.v1.AnswerQuestionWithToolsRequest.tools:type_name -> sourcebridge.reasoning.v1.ToolSchema 18, // 27: sourcebridge.reasoning.v1.AnswerQuestionWithToolsResponse.turn:type_name -> sourcebridge.reasoning.v1.AgentMessage - 33, // 28: sourcebridge.reasoning.v1.AnswerQuestionWithToolsResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage - 33, // 29: sourcebridge.reasoning.v1.ClassifyQuestionResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage - 33, // 30: sourcebridge.reasoning.v1.DecomposeQuestionResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage + 36, // 28: sourcebridge.reasoning.v1.AnswerQuestionWithToolsResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage + 36, // 29: sourcebridge.reasoning.v1.ClassifyQuestionResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage + 36, // 30: sourcebridge.reasoning.v1.DecomposeQuestionResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage 28, // 31: sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerRequest.sub_answers:type_name -> sourcebridge.reasoning.v1.DecomposedSubAnswer - 33, // 32: sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage - 0, // 33: sourcebridge.reasoning.v1.ReasoningService.AnalyzeSymbol:input_type -> sourcebridge.reasoning.v1.AnalyzeSymbolRequest - 2, // 34: sourcebridge.reasoning.v1.ReasoningService.ExplainRelationship:input_type -> sourcebridge.reasoning.v1.ExplainRelationshipRequest - 4, // 35: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestion:input_type -> sourcebridge.reasoning.v1.AnswerQuestionRequest - 6, // 36: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestionStream:input_type -> sourcebridge.reasoning.v1.AnswerQuestionStreamRequest - 9, // 37: sourcebridge.reasoning.v1.ReasoningService.ReviewFile:input_type -> sourcebridge.reasoning.v1.ReviewFileRequest - 12, // 38: sourcebridge.reasoning.v1.ReasoningService.GenerateEmbedding:input_type -> sourcebridge.reasoning.v1.GenerateEmbeddingRequest - 14, // 39: sourcebridge.reasoning.v1.ReasoningService.SimulateChange:input_type -> sourcebridge.reasoning.v1.SimulateChangeRequest - 21, // 40: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestionWithTools:input_type -> sourcebridge.reasoning.v1.AnswerQuestionWithToolsRequest - 30, // 41: sourcebridge.reasoning.v1.ReasoningService.GetProviderCapabilities:input_type -> sourcebridge.reasoning.v1.GetProviderCapabilitiesRequest - 23, // 42: sourcebridge.reasoning.v1.ReasoningService.ClassifyQuestion:input_type -> sourcebridge.reasoning.v1.ClassifyQuestionRequest - 25, // 43: sourcebridge.reasoning.v1.ReasoningService.DecomposeQuestion:input_type -> sourcebridge.reasoning.v1.DecomposeQuestionRequest - 27, // 44: sourcebridge.reasoning.v1.ReasoningService.SynthesizeDecomposedAnswer:input_type -> sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerRequest - 1, // 45: sourcebridge.reasoning.v1.ReasoningService.AnalyzeSymbol:output_type -> sourcebridge.reasoning.v1.AnalyzeSymbolResponse - 3, // 46: sourcebridge.reasoning.v1.ReasoningService.ExplainRelationship:output_type -> sourcebridge.reasoning.v1.ExplainRelationshipResponse - 5, // 47: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestion:output_type -> sourcebridge.reasoning.v1.AnswerQuestionResponse - 7, // 48: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestionStream:output_type -> sourcebridge.reasoning.v1.AnswerQuestionStreamResponse - 10, // 49: sourcebridge.reasoning.v1.ReasoningService.ReviewFile:output_type -> sourcebridge.reasoning.v1.ReviewFileResponse - 13, // 50: sourcebridge.reasoning.v1.ReasoningService.GenerateEmbedding:output_type -> sourcebridge.reasoning.v1.GenerateEmbeddingResponse - 16, // 51: sourcebridge.reasoning.v1.ReasoningService.SimulateChange:output_type -> sourcebridge.reasoning.v1.SimulateChangeResponse - 22, // 52: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestionWithTools:output_type -> sourcebridge.reasoning.v1.AnswerQuestionWithToolsResponse - 31, // 53: sourcebridge.reasoning.v1.ReasoningService.GetProviderCapabilities:output_type -> sourcebridge.reasoning.v1.GetProviderCapabilitiesResponse - 24, // 54: sourcebridge.reasoning.v1.ReasoningService.ClassifyQuestion:output_type -> sourcebridge.reasoning.v1.ClassifyQuestionResponse - 26, // 55: sourcebridge.reasoning.v1.ReasoningService.DecomposeQuestion:output_type -> sourcebridge.reasoning.v1.DecomposeQuestionResponse - 29, // 56: sourcebridge.reasoning.v1.ReasoningService.SynthesizeDecomposedAnswer:output_type -> sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerResponse - 45, // [45:57] is the sub-list for method output_type - 33, // [33:45] is the sub-list for method input_type - 33, // [33:33] is the sub-list for extension type_name - 33, // [33:33] is the sub-list for extension extendee - 0, // [0:33] is the sub-list for field type_name + 36, // 32: sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerResponse.usage:type_name -> sourcebridge.common.v1.LLMUsage + 34, // 33: sourcebridge.reasoning.v1.GetLLMGateSnapshotResponse.gates:type_name -> sourcebridge.reasoning.v1.LLMGateEntry + 0, // 34: sourcebridge.reasoning.v1.ReasoningService.AnalyzeSymbol:input_type -> sourcebridge.reasoning.v1.AnalyzeSymbolRequest + 2, // 35: sourcebridge.reasoning.v1.ReasoningService.ExplainRelationship:input_type -> sourcebridge.reasoning.v1.ExplainRelationshipRequest + 4, // 36: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestion:input_type -> sourcebridge.reasoning.v1.AnswerQuestionRequest + 6, // 37: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestionStream:input_type -> sourcebridge.reasoning.v1.AnswerQuestionStreamRequest + 9, // 38: sourcebridge.reasoning.v1.ReasoningService.ReviewFile:input_type -> sourcebridge.reasoning.v1.ReviewFileRequest + 12, // 39: sourcebridge.reasoning.v1.ReasoningService.GenerateEmbedding:input_type -> sourcebridge.reasoning.v1.GenerateEmbeddingRequest + 14, // 40: sourcebridge.reasoning.v1.ReasoningService.SimulateChange:input_type -> sourcebridge.reasoning.v1.SimulateChangeRequest + 21, // 41: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestionWithTools:input_type -> sourcebridge.reasoning.v1.AnswerQuestionWithToolsRequest + 30, // 42: sourcebridge.reasoning.v1.ReasoningService.GetProviderCapabilities:input_type -> sourcebridge.reasoning.v1.GetProviderCapabilitiesRequest + 32, // 43: sourcebridge.reasoning.v1.ReasoningService.GetLLMGateSnapshot:input_type -> sourcebridge.reasoning.v1.GetLLMGateSnapshotRequest + 23, // 44: sourcebridge.reasoning.v1.ReasoningService.ClassifyQuestion:input_type -> sourcebridge.reasoning.v1.ClassifyQuestionRequest + 25, // 45: sourcebridge.reasoning.v1.ReasoningService.DecomposeQuestion:input_type -> sourcebridge.reasoning.v1.DecomposeQuestionRequest + 27, // 46: sourcebridge.reasoning.v1.ReasoningService.SynthesizeDecomposedAnswer:input_type -> sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerRequest + 1, // 47: sourcebridge.reasoning.v1.ReasoningService.AnalyzeSymbol:output_type -> sourcebridge.reasoning.v1.AnalyzeSymbolResponse + 3, // 48: sourcebridge.reasoning.v1.ReasoningService.ExplainRelationship:output_type -> sourcebridge.reasoning.v1.ExplainRelationshipResponse + 5, // 49: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestion:output_type -> sourcebridge.reasoning.v1.AnswerQuestionResponse + 7, // 50: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestionStream:output_type -> sourcebridge.reasoning.v1.AnswerQuestionStreamResponse + 10, // 51: sourcebridge.reasoning.v1.ReasoningService.ReviewFile:output_type -> sourcebridge.reasoning.v1.ReviewFileResponse + 13, // 52: sourcebridge.reasoning.v1.ReasoningService.GenerateEmbedding:output_type -> sourcebridge.reasoning.v1.GenerateEmbeddingResponse + 16, // 53: sourcebridge.reasoning.v1.ReasoningService.SimulateChange:output_type -> sourcebridge.reasoning.v1.SimulateChangeResponse + 22, // 54: sourcebridge.reasoning.v1.ReasoningService.AnswerQuestionWithTools:output_type -> sourcebridge.reasoning.v1.AnswerQuestionWithToolsResponse + 31, // 55: sourcebridge.reasoning.v1.ReasoningService.GetProviderCapabilities:output_type -> sourcebridge.reasoning.v1.GetProviderCapabilitiesResponse + 33, // 56: sourcebridge.reasoning.v1.ReasoningService.GetLLMGateSnapshot:output_type -> sourcebridge.reasoning.v1.GetLLMGateSnapshotResponse + 24, // 57: sourcebridge.reasoning.v1.ReasoningService.ClassifyQuestion:output_type -> sourcebridge.reasoning.v1.ClassifyQuestionResponse + 26, // 58: sourcebridge.reasoning.v1.ReasoningService.DecomposeQuestion:output_type -> sourcebridge.reasoning.v1.DecomposeQuestionResponse + 29, // 59: sourcebridge.reasoning.v1.ReasoningService.SynthesizeDecomposedAnswer:output_type -> sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerResponse + 47, // [47:60] is the sub-list for method output_type + 34, // [34:47] is the sub-list for method input_type + 34, // [34:34] is the sub-list for extension type_name + 34, // [34:34] is the sub-list for extension extendee + 0, // [0:34] is the sub-list for field type_name } func init() { file_reasoning_v1_reasoning_proto_init() } @@ -2812,7 +3050,7 @@ func file_reasoning_v1_reasoning_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_reasoning_v1_reasoning_proto_rawDesc), len(file_reasoning_v1_reasoning_proto_rawDesc)), NumEnums: 0, - NumMessages: 32, + NumMessages: 35, NumExtensions: 0, NumServices: 1, }, diff --git a/gen/go/reasoning/v1/reasoning_grpc.pb.go b/gen/go/reasoning/v1/reasoning_grpc.pb.go index d308a3e5..498feb9b 100644 --- a/gen/go/reasoning/v1/reasoning_grpc.pb.go +++ b/gen/go/reasoning/v1/reasoning_grpc.pb.go @@ -28,6 +28,7 @@ const ( ReasoningService_SimulateChange_FullMethodName = "/sourcebridge.reasoning.v1.ReasoningService/SimulateChange" ReasoningService_AnswerQuestionWithTools_FullMethodName = "/sourcebridge.reasoning.v1.ReasoningService/AnswerQuestionWithTools" ReasoningService_GetProviderCapabilities_FullMethodName = "/sourcebridge.reasoning.v1.ReasoningService/GetProviderCapabilities" + ReasoningService_GetLLMGateSnapshot_FullMethodName = "/sourcebridge.reasoning.v1.ReasoningService/GetLLMGateSnapshot" ReasoningService_ClassifyQuestion_FullMethodName = "/sourcebridge.reasoning.v1.ReasoningService/ClassifyQuestion" ReasoningService_DecomposeQuestion_FullMethodName = "/sourcebridge.reasoning.v1.ReasoningService/DecomposeQuestion" ReasoningService_SynthesizeDecomposedAnswer_FullMethodName = "/sourcebridge.reasoning.v1.ReasoningService/SynthesizeDecomposedAnswer" @@ -78,6 +79,15 @@ type ReasoningServiceClient interface { // and caches the result so the agentic path can be gated without // a per-request round-trip. GetProviderCapabilities(ctx context.Context, in *GetProviderCapabilitiesRequest, opts ...grpc.CallOption) (*GetProviderCapabilitiesResponse, error) + // GetLLMGateSnapshot returns a point-in-time snapshot of all active + // per-provider concurrency gates in the worker. Used by the admin + // /api/v1/admin/llm/activity endpoint to surface real-time gate + // counters (in-flight, queued, tok/s) without a per-request round-trip. + // + // An explicit request struct (rather than google.protobuf.Empty) is + // used for forward compatibility: filter fields (provider, kind) can + // be added without redeclaring the RPC. + GetLLMGateSnapshot(ctx context.Context, in *GetLLMGateSnapshotRequest, opts ...grpc.CallOption) (*GetLLMGateSnapshotResponse, error) // ClassifyQuestion runs a cheap LLM classifier (Haiku) that // returns the question's likely class plus evidence-kind hints // (needs_call_graph, needs_tests, ...) and advisory symbol / @@ -210,6 +220,16 @@ func (c *reasoningServiceClient) GetProviderCapabilities(ctx context.Context, in return out, nil } +func (c *reasoningServiceClient) GetLLMGateSnapshot(ctx context.Context, in *GetLLMGateSnapshotRequest, opts ...grpc.CallOption) (*GetLLMGateSnapshotResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetLLMGateSnapshotResponse) + err := c.cc.Invoke(ctx, ReasoningService_GetLLMGateSnapshot_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *reasoningServiceClient) ClassifyQuestion(ctx context.Context, in *ClassifyQuestionRequest, opts ...grpc.CallOption) (*ClassifyQuestionResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ClassifyQuestionResponse) @@ -285,6 +305,15 @@ type ReasoningServiceServer interface { // and caches the result so the agentic path can be gated without // a per-request round-trip. GetProviderCapabilities(context.Context, *GetProviderCapabilitiesRequest) (*GetProviderCapabilitiesResponse, error) + // GetLLMGateSnapshot returns a point-in-time snapshot of all active + // per-provider concurrency gates in the worker. Used by the admin + // /api/v1/admin/llm/activity endpoint to surface real-time gate + // counters (in-flight, queued, tok/s) without a per-request round-trip. + // + // An explicit request struct (rather than google.protobuf.Empty) is + // used for forward compatibility: filter fields (provider, kind) can + // be added without redeclaring the RPC. + GetLLMGateSnapshot(context.Context, *GetLLMGateSnapshotRequest) (*GetLLMGateSnapshotResponse, error) // ClassifyQuestion runs a cheap LLM classifier (Haiku) that // returns the question's likely class plus evidence-kind hints // (needs_call_graph, needs_tests, ...) and advisory symbol / @@ -345,6 +374,9 @@ func (UnimplementedReasoningServiceServer) AnswerQuestionWithTools(context.Conte func (UnimplementedReasoningServiceServer) GetProviderCapabilities(context.Context, *GetProviderCapabilitiesRequest) (*GetProviderCapabilitiesResponse, error) { return nil, status.Error(codes.Unimplemented, "method GetProviderCapabilities not implemented") } +func (UnimplementedReasoningServiceServer) GetLLMGateSnapshot(context.Context, *GetLLMGateSnapshotRequest) (*GetLLMGateSnapshotResponse, error) { + return nil, status.Error(codes.Unimplemented, "method GetLLMGateSnapshot not implemented") +} func (UnimplementedReasoningServiceServer) ClassifyQuestion(context.Context, *ClassifyQuestionRequest) (*ClassifyQuestionResponse, error) { return nil, status.Error(codes.Unimplemented, "method ClassifyQuestion not implemented") } @@ -530,6 +562,24 @@ func _ReasoningService_GetProviderCapabilities_Handler(srv interface{}, ctx cont return interceptor(ctx, in, info, handler) } +func _ReasoningService_GetLLMGateSnapshot_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetLLMGateSnapshotRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ReasoningServiceServer).GetLLMGateSnapshot(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ReasoningService_GetLLMGateSnapshot_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ReasoningServiceServer).GetLLMGateSnapshot(ctx, req.(*GetLLMGateSnapshotRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _ReasoningService_ClassifyQuestion_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(ClassifyQuestionRequest) if err := dec(in); err != nil { @@ -623,6 +673,10 @@ var ReasoningService_ServiceDesc = grpc.ServiceDesc{ MethodName: "GetProviderCapabilities", Handler: _ReasoningService_GetProviderCapabilities_Handler, }, + { + MethodName: "GetLLMGateSnapshot", + Handler: _ReasoningService_GetLLMGateSnapshot_Handler, + }, { MethodName: "ClassifyQuestion", Handler: _ReasoningService_ClassifyQuestion_Handler, diff --git a/gen/python/common/v1/knowledge_progress_pb2.py b/gen/python/common/v1/knowledge_progress_pb2.py index 01da0629..f1e48c06 100644 --- a/gen/python/common/v1/knowledge_progress_pb2.py +++ b/gen/python/common/v1/knowledge_progress_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"common/v1/knowledge_progress.proto\x12\x16sourcebridge.common.v1\"\x89\x02\n\x17KnowledgeStreamProgress\x12\x35\n\x05phase\x18\x01 \x01(\x0e\x32&.sourcebridge.common.v1.KnowledgePhase\x12\x17\n\x0f\x63ompleted_units\x18\x02 \x01(\x05\x12\x13\n\x0btotal_units\x18\x03 \x01(\x05\x12\x11\n\tunit_kind\x18\x04 \x01(\t\x12\x0f\n\x07message\x18\x05 \x01(\t\x12\x17\n\x0fleaf_cache_hits\x18\n \x01(\x05\x12\x17\n\x0f\x66ile_cache_hits\x18\x0b \x01(\x05\x12\x1a\n\x12package_cache_hits\x18\x0c \x01(\x05\x12\x17\n\x0froot_cache_hits\x18\r \x01(\x05\"c\n\x1aKnowledgeStreamPhaseMarker\x12\x35\n\x05phase\x18\x01 \x01(\x0e\x32&.sourcebridge.common.v1.KnowledgePhase\x12\x0e\n\x06\x64\x65tail\x18\x02 \x01(\t*\x9e\x02\n\x0eKnowledgePhase\x12\x1f\n\x1bKNOWLEDGE_PHASE_UNSPECIFIED\x10\x00\x12\x1c\n\x18KNOWLEDGE_PHASE_SNAPSHOT\x10\x01\x12\"\n\x1eKNOWLEDGE_PHASE_LEAF_SUMMARIES\x10\x02\x12\"\n\x1eKNOWLEDGE_PHASE_FILE_SUMMARIES\x10\x03\x12%\n!KNOWLEDGE_PHASE_PACKAGE_SUMMARIES\x10\x04\x12\"\n\x1eKNOWLEDGE_PHASE_ROOT_SYNTHESIS\x10\x05\x12\x1a\n\x16KNOWLEDGE_PHASE_RENDER\x10\x06\x12\x1e\n\x1aKNOWLEDGE_PHASE_FINALIZING\x10\x07\x42@Z>github.com/sourcebridge/sourcebridge/gen/go/common/v1;commonv1b\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"common/v1/knowledge_progress.proto\x12\x16sourcebridge.common.v1\"\xac\x02\n\x17KnowledgeStreamProgress\x12\x35\n\x05phase\x18\x01 \x01(\x0e\x32&.sourcebridge.common.v1.KnowledgePhase\x12\x17\n\x0f\x63ompleted_units\x18\x02 \x01(\x05\x12\x13\n\x0btotal_units\x18\x03 \x01(\x05\x12\x11\n\tunit_kind\x18\x04 \x01(\t\x12\x0f\n\x07message\x18\x05 \x01(\t\x12\x17\n\x0fleaf_cache_hits\x18\n \x01(\x05\x12\x17\n\x0f\x66ile_cache_hits\x18\x0b \x01(\x05\x12\x1a\n\x12package_cache_hits\x18\x0c \x01(\x05\x12\x17\n\x0froot_cache_hits\x18\r \x01(\x05\x12!\n\x19\x63urrent_tokens_per_second\x18\x0e \x01(\x02\"c\n\x1aKnowledgeStreamPhaseMarker\x12\x35\n\x05phase\x18\x01 \x01(\x0e\x32&.sourcebridge.common.v1.KnowledgePhase\x12\x0e\n\x06\x64\x65tail\x18\x02 \x01(\t*\x9e\x02\n\x0eKnowledgePhase\x12\x1f\n\x1bKNOWLEDGE_PHASE_UNSPECIFIED\x10\x00\x12\x1c\n\x18KNOWLEDGE_PHASE_SNAPSHOT\x10\x01\x12\"\n\x1eKNOWLEDGE_PHASE_LEAF_SUMMARIES\x10\x02\x12\"\n\x1eKNOWLEDGE_PHASE_FILE_SUMMARIES\x10\x03\x12%\n!KNOWLEDGE_PHASE_PACKAGE_SUMMARIES\x10\x04\x12\"\n\x1eKNOWLEDGE_PHASE_ROOT_SYNTHESIS\x10\x05\x12\x1a\n\x16KNOWLEDGE_PHASE_RENDER\x10\x06\x12\x1e\n\x1aKNOWLEDGE_PHASE_FINALIZING\x10\x07\x42@Z>github.com/sourcebridge/sourcebridge/gen/go/common/v1;commonv1b\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -32,10 +32,10 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'Z>github.com/sourcebridge/sourcebridge/gen/go/common/v1;commonv1' - _globals['_KNOWLEDGEPHASE']._serialized_start=432 - _globals['_KNOWLEDGEPHASE']._serialized_end=718 + _globals['_KNOWLEDGEPHASE']._serialized_start=467 + _globals['_KNOWLEDGEPHASE']._serialized_end=753 _globals['_KNOWLEDGESTREAMPROGRESS']._serialized_start=63 - _globals['_KNOWLEDGESTREAMPROGRESS']._serialized_end=328 - _globals['_KNOWLEDGESTREAMPHASEMARKER']._serialized_start=330 - _globals['_KNOWLEDGESTREAMPHASEMARKER']._serialized_end=429 + _globals['_KNOWLEDGESTREAMPROGRESS']._serialized_end=363 + _globals['_KNOWLEDGESTREAMPHASEMARKER']._serialized_start=365 + _globals['_KNOWLEDGESTREAMPHASEMARKER']._serialized_end=464 # @@protoc_insertion_point(module_scope) diff --git a/gen/python/common/v1/knowledge_progress_pb2.pyi b/gen/python/common/v1/knowledge_progress_pb2.pyi index f99ef818..a7ace33c 100644 --- a/gen/python/common/v1/knowledge_progress_pb2.pyi +++ b/gen/python/common/v1/knowledge_progress_pb2.pyi @@ -25,7 +25,7 @@ KNOWLEDGE_PHASE_RENDER: KnowledgePhase KNOWLEDGE_PHASE_FINALIZING: KnowledgePhase class KnowledgeStreamProgress(_message.Message): - __slots__ = ("phase", "completed_units", "total_units", "unit_kind", "message", "leaf_cache_hits", "file_cache_hits", "package_cache_hits", "root_cache_hits") + __slots__ = ("phase", "completed_units", "total_units", "unit_kind", "message", "leaf_cache_hits", "file_cache_hits", "package_cache_hits", "root_cache_hits", "current_tokens_per_second") PHASE_FIELD_NUMBER: _ClassVar[int] COMPLETED_UNITS_FIELD_NUMBER: _ClassVar[int] TOTAL_UNITS_FIELD_NUMBER: _ClassVar[int] @@ -35,6 +35,7 @@ class KnowledgeStreamProgress(_message.Message): FILE_CACHE_HITS_FIELD_NUMBER: _ClassVar[int] PACKAGE_CACHE_HITS_FIELD_NUMBER: _ClassVar[int] ROOT_CACHE_HITS_FIELD_NUMBER: _ClassVar[int] + CURRENT_TOKENS_PER_SECOND_FIELD_NUMBER: _ClassVar[int] phase: KnowledgePhase completed_units: int total_units: int @@ -44,7 +45,8 @@ class KnowledgeStreamProgress(_message.Message): file_cache_hits: int package_cache_hits: int root_cache_hits: int - def __init__(self, phase: _Optional[_Union[KnowledgePhase, str]] = ..., completed_units: _Optional[int] = ..., total_units: _Optional[int] = ..., unit_kind: _Optional[str] = ..., message: _Optional[str] = ..., leaf_cache_hits: _Optional[int] = ..., file_cache_hits: _Optional[int] = ..., package_cache_hits: _Optional[int] = ..., root_cache_hits: _Optional[int] = ...) -> None: ... + current_tokens_per_second: float + def __init__(self, phase: _Optional[_Union[KnowledgePhase, str]] = ..., completed_units: _Optional[int] = ..., total_units: _Optional[int] = ..., unit_kind: _Optional[str] = ..., message: _Optional[str] = ..., leaf_cache_hits: _Optional[int] = ..., file_cache_hits: _Optional[int] = ..., package_cache_hits: _Optional[int] = ..., root_cache_hits: _Optional[int] = ..., current_tokens_per_second: _Optional[float] = ...) -> None: ... class KnowledgeStreamPhaseMarker(_message.Message): __slots__ = ("phase", "detail") diff --git a/gen/python/reasoning/v1/reasoning_pb2.py b/gen/python/reasoning/v1/reasoning_pb2.py index 2a0fce0c..e04a5c19 100644 --- a/gen/python/reasoning/v1/reasoning_pb2.py +++ b/gen/python/reasoning/v1/reasoning_pb2.py @@ -25,7 +25,7 @@ from common.v1 import types_pb2 as common_dot_v1_dot_types__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1creasoning/v1/reasoning.proto\x12\x19sourcebridge.reasoning.v1\x1a\x15\x63ommon/v1/types.proto\"~\n\x14\x41nalyzeSymbolRequest\x12\x32\n\x06symbol\x18\x01 \x01(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\x1b\n\x13surrounding_context\x18\x02 \x01(\t\x12\x15\n\rrepository_id\x18\x03 \x01(\t\"\x91\x01\n\x15\x41nalyzeSymbolResponse\x12\x0f\n\x07summary\x18\x01 \x01(\t\x12\x0f\n\x07purpose\x18\x02 \x01(\t\x12\x10\n\x08\x63oncerns\x18\x03 \x03(\t\x12\x13\n\x0bsuggestions\x18\x04 \x03(\t\x12/\n\x05usage\x18\x05 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\x9f\x01\n\x1a\x45xplainRelationshipRequest\x12\x32\n\x06source\x18\x01 \x01(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\x32\n\x06target\x18\x02 \x01(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\x19\n\x11relationship_type\x18\x03 \x01(\t\"\x9b\x01\n\x1b\x45xplainRelationshipResponse\x12\x13\n\x0b\x65xplanation\x18\x01 \x01(\t\x12\x36\n\nconfidence\x18\x02 \x01(\x0e\x32\".sourcebridge.common.v1.Confidence\x12/\n\x05usage\x18\x03 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\xee\x01\n\x15\x41nswerQuestionRequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x15\n\rrepository_id\x18\x02 \x01(\t\x12;\n\x0f\x63ontext_symbols\x18\x03 \x03(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x14\n\x0c\x63ontext_code\x18\x05 \x01(\t\x12\x11\n\tfile_path\x18\x06 \x01(\t\x12\x32\n\x08language\x18\x07 \x01(\x0e\x32 .sourcebridge.common.v1.Language\"\x99\x01\n\x16\x41nswerQuestionResponse\x12\x0e\n\x06\x61nswer\x18\x01 \x01(\t\x12>\n\x12referenced_symbols\x18\x02 \x03(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12/\n\x05usage\x18\x03 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\xf4\x01\n\x1b\x41nswerQuestionStreamRequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x15\n\rrepository_id\x18\x02 \x01(\t\x12;\n\x0f\x63ontext_symbols\x18\x03 \x03(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x14\n\x0c\x63ontext_code\x18\x05 \x01(\t\x12\x11\n\tfile_path\x18\x06 \x01(\t\x12\x32\n\x08language\x18\x07 \x01(\x0e\x32 .sourcebridge.common.v1.Language\"\xf4\x01\n\x1c\x41nswerQuestionStreamResponse\x12\x15\n\rcontent_delta\x18\x01 \x01(\t\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\x12>\n\x12referenced_symbols\x18\x03 \x03(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12/\n\x05usage\x18\x04 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\x12:\n\x08progress\x18\x05 \x01(\x0b\x32(.sourcebridge.reasoning.v1.ProgressEvent\"U\n\rProgressEvent\x12\r\n\x05phase\x18\x01 \x01(\t\x12\x0e\n\x06\x64\x65tail\x18\x02 \x01(\t\x12\x11\n\ttool_name\x18\x03 \x01(\t\x12\x12\n\nelapsed_ms\x18\x04 \x01(\x03\"\x94\x01\n\x11ReviewFileRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x11\n\tfile_path\x18\x02 \x01(\t\x12\x32\n\x08language\x18\x03 \x01(\x0e\x32 .sourcebridge.common.v1.Language\x12\x0f\n\x07\x63ontent\x18\x04 \x01(\t\x12\x10\n\x08template\x18\x05 \x01(\t\"\xa2\x01\n\x12ReviewFileResponse\x12\x10\n\x08template\x18\x01 \x01(\t\x12:\n\x08\x66indings\x18\x02 \x03(\x0b\x32(.sourcebridge.reasoning.v1.ReviewFinding\x12\r\n\x05score\x18\x03 \x01(\x02\x12/\n\x05usage\x18\x04 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\x91\x01\n\rReviewFinding\x12\x10\n\x08\x63\x61tegory\x18\x01 \x01(\t\x12\x10\n\x08severity\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\x12\x11\n\tfile_path\x18\x04 \x01(\t\x12\x12\n\nstart_line\x18\x05 \x01(\x05\x12\x10\n\x08\x65nd_line\x18\x06 \x01(\x05\x12\x12\n\nsuggestion\x18\x07 \x01(\t\"7\n\x18GenerateEmbeddingRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\"\x82\x01\n\x19GenerateEmbeddingResponse\x12\x34\n\tembedding\x18\x01 \x01(\x0b\x32!.sourcebridge.common.v1.Embedding\x12/\n\x05usage\x18\x02 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\xd1\x01\n\x15SimulateChangeRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x13\n\x0b\x61nchor_file\x18\x03 \x01(\t\x12\x15\n\ranchor_symbol\x18\x04 \x01(\t\x12\x33\n\x07symbols\x18\x05 \x03(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\r\n\x05top_n\x18\x06 \x01(\x05\x12\x1c\n\x14\x63onfidence_threshold\x18\x07 \x01(\x02\"\x97\x01\n\x14SimulatedSymbolMatch\x12\x11\n\tsymbol_id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x16\n\x0equalified_name\x18\x03 \x01(\t\x12\x0c\n\x04kind\x18\x04 \x01(\t\x12\x11\n\tfile_path\x18\x05 \x01(\t\x12\x12\n\nsimilarity\x18\x06 \x01(\x02\x12\x11\n\tis_anchor\x18\x07 \x01(\x08\"\xd4\x01\n\x16SimulateChangeResponse\x12I\n\x10resolved_symbols\x18\x01 \x03(\x0b\x32/.sourcebridge.reasoning.v1.SimulatedSymbolMatch\x12#\n\x1b\x64\x65scription_embedding_model\x18\x02 \x01(\t\x12\x19\n\x11symbols_evaluated\x18\x03 \x01(\x05\x12/\n\x05usage\x18\x04 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"J\n\nToolSchema\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x19\n\x11input_schema_json\x18\x03 \x01(\t\"\xa0\x01\n\x0c\x41gentMessage\x12\x0c\n\x04role\x18\x01 \x01(\t\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x37\n\ntool_calls\x18\x03 \x03(\x0b\x32#.sourcebridge.reasoning.v1.ToolCall\x12;\n\x0ctool_results\x18\x04 \x03(\x0b\x32%.sourcebridge.reasoning.v1.ToolResult\"<\n\x08ToolCall\x12\x0f\n\x07\x63\x61ll_id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\targs_json\x18\x03 \x01(\t\"Y\n\nToolResult\x12\x0f\n\x07\x63\x61ll_id\x18\x01 \x01(\t\x12\n\n\x02ok\x18\x02 \x01(\x08\x12\x11\n\tdata_json\x18\x03 \x01(\t\x12\r\n\x05\x65rror\x18\x04 \x01(\t\x12\x0c\n\x04hint\x18\x05 \x01(\t\"\xdb\x01\n\x1e\x41nswerQuestionWithToolsRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x39\n\x08messages\x18\x02 \x03(\x0b\x32\'.sourcebridge.reasoning.v1.AgentMessage\x12\x34\n\x05tools\x18\x03 \x03(\x0b\x32%.sourcebridge.reasoning.v1.ToolSchema\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x1d\n\x15\x65nable_prompt_caching\x18\x05 \x01(\x08\"\x87\x02\n\x1f\x41nswerQuestionWithToolsResponse\x12\x1c\n\x14\x63\x61pability_supported\x18\x01 \x01(\x08\x12\x35\n\x04turn\x18\x02 \x01(\x0b\x32\'.sourcebridge.reasoning.v1.AgentMessage\x12/\n\x05usage\x18\x03 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\x12\x18\n\x10termination_hint\x18\x04 \x01(\t\x12#\n\x1b\x63\x61\x63he_creation_input_tokens\x18\x05 \x01(\x03\x12\x1f\n\x17\x63\x61\x63he_read_input_tokens\x18\x06 \x01(\x03\"j\n\x17\x43lassifyQuestionRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x10\n\x08question\x18\x02 \x01(\t\x12\x11\n\tfile_path\x18\x03 \x01(\t\x12\x13\n\x0bpinned_code\x18\x04 \x01(\t\"\xae\x02\n\x18\x43lassifyQuestionResponse\x12\x1c\n\x14\x63\x61pability_supported\x18\x01 \x01(\x08\x12\x16\n\x0equestion_class\x18\x02 \x01(\t\x12\x18\n\x10needs_call_graph\x18\x03 \x01(\x08\x12\x1a\n\x12needs_requirements\x18\x04 \x01(\x08\x12\x13\n\x0bneeds_tests\x18\x05 \x01(\x08\x12\x17\n\x0fneeds_summaries\x18\x06 \x01(\x08\x12\x19\n\x11symbol_candidates\x18\x07 \x03(\t\x12\x17\n\x0f\x66ile_candidates\x18\x08 \x03(\t\x12\x13\n\x0btopic_terms\x18\t \x03(\t\x12/\n\x05usage\x18\n \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"v\n\x18\x44\x65\x63omposeQuestionRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x10\n\x08question\x18\x02 \x01(\t\x12\x16\n\x0equestion_class\x18\x03 \x01(\t\x12\x19\n\x11max_sub_questions\x18\x04 \x01(\x05\"\x81\x01\n\x19\x44\x65\x63omposeQuestionResponse\x12\x1c\n\x14\x63\x61pability_supported\x18\x01 \x01(\x08\x12\x15\n\rsub_questions\x18\x02 \x03(\t\x12/\n\x05usage\x18\x03 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\xb9\x01\n!SynthesizeDecomposedAnswerRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x19\n\x11original_question\x18\x02 \x01(\t\x12\x43\n\x0bsub_answers\x18\x03 \x03(\x0b\x32..sourcebridge.reasoning.v1.DecomposedSubAnswer\x12\x1d\n\x15\x65nable_prompt_caching\x18\x04 \x01(\x08\"\x90\x01\n\x13\x44\x65\x63omposedSubAnswer\x12\x14\n\x0csub_question\x18\x01 \x01(\t\x12\x12\n\nsub_answer\x18\x02 \x01(\t\x12\x19\n\x11reference_handles\x18\x03 \x03(\t\x12\x1a\n\x12termination_reason\x18\x04 \x01(\t\x12\x18\n\x10tool_calls_count\x18\x05 \x01(\x05\"\xab\x01\n\"SynthesizeDecomposedAnswerResponse\x12\x0e\n\x06\x61nswer\x18\x01 \x01(\t\x12/\n\x05usage\x18\x02 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\x12#\n\x1b\x63\x61\x63he_creation_input_tokens\x18\x03 \x01(\x03\x12\x1f\n\x17\x63\x61\x63he_read_input_tokens\x18\x04 \x01(\x03\" \n\x1eGetProviderCapabilitiesRequest\"\xc2\x01\n\x1fGetProviderCapabilitiesResponse\x12\x10\n\x08provider\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x1a\n\x12tool_use_supported\x18\x03 \x01(\x08\x12 \n\x18prompt_caching_supported\x18\x04 \x01(\x08\x12\x1c\n\x14max_concurrent_calls\x18\x05 \x01(\x05\x12\"\n\x1amax_concurrent_calls_known\x18\x06 \x01(\x08\x32\xb1\x0c\n\x10ReasoningService\x12r\n\rAnalyzeSymbol\x12/.sourcebridge.reasoning.v1.AnalyzeSymbolRequest\x1a\x30.sourcebridge.reasoning.v1.AnalyzeSymbolResponse\x12\x84\x01\n\x13\x45xplainRelationship\x12\x35.sourcebridge.reasoning.v1.ExplainRelationshipRequest\x1a\x36.sourcebridge.reasoning.v1.ExplainRelationshipResponse\x12u\n\x0e\x41nswerQuestion\x12\x30.sourcebridge.reasoning.v1.AnswerQuestionRequest\x1a\x31.sourcebridge.reasoning.v1.AnswerQuestionResponse\x12\x89\x01\n\x14\x41nswerQuestionStream\x12\x36.sourcebridge.reasoning.v1.AnswerQuestionStreamRequest\x1a\x37.sourcebridge.reasoning.v1.AnswerQuestionStreamResponse0\x01\x12i\n\nReviewFile\x12,.sourcebridge.reasoning.v1.ReviewFileRequest\x1a-.sourcebridge.reasoning.v1.ReviewFileResponse\x12~\n\x11GenerateEmbedding\x12\x33.sourcebridge.reasoning.v1.GenerateEmbeddingRequest\x1a\x34.sourcebridge.reasoning.v1.GenerateEmbeddingResponse\x12u\n\x0eSimulateChange\x12\x30.sourcebridge.reasoning.v1.SimulateChangeRequest\x1a\x31.sourcebridge.reasoning.v1.SimulateChangeResponse\x12\x90\x01\n\x17\x41nswerQuestionWithTools\x12\x39.sourcebridge.reasoning.v1.AnswerQuestionWithToolsRequest\x1a:.sourcebridge.reasoning.v1.AnswerQuestionWithToolsResponse\x12\x90\x01\n\x17GetProviderCapabilities\x12\x39.sourcebridge.reasoning.v1.GetProviderCapabilitiesRequest\x1a:.sourcebridge.reasoning.v1.GetProviderCapabilitiesResponse\x12{\n\x10\x43lassifyQuestion\x12\x32.sourcebridge.reasoning.v1.ClassifyQuestionRequest\x1a\x33.sourcebridge.reasoning.v1.ClassifyQuestionResponse\x12~\n\x11\x44\x65\x63omposeQuestion\x12\x33.sourcebridge.reasoning.v1.DecomposeQuestionRequest\x1a\x34.sourcebridge.reasoning.v1.DecomposeQuestionResponse\x12\x99\x01\n\x1aSynthesizeDecomposedAnswer\x12<.sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerRequest\x1a=.sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerResponseBFZDgithub.com/sourcebridge/sourcebridge/gen/go/reasoning/v1;reasoningv1b\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1creasoning/v1/reasoning.proto\x12\x19sourcebridge.reasoning.v1\x1a\x15\x63ommon/v1/types.proto\"~\n\x14\x41nalyzeSymbolRequest\x12\x32\n\x06symbol\x18\x01 \x01(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\x1b\n\x13surrounding_context\x18\x02 \x01(\t\x12\x15\n\rrepository_id\x18\x03 \x01(\t\"\x91\x01\n\x15\x41nalyzeSymbolResponse\x12\x0f\n\x07summary\x18\x01 \x01(\t\x12\x0f\n\x07purpose\x18\x02 \x01(\t\x12\x10\n\x08\x63oncerns\x18\x03 \x03(\t\x12\x13\n\x0bsuggestions\x18\x04 \x03(\t\x12/\n\x05usage\x18\x05 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\x9f\x01\n\x1a\x45xplainRelationshipRequest\x12\x32\n\x06source\x18\x01 \x01(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\x32\n\x06target\x18\x02 \x01(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\x19\n\x11relationship_type\x18\x03 \x01(\t\"\x9b\x01\n\x1b\x45xplainRelationshipResponse\x12\x13\n\x0b\x65xplanation\x18\x01 \x01(\t\x12\x36\n\nconfidence\x18\x02 \x01(\x0e\x32\".sourcebridge.common.v1.Confidence\x12/\n\x05usage\x18\x03 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\xee\x01\n\x15\x41nswerQuestionRequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x15\n\rrepository_id\x18\x02 \x01(\t\x12;\n\x0f\x63ontext_symbols\x18\x03 \x03(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x14\n\x0c\x63ontext_code\x18\x05 \x01(\t\x12\x11\n\tfile_path\x18\x06 \x01(\t\x12\x32\n\x08language\x18\x07 \x01(\x0e\x32 .sourcebridge.common.v1.Language\"\x99\x01\n\x16\x41nswerQuestionResponse\x12\x0e\n\x06\x61nswer\x18\x01 \x01(\t\x12>\n\x12referenced_symbols\x18\x02 \x03(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12/\n\x05usage\x18\x03 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\xf4\x01\n\x1b\x41nswerQuestionStreamRequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x15\n\rrepository_id\x18\x02 \x01(\t\x12;\n\x0f\x63ontext_symbols\x18\x03 \x03(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x14\n\x0c\x63ontext_code\x18\x05 \x01(\t\x12\x11\n\tfile_path\x18\x06 \x01(\t\x12\x32\n\x08language\x18\x07 \x01(\x0e\x32 .sourcebridge.common.v1.Language\"\xf4\x01\n\x1c\x41nswerQuestionStreamResponse\x12\x15\n\rcontent_delta\x18\x01 \x01(\t\x12\x10\n\x08\x66inished\x18\x02 \x01(\x08\x12>\n\x12referenced_symbols\x18\x03 \x03(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12/\n\x05usage\x18\x04 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\x12:\n\x08progress\x18\x05 \x01(\x0b\x32(.sourcebridge.reasoning.v1.ProgressEvent\"U\n\rProgressEvent\x12\r\n\x05phase\x18\x01 \x01(\t\x12\x0e\n\x06\x64\x65tail\x18\x02 \x01(\t\x12\x11\n\ttool_name\x18\x03 \x01(\t\x12\x12\n\nelapsed_ms\x18\x04 \x01(\x03\"\x94\x01\n\x11ReviewFileRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x11\n\tfile_path\x18\x02 \x01(\t\x12\x32\n\x08language\x18\x03 \x01(\x0e\x32 .sourcebridge.common.v1.Language\x12\x0f\n\x07\x63ontent\x18\x04 \x01(\t\x12\x10\n\x08template\x18\x05 \x01(\t\"\xa2\x01\n\x12ReviewFileResponse\x12\x10\n\x08template\x18\x01 \x01(\t\x12:\n\x08\x66indings\x18\x02 \x03(\x0b\x32(.sourcebridge.reasoning.v1.ReviewFinding\x12\r\n\x05score\x18\x03 \x01(\x02\x12/\n\x05usage\x18\x04 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\x91\x01\n\rReviewFinding\x12\x10\n\x08\x63\x61tegory\x18\x01 \x01(\t\x12\x10\n\x08severity\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\x12\x11\n\tfile_path\x18\x04 \x01(\t\x12\x12\n\nstart_line\x18\x05 \x01(\x05\x12\x10\n\x08\x65nd_line\x18\x06 \x01(\x05\x12\x12\n\nsuggestion\x18\x07 \x01(\t\"7\n\x18GenerateEmbeddingRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\"\x82\x01\n\x19GenerateEmbeddingResponse\x12\x34\n\tembedding\x18\x01 \x01(\x0b\x32!.sourcebridge.common.v1.Embedding\x12/\n\x05usage\x18\x02 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\xd1\x01\n\x15SimulateChangeRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x13\n\x0b\x61nchor_file\x18\x03 \x01(\t\x12\x15\n\ranchor_symbol\x18\x04 \x01(\t\x12\x33\n\x07symbols\x18\x05 \x03(\x0b\x32\".sourcebridge.common.v1.CodeSymbol\x12\r\n\x05top_n\x18\x06 \x01(\x05\x12\x1c\n\x14\x63onfidence_threshold\x18\x07 \x01(\x02\"\x97\x01\n\x14SimulatedSymbolMatch\x12\x11\n\tsymbol_id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x16\n\x0equalified_name\x18\x03 \x01(\t\x12\x0c\n\x04kind\x18\x04 \x01(\t\x12\x11\n\tfile_path\x18\x05 \x01(\t\x12\x12\n\nsimilarity\x18\x06 \x01(\x02\x12\x11\n\tis_anchor\x18\x07 \x01(\x08\"\xd4\x01\n\x16SimulateChangeResponse\x12I\n\x10resolved_symbols\x18\x01 \x03(\x0b\x32/.sourcebridge.reasoning.v1.SimulatedSymbolMatch\x12#\n\x1b\x64\x65scription_embedding_model\x18\x02 \x01(\t\x12\x19\n\x11symbols_evaluated\x18\x03 \x01(\x05\x12/\n\x05usage\x18\x04 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"J\n\nToolSchema\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x19\n\x11input_schema_json\x18\x03 \x01(\t\"\xa0\x01\n\x0c\x41gentMessage\x12\x0c\n\x04role\x18\x01 \x01(\t\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x37\n\ntool_calls\x18\x03 \x03(\x0b\x32#.sourcebridge.reasoning.v1.ToolCall\x12;\n\x0ctool_results\x18\x04 \x03(\x0b\x32%.sourcebridge.reasoning.v1.ToolResult\"<\n\x08ToolCall\x12\x0f\n\x07\x63\x61ll_id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\targs_json\x18\x03 \x01(\t\"Y\n\nToolResult\x12\x0f\n\x07\x63\x61ll_id\x18\x01 \x01(\t\x12\n\n\x02ok\x18\x02 \x01(\x08\x12\x11\n\tdata_json\x18\x03 \x01(\t\x12\r\n\x05\x65rror\x18\x04 \x01(\t\x12\x0c\n\x04hint\x18\x05 \x01(\t\"\xdb\x01\n\x1e\x41nswerQuestionWithToolsRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x39\n\x08messages\x18\x02 \x03(\x0b\x32\'.sourcebridge.reasoning.v1.AgentMessage\x12\x34\n\x05tools\x18\x03 \x03(\x0b\x32%.sourcebridge.reasoning.v1.ToolSchema\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x1d\n\x15\x65nable_prompt_caching\x18\x05 \x01(\x08\"\x87\x02\n\x1f\x41nswerQuestionWithToolsResponse\x12\x1c\n\x14\x63\x61pability_supported\x18\x01 \x01(\x08\x12\x35\n\x04turn\x18\x02 \x01(\x0b\x32\'.sourcebridge.reasoning.v1.AgentMessage\x12/\n\x05usage\x18\x03 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\x12\x18\n\x10termination_hint\x18\x04 \x01(\t\x12#\n\x1b\x63\x61\x63he_creation_input_tokens\x18\x05 \x01(\x03\x12\x1f\n\x17\x63\x61\x63he_read_input_tokens\x18\x06 \x01(\x03\"j\n\x17\x43lassifyQuestionRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x10\n\x08question\x18\x02 \x01(\t\x12\x11\n\tfile_path\x18\x03 \x01(\t\x12\x13\n\x0bpinned_code\x18\x04 \x01(\t\"\xae\x02\n\x18\x43lassifyQuestionResponse\x12\x1c\n\x14\x63\x61pability_supported\x18\x01 \x01(\x08\x12\x16\n\x0equestion_class\x18\x02 \x01(\t\x12\x18\n\x10needs_call_graph\x18\x03 \x01(\x08\x12\x1a\n\x12needs_requirements\x18\x04 \x01(\x08\x12\x13\n\x0bneeds_tests\x18\x05 \x01(\x08\x12\x17\n\x0fneeds_summaries\x18\x06 \x01(\x08\x12\x19\n\x11symbol_candidates\x18\x07 \x03(\t\x12\x17\n\x0f\x66ile_candidates\x18\x08 \x03(\t\x12\x13\n\x0btopic_terms\x18\t \x03(\t\x12/\n\x05usage\x18\n \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"v\n\x18\x44\x65\x63omposeQuestionRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x10\n\x08question\x18\x02 \x01(\t\x12\x16\n\x0equestion_class\x18\x03 \x01(\t\x12\x19\n\x11max_sub_questions\x18\x04 \x01(\x05\"\x81\x01\n\x19\x44\x65\x63omposeQuestionResponse\x12\x1c\n\x14\x63\x61pability_supported\x18\x01 \x01(\x08\x12\x15\n\rsub_questions\x18\x02 \x03(\t\x12/\n\x05usage\x18\x03 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\"\xb9\x01\n!SynthesizeDecomposedAnswerRequest\x12\x15\n\rrepository_id\x18\x01 \x01(\t\x12\x19\n\x11original_question\x18\x02 \x01(\t\x12\x43\n\x0bsub_answers\x18\x03 \x03(\x0b\x32..sourcebridge.reasoning.v1.DecomposedSubAnswer\x12\x1d\n\x15\x65nable_prompt_caching\x18\x04 \x01(\x08\"\x90\x01\n\x13\x44\x65\x63omposedSubAnswer\x12\x14\n\x0csub_question\x18\x01 \x01(\t\x12\x12\n\nsub_answer\x18\x02 \x01(\t\x12\x19\n\x11reference_handles\x18\x03 \x03(\t\x12\x1a\n\x12termination_reason\x18\x04 \x01(\t\x12\x18\n\x10tool_calls_count\x18\x05 \x01(\x05\"\xab\x01\n\"SynthesizeDecomposedAnswerResponse\x12\x0e\n\x06\x61nswer\x18\x01 \x01(\t\x12/\n\x05usage\x18\x02 \x01(\x0b\x32 .sourcebridge.common.v1.LLMUsage\x12#\n\x1b\x63\x61\x63he_creation_input_tokens\x18\x03 \x01(\x03\x12\x1f\n\x17\x63\x61\x63he_read_input_tokens\x18\x04 \x01(\x03\" \n\x1eGetProviderCapabilitiesRequest\"\xc2\x01\n\x1fGetProviderCapabilitiesResponse\x12\x10\n\x08provider\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x1a\n\x12tool_use_supported\x18\x03 \x01(\x08\x12 \n\x18prompt_caching_supported\x18\x04 \x01(\x08\x12\x1c\n\x14max_concurrent_calls\x18\x05 \x01(\x05\x12\"\n\x1amax_concurrent_calls_known\x18\x06 \x01(\x08\"\x1b\n\x19GetLLMGateSnapshotRequest\"T\n\x1aGetLLMGateSnapshotResponse\x12\x36\n\x05gates\x18\x01 \x03(\x0b\x32\'.sourcebridge.reasoning.v1.LLMGateEntry\"\xe5\x01\n\x0cLLMGateEntry\x12\x10\n\x08provider\x18\x01 \x01(\t\x12\x1b\n\x13\x62\x61se_url_normalized\x18\x02 \x01(\t\x12\x0c\n\x04kind\x18\x03 \x01(\t\x12\x11\n\tin_flight\x18\x04 \x01(\x05\x12\x0e\n\x06queued\x18\x05 \x01(\x05\x12\x16\n\x0emax_concurrent\x18\x06 \x01(\x05\x12\x1b\n\x13retries_since_start\x18\x07 \x01(\x03\x12\x18\n\x10recent_429_count\x18\x08 \x01(\x03\x12\x19\n\x11tokens_per_second\x18\t \x01(\x01\x12\x0b\n\x03rpm\x18\n \x01(\x05\x32\xb5\r\n\x10ReasoningService\x12r\n\rAnalyzeSymbol\x12/.sourcebridge.reasoning.v1.AnalyzeSymbolRequest\x1a\x30.sourcebridge.reasoning.v1.AnalyzeSymbolResponse\x12\x84\x01\n\x13\x45xplainRelationship\x12\x35.sourcebridge.reasoning.v1.ExplainRelationshipRequest\x1a\x36.sourcebridge.reasoning.v1.ExplainRelationshipResponse\x12u\n\x0e\x41nswerQuestion\x12\x30.sourcebridge.reasoning.v1.AnswerQuestionRequest\x1a\x31.sourcebridge.reasoning.v1.AnswerQuestionResponse\x12\x89\x01\n\x14\x41nswerQuestionStream\x12\x36.sourcebridge.reasoning.v1.AnswerQuestionStreamRequest\x1a\x37.sourcebridge.reasoning.v1.AnswerQuestionStreamResponse0\x01\x12i\n\nReviewFile\x12,.sourcebridge.reasoning.v1.ReviewFileRequest\x1a-.sourcebridge.reasoning.v1.ReviewFileResponse\x12~\n\x11GenerateEmbedding\x12\x33.sourcebridge.reasoning.v1.GenerateEmbeddingRequest\x1a\x34.sourcebridge.reasoning.v1.GenerateEmbeddingResponse\x12u\n\x0eSimulateChange\x12\x30.sourcebridge.reasoning.v1.SimulateChangeRequest\x1a\x31.sourcebridge.reasoning.v1.SimulateChangeResponse\x12\x90\x01\n\x17\x41nswerQuestionWithTools\x12\x39.sourcebridge.reasoning.v1.AnswerQuestionWithToolsRequest\x1a:.sourcebridge.reasoning.v1.AnswerQuestionWithToolsResponse\x12\x90\x01\n\x17GetProviderCapabilities\x12\x39.sourcebridge.reasoning.v1.GetProviderCapabilitiesRequest\x1a:.sourcebridge.reasoning.v1.GetProviderCapabilitiesResponse\x12\x81\x01\n\x12GetLLMGateSnapshot\x12\x34.sourcebridge.reasoning.v1.GetLLMGateSnapshotRequest\x1a\x35.sourcebridge.reasoning.v1.GetLLMGateSnapshotResponse\x12{\n\x10\x43lassifyQuestion\x12\x32.sourcebridge.reasoning.v1.ClassifyQuestionRequest\x1a\x33.sourcebridge.reasoning.v1.ClassifyQuestionResponse\x12~\n\x11\x44\x65\x63omposeQuestion\x12\x33.sourcebridge.reasoning.v1.DecomposeQuestionRequest\x1a\x34.sourcebridge.reasoning.v1.DecomposeQuestionResponse\x12\x99\x01\n\x1aSynthesizeDecomposedAnswer\x12<.sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerRequest\x1a=.sourcebridge.reasoning.v1.SynthesizeDecomposedAnswerResponseBFZDgithub.com/sourcebridge/sourcebridge/gen/go/reasoning/v1;reasoningv1b\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -97,6 +97,12 @@ _globals['_GETPROVIDERCAPABILITIESREQUEST']._serialized_end=4977 _globals['_GETPROVIDERCAPABILITIESRESPONSE']._serialized_start=4980 _globals['_GETPROVIDERCAPABILITIESRESPONSE']._serialized_end=5174 - _globals['_REASONINGSERVICE']._serialized_start=5177 - _globals['_REASONINGSERVICE']._serialized_end=6762 + _globals['_GETLLMGATESNAPSHOTREQUEST']._serialized_start=5176 + _globals['_GETLLMGATESNAPSHOTREQUEST']._serialized_end=5203 + _globals['_GETLLMGATESNAPSHOTRESPONSE']._serialized_start=5205 + _globals['_GETLLMGATESNAPSHOTRESPONSE']._serialized_end=5289 + _globals['_LLMGATEENTRY']._serialized_start=5292 + _globals['_LLMGATEENTRY']._serialized_end=5521 + _globals['_REASONINGSERVICE']._serialized_start=5524 + _globals['_REASONINGSERVICE']._serialized_end=7241 # @@protoc_insertion_point(module_scope) diff --git a/gen/python/reasoning/v1/reasoning_pb2.pyi b/gen/python/reasoning/v1/reasoning_pb2.pyi index b6a574ba..6dd7ff6d 100644 --- a/gen/python/reasoning/v1/reasoning_pb2.pyi +++ b/gen/python/reasoning/v1/reasoning_pb2.pyi @@ -422,3 +422,37 @@ class GetProviderCapabilitiesResponse(_message.Message): max_concurrent_calls: int max_concurrent_calls_known: bool def __init__(self, provider: _Optional[str] = ..., model: _Optional[str] = ..., tool_use_supported: bool = ..., prompt_caching_supported: bool = ..., max_concurrent_calls: _Optional[int] = ..., max_concurrent_calls_known: bool = ...) -> None: ... + +class GetLLMGateSnapshotRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class GetLLMGateSnapshotResponse(_message.Message): + __slots__ = ("gates",) + GATES_FIELD_NUMBER: _ClassVar[int] + gates: _containers.RepeatedCompositeFieldContainer[LLMGateEntry] + def __init__(self, gates: _Optional[_Iterable[_Union[LLMGateEntry, _Mapping]]] = ...) -> None: ... + +class LLMGateEntry(_message.Message): + __slots__ = ("provider", "base_url_normalized", "kind", "in_flight", "queued", "max_concurrent", "retries_since_start", "recent_429_count", "tokens_per_second", "rpm") + PROVIDER_FIELD_NUMBER: _ClassVar[int] + BASE_URL_NORMALIZED_FIELD_NUMBER: _ClassVar[int] + KIND_FIELD_NUMBER: _ClassVar[int] + IN_FLIGHT_FIELD_NUMBER: _ClassVar[int] + QUEUED_FIELD_NUMBER: _ClassVar[int] + MAX_CONCURRENT_FIELD_NUMBER: _ClassVar[int] + RETRIES_SINCE_START_FIELD_NUMBER: _ClassVar[int] + RECENT_429_COUNT_FIELD_NUMBER: _ClassVar[int] + TOKENS_PER_SECOND_FIELD_NUMBER: _ClassVar[int] + RPM_FIELD_NUMBER: _ClassVar[int] + provider: str + base_url_normalized: str + kind: str + in_flight: int + queued: int + max_concurrent: int + retries_since_start: int + recent_429_count: int + tokens_per_second: float + rpm: int + def __init__(self, provider: _Optional[str] = ..., base_url_normalized: _Optional[str] = ..., kind: _Optional[str] = ..., in_flight: _Optional[int] = ..., queued: _Optional[int] = ..., max_concurrent: _Optional[int] = ..., retries_since_start: _Optional[int] = ..., recent_429_count: _Optional[int] = ..., tokens_per_second: _Optional[float] = ..., rpm: _Optional[int] = ...) -> None: ... diff --git a/gen/python/reasoning/v1/reasoning_pb2_grpc.py b/gen/python/reasoning/v1/reasoning_pb2_grpc.py index 32d6887d..ca784519 100644 --- a/gen/python/reasoning/v1/reasoning_pb2_grpc.py +++ b/gen/python/reasoning/v1/reasoning_pb2_grpc.py @@ -80,6 +80,11 @@ def __init__(self, channel): request_serializer=reasoning_dot_v1_dot_reasoning__pb2.GetProviderCapabilitiesRequest.SerializeToString, response_deserializer=reasoning_dot_v1_dot_reasoning__pb2.GetProviderCapabilitiesResponse.FromString, _registered_method=True) + self.GetLLMGateSnapshot = channel.unary_unary( + '/sourcebridge.reasoning.v1.ReasoningService/GetLLMGateSnapshot', + request_serializer=reasoning_dot_v1_dot_reasoning__pb2.GetLLMGateSnapshotRequest.SerializeToString, + response_deserializer=reasoning_dot_v1_dot_reasoning__pb2.GetLLMGateSnapshotResponse.FromString, + _registered_method=True) self.ClassifyQuestion = channel.unary_unary( '/sourcebridge.reasoning.v1.ReasoningService/ClassifyQuestion', request_serializer=reasoning_dot_v1_dot_reasoning__pb2.ClassifyQuestionRequest.SerializeToString, @@ -185,6 +190,20 @@ def GetProviderCapabilities(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetLLMGateSnapshot(self, request, context): + """GetLLMGateSnapshot returns a point-in-time snapshot of all active + per-provider concurrency gates in the worker. Used by the admin + /api/v1/admin/llm/activity endpoint to surface real-time gate + counters (in-flight, queued, tok/s) without a per-request round-trip. + + An explicit request struct (rather than google.protobuf.Empty) is + used for forward compatibility: filter fields (provider, kind) can + be added without redeclaring the RPC. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def ClassifyQuestion(self, request, context): """ClassifyQuestion runs a cheap LLM classifier (Haiku) that returns the question's likely class plus evidence-kind hints @@ -271,6 +290,11 @@ def add_ReasoningServiceServicer_to_server(servicer, server): request_deserializer=reasoning_dot_v1_dot_reasoning__pb2.GetProviderCapabilitiesRequest.FromString, response_serializer=reasoning_dot_v1_dot_reasoning__pb2.GetProviderCapabilitiesResponse.SerializeToString, ), + 'GetLLMGateSnapshot': grpc.unary_unary_rpc_method_handler( + servicer.GetLLMGateSnapshot, + request_deserializer=reasoning_dot_v1_dot_reasoning__pb2.GetLLMGateSnapshotRequest.FromString, + response_serializer=reasoning_dot_v1_dot_reasoning__pb2.GetLLMGateSnapshotResponse.SerializeToString, + ), 'ClassifyQuestion': grpc.unary_unary_rpc_method_handler( servicer.ClassifyQuestion, request_deserializer=reasoning_dot_v1_dot_reasoning__pb2.ClassifyQuestionRequest.FromString, @@ -541,6 +565,33 @@ def GetProviderCapabilities(request, metadata, _registered_method=True) + @staticmethod + def GetLLMGateSnapshot(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/sourcebridge.reasoning.v1.ReasoningService/GetLLMGateSnapshot', + reasoning_dot_v1_dot_reasoning__pb2.GetLLMGateSnapshotRequest.SerializeToString, + reasoning_dot_v1_dot_reasoning__pb2.GetLLMGateSnapshotResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + @staticmethod def ClassifyQuestion(request, target, diff --git a/internal/api/graphql/knowledge_generation_architecture_diagram.go b/internal/api/graphql/knowledge_generation_architecture_diagram.go index 4a495fe7..85cf29bd 100644 --- a/internal/api/graphql/knowledge_generation_architecture_diagram.go +++ b/internal/api/graphql/knowledge_generation_architecture_diagram.go @@ -136,7 +136,7 @@ func (s architectureDiagramGenerationService) runGenerationPipeline( audience := p.audience depth := p.depth - rt.ReportProgress(0.1, "snapshot", "Snapshot assembled") + rt.ReportProgress(0.1, "snapshot", "Snapshot assembled", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.1, "snapshot", "Snapshot assembled") var architectureBundle architectureDiagramPromptBundle @@ -147,7 +147,7 @@ func (s architectureDiagramGenerationService) runGenerationPipeline( } else { understandingForDiagram = understanding if reused { - rt.ReportProgress(0.12, "understanding", "Using cached repository understanding") + rt.ReportProgress(0.12, "understanding", "Using cached repository understanding", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.12, "understanding", "Using cached repository understanding") } } @@ -188,7 +188,7 @@ func (s architectureDiagramGenerationService) runGenerationPipeline( rt.ReportTokens(int(resp.Usage.InputTokens), int(resp.Usage.OutputTokens)) } - rt.ReportProgress(0.96, "llm", "LLM completed, persisting diagram") + rt.ReportProgress(0.96, "llm", "LLM completed, persisting diagram", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.8, "llm", "LLM completed, persisting diagram") sections := []knowledgepkg.Section{{ @@ -230,7 +230,7 @@ func (s architectureDiagramGenerationService) runGenerationPipeline( if err := r.KnowledgeStore.UpdateKnowledgeArtifactStatus(artifact.ID, knowledgepkg.StatusReady); err != nil { return err } - rt.ReportProgress(1.0, "ready", "AI architecture diagram ready") + rt.ReportProgress(1.0, "ready", "AI architecture diagram ready", 0) return nil } diff --git a/internal/api/graphql/knowledge_generation_cliff_notes.go b/internal/api/graphql/knowledge_generation_cliff_notes.go index 329378c0..0e923465 100644 --- a/internal/api/graphql/knowledge_generation_cliff_notes.go +++ b/internal/api/graphql/knowledge_generation_cliff_notes.go @@ -250,10 +250,10 @@ func (s cliffNotesGenerationService) runGenerationPipeline( } }() genStart := time.Now() - rt.ReportProgress(0.1, "snapshot", "Snapshot assembled") + rt.ReportProgress(0.1, "snapshot", "Snapshot assembled", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.1, "snapshot", "Snapshot assembled") if reusedUnderstanding { - rt.ReportProgress(0.12, "understanding", "Using cached repository understanding") + rt.ReportProgress(0.12, "understanding", "Using cached repository understanding", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.12, "understanding", "Using cached repository understanding") } appendJobLog(r.Orchestrator, rt, llm.LogLevelInfo, "snapshot", "snapshot_assembled", "Snapshot assembled", map[string]any{ @@ -367,7 +367,7 @@ func (s cliffNotesGenerationService) runGenerationPipeline( if reusedSummaries > 0 { llmMessage = fmt.Sprintf("LLM completed, reused %d summaries, persisting sections", reusedSummaries) } - rt.ReportProgress(0.96, "llm", llmMessage) + rt.ReportProgress(0.96, "llm", llmMessage, 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.8, "llm", llmMessage) if resp.Usage != nil { @@ -455,7 +455,7 @@ func (s cliffNotesGenerationService) runGenerationPipeline( if reusedSummaries > 0 { readyMessage = fmt.Sprintf("Cliff notes ready · reused %d summaries", reusedSummaries) } - rt.ReportProgress(1.0, "ready", readyMessage) + rt.ReportProgress(1.0, "ready", readyMessage, 0) slog.Info("cliff_notes_generation_completed", "artifact_id", artifact.ID, "scope_type", string(scope.ScopeType), diff --git a/internal/api/graphql/knowledge_generation_code_tour.go b/internal/api/graphql/knowledge_generation_code_tour.go index 2e37bdf2..0a738970 100644 --- a/internal/api/graphql/knowledge_generation_code_tour.go +++ b/internal/api/graphql/knowledge_generation_code_tour.go @@ -148,14 +148,14 @@ func (s codeTourGenerationService) runGenerationPipeline( theme := p.theme enrichedSnapJSON := snapJSON - rt.ReportProgress(0.1, "snapshot", "Snapshot assembled") + rt.ReportProgress(0.1, "snapshot", "Snapshot assembled", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.1, "snapshot", "Snapshot assembled") if artifactUsesUnderstanding(generationMode) { if understanding, reused, err := r.ensureFreshRepositoryUnderstanding(runCtx, rt, repo, artifact, snap.SourceRevision, snapJSON); err != nil { return err } else { if reused { - rt.ReportProgress(0.12, "understanding", "Using cached repository understanding") + rt.ReportProgress(0.12, "understanding", "Using cached repository understanding", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.12, "understanding", "Using cached repository understanding") } if understanding != nil { @@ -190,7 +190,7 @@ func (s codeTourGenerationService) runGenerationPipeline( return err } - rt.ReportProgress(0.96, "llm", "LLM completed, persisting stops") + rt.ReportProgress(0.96, "llm", "LLM completed, persisting stops", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.8, "llm", "LLM completed, persisting") if resp.Usage != nil { @@ -243,7 +243,7 @@ func (s codeTourGenerationService) runGenerationPipeline( if err := r.KnowledgeStore.UpdateKnowledgeArtifactStatus(artifact.ID, knowledgepkg.StatusReady); err != nil { slog.Error("failed to mark code tour ready", "artifact_id", artifact.ID, "error", err) } - rt.ReportProgress(1.0, "ready", "Code tour ready") + rt.ReportProgress(1.0, "ready", "Code tour ready", 0) slog.Info("code tour generation complete", "artifact_id", artifact.ID) return nil } diff --git a/internal/api/graphql/knowledge_generation_learning_path.go b/internal/api/graphql/knowledge_generation_learning_path.go index 569d71ef..b28c7aa3 100644 --- a/internal/api/graphql/knowledge_generation_learning_path.go +++ b/internal/api/graphql/knowledge_generation_learning_path.go @@ -148,14 +148,14 @@ func (s learningPathGenerationService) runGenerationPipeline( focusArea := p.focusArea enrichedSnapJSON := snapJSON - rt.ReportProgress(0.1, "snapshot", "Snapshot assembled") + rt.ReportProgress(0.1, "snapshot", "Snapshot assembled", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.1, "snapshot", "Snapshot assembled") if artifactUsesUnderstanding(generationMode) { if understanding, reused, err := r.ensureFreshRepositoryUnderstanding(runCtx, rt, repo, artifact, snap.SourceRevision, snapJSON); err != nil { return err } else { if reused { - rt.ReportProgress(0.12, "understanding", "Using cached repository understanding") + rt.ReportProgress(0.12, "understanding", "Using cached repository understanding", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.12, "understanding", "Using cached repository understanding") } if understanding != nil { @@ -190,7 +190,7 @@ func (s learningPathGenerationService) runGenerationPipeline( return err } - rt.ReportProgress(0.96, "llm", "LLM completed, persisting steps") + rt.ReportProgress(0.96, "llm", "LLM completed, persisting steps", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.8, "llm", "LLM completed, persisting") if resp.Usage != nil { @@ -248,7 +248,7 @@ func (s learningPathGenerationService) runGenerationPipeline( if err := r.KnowledgeStore.UpdateKnowledgeArtifactStatus(artifact.ID, knowledgepkg.StatusReady); err != nil { slog.Error("failed to mark learning path ready", "artifact_id", artifact.ID, "error", err) } - rt.ReportProgress(1.0, "ready", "Learning path ready") + rt.ReportProgress(1.0, "ready", "Learning path ready", 0) slog.Info("learning path generation complete", "artifact_id", artifact.ID) return nil } diff --git a/internal/api/graphql/knowledge_generation_workflow_story.go b/internal/api/graphql/knowledge_generation_workflow_story.go index 26dde0b6..7e800d3a 100644 --- a/internal/api/graphql/knowledge_generation_workflow_story.go +++ b/internal/api/graphql/knowledge_generation_workflow_story.go @@ -159,14 +159,14 @@ func (s workflowStoryGenerationService) runGenerationPipeline( executionPathJSON := p.executionPathJSON enrichedSnapJSON := snapJSON - rt.ReportProgress(0.1, "snapshot", "Snapshot assembled") + rt.ReportProgress(0.1, "snapshot", "Snapshot assembled", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.1, "snapshot", "Snapshot assembled") if artifactUsesUnderstanding(generationMode) { if understanding, reused, err := r.ensureFreshRepositoryUnderstanding(runCtx, rt, repo, artifact, snap.SourceRevision, snapJSON); err != nil { return err } else { if reused { - rt.ReportProgress(0.12, "understanding", "Using cached repository understanding") + rt.ReportProgress(0.12, "understanding", "Using cached repository understanding", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.12, "understanding", "Using cached repository understanding") } if understanding != nil { @@ -204,7 +204,7 @@ func (s workflowStoryGenerationService) runGenerationPipeline( return err } - rt.ReportProgress(0.96, "llm", "LLM completed, persisting sections") + rt.ReportProgress(0.96, "llm", "LLM completed, persisting sections", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.8, "llm", "LLM completed, persisting") if resp.Usage != nil { @@ -240,7 +240,7 @@ func (s workflowStoryGenerationService) runGenerationPipeline( if err := r.KnowledgeStore.UpdateKnowledgeArtifactStatus(artifact.ID, knowledgepkg.StatusReady); err != nil { slog.Error("failed to mark workflow story ready", "artifact_id", artifact.ID, "error", err) } - rt.ReportProgress(1.0, "ready", "Workflow story ready") + rt.ReportProgress(1.0, "ready", "Workflow story ready", 0) slog.Info("workflow story generation complete", "artifact_id", artifact.ID) return nil } diff --git a/internal/api/graphql/knowledge_job.go b/internal/api/graphql/knowledge_job.go index ecf7c346..6813b945 100644 --- a/internal/api/graphql/knowledge_job.go +++ b/internal/api/graphql/knowledge_job.go @@ -43,7 +43,7 @@ func setKnowledgeQueueHeartbeatInterval(interval time.Duration) { // owned this type was deleted in CA-122 once every knowledge RPC // became server-streaming and could surface real per-phase progress // to the orchestrator's UpdatedAt heartbeat. -type progressWriter func(progress float64, phase, message string) error +type progressWriter func(progress float64, phase, message string, throughputTPS float64) error // knowledgeJobTargetKey returns the canonical dedupe key the orchestrator // uses for a knowledge artifact generation job. Matching keys collapse to @@ -150,7 +150,7 @@ func startKnowledgeQueueHeartbeat(ctx context.Context, rt llm.Runtime, artifactI case <-hbCtx.Done(): return case <-tick.C: - rt.ReportProgress(0.02, "queued", "Waiting for knowledge generation slot") + rt.ReportProgress(0.02, "queued", "Waiting for knowledge generation slot", 0) if store != nil && artifactID != "" { if err := store.UpdateKnowledgeArtifactProgressWithPhase(artifactID, 0.02, "queued", "Waiting for knowledge generation slot"); err != nil { knowledgeProgressWriteErrorsTotal.Add(1) @@ -230,7 +230,7 @@ func (r *Resolver) enqueueKnowledgeJob( GenerationMode: string(artifact.GenerationMode), MaxAttempts: knowledgeJobMaxAttempts(artifact, scope), RunWithContext: func(runCtx context.Context, rt llm.Runtime) error { - rt.ReportProgress(0.02, "queued", "Waiting for knowledge generation slot") + rt.ReportProgress(0.02, "queued", "Waiting for knowledge generation slot", 0) if err := r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.02, "queued", "Waiting for knowledge generation slot"); err != nil { knowledgeProgressWriteErrorsTotal.Add(1) slog.Warn("knowledge_progress_write_failed", @@ -340,7 +340,7 @@ func enqueueRepositoryUnderstandingJob( // rather than redoing all the work. MaxAttempts: 3, RunWithContext: func(runCtx context.Context, rt llm.Runtime) error { - rt.ReportProgress(0.02, "queued", "Waiting for knowledge generation slot") + rt.ReportProgress(0.02, "queued", "Waiting for knowledge generation slot", 0) appendJobLog(r.Orchestrator, rt, llm.LogLevelInfo, "queued", "knowledge_slot_wait_started", "Waiting for knowledge generation slot", map[string]any{ "job_type": "build_repository_understanding", }) diff --git a/internal/api/graphql/knowledge_job_test.go b/internal/api/graphql/knowledge_job_test.go index ef1a5df8..4d416123 100644 --- a/internal/api/graphql/knowledge_job_test.go +++ b/internal/api/graphql/knowledge_job_test.go @@ -42,7 +42,7 @@ func TestEnqueueKnowledgeJobCreatesQueuedKnowledgeJob(t *testing.T) { block := make(chan struct{}) err = r.enqueueKnowledgeJob(context.Background(), artifact, "seed:cliff_notes", 1234, func(_ context.Context, rt llm.Runtime) error { - rt.ReportProgress(0.2, "snapshot", "queued") + rt.ReportProgress(0.2, "snapshot", "queued", 0) <-block return nil }) @@ -139,7 +139,7 @@ func TestKnowledgeJobsShareGlobalConcurrencyGate(t *testing.T) { first := makeArtifact("repo-1", knowledge.ArtifactCliffNotes) if err := r.enqueueKnowledgeJob(context.Background(), first, "cliff_notes", 100, func(_ context.Context, rt llm.Runtime) error { - rt.ReportProgress(0.25, "generating", "first") + rt.ReportProgress(0.25, "generating", "first", 0) entered <- "cliff_notes" <-releaseRunning return nil @@ -149,7 +149,7 @@ func TestKnowledgeJobsShareGlobalConcurrencyGate(t *testing.T) { second := makeArtifact("repo-1", knowledge.ArtifactCodeTour) if err := r.enqueueKnowledgeJob(context.Background(), second, "code_tour", 100, func(_ context.Context, rt llm.Runtime) error { - rt.ReportProgress(0.25, "generating", "second") + rt.ReportProgress(0.25, "generating", "second", 0) entered <- "code_tour" <-releaseRunning return nil @@ -276,7 +276,7 @@ func TestQueuedKnowledgeJobsHeartbeatWhileWaitingForGate(t *testing.T) { stop := startKnowledgeQueueHeartbeat(ctx, testRuntime{ jobID: "job-1", setProgress: func(progress float64, phase, message string) { - _ = jobStore.SetProgress("job-1", progress, phase, message) + _ = jobStore.SetProgress("job-1", progress, phase, message, 0) }, }, artifact.ID, knowledgeStore) defer stop() @@ -309,7 +309,7 @@ type testRuntime struct { func (t testRuntime) JobID() string { return t.jobID } -func (t testRuntime) ReportProgress(progress float64, phase, message string) { +func (t testRuntime) ReportProgress(progress float64, phase, message string, _ float64) { if t.setProgress != nil { t.setProgress(progress, phase, message) } diff --git a/internal/api/graphql/knowledge_seed.go b/internal/api/graphql/knowledge_seed.go index 78df8556..4e54823b 100644 --- a/internal/api/graphql/knowledge_seed.go +++ b/internal/api/graphql/knowledge_seed.go @@ -114,7 +114,7 @@ func (r *mutationResolver) ensureKnowledgeArtifact(repo *graphstore.Repository, run := func(runCtx context.Context, rt llm.Runtime) error { snapshotBytes := []byte(snapshotJSON) enrichedSnapshotJSON := snapshotJSON - rt.ReportProgress(0.1, "snapshot", "Seed snapshot assembled") + rt.ReportProgress(0.1, "snapshot", "Seed snapshot assembled", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.1, "snapshot", "Seed snapshot assembled") // CA-122 Phase 6/7: stream-driven progress for the seed pipeline. streamDriver := r.runStreamProgressDriver(runCtx, rt, artifact.ID, rpcBucketForArtifact(artifact)) @@ -140,7 +140,7 @@ func (r *mutationResolver) ensureKnowledgeArtifact(repo *graphstore.Repository, if _, err := updateUnderstandingForCliffNotes(r.KnowledgeStore, artifact, key.Scope, sourceRevision, resp, knowledgepkg.UnderstandingFirstPassReady); err != nil { slog.Warn("failed to update repository understanding from seed cliff notes", "artifact_id", artifact.ID, "error", err) } - rt.ReportProgress(0.96, "llm", "Seed LLM completed, persisting") + rt.ReportProgress(0.96, "llm", "Seed LLM completed, persisting", 0) sections := make([]knowledgepkg.Section, len(resp.Sections)) for i, sec := range resp.Sections { sections[i] = knowledgepkg.Section{ @@ -160,7 +160,7 @@ func (r *mutationResolver) ensureKnowledgeArtifact(repo *graphstore.Repository, return err } else { if reused { - rt.ReportProgress(0.12, "understanding", "Using cached repository understanding") + rt.ReportProgress(0.12, "understanding", "Using cached repository understanding", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.12, "understanding", "Using cached repository understanding") } if understanding != nil { @@ -181,7 +181,7 @@ func (r *mutationResolver) ensureKnowledgeArtifact(repo *graphstore.Repository, if err != nil { return err } - rt.ReportProgress(0.96, "llm", "Seed LLM completed, persisting") + rt.ReportProgress(0.96, "llm", "Seed LLM completed, persisting", 0) sections := make([]knowledgepkg.Section, len(resp.Steps)) for i, step := range resp.Steps { sections[i] = knowledgepkg.Section{ @@ -199,7 +199,7 @@ func (r *mutationResolver) ensureKnowledgeArtifact(repo *graphstore.Repository, return err } else { if reused { - rt.ReportProgress(0.12, "understanding", "Using cached repository understanding") + rt.ReportProgress(0.12, "understanding", "Using cached repository understanding", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.12, "understanding", "Using cached repository understanding") } if understanding != nil { @@ -220,7 +220,7 @@ func (r *mutationResolver) ensureKnowledgeArtifact(repo *graphstore.Repository, if err != nil { return err } - rt.ReportProgress(0.96, "llm", "Seed LLM completed, persisting") + rt.ReportProgress(0.96, "llm", "Seed LLM completed, persisting", 0) sections := make([]knowledgepkg.Section, len(resp.Stops)) for i, stop := range resp.Stops { sections[i] = knowledgepkg.Section{ @@ -244,7 +244,7 @@ func (r *mutationResolver) ensureKnowledgeArtifact(repo *graphstore.Repository, return err } else { if reused { - rt.ReportProgress(0.12, "understanding", "Using cached repository understanding") + rt.ReportProgress(0.12, "understanding", "Using cached repository understanding", 0) _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.12, "understanding", "Using cached repository understanding") } if understanding != nil { @@ -267,7 +267,7 @@ func (r *mutationResolver) ensureKnowledgeArtifact(repo *graphstore.Repository, if err != nil { return err } - rt.ReportProgress(0.96, "llm", "Seed LLM completed, persisting") + rt.ReportProgress(0.96, "llm", "Seed LLM completed, persisting", 0) sections := make([]knowledgepkg.Section, len(resp.Sections)) for i, sec := range resp.Sections { sections[i] = knowledgepkg.Section{ @@ -285,7 +285,7 @@ func (r *mutationResolver) ensureKnowledgeArtifact(repo *graphstore.Repository, default: return nil } - rt.ReportProgress(1.0, "ready", "Seed artifact ready") + rt.ReportProgress(1.0, "ready", "Seed artifact ready", 0) return nil } diff --git a/internal/api/graphql/knowledge_stream_driver.go b/internal/api/graphql/knowledge_stream_driver.go index 3e136c7c..56b18b63 100644 --- a/internal/api/graphql/knowledge_stream_driver.go +++ b/internal/api/graphql/knowledge_stream_driver.go @@ -241,8 +241,8 @@ func (d *streamProgressDriver) handlePhase(pm *commonv1.KnowledgeStreamPhaseMark d.currentMu.Unlock() msg := phaseLabel(pm.GetPhase()) - d.rt.ReportProgress(pct, "generating", msg) - if err := d.write(pct, "generating", msg); err != nil { + d.rt.ReportProgress(pct, "generating", msg, 0) + if err := d.write(pct, "generating", msg, 0); err != nil { d.logWriteErr(err, "phase") } } @@ -281,8 +281,9 @@ func (d *streamProgressDriver) handleProgress(p *commonv1.KnowledgeStreamProgres if msg == "" { msg = phaseLabel(d.curPhase) } - d.rt.ReportProgress(pct, "generating", msg) - if err := d.write(pct, "generating", msg); err != nil { + throughputTPS := float64(p.GetCurrentTokensPerSecond()) + d.rt.ReportProgress(pct, "generating", msg, throughputTPS) + if err := d.write(pct, "generating", msg, throughputTPS); err != nil { d.logWriteErr(err, "progress") } } @@ -324,7 +325,7 @@ func (r *Resolver) runStreamProgressDriver( _ = ctx // reserved for future per-driver cancellation; the // streaming RPC's own ctx already governs its lifetime. return newStreamProgressDriver(rt, - func(p float64, phase, msg string) error { + func(p float64, phase, msg string, _ float64) error { return r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifactID, p, phase, msg) }, kind, @@ -343,7 +344,7 @@ func (r *Resolver) runUnderstandingStreamDriver( ) *streamProgressDriver { _ = ctx return newStreamProgressDriver(rt, - func(p float64, phase, msg string) error { + func(p float64, phase, msg string, _ float64) error { return r.KnowledgeStore.UpdateRepositoryUnderstandingProgress(understandingID, p, phase, msg) }, kind, diff --git a/internal/api/graphql/knowledge_stream_driver_test.go b/internal/api/graphql/knowledge_stream_driver_test.go index 4896244e..8c279b05 100644 --- a/internal/api/graphql/knowledge_stream_driver_test.go +++ b/internal/api/graphql/knowledge_stream_driver_test.go @@ -23,17 +23,18 @@ type streamFakeRuntime struct { } type reportedProgress struct { - pct float64 - phase string - message string + pct float64 + phase string + message string + throughputTPS float64 } func (f *streamFakeRuntime) JobID() string { return "fake-job" } -func (f *streamFakeRuntime) ReportProgress(p float64, phase, message string) { +func (f *streamFakeRuntime) ReportProgress(p float64, phase, message string, throughputTPS float64) { f.mu.Lock() defer f.mu.Unlock() - f.progress = append(f.progress, reportedProgress{pct: p, phase: phase, message: message}) + f.progress = append(f.progress, reportedProgress{pct: p, phase: phase, message: message, throughputTPS: throughputTPS}) } func (f *streamFakeRuntime) ReportTokens(int, int) {} @@ -58,7 +59,7 @@ func TestStreamProgressDriverWritesPhaseAndProgress(t *testing.T) { writes []reportedProgress mu sync.Mutex ) - write := func(p float64, phase, message string) error { + write := func(p float64, phase, message string, _ float64) error { mu.Lock() defer mu.Unlock() writes = append(writes, reportedProgress{pct: p, phase: phase, message: message}) @@ -120,7 +121,7 @@ func TestStreamProgressDriverDropsOldestUnderBackpressure(t *testing.T) { var lastSeen atomic.Int32 var writeMu sync.Mutex var writes []float64 - write := func(p float64, phase, message string) error { + write := func(p float64, phase, message string, _ float64) error { <-gate writeMu.Lock() writes = append(writes, p) @@ -168,13 +169,62 @@ func TestStreamProgressDriverDropsOldestUnderBackpressure(t *testing.T) { } } +// TestStreamProgressDriverPropagatesThroughputTPS verifies that a +// KnowledgeStreamProgress message carrying a non-zero +// current_tokens_per_second value flows from the proto field through +// handleProgress into both rt.ReportProgress and the row writer. +func TestStreamProgressDriverPropagatesThroughputTPS(t *testing.T) { + rt := &streamFakeRuntime{} + var ( + writtenTPS []float64 + mu sync.Mutex + ) + write := func(p float64, phase, message string, throughputTPS float64) error { + mu.Lock() + defer mu.Unlock() + writtenTPS = append(writtenTPS, throughputTPS) + return nil + } + + d := newStreamProgressDriver(rt, write, rpcBucketCollapsed, "artifact_id", "tps-test") + defer d.Close() + + handler := d.OnProgress() + handler(worker.KnowledgeStreamEvent{Progress: &commonv1.KnowledgeStreamProgress{ + Phase: commonv1.KnowledgePhase_KNOWLEDGE_PHASE_RENDER, + CompletedUnits: 1, + TotalUnits: 2, + UnitKind: "section", + Message: "rendering", + CurrentTokensPerSecond: 12.3, + }}) + d.Close() + + rtSnap := rt.snapshot() + mu.Lock() + defer mu.Unlock() + + if len(rtSnap) == 0 { + t.Fatal("expected at least one ReportProgress call") + } + if got := rtSnap[len(rtSnap)-1].throughputTPS; got < 12.2 || got > 12.4 { + t.Fatalf("expected throughputTPS≈12.3 in rt.ReportProgress, got %v", got) + } + if len(writtenTPS) == 0 { + t.Fatal("expected at least one row write") + } + if got := writtenTPS[len(writtenTPS)-1]; got < 12.2 || got > 12.4 { + t.Fatalf("expected throughputTPS≈12.3 in write callback, got %v", got) + } +} + // TestStreamProgressDriverCloseDrains exercises the Close-before- // terminal-state-write contract: events queued before Close are all // written before Close returns. func TestStreamProgressDriverCloseDrains(t *testing.T) { rt := &streamFakeRuntime{} var writes int32 - write := func(p float64, phase, message string) error { + write := func(p float64, phase, message string, _ float64) error { atomic.AddInt32(&writes, 1) // Small artificial delay to verify Close waits. time.Sleep(2 * time.Millisecond) diff --git a/internal/api/graphql/knowledge_support.go b/internal/api/graphql/knowledge_support.go index db224cc4..a66f0500 100644 --- a/internal/api/graphql/knowledge_support.go +++ b/internal/api/graphql/knowledge_support.go @@ -1219,7 +1219,7 @@ func (r *Resolver) ensureFreshRepositoryUnderstanding( "reason", "cliff_notes_already_generating", "cliff_notes_id", existingCN.ID) if rt != nil { - rt.ReportProgress(0.12, "understanding", "Cliff notes in progress — proceeding without understanding") + rt.ReportProgress(0.12, "understanding", "Cliff notes in progress — proceeding without understanding", 0) } return nil, false, nil } @@ -1229,7 +1229,7 @@ func (r *Resolver) ensureFreshRepositoryUnderstanding( slog.Warn("failed to seed repository understanding", "artifact_id", artifact.ID, "error", err) } if rt != nil { - rt.ReportProgress(0.12, "understanding", "Building repository understanding") + rt.ReportProgress(0.12, "understanding", "Building repository understanding", 0) } _ = r.KnowledgeStore.UpdateKnowledgeArtifactProgressWithPhase(artifact.ID, 0.12, "understanding", "Building repository understanding") streamDriver := r.runStreamProgressDriver(ctx, rt, artifact.ID, rpcBucketHierarchical) @@ -1846,7 +1846,7 @@ func (r *Resolver) enqueueSingleCliffNotesDeepening( GenerationMode: string(artifact.GenerationMode), MaxAttempts: 2, RunWithContext: func(runCtx context.Context, rt llm.Runtime) error { - rt.ReportProgress(0.05, "deepening", "Deepening critical cliff note sections") + rt.ReportProgress(0.05, "deepening", "Deepening critical cliff note sections", 0) markCliffNotesDeepRefinementStatus(r.KnowledgeStore, artifact, r.KnowledgeStore.GetKnowledgeSections(artifact.ID), selectedTitles, knowledgepkg.RefinementRunning, "") bgCtx := withCliffNotesRenderMetadata(runCtx, true, selectedTitles, string(knowledgepkg.DepthMedium), "product_core") streamDriver := r.runStreamProgressDriver(bgCtx, rt, artifact.ID, rpcBucketForArtifact(artifact)) @@ -1899,7 +1899,7 @@ func (r *Resolver) enqueueSingleCliffNotesDeepening( } outcome, outcomeError := cliffNotesDeepeningOutcome(merged, selectedTitles) markCliffNotesDeepRefinementStatus(r.KnowledgeStore, artifact, merged, selectedTitles, outcome, outcomeError) - rt.ReportProgress(1.0, "ready", "Section deepening complete") + rt.ReportProgress(1.0, "ready", "Section deepening complete", 0) return nil }, } diff --git a/internal/api/graphql/living_wiki_coldstart_test.go b/internal/api/graphql/living_wiki_coldstart_test.go index 3d1d75c3..a6aeeba0 100644 --- a/internal/api/graphql/living_wiki_coldstart_test.go +++ b/internal/api/graphql/living_wiki_coldstart_test.go @@ -172,11 +172,11 @@ func csRunnerFromPages( total := len(pages) if total == 0 { - rt.ReportProgress(1.0, "ok", "no pages") + rt.ReportProgress(1.0, "ok", "no pages", 0) return nil } - rt.ReportProgress(0.05, "generating", fmt.Sprintf("starting %d pages", total)) + rt.ReportProgress(0.05, "generating", fmt.Sprintf("starting %d pages", total), 0) var generated, excludedCount int32 var excludedIDsAcc atomicStringSlice @@ -195,7 +195,7 @@ func csRunnerFromPages( } done := int(atomic.LoadInt32(&generated)) + int(atomic.LoadInt32(&excludedCount)) rt.ReportProgress(0.05+0.90*float64(done)/float64(total), - "generating", fmt.Sprintf("%d/%d", done, total)) + "generating", fmt.Sprintf("%d/%d", done, total), 0) }, } @@ -219,7 +219,7 @@ func csRunnerFromPages( finalGen := int(atomic.LoadInt32(&generated)) finalExcl := int(atomic.LoadInt32(&excludedCount)) - rt.ReportProgress(1.0, status, fmt.Sprintf("%d gen, %d excl", finalGen, finalExcl)) + rt.ReportProgress(1.0, status, fmt.Sprintf("%d gen, %d excl", finalGen, finalExcl), 0) if jrs != nil { now := time.Now() @@ -264,7 +264,7 @@ type fakeRuntime struct { } func (f *fakeRuntime) JobID() string { return f.jobID } -func (f *fakeRuntime) ReportProgress(p float64, phase, _ string) { +func (f *fakeRuntime) ReportProgress(p float64, phase, _ string, _ float64) { f.mu.Lock() defer f.mu.Unlock() f.progress = p @@ -598,7 +598,7 @@ func TestColdStartJobAppearsInSharedActivityFeed(t *testing.T) { RepoID: "feed-test", Priority: llm.PriorityInteractive, RunWithContext: func(runCtx context.Context, rt llm.Runtime) error { - rt.ReportProgress(0.1, "generating", "testing") + rt.ReportProgress(0.1, "generating", "testing", 0) select { case <-block: case <-runCtx.Done(): @@ -1318,11 +1318,11 @@ func csRunnerFromPagesWithSinks( total := len(pages) if total == 0 { - rt.ReportProgress(1.0, "ok", "no pages") + rt.ReportProgress(1.0, "ok", "no pages", 0) return nil } - rt.ReportProgress(0.05, "generating", fmt.Sprintf("starting %d pages", total)) + rt.ReportProgress(0.05, "generating", fmt.Sprintf("starting %d pages", total), 0) var generated, excludedCount int32 var excludedIDsAcc atomicStringSlice @@ -1341,7 +1341,7 @@ func csRunnerFromPagesWithSinks( } done := int(atomic.LoadInt32(&generated)) + int(atomic.LoadInt32(&excludedCount)) rt.ReportProgress(0.05+0.90*float64(done)/float64(total), - "generating", fmt.Sprintf("%d/%d", done, total)) + "generating", fmt.Sprintf("%d/%d", done, total), 0) }, } @@ -1377,7 +1377,7 @@ func csRunnerFromPagesWithSinks( ) } - rt.ReportProgress(1.0, status, fmt.Sprintf("%d gen, %d excl", finalGen, finalExcl)) + rt.ReportProgress(1.0, status, fmt.Sprintf("%d gen, %d excl", finalGen, finalExcl), 0) if jrs != nil { now := time.Now() diff --git a/internal/api/graphql/llm_sync.go b/internal/api/graphql/llm_sync.go index 649f5153..93f6c6da 100644 --- a/internal/api/graphql/llm_sync.go +++ b/internal/api/graphql/llm_sync.go @@ -81,7 +81,7 @@ func (r *Resolver) runSyncLLMJob( type noopRuntime struct{} func (noopRuntime) JobID() string { return "" } -func (noopRuntime) ReportProgress(progress float64, phase, message string) {} +func (noopRuntime) ReportProgress(progress float64, phase, message string, throughputTPS float64) {} func (noopRuntime) ReportTokens(input, output int) {} func (noopRuntime) ReportSnapshotBytes(bytes int) {} func (noopRuntime) Heartbeat() error { return nil } diff --git a/internal/api/graphql/schema.resolvers.go b/internal/api/graphql/schema.resolvers.go index a41088f8..45b35aba 100644 --- a/internal/api/graphql/schema.resolvers.go +++ b/internal/api/graphql/schema.resolvers.go @@ -389,7 +389,7 @@ func (r *mutationResolver) BuildRepositoryUnderstanding(ctx context.Context, inp } err = enqueueRepositoryUnderstandingJob(ctx, r.Resolver, repo, understanding, scope, snapshotJSON, func(runCtx context.Context, rt llm.Runtime) error { - rt.ReportProgress(0.1, "snapshot", "Snapshot assembled") + rt.ReportProgress(0.1, "snapshot", "Snapshot assembled", 0) appendJobLog(r.Orchestrator, rt, llm.LogLevelInfo, "snapshot", "snapshot_assembled", "Snapshot assembled", map[string]any{ "snapshot_bytes": len(snapshotJSON), "scope_type": string(scope.ScopeType), @@ -423,7 +423,7 @@ func (r *mutationResolver) BuildRepositoryUnderstanding(ctx context.Context, inp if _, err := updateUnderstandingForCliffNotes(r.KnowledgeStore, &knowledgepkg.Artifact{RepositoryID: repo.ID}, scope, snap.SourceRevision, resp, knowledgepkg.UnderstandingFirstPassReady); err != nil { return err } - rt.ReportProgress(1.0, "ready", "Repository understanding ready") + rt.ReportProgress(1.0, "ready", "Repository understanding ready", 0) appendJobLog(r.Orchestrator, rt, llm.LogLevelInfo, "ready", "repository_understanding_ready", "Repository understanding ready", map[string]any{ "cached_nodes": func() int32 { if resp.Diagnostics == nil { diff --git a/internal/api/rest/admin_llm_monitor.go b/internal/api/rest/admin_llm_monitor.go index 0181dd6d..e0a07cc9 100644 --- a/internal/api/rest/admin_llm_monitor.go +++ b/internal/api/rest/admin_llm_monitor.go @@ -4,11 +4,13 @@ package rest import ( + "context" "encoding/json" "fmt" "net/http" "sort" "strconv" + "sync" "time" "github.com/go-chi/chi/v5" @@ -33,6 +35,46 @@ type monitorActivityResponse struct { // Stats is derived queue state that tests and the Monitor header // rely on (max concurrency, current in-flight, pending queue depth). Stats monitorStats `json:"stats"` + // GateSnapshot is a point-in-time view of per-provider LLM concurrency + // gates sourced from the worker via GetLLMGateSnapshot. Omitted entirely + // when the worker is unreachable, the call errors, or the wrapper is + // kill-switched, so old/disabled-wrapper deployments don't surface a + // misleading empty array. + GateSnapshot []monitorGateEntry `json:"gate_snapshot,omitempty"` +} + +// monitorGateEntry mirrors LLMGateEntry from the proto — one row per +// (provider, base_url_normalized, kind) gate in the worker registry. +type monitorGateEntry struct { + Provider string `json:"provider"` + BaseURLNormalized string `json:"base_url_normalized"` + Kind string `json:"kind"` + InFlight int `json:"in_flight"` + Queued int `json:"queued"` + MaxConcurrent int `json:"max_concurrent"` + RetriesSinceStart int64 `json:"retries_since_start"` + Recent429Count int64 `json:"recent_429_count"` + TokensPerSecond float64 `json:"tokens_per_second"` + RPM int `json:"rpm,omitempty"` +} + +// gateSnapshotFetcher is a narrow interface over the gRPC call so the cache +// can be unit-tested without a live worker connection. The production wiring +// uses *worker.Client directly; tests inject a stub. +type gateSnapshotFetcher interface { + GetLLMGateSnapshot(ctx context.Context) ([]monitorGateEntry, error) +} + +// gateSnapshotCache holds the last successful gate snapshot from the worker +// and the time it was fetched. The 1-second TTL prevents hammering the worker +// when the admin monitor is polled every 2 seconds. +// +// The optional fetcher field overrides the default s.worker path for testing. +type gateSnapshotCache struct { + mu sync.Mutex + entries []monitorGateEntry + fetchedAt time.Time + fetcher gateSnapshotFetcher // non-nil in tests only; nil = use s.worker } // monitorHealth is the traffic-light summary at the top of the Monitor @@ -127,16 +169,20 @@ type monitorJobView struct { SkippedFileUnits int `json:"skipped_file_units"` SkippedPackageUnits int `json:"skipped_package_units"` SkippedRootUnits int `json:"skipped_root_units"` - ArtifactID string `json:"artifact_id,omitempty"` - RepoID string `json:"repo_id,omitempty"` - ElapsedMs int64 `json:"elapsed_ms"` - QueuePosition int `json:"queue_position,omitempty"` - QueueDepth int `json:"queue_depth,omitempty"` - EstimatedWaitMs int64 `json:"estimated_wait_ms,omitempty"` - CreatedAt time.Time `json:"created_at"` - StartedAt *time.Time `json:"started_at,omitempty"` - UpdatedAt time.Time `json:"updated_at"` - CompletedAt *time.Time `json:"completed_at,omitempty"` + ArtifactID string `json:"artifact_id,omitempty"` + RepoID string `json:"repo_id,omitempty"` + // CurrentTokensPerSecond is the instantaneous LLM throughput sampled + // from the gate's 60-second ring buffer at the last progress update. + // Zero means unknown; consumers MUST treat zero as "unknown". + CurrentTokensPerSecond float64 `json:"current_tokens_per_second,omitempty"` + ElapsedMs int64 `json:"elapsed_ms"` + QueuePosition int `json:"queue_position,omitempty"` + QueueDepth int `json:"queue_depth,omitempty"` + EstimatedWaitMs int64 `json:"estimated_wait_ms,omitempty"` + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + UpdatedAt time.Time `json:"updated_at"` + CompletedAt *time.Time `json:"completed_at,omitempty"` } type monitorJobLogView struct { @@ -239,13 +285,14 @@ func toMonitorJobView(j *llm.Job) monitorJobView { SkippedFileUnits: j.SkippedFileUnits, SkippedPackageUnits: j.SkippedPackageUnits, SkippedRootUnits: j.SkippedRootUnits, - ArtifactID: j.ArtifactID, - RepoID: j.RepoID, - ElapsedMs: j.Elapsed().Milliseconds(), - CreatedAt: j.CreatedAt, - StartedAt: j.StartedAt, - UpdatedAt: j.UpdatedAt, - CompletedAt: j.CompletedAt, + ArtifactID: j.ArtifactID, + RepoID: j.RepoID, + CurrentTokensPerSecond: j.CurrentTokensPerSecond, + ElapsedMs: j.Elapsed().Milliseconds(), + CreatedAt: j.CreatedAt, + StartedAt: j.StartedAt, + UpdatedAt: j.UpdatedAt, + CompletedAt: j.CompletedAt, } } @@ -377,10 +424,84 @@ func (s *Server) handleLLMActivity(w http.ResponseWriter, r *http.Request) { PendingMaintenance: countPendingPriority(pending, llm.PriorityMaintenance), PendingPrewarm: countPendingPriority(pending, llm.PriorityPrewarm), }, + GateSnapshot: s.fetchGateSnapshot(r.Context()), } writeJSON(w, http.StatusOK, resp) } +// fetchGateSnapshot returns the worker's gate snapshot, using a 1-second +// in-process cache to avoid hammering the worker on every 2-second poll. +// +// Returns nil (not an empty slice) when: +// - s.worker is nil (no worker configured) and no test fetcher is set +// - the worker is unreachable or the RPC errors +// - the snapshot is empty (kill-switch off or no gates registered yet) +// +// nil causes the GateSnapshot field to be omitted (omitempty) so old +// or kill-switched deployments don't surface a misleading empty array. +func (s *Server) fetchGateSnapshot(ctx context.Context) []monitorGateEntry { + s.gateSnapshotCache.mu.Lock() + defer s.gateSnapshotCache.mu.Unlock() + + // Determine the fetch function to use (test hook or production path). + var fetch func(context.Context) ([]monitorGateEntry, error) + if s.gateSnapshotCache.fetcher != nil { + fetch = s.gateSnapshotCache.fetcher.GetLLMGateSnapshot + } else if s.worker != nil { + fetch = func(ctx context.Context) ([]monitorGateEntry, error) { + resp, err := s.worker.GetLLMGateSnapshot(ctx) + if err != nil || resp == nil { + return nil, err + } + entries := make([]monitorGateEntry, 0, len(resp.Gates)) + for _, g := range resp.Gates { + entries = append(entries, monitorGateEntry{ + Provider: g.Provider, + BaseURLNormalized: g.BaseUrlNormalized, + Kind: g.Kind, + InFlight: int(g.InFlight), + Queued: int(g.Queued), + MaxConcurrent: int(g.MaxConcurrent), + RetriesSinceStart: g.RetriesSinceStart, + Recent429Count: g.Recent_429Count, + TokensPerSecond: g.TokensPerSecond, + RPM: int(g.Rpm), + }) + } + return entries, nil + } + } else { + return nil + } + + if time.Since(s.gateSnapshotCache.fetchedAt) < time.Second { + // Cache hit: return a copy to avoid races if the caller holds a reference. + if len(s.gateSnapshotCache.entries) == 0 { + return nil + } + out := make([]monitorGateEntry, len(s.gateSnapshotCache.entries)) + copy(out, s.gateSnapshotCache.entries) + return out + } + + // Cache miss: fetch from the worker (lock held so concurrent callers don't + // all fire simultaneously; the 1-second TTL means at most one caller + // waits for the gRPC call at a time). + entries, err := fetch(ctx) + if err != nil || len(entries) == 0 { + // Don't update the cache timestamp on error so the next caller retries + // immediately rather than waiting the full TTL on a transient failure. + return nil + } + + s.gateSnapshotCache.entries = entries + s.gateSnapshotCache.fetchedAt = time.Now() + + out := make([]monitorGateEntry, len(entries)) + copy(out, entries) + return out +} + func modeRollups(jobs []monitorJobView) map[string]monitorModeRollup { if len(jobs) == 0 { return nil diff --git a/internal/api/rest/admin_llm_monitor_test.go b/internal/api/rest/admin_llm_monitor_test.go index eec04b55..9d617771 100644 --- a/internal/api/rest/admin_llm_monitor_test.go +++ b/internal/api/rest/admin_llm_monitor_test.go @@ -4,7 +4,9 @@ package rest import ( + "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" @@ -17,6 +19,34 @@ import ( "github.com/sourcebridge/sourcebridge/internal/llm/orchestrator" ) +// ────────────────────────────────────────────────────────────────────────────── +// Gate snapshot test helpers + +// fakeGateFetcher is a test stub for gateSnapshotFetcher. It records how many +// times GetLLMGateSnapshot was called and returns a fixed response. +type fakeGateFetcher struct { + entries []monitorGateEntry + err error + calls int +} + +func (f *fakeGateFetcher) GetLLMGateSnapshot(_ context.Context) ([]monitorGateEntry, error) { + f.calls++ + if f.err != nil { + return nil, f.err + } + return f.entries, nil +} + +// newMonitorTestServerWithGateFetcher builds a Server with a gate fetcher stub +// wired in for testing fetchGateSnapshot without a live worker. +func newMonitorTestServerWithGateFetcher(t *testing.T, fetcher gateSnapshotFetcher) *Server { + t.Helper() + s := newMonitorTestServer(t) + s.gateSnapshotCache.fetcher = fetcher + return s +} + // newMonitorTestServer builds a Server instance wired to an isolated // orchestrator + in-memory JobStore, sufficient for testing the // monitor HTTP handlers without pulling in the full server stack. @@ -239,7 +269,7 @@ func TestHandleLLMActivityShowsCompletedJob(t *testing.T) { JobType: "cliff_notes", TargetKey: "repo-1:activity", Run: func(rt llm.Runtime) error { - rt.ReportProgress(0.5, "midway", "halfway") + rt.ReportProgress(0.5, "midway", "halfway", 0) rt.ReportTokens(200, 150) close(done) return nil @@ -314,7 +344,7 @@ func TestHandleLLMJobLogs(t *testing.T) { JobType: "cliff_notes", TargetKey: "repo-1:logs", Run: func(rt llm.Runtime) error { - rt.ReportProgress(0.25, "snapshot", "Snapshot assembled") + rt.ReportProgress(0.25, "snapshot", "Snapshot assembled", 0) return nil }, }) @@ -457,6 +487,53 @@ func TestParseListFilterBasicFields(t *testing.T) { } } +func TestMonitorJobViewCurrentTokensPerSecondSerializesWhenNonZero(t *testing.T) { + view := monitorJobView{ + ID: "job-123", + Status: string(llm.StatusGenerating), + CurrentTokensPerSecond: 12.3, + ElapsedMs: 500, + UpdatedAt: time.Now(), + } + b, err := json.Marshal(view) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var out map[string]any + if err := json.Unmarshal(b, &out); err != nil { + t.Fatalf("unmarshal: %v", err) + } + got, ok := out["current_tokens_per_second"] + if !ok { + t.Fatal("expected current_tokens_per_second key in JSON output") + } + // JSON numbers unmarshal to float64 from map[string]any. + if v, _ := got.(float64); v < 12.2 || v > 12.4 { + t.Fatalf("expected current_tokens_per_second≈12.3, got %v", got) + } +} + +func TestMonitorJobViewCurrentTokensPerSecondOmittedWhenZero(t *testing.T) { + view := monitorJobView{ + ID: "job-456", + Status: string(llm.StatusGenerating), + ElapsedMs: 100, + UpdatedAt: time.Now(), + // CurrentTokensPerSecond intentionally zero (default) + } + b, err := json.Marshal(view) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var out map[string]any + if err := json.Unmarshal(b, &out); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if _, ok := out["current_tokens_per_second"]; ok { + t.Fatal("expected current_tokens_per_second to be omitted when zero") + } +} + func TestEventMatchesFilter(t *testing.T) { job := &llm.Job{ Subsystem: llm.SubsystemKnowledge, @@ -488,3 +565,128 @@ func TestEventMatchesFilter(t *testing.T) { t.Fatal("mismatched repo should not pass") } } + +// ────────────────────────────────────────────────────────────────────────────── +// Phase 7: fetchGateSnapshot + handleLLMActivity gate_snapshot field tests + +// TestHandleLLMActivityGateSnapshotPopulated verifies that handleLLMActivity +// includes gate_snapshot when the worker stub returns a non-empty snapshot. +func TestHandleLLMActivityGateSnapshotPopulated(t *testing.T) { + want := []monitorGateEntry{ + { + Provider: "ollama", + BaseURLNormalized: "http://localhost:11434", + Kind: "llm", + InFlight: 1, + Queued: 2, + MaxConcurrent: 4, + RetriesSinceStart: 3, + Recent429Count: 1, + TokensPerSecond: 12.5, + }, + { + Provider: "ollama", + BaseURLNormalized: "http://localhost:11434", + Kind: "embedding", + InFlight: 0, + Queued: 0, + MaxConcurrent: 4, + TokensPerSecond: 0, + }, + } + fetcher := &fakeGateFetcher{entries: want} + s := newMonitorTestServerWithGateFetcher(t, fetcher) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/llm/activity", nil) + w := httptest.NewRecorder() + s.handleLLMActivity(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + var resp monitorActivityResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode body: %v", err) + } + if len(resp.GateSnapshot) != 2 { + t.Fatalf("expected 2 gate entries, got %d", len(resp.GateSnapshot)) + } + got := resp.GateSnapshot[0] + if got.Provider != "ollama" { + t.Errorf("provider: want %q, got %q", "ollama", got.Provider) + } + if got.Kind != "llm" { + t.Errorf("kind: want %q, got %q", "llm", got.Kind) + } + if got.InFlight != 1 { + t.Errorf("in_flight: want 1, got %d", got.InFlight) + } + if got.MaxConcurrent != 4 { + t.Errorf("max_concurrent: want 4, got %d", got.MaxConcurrent) + } + if got.TokensPerSecond != 12.5 { + t.Errorf("tokens_per_second: want 12.5, got %f", got.TokensPerSecond) + } + if fetcher.calls != 1 { + t.Errorf("expected exactly 1 gRPC call, got %d", fetcher.calls) + } +} + +// TestHandleLLMActivityGateSnapshotOmittedOnError verifies that gate_snapshot +// is absent from the response (not an empty array) when the fetcher errors. +// The rest of the response must remain intact and return 200. +func TestHandleLLMActivityGateSnapshotOmittedOnError(t *testing.T) { + fetcher := &fakeGateFetcher{err: errors.New("worker unreachable")} + s := newMonitorTestServerWithGateFetcher(t, fetcher) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/llm/activity", nil) + w := httptest.NewRecorder() + s.handleLLMActivity(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 even when worker errors, got %d", w.Code) + } + // Verify gate_snapshot key is absent from the raw JSON (omitempty). + body := w.Body.String() + if strings.Contains(body, `"gate_snapshot"`) { + t.Fatalf("gate_snapshot should be absent when worker errors, but found it in: %s", body) + } + // Verify the rest of the response is intact. + var resp monitorActivityResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode body: %v", err) + } + if resp.Stats.MaxConcurrency != 2 { + t.Errorf("stats should still be present; max_concurrency: want 2, got %d", resp.Stats.MaxConcurrency) + } +} + +// TestFetchGateSnapshotCacheDeduplicatesGRPCCalls verifies that two calls +// within the 1-second TTL window trigger only one underlying gRPC fetch. +func TestFetchGateSnapshotCacheDeduplicatesGRPCCalls(t *testing.T) { + fetcher := &fakeGateFetcher{entries: []monitorGateEntry{ + {Provider: "openai", Kind: "llm", MaxConcurrent: 8}, + }} + s := newMonitorTestServerWithGateFetcher(t, fetcher) + + ctx := context.Background() + first := s.fetchGateSnapshot(ctx) + second := s.fetchGateSnapshot(ctx) + + if len(first) != 1 || len(second) != 1 { + t.Fatalf("expected 1 entry on both calls, got %d and %d", len(first), len(second)) + } + if fetcher.calls != 1 { + t.Errorf("cache should deduplicate: expected 1 gRPC call, got %d", fetcher.calls) + } +} + +// TestFetchGateSnapshotNilWorkerReturnsNil verifies that fetchGateSnapshot +// returns nil (omitempty) when the server has no worker and no test fetcher. +func TestFetchGateSnapshotNilWorkerReturnsNil(t *testing.T) { + s := newMonitorTestServer(t) // worker is nil in the base fixture + result := s.fetchGateSnapshot(context.Background()) + if result != nil { + t.Fatalf("expected nil from fetchGateSnapshot when worker is nil, got %v", result) + } +} diff --git a/internal/api/rest/clustering_handler.go b/internal/api/rest/clustering_handler.go index e19c8e42..6e2cf077 100644 --- a/internal/api/rest/clustering_handler.go +++ b/internal/api/rest/clustering_handler.go @@ -374,7 +374,7 @@ func runRelabelClusters(ctx context.Context, rt llm.Runtime, cs clustering.Clust return ctx.Err() } progress := float64(i) / total - rt.ReportProgress(progress, "relabeling", fmt.Sprintf("Relabeling cluster %d of %d", i+1, int(total))) + rt.ReportProgress(progress, "relabeling", fmt.Sprintf("Relabeling cluster %d of %d", i+1, int(total)), 0) c, err := cs.GetClusterByID(ctx, clusterID) if err != nil || c == nil { @@ -409,7 +409,7 @@ func runRelabelClusters(ctx context.Context, rt llm.Runtime, cs clustering.Clust } } } - rt.ReportProgress(1.0, "ready", fmt.Sprintf("Relabeled %d clusters", len(clusterIDs))) + rt.ReportProgress(1.0, "ready", fmt.Sprintf("Relabeled %d clusters", len(clusterIDs)), 0) return nil } diff --git a/internal/api/rest/repo_llm_monitor_test.go b/internal/api/rest/repo_llm_monitor_test.go index 749046ff..51bfe4ae 100644 --- a/internal/api/rest/repo_llm_monitor_test.go +++ b/internal/api/rest/repo_llm_monitor_test.go @@ -195,7 +195,7 @@ func TestRepoLLMJobLogsMatchingRepo(t *testing.T) { TargetKey: "repo-1:logs-match", RepoID: "repo-abc", Run: func(rt llm.Runtime) error { - rt.ReportProgress(0.25, "snapshot", "Snapshot assembled") + rt.ReportProgress(0.25, "snapshot", "Snapshot assembled", 0) return nil }, }) diff --git a/internal/api/rest/router.go b/internal/api/rest/router.go index 379ec5ec..37ac719c 100644 --- a/internal/api/rest/router.go +++ b/internal/api/rest/router.go @@ -314,6 +314,7 @@ type Server struct { clusterRunner *clustering.Runner // subsystem clustering job dispatcher; nil = feature disabled healthChecker *HealthChecker // shared DB+worker probe; nil = embedded/test mode, handlers fall back to local checks workerVersionLookup *versionLookup // best-effort cache for worker GetVersion (CA-136); nil = workerVersion always "" in /api/v1/version + gateSnapshotCache gateSnapshotCache // 1-second TTL cache for worker gate snapshot (Phase 7) // encryptionKeySet is true when the API booted with a resolved encryption // key (from SOURCEBRIDGE_SECURITY_ENCRYPTION_KEY_FILE or the literal env diff --git a/internal/clustering/job.go b/internal/clustering/job.go index 44bed5e2..e2003e56 100644 --- a/internal/clustering/job.go +++ b/internal/clustering/job.go @@ -100,7 +100,7 @@ func NewOrchestratorDispatcher(o *orchestrator.Orchestrator) Dispatcher { // run is the job body. It executes the full clustering pipeline for one repo. func (r *Runner) run(ctx context.Context, rt llm.Runtime, repoID, commitSHA string) error { - rt.ReportProgress(0.05, "loading", "Loading call graph") + rt.ReportProgress(0.05, "loading", "Loading call graph", 0) // 1. Load call edges. rawEdges := r.store.GetCallEdges(repoID) @@ -111,11 +111,11 @@ func (r *Runner) run(ctx context.Context, rt llm.Runtime, repoID, commitSHA stri if storedHash != "" && storedHash == currentHash { slog.Info("clustering: call graph unchanged, skipping", "repo_id", repoID, "edge_hash", currentHash[:8]) - rt.ReportProgress(1.0, "unchanged", "Call graph unchanged — skipping re-cluster") + rt.ReportProgress(1.0, "unchanged", "Call graph unchanged — skipping re-cluster", 0) return nil } - rt.ReportProgress(0.15, "running_lpa", "Running label propagation") + rt.ReportProgress(0.15, "running_lpa", "Running label propagation", 0) // 3. Collect all node IDs. nodeSet := make(map[string]struct{}, len(rawEdges)*2) @@ -133,7 +133,7 @@ func (r *Runner) run(ctx context.Context, rt llm.Runtime, repoID, commitSHA stri seed := BuildSeed(repoID, commitSHA) lpaResult := RunLPA(rawEdges, nodeIDs, seed) - rt.ReportProgress(0.60, "building_clusters", "Building cluster records") + rt.ReportProgress(0.60, "building_clusters", "Building cluster records", 0) // 5. Group nodes by their final label. groups := make(map[string][]string, len(nodeIDs)/4+1) @@ -201,7 +201,7 @@ func (r *Runner) run(ctx context.Context, rt llm.Runtime, repoID, commitSHA stri clusters = append(clusters, cls) } - rt.ReportProgress(0.80, "persisting", "Persisting clusters") + rt.ReportProgress(0.80, "persisting", "Persisting clusters", 0) // 9. Atomic replace: delete old clusters and insert new ones in a single // transaction so GetClusters never returns empty mid-swap. @@ -228,7 +228,7 @@ func (r *Runner) run(ctx context.Context, rt llm.Runtime, repoID, commitSHA stri "size_p95", sp95, ) - rt.ReportProgress(1.0, "ready", fmt.Sprintf("Clustered %d symbols into %d subsystems (Q=%.2f)", len(nodeIDs), len(clusters), q)) + rt.ReportProgress(1.0, "ready", fmt.Sprintf("Clustered %d symbols into %d subsystems (Q=%.2f)", len(nodeIDs), len(clusters), q), 0) return nil } diff --git a/internal/db/llm_job_store.go b/internal/db/llm_job_store.go index 0a6a23ff..dadab42a 100644 --- a/internal/db/llm_job_store.go +++ b/internal/db/llm_job_store.go @@ -57,12 +57,17 @@ type surrealLLMJob struct { SkippedFileUnits int `json:"skipped_file_units"` SkippedPackageUnits int `json:"skipped_package_units"` SkippedRootUnits int `json:"skipped_root_units"` - ArtifactID string `json:"artifact_id"` - RepoID string `json:"repo_id"` - CreatedAt surrealTime `json:"created_at"` - StartedAt surrealTime `json:"started_at"` - UpdatedAt surrealTime `json:"updated_at"` - CompletedAt surrealTime `json:"completed_at"` + ArtifactID string `json:"artifact_id"` + RepoID string `json:"repo_id"` + // CurrentTokensPerSecond is stored as a pointer so that zero and + // absent are distinct on round-trip: a missing field decodes to nil + // (unknown) rather than 0 (zero rate), while an explicit 0 is + // preserved but the caller treats it as "unknown" per the Job docs. + CurrentTokensPerSecond *float64 `json:"current_tokens_per_second,omitempty"` + CreatedAt surrealTime `json:"created_at"` + StartedAt surrealTime `json:"started_at"` + UpdatedAt surrealTime `json:"updated_at"` + CompletedAt surrealTime `json:"completed_at"` } type surrealLLMJobLog struct { @@ -123,6 +128,9 @@ func (r *surrealLLMJob) toJob() *llm.Job { CreatedAt: r.CreatedAt.Time, UpdatedAt: r.UpdatedAt.Time, } + if r.CurrentTokensPerSecond != nil { + job.CurrentTokensPerSecond = *r.CurrentTokensPerSecond + } if !r.StartedAt.Time.IsZero() { t := r.StartedAt.Time job.StartedAt = &t @@ -212,6 +220,7 @@ func (s *SurrealStore) Create(job *llm.Job) (*llm.Job, error) { skipped_root_units = $skipped_root_units, artifact_id = $artifact_id, repo_id = $repo_id, + current_tokens_per_second = $current_tokens_per_second, created_at = time::now(), updated_at = time::now()` @@ -252,6 +261,7 @@ func (s *SurrealStore) Create(job *llm.Job) (*llm.Job, error) { "skipped_root_units": job.SkippedRootUnits, "artifact_id": job.ArtifactID, "repo_id": job.RepoID, + "current_tokens_per_second": job.CurrentTokensPerSecond, } if _, err := surrealdb.Query[interface{}](ctx(), db, sql, vars); err != nil { @@ -307,6 +317,7 @@ func (s *SurrealStore) Update(job *llm.Job) error { skipped_root_units = $skipped_root_units, artifact_id = $artifact_id, repo_id = $repo_id, + current_tokens_per_second = $current_tokens_per_second, updated_at = time::now()` vars := map[string]any{ "id": job.ID, @@ -345,6 +356,7 @@ func (s *SurrealStore) Update(job *llm.Job) error { "skipped_root_units": job.SkippedRootUnits, "artifact_id": job.ArtifactID, "repo_id": job.RepoID, + "current_tokens_per_second": job.CurrentTokensPerSecond, } _, err := queryOne[interface{}](ctx(), db, sql, vars) return err @@ -532,7 +544,7 @@ func (s *SurrealStore) SetStatus(id string, status llm.JobStatus) error { } // SetProgress updates the progress fields. -func (s *SurrealStore) SetProgress(id string, progress float64, phase, message string) error { +func (s *SurrealStore) SetProgress(id string, progress float64, phase, message string, throughputTPS float64) error { db := s.client.DB() if db == nil { return fmt.Errorf("database not connected") @@ -548,6 +560,7 @@ func (s *SurrealStore) SetProgress(id string, progress float64, phase, message s progress = $progress, progress_phase = $phase, progress_message = $message, + current_tokens_per_second = $tps, updated_at = time::now() WHERE status = 'pending' OR status = 'generating'`, map[string]any{ @@ -555,6 +568,7 @@ func (s *SurrealStore) SetProgress(id string, progress float64, phase, message s "progress": progress, "phase": phase, "message": message, + "tps": throughputTPS, }) return err } diff --git a/internal/livingwiki/coldstart/runner.go b/internal/livingwiki/coldstart/runner.go index 32209462..bb499b22 100644 --- a/internal/livingwiki/coldstart/runner.go +++ b/internal/livingwiki/coldstart/runner.go @@ -189,7 +189,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { } if cfg.LWOrch == nil { return func(_ context.Context, rt llm.Runtime) error { - rt.ReportProgress(1.0, "unavailable", "Living-wiki orchestrator not configured") + rt.ReportProgress(1.0, "unavailable", "Living-wiki orchestrator not configured", 0) return nil } } @@ -198,7 +198,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { jobID := rt.JobID() start := time.Now() - rt.ReportProgress(0.0, "planning", "Resolving page taxonomy") + rt.ReportProgress(0.0, "planning", "Resolving page taxonomy", 0) // ── Step 1.65 (NEW at r3): Resolve LLM identity ONCE before taxonomy ───── // @@ -290,7 +290,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { } } if len(pages) == 0 { - rt.ReportProgress(1.0, "ok", "No previously-excluded pages found; nothing to retry") + rt.ReportProgress(1.0, "ok", "No previously-excluded pages found; nothing to retry", 0) return nil } } else { @@ -303,7 +303,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { } if len(pages) == 0 { - rt.ReportProgress(1.0, "ok", "No pages to generate for this repository") + rt.ReportProgress(1.0, "ok", "No pages to generate for this repository", 0) return nil } @@ -322,7 +322,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { pages = ApplyPageSelection(pages, cfg.SelectedPageIDs) if len(pages) == 0 { - rt.ReportProgress(1.0, "ok", "No pages remain after selection filter") + rt.ReportProgress(1.0, "ok", "No pages remain after selection filter", 0) return nil } @@ -375,7 +375,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { "cap_value", capValue, "excluded_only_retry", excludedOnlyRetry, ) - rt.ReportProgress(0.01, "planning", planningMsg) + rt.ReportProgress(0.01, "planning", planningMsg, 0) // ── Step 1.5: Attach knowledge artifacts to architecture pages ───────── // @@ -557,7 +557,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { if toGenerate == 0 && len(skipNeedsFixup) == 0 { rt.ReportProgress(1.0, "ok", fmt.Sprintf( - "All %d pages already up to date — nothing to regenerate", total)) + "All %d pages already up to date — nothing to regenerate", total), 0) if cfg.JobResultStore != nil { now := time.Now() _ = cfg.JobResultStore.Save(runCtx, cfg.TenantID, &livingwiki.LivingWikiJobResult{ @@ -576,7 +576,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { rt.ReportProgress(0.05, "generating", fmt.Sprintf( "Generating %d pages (%s; %d already up to date, %d need fixup)", - toGenerate, planningMsg, len(skipFully), len(skipNeedsFixup))) + toGenerate, planningMsg, len(skipFully), len(skipNeedsFixup)), 0) // ── Step 1.9: Wire async dispatch worker (CR2) ─────────────────────────── // @@ -656,7 +656,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { progress = 0.05 + 0.85*float64(done)/float64(toGenerate) } rt.ReportProgress(progress, "generating", - fmt.Sprintf("%d/%d pages complete", done, toGenerate)) + fmt.Sprintf("%d/%d pages complete", done, toGenerate), 0) // Phase 2: trigger an index update every indexUpdateEvery page completions. // OnPageDone now fires from the orchestrator's single-goroutine persistence @@ -719,7 +719,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { p = 0.05 + 0.85*float64(done)/float64(toGenerate) } rt.ReportProgress(p, "generating", - fmt.Sprintf("%d/%d pages complete", done, toGenerate)) + fmt.Sprintf("%d/%d pages complete", done, toGenerate), 0) // Phase 2: refresh the index page on every 30s heartbeat tick. // Run in a goroutine so a slow sink write does not delay the next // heartbeat. runCtx (not hbCtx) so the write can complete even if @@ -866,7 +866,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { // errors leave sinks untouched. (codex r1b [Medium]) if (err == nil || isPartial) && (len(result.Generated) > 0 || len(skippedPageIDs) > 0) { rt.ReportProgress(0.92, "pushing", fmt.Sprintf( - "Pushing %d pages to sinks", len(result.Generated))) + "Pushing %d pages to sinks", len(result.Generated)), 0) // Resolve the repository name for the Confluence root page title. // Best-effort: fall back to an empty string if the store is nil or @@ -898,7 +898,7 @@ func BuildRunner(cfg Config) func(ctx context.Context, rt llm.Runtime) error { rt.ReportProgress(1.0, status, fmt.Sprintf( "Generation complete: %d generated, %d excluded", finalGen, finalExcl, - )) + ), 0) // ── Step 5: Persist LivingWikiJobResult ─────────────────────────────── if cfg.JobResultStore != nil { diff --git a/internal/livingwiki/coldstart/runner_test.go b/internal/livingwiki/coldstart/runner_test.go index 698e3697..b33fe1bc 100644 --- a/internal/livingwiki/coldstart/runner_test.go +++ b/internal/livingwiki/coldstart/runner_test.go @@ -124,7 +124,7 @@ func (s *stubRuntime) JobID() string { return s.jobID } -func (s *stubRuntime) ReportProgress(progress float64, status, message string) { +func (s *stubRuntime) ReportProgress(progress float64, status, message string, _ float64) { if s.onReportProgress != nil { s.onReportProgress(progress, status, message) } diff --git a/internal/llm/job.go b/internal/llm/job.go index 323b79ac..f8d27bbb 100644 --- a/internal/llm/job.go +++ b/internal/llm/job.go @@ -146,6 +146,11 @@ type Job struct { Progress float64 `json:"progress"` ProgressPhase string `json:"progress_phase,omitempty"` ProgressMessage string `json:"progress_message,omitempty"` + // CurrentTokensPerSecond is the instantaneous LLM throughput sampled + // from the gate's 60-second ring buffer at the most recent progress + // update. Zero means unknown (non-streaming path, cold gate, or + // kill-switch disabled). Consumers MUST treat zero as "unknown". + CurrentTokensPerSecond float64 `json:"current_tokens_per_second,omitempty"` ErrorCode string `json:"error_code,omitempty"` ErrorMessage string `json:"error_message,omitempty"` @@ -277,9 +282,11 @@ type Runtime interface { // JobID returns the persisted id of the job currently running. JobID() string // ReportProgress updates the job's progress (0.0-1.0), phase label, - // and human-readable message. Updates are debounced by the orchestrator - // to avoid write amplification. - ReportProgress(progress float64, phase, message string) + // and human-readable message. throughputTPS carries the instantaneous + // LLM token/s from the gate's 60-second ring buffer at the moment the + // progress event was emitted; pass 0 when not available. Updates are + // debounced by the orchestrator to avoid write amplification. + ReportProgress(progress float64, phase, message string, throughputTPS float64) // ReportTokens records the input/output token counts for billing and // metrics. Typically called once at the end of the job. ReportTokens(input, output int) diff --git a/internal/llm/memstore.go b/internal/llm/memstore.go index 3c6a58e7..490a99d4 100644 --- a/internal/llm/memstore.go +++ b/internal/llm/memstore.go @@ -221,7 +221,7 @@ func (s *MemStore) SetStatus(id string, status JobStatus) error { } // SetProgress updates the progress fields without changing status. -func (s *MemStore) SetProgress(id string, progress float64, phase, message string) error { +func (s *MemStore) SetProgress(id string, progress float64, phase, message string, throughputTPS float64) error { s.mu.Lock() defer s.mu.Unlock() j, ok := s.jobs[id] @@ -240,6 +240,9 @@ func (s *MemStore) SetProgress(id string, progress float64, phase, message strin j.Progress = progress j.ProgressPhase = phase j.ProgressMessage = message + if throughputTPS > 0 { + j.CurrentTokensPerSecond = throughputTPS + } j.UpdatedAt = time.Now() return nil } diff --git a/internal/llm/memstore_test.go b/internal/llm/memstore_test.go index 73afc426..4c0d2c21 100644 --- a/internal/llm/memstore_test.go +++ b/internal/llm/memstore_test.go @@ -165,7 +165,7 @@ func TestMemStoreIgnoresWritesForTerminalJobs(t *testing.T) { store := NewMemStore() _, _ = store.Create(newTestJob("done", "tk", StatusCancelled)) before := store.GetByID("done") - if err := store.SetProgress("done", 0.75, "render", "ignored"); err != nil { + if err := store.SetProgress("done", 0.75, "render", "ignored", 0); err != nil { t.Fatalf("SetProgress failed: %v", err) } if err := store.SetTokens("done", 10, 20); err != nil { diff --git a/internal/llm/orchestrator/orchestrator.go b/internal/llm/orchestrator/orchestrator.go index 6601101c..e23e24bc 100644 --- a/internal/llm/orchestrator/orchestrator.go +++ b/internal/llm/orchestrator/orchestrator.go @@ -1198,7 +1198,7 @@ func (o *Orchestrator) runJob(item *workItem) { } for attempt := 1; attempt <= maxAttempts; attempt++ { if cooldown := o.breaker.waitDuration(req.Subsystem); cooldown > 0 { - _ = o.store.SetProgress(jobID, 0.02, "backoff", "Waiting for model backend to recover") + _ = o.store.SetProgress(jobID, 0.02, "backoff", "Waiting for model backend to recover", 0) if job := o.store.GetByID(jobID); job != nil { o.publish(llm.JobEvent{Kind: llm.EventProgress, Job: job}) } diff --git a/internal/llm/orchestrator/orchestrator_test.go b/internal/llm/orchestrator/orchestrator_test.go index cff9a759..4c6637bb 100644 --- a/internal/llm/orchestrator/orchestrator_test.go +++ b/internal/llm/orchestrator/orchestrator_test.go @@ -55,7 +55,7 @@ func TestOrchestratorEnqueueRunsJobToCompletion(t *testing.T) { JobType: "cliff_notes", TargetKey: "repo-1:cliff_notes:dev:medium", Run: func(rt llm.Runtime) error { - rt.ReportProgress(0.5, "mid", "halfway") + rt.ReportProgress(0.5, "mid", "halfway", 0) rt.ReportTokens(1000, 500) rt.ReportSnapshotBytes(12345) ran.Store(true) @@ -451,9 +451,9 @@ func TestOrchestratorPublishesEvents(t *testing.T) { JobType: "cliff_notes", TargetKey: "repo-1:events", Run: func(rt llm.Runtime) error { - rt.ReportProgress(0.25, "building", "building") + rt.ReportProgress(0.25, "building", "building", 0) time.Sleep(15 * time.Millisecond) // cross the debounce window - rt.ReportProgress(0.75, "finishing", "nearly done") + rt.ReportProgress(0.75, "finishing", "nearly done", 0) return nil }, }) @@ -894,7 +894,7 @@ func TestRuntimeHeartbeatIsNoopOnTerminalJob(t *testing.T) { JobType: "heartbeat_terminal_test", TargetKey: "heartbeat-terminal-1", Run: func(rt llm.Runtime) error { - rt.ReportProgress(1.0, "ok", "done") + rt.ReportProgress(1.0, "ok", "done", 0) return nil }, }) @@ -945,12 +945,12 @@ func TestRuntimeHeartbeatBypassesProgressDebounce(t *testing.T) { JobType: "heartbeat_debounce_test", TargetKey: "heartbeat-debounce-1", Run: func(rt llm.Runtime) error { - rt.ReportProgress(0.5, "mid", "first") + rt.ReportProgress(0.5, "mid", "first", 0) afterProgress1 := orch.GetJob(rt.JobID()).UpdatedAt time.Sleep(20 * time.Millisecond) // Second ReportProgress is debounced (5s window) so should NOT // advance UpdatedAt. - rt.ReportProgress(0.5, "mid", "first") + rt.ReportProgress(0.5, "mid", "first", 0) afterProgress2 := orch.GetJob(rt.JobID()).UpdatedAt time.Sleep(20 * time.Millisecond) // Heartbeat MUST advance UpdatedAt despite the debounce window. @@ -1151,7 +1151,7 @@ func TestReaper_HeartbeatStale_LivingWiki_QueuedPhase_NotReapedEarly(t *testing. t.Fatalf("enqueue: %v", err) } // Job is StatusPending; set progress phase to "queued". - if err := orch.store.SetProgress(job.ID, 0, "queued", ""); err != nil { + if err := orch.store.SetProgress(job.ID, 0, "queued", "", 0); err != nil { t.Fatalf("set progress: %v", err) } // 6 min stale — above heartbeatStaleThreshold but within queued override. diff --git a/internal/llm/orchestrator/runtime.go b/internal/llm/orchestrator/runtime.go index 081a7cd1..8f832d19 100644 --- a/internal/llm/orchestrator/runtime.go +++ b/internal/llm/orchestrator/runtime.go @@ -29,6 +29,7 @@ type runtime struct { lastProgress float64 lastPhase string lastMessage string + lastThroughputTPS float64 lastWrite time.Time pendingProgress bool lastLoggedPhase string @@ -52,8 +53,10 @@ func (r *runtime) JobID() string { return r.jobID } // ReportProgress records a progress update. The update is written to // the store immediately if the debounce window has elapsed, or buffered -// and written on the next tick / flush otherwise. -func (r *runtime) ReportProgress(progress float64, phase, message string) { +// and written on the next tick / flush otherwise. throughputTPS carries +// the instantaneous LLM token/s from the gate's 60-second ring buffer; +// pass 0 when not available (non-streaming callers, clustering, etc.). +func (r *runtime) ReportProgress(progress float64, phase, message string, throughputTPS float64) { r.mu.Lock() defer r.mu.Unlock() if progress < 0 { @@ -65,6 +68,7 @@ func (r *runtime) ReportProgress(progress float64, phase, message string) { r.lastProgress = progress r.lastPhase = phase r.lastMessage = message + r.lastThroughputTPS = throughputTPS r.pendingProgress = true now := time.Now() @@ -132,7 +136,7 @@ func (r *runtime) flush() { // writeProgressLocked persists the current buffered progress values. The // caller must hold r.mu. func (r *runtime) writeProgressLocked(now time.Time) { - if err := r.orch.store.SetProgress(r.jobID, r.lastProgress, r.lastPhase, r.lastMessage); err != nil { + if err := r.orch.store.SetProgress(r.jobID, r.lastProgress, r.lastPhase, r.lastMessage, r.lastThroughputTPS); err != nil { return } r.lastWrite = now diff --git a/internal/llm/store.go b/internal/llm/store.go index fee0c398..31b2c194 100644 --- a/internal/llm/store.go +++ b/internal/llm/store.go @@ -48,7 +48,10 @@ type JobStore interface { // SetProgress writes a progress update. Callers are expected to // debounce upstream; the store writes every call it receives. - SetProgress(id string, progress float64, phase, message string) error + // throughputTPS carries the instantaneous LLM token/s from the gate's + // 60-second ring buffer; 0 means not available (non-streaming paths, + // or before the first streaming completion). + SetProgress(id string, progress float64, phase, message string, throughputTPS float64) error // Heartbeat bumps updated_at to time::now() without changing any other // field. Used by long-running jobs to assert liveness when no progress diff --git a/internal/worker/client.go b/internal/worker/client.go index 495fe6c2..1771dab0 100644 --- a/internal/worker/client.go +++ b/internal/worker/client.go @@ -1101,6 +1101,23 @@ func (c *Client) GetProviderCapabilities(ctx context.Context) (*reasoningv1.GetP return b.reasoning.GetProviderCapabilities(ctx, &reasoningv1.GetProviderCapabilitiesRequest{}) } +// GetLLMGateSnapshot returns a point-in-time snapshot of all active +// per-provider concurrency gates from the worker. Used by the admin +// monitor endpoint to surface real-time gate counters. +// +// Short timeout (5s): this is a read-only, lock-free snapshot; any +// call that takes longer than 5s indicates a stuck worker. +func (c *Client) GetLLMGateSnapshot(ctx context.Context) (*reasoningv1.GetLLMGateSnapshotResponse, error) { + b := c.acquire() + if b == nil { + return nil, errClientClosed + } + defer c.release(b) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + return b.reasoning.GetLLMGateSnapshot(ctx, &reasoningv1.GetLLMGateSnapshotRequest{}) +} + // ClassifyQuestion runs the LLM-backed question classifier. Quick // timeout (2s) because callers fall back to the keyword classifier // when this fails. diff --git a/proto/common/v1/knowledge_progress.proto b/proto/common/v1/knowledge_progress.proto index 7b2ed726..a204aefb 100644 --- a/proto/common/v1/knowledge_progress.proto +++ b/proto/common/v1/knowledge_progress.proto @@ -68,6 +68,13 @@ message KnowledgeStreamProgress { int32 file_cache_hits = 11; int32 package_cache_hits = 12; int32 root_cache_hits = 13; + + // Instantaneous throughput from the active LLM gate's 60-second ring + // buffer at the moment this progress event was emitted. Zero when the + // gate has not recorded any completions yet (cold start, non-streaming + // path, or kill-switch disabled). Consumers MUST treat zero as "unknown", + // not as "zero tokens per second". Added Phase 6 (worker LLM concurrency). + float current_tokens_per_second = 14; } // KnowledgeStreamPhaseMarker fires once per phase transition so the diff --git a/proto/reasoning/v1/reasoning.proto b/proto/reasoning/v1/reasoning.proto index 735f3713..399198af 100644 --- a/proto/reasoning/v1/reasoning.proto +++ b/proto/reasoning/v1/reasoning.proto @@ -56,6 +56,16 @@ service ReasoningService { // a per-request round-trip. rpc GetProviderCapabilities(GetProviderCapabilitiesRequest) returns (GetProviderCapabilitiesResponse); + // GetLLMGateSnapshot returns a point-in-time snapshot of all active + // per-provider concurrency gates in the worker. Used by the admin + // /api/v1/admin/llm/activity endpoint to surface real-time gate + // counters (in-flight, queued, tok/s) without a per-request round-trip. + // + // An explicit request struct (rather than google.protobuf.Empty) is + // used for forward compatibility: filter fields (provider, kind) can + // be added without redeclaring the RPC. + rpc GetLLMGateSnapshot(GetLLMGateSnapshotRequest) returns (GetLLMGateSnapshotResponse); + // ClassifyQuestion runs a cheap LLM classifier (Haiku) that // returns the question's likely class plus evidence-kind hints // (needs_call_graph, needs_tests, ...) and advisory symbol / @@ -466,3 +476,51 @@ message GetProviderCapabilitiesResponse { // report a value. When false, callers MUST NOT clamp on this field. bool max_concurrent_calls_known = 6; } + +// GetLLMGateSnapshotRequest is intentionally empty (not google.protobuf.Empty) +// for forward compatibility — future filter fields (provider, kind) can be +// added here without redeclaring the RPC. +message GetLLMGateSnapshotRequest {} + +message GetLLMGateSnapshotResponse { + repeated LLMGateEntry gates = 1; +} + +// LLMGateEntry is one row in the gate snapshot — one entry per +// (provider, base_url_normalized, kind) triple that has been +// registered in the gate registry since the worker started. +message LLMGateEntry { + // Canonical provider name (e.g. "ollama", "openai"). + string provider = 1; + + // Normalized origin URL — "scheme://host:port" with path/query stripped. + // Host-gated providers collapse to the daemon origin; frontier providers + // use the raw base_url per Decision 1 in the concurrency plan. + string base_url_normalized = 2; + + // "llm" or "embedding". Host-gated providers (Ollama, vLLM, …) emit one + // row per kind sharing a single max_concurrent and tokens_per_second. + string kind = 3; + + // Current number of LLM calls executing inside the semaphore right now. + int32 in_flight = 4; + + // Number of callers waiting for a semaphore slot (Decision 11 waiter count). + int32 queued = 5; + + // Effective semaphore size — the cap enforced by this gate. + int32 max_concurrent = 6; + + // Cumulative retry count since the worker started. + int64 retries_since_start = 7; + + // Cumulative 429 / RateLimitError count since the worker started. + int64 recent_429_count = 8; + + // Output tokens per second averaged over the last 60 seconds (60-second + // ring buffer). Zero means no completions have been recorded yet. + double tokens_per_second = 9; + + // Configured requests-per-minute limit for this gate; 0 = no RPM limiter. + int32 rpm = 10; +} diff --git a/web/src/app/(app)/admin/monitor/page.tsx b/web/src/app/(app)/admin/monitor/page.tsx index fb9cf800..66e6ebf7 100644 --- a/web/src/app/(app)/admin/monitor/page.tsx +++ b/web/src/app/(app)/admin/monitor/page.tsx @@ -10,6 +10,7 @@ import { Panel } from "@/components/ui/panel"; import { StatCard } from "@/components/ui/stat-card"; import { authFetch } from "@/lib/auth-fetch"; import { normalizeActivityResponse } from "@/lib/llm/activity"; +import type { LLMGateEntry } from "@/lib/llm/activity"; import { JobProgress, formatElapsedMs, @@ -140,6 +141,10 @@ interface ActivityResponse { pending_maintenance?: number; pending_prewarm?: number; }; + // Present when the worker is reachable and has at least one active gate. + // Omitted entirely when the worker is unreachable or the wrapper is + // kill-switched (never an empty array — see Go handler omitempty). + gate_snapshot?: LLMGateEntry[]; } const POLL_INTERVAL_MS = 2000; @@ -486,6 +491,11 @@ export default function MonitorPage() { )} + {/* LLM Gate Activity — only rendered when gate_snapshot is present */} + {data?.gate_snapshot && data.gate_snapshot.length > 0 ? ( + + ) : null} + {/* Zone 2 — Now running */}
@@ -650,7 +660,14 @@ function ActiveJobCard({
- {formatElapsedMs(job.elapsed_ms)} + + {formatElapsedMs(job.elapsed_ms)} + {job.status === "generating" && (job.current_tokens_per_second ?? 0) > 0 ? ( + + {job.current_tokens_per_second!.toFixed(1)} tok/s + + ) : null} + {formatGenerationMode(job.generation_mode) ? {formatGenerationMode(job.generation_mode)} : null} {formatPriority(job.priority) ? {formatPriority(job.priority)} : null} @@ -912,6 +929,83 @@ function JobDetailDrawer({ job, onClose }: { job: JobView; onClose: () => void } ); } +// ────────────────────────────────────────────────────────────────────────────── +// LLM Gate Activity + +function LLMGateActivitySection({ gates }: { gates: LLMGateEntry[] }) { + return ( + +
+

LLM Gate Activity

+

+ Per-provider concurrency gate counters — live snapshot from the worker. One row per + provider + endpoint + kind combination. +

+
+
+ + + + + + + + + + + + + + + {gates.map((gate, idx) => ( + + + + + + + + + + + ))} + +
ProviderEndpointKindIn-flight / CapQueuedTok/s429sRetries
+ {gate.provider} + + + {gate.base_url_normalized || "—"} + + {gate.kind} + {gate.in_flight} /  + {gate.max_concurrent > 0 ? gate.max_concurrent : "∞"} + + {gate.queued > 0 ? ( + {gate.queued} + ) : ( + gate.queued + )} + + {gate.tokens_per_second > 0 ? gate.tokens_per_second.toFixed(1) : "—"} + + {gate.recent_429_count > 0 ? ( + + {gate.recent_429_count} + + ) : ( + gate.recent_429_count + )} + + {gate.retries_since_start} +
+
+
+ ); +} + function JobLogsPanel({ job }: { job: JobView }) { const [logs, setLogs] = useState([]); const [error, setError] = useState(null); diff --git a/web/src/lib/llm/activity.ts b/web/src/lib/llm/activity.ts index a5efd105..ce9fbdfd 100644 --- a/web/src/lib/llm/activity.ts +++ b/web/src/lib/llm/activity.ts @@ -1,5 +1,25 @@ "use client"; +/** + * LLMGateEntry mirrors the monitorGateEntry JSON shape from the Go REST handler. + * One entry per (provider, base_url_normalized, kind) gate in the worker registry. + * + * Consumers must treat zero values as "unknown/uncapped" for max_concurrent + * (matches the (known=true, calls=0) unbounded encoding from GetProviderCapabilities). + */ +export interface LLMGateEntry { + provider: string; + base_url_normalized: string; + kind: "llm" | "embedding"; + in_flight: number; + queued: number; + max_concurrent: number; + retries_since_start: number; + recent_429_count: number; + tokens_per_second: number; + rpm?: number; +} + interface ActivityEnvelope { active?: TJob[]; recent?: TJob[]; diff --git a/web/src/lib/llm/job-types.ts b/web/src/lib/llm/job-types.ts index 1eb0e043..f6612d30 100644 --- a/web/src/lib/llm/job-types.ts +++ b/web/src/lib/llm/job-types.ts @@ -44,6 +44,12 @@ export interface LLMJobView { estimated_wait_ms?: number; generation_mode?: "classic" | "understanding_first"; priority?: "interactive" | "maintenance" | "prewarm"; + /** + * Instantaneous LLM throughput sampled from the gate's 60-second ring + * buffer at the last progress update. Zero means unknown (non-streaming + * path, cold gate, or kill-switch disabled). Treat zero as "unknown". + */ + current_tokens_per_second?: number; elapsed_ms: number; updated_at: string; created_at?: string; diff --git a/workers/__main__.py b/workers/__main__.py index 4207ebfe..5fa97000 100644 --- a/workers/__main__.py +++ b/workers/__main__.py @@ -30,6 +30,7 @@ from workers import __version__ as _worker_version # noqa: E402 from workers.common.config import WorkerConfig # noqa: E402 from workers.common.embedding.config import create_embedding_provider # noqa: E402 +from workers.common.llm.concurrency import ConcurrencyConfig, ProviderGateRegistry # noqa: E402 from workers.common.llm.concurrency_probe import OpenAICompatProbeBackend, run_startup_probe # noqa: E402 from workers.common.llm.factory import create_llm_provider, create_report_provider # noqa: E402 from workers.contracts.servicer import ContractsServicer # noqa: E402 @@ -41,9 +42,11 @@ from workers.requirements.servicer import RequirementsServicer # noqa: E402 from workers.version_servicer import VersionServicer # noqa: E402 - _LOOPBACK_PREFIXES = ("127.", "::1", "localhost") _UNAUTHENTICATED_BIND_ADDRESSES = ("[::]", "0.0.0.0", "") +# Providers where a real concurrency limit is meaningful; frontier APIs +# (anthropic, openai, openrouter) are unbounded so no probe is needed. +_LOCAL_PROBE_PROVIDERS = frozenset({"ollama", "vllm", "llama-cpp", "sglang", "lmstudio"}) def _is_non_loopback(addr: str) -> bool: @@ -196,22 +199,36 @@ async def serve() -> None: version=_worker_version, ) + # --- Construct concurrency gate registry (long-lived, shared across all servicers) --- + # ConcurrencyConfig.from_env() reads SOURCEBRIDGE_LLM_* env vars (Decision 7). + # Phase 3: real Decision 6 defaults loaded; tenacity retry active (max_attempts=5); + # SDK retry disabled (max_retries=0 on AsyncOpenAI / AsyncAnthropic). + # Kill switch: SOURCEBRIDGE_LLM_CONCURRENCY_WRAPPER_ENABLED=false to revert. + concurrency_config = ConcurrencyConfig.from_env() + gate_registry = ProviderGateRegistry(concurrency_config) + log.info( + "concurrency_gate_registry_initialized", + wrapper_enabled=concurrency_config.wrapper_enabled, + retry_max_attempts=concurrency_config.retry_max_attempts, + llm_max_concurrent=concurrency_config.llm_max_concurrent, + ) + # --- Initialize providers (long-lived, connection-pooled) --- - llm_provider = create_llm_provider(config) - report_llm = create_report_provider(config) + llm_provider = await create_llm_provider(config, gate_registry=gate_registry) + report_llm = await create_report_provider(config, gate_registry=gate_registry) if report_llm: log.info( "report_llm_provider_configured", provider=config.llm_report_provider or config.llm_provider, model=config.llm_report_model, ) - embedding_provider = create_embedding_provider(config) + embedding_provider = await create_embedding_provider(config, gate_registry=gate_registry) summary_node_cache = SurrealSummaryNodeCache.from_config(config) # D10: Warn if the worker will be exposed on a non-loopback address without # authentication. This fires regardless of whether the capacity probe is used. listen_addr_early = f"[::]:{config.grpc_port}" - bind_host = f"[::]" # default gRPC bind host before port is chosen + bind_host = "[::]" # default gRPC bind host before port is chosen if not config.tls_enabled and not config.grpc_auth_secret and _is_non_loopback(bind_host): log.error( "worker_grpc_unauthenticated_non_loopback_bind", @@ -235,7 +252,12 @@ async def serve() -> None: ) # --- Register servicers --- - reasoning_servicer = ReasoningServicer(llm_provider, embedding_provider, worker_config=config) + reasoning_servicer = ReasoningServicer( + llm_provider, + embedding_provider, + worker_config=config, + gate_registry=gate_registry, + ) reasoning_pb2_grpc.add_ReasoningServiceServicer_to_server(reasoning_servicer, server) linking_servicer = LinkingServicer(llm_provider, embedding_provider) @@ -251,6 +273,7 @@ async def serve() -> None: report_llm=report_llm, worker_config=config, summary_node_cache=summary_node_cache, + gate_registry=gate_registry, ) knowledge_pb2_grpc.add_KnowledgeServiceServicer_to_server(knowledge_servicer, server) report_pb2_grpc.add_EnterpriseReportServiceServicer_to_server( @@ -376,9 +399,6 @@ async def serve() -> None: # Fires after server is serving to avoid delaying startup. The result is # informational: a WARN fires when declared vs observed parallelism disagrees # by >=2x. The declared value is never auto-overridden (D1). - # Only fires for local/self-hosted providers where a real concurrency limit - # is meaningful; frontier APIs (anthropic, openai, openrouter) are unbounded. - _LOCAL_PROBE_PROVIDERS = {"ollama", "vllm", "llama-cpp", "sglang", "lmstudio"} if ( config.llm_provider in _LOCAL_PROBE_PROVIDERS and not config.test_mode @@ -478,6 +498,10 @@ def _signal_handler() -> None: if hasattr(embedding_provider, "close"): await embedding_provider.close() + # Step 4: close the gate registry (cancels aggregator tasks; idempotent). + await gate_registry.close() + log.info("gate_registry_closed") + log.info("worker_stopped") diff --git a/workers/benchmarks/run_comprehension_bench.py b/workers/benchmarks/run_comprehension_bench.py index c443e073..55fa97a7 100644 --- a/workers/benchmarks/run_comprehension_bench.py +++ b/workers/benchmarks/run_comprehension_bench.py @@ -12,8 +12,8 @@ import yaml +from workers.common.cli_main import build_cli_runtime_provider from workers.common.config import WorkerConfig -from workers.common.llm.config import create_llm_provider from workers.common.llm.fake import FakeLLMProvider from workers.knowledge.cliff_notes import generate_cliff_notes from workers.knowledge.code_tour import generate_code_tour @@ -123,16 +123,21 @@ def _effective_provider_mode(case: dict[str, Any], override: str | None) -> str: return override -def _create_provider(provider_mode: str) -> tuple[Any, str, str]: +async def _create_provider(provider_mode: str) -> tuple[Any, str, str, Any]: + """Return (provider, provider_name, model_id, gate_registry). + + ``gate_registry`` is ``None`` for fake mode (no lifecycle to manage). + Callers must ``await gate_registry.close()`` when done with a live provider. + """ if provider_mode == "fake": provider = FakeLLMProvider() - return provider, "fake", provider.default_model + return provider, "fake", provider.default_model, None if provider_mode == "live": config = WorkerConfig() - provider = create_llm_provider(config) + provider, gate_registry = await build_cli_runtime_provider(config) provider_name = getattr(provider, "provider_name", None) or config.llm_provider model_id = getattr(provider, "default_model", None) or config.llm_model - return provider, provider_name, model_id + return provider, provider_name, model_id, gate_registry raise ValueError(f"unsupported provider mode: {provider_mode}") @@ -178,12 +183,12 @@ def _check_workflow_story_non_empty(result: Any) -> bool: async def _run_case(case: dict[str, Any], provider_mode_override: str | None = None) -> BenchmarkResult: provider_mode = _effective_provider_mode(case, provider_mode_override) - provider, provider_name, model_id = _create_provider(provider_mode) + provider, provider_name, model_id, gate_registry = await _create_provider(provider_mode) snapshot_json = _snapshot_json_for_corpus(case["corpus_id"]) artifact_type = case["artifact_type"] started = time.perf_counter() - try: + try: # noqa: SIM105 — gate cleanup requires finally even when no exception if artifact_type == "cliff_notes": result, usage = await generate_cliff_notes( provider=provider, @@ -268,6 +273,9 @@ async def _run_case(case: dict[str, Any], provider_mode_override: str | None = N checks={}, metrics={}, ) + finally: + if gate_registry is not None: + await gate_registry.close() def _write_report(results_dir: Path, results: list[BenchmarkResult]) -> None: diff --git a/workers/cli_ask.py b/workers/cli_ask.py index 4dae4365..e447aa07 100644 --- a/workers/cli_ask.py +++ b/workers/cli_ask.py @@ -18,8 +18,8 @@ if _parent not in sys.path: sys.path.insert(0, _parent) +from workers.common.cli_main import build_cli_runtime_provider # noqa: E402 from workers.common.config import WorkerConfig # noqa: E402 -from workers.common.llm.config import create_llm_provider # noqa: E402 from workers.common.surreal import SurrealClient # noqa: E402 from workers.reasoning.discussion import discuss_code # noqa: E402 @@ -872,29 +872,32 @@ async def main() -> None: "question_type": _question_type(question), } - provider = create_llm_provider(config) - with contextlib.redirect_stdout(sys.stderr): - answer, usage = await discuss_code( - provider, - question, - context_code, - context_metadata=context_metadata, - ) + provider, gate_registry = await build_cli_runtime_provider(config) + try: + with contextlib.redirect_stdout(sys.stderr): + answer, usage = await discuss_code( + provider, + question, + context_code, + context_metadata=context_metadata, + ) - output = { - "answer": answer.answer, - "references": answer.references, - "related_requirements": answer.related_requirements, - "mode": mode, - "diagnostics": diagnostics, - "usage": { - "provider": usage.provider, - "model": usage.model, - "input_tokens": usage.input_tokens, - "output_tokens": usage.output_tokens, - }, - } - print(json.dumps(output, indent=2)) + output = { + "answer": answer.answer, + "references": answer.references, + "related_requirements": answer.related_requirements, + "mode": mode, + "diagnostics": diagnostics, + "usage": { + "provider": usage.provider, + "model": usage.model, + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + }, + } + print(json.dumps(output, indent=2)) + finally: + await gate_registry.close() if __name__ == "__main__": diff --git a/workers/cli_review.py b/workers/cli_review.py index 581f6add..4be8b24b 100644 --- a/workers/cli_review.py +++ b/workers/cli_review.py @@ -14,8 +14,8 @@ if _parent not in sys.path: sys.path.insert(0, _parent) +from workers.common.cli_main import build_cli_runtime_provider # noqa: E402 from workers.common.config import WorkerConfig # noqa: E402 -from workers.common.llm.config import create_llm_provider # noqa: E402 from workers.reasoning.cache import UsageTracker # noqa: E402 from workers.reasoning.reviewer import review_code # noqa: E402 @@ -47,36 +47,38 @@ async def main() -> None: language = lang_map.get(ext, "unknown") config = WorkerConfig() - provider = create_llm_provider(config) + provider, gate_registry = await build_cli_runtime_provider(config) + try: + tracker = UsageTracker() + with contextlib.redirect_stdout(sys.stderr): + result, usage = await review_code(provider, file_path, language, content, template=template) + tracker.record(usage) - tracker = UsageTracker() - with contextlib.redirect_stdout(sys.stderr): - result, usage = await review_code(provider, file_path, language, content, template=template) - tracker.record(usage) - - output = { - "template": result.template, - "findings": [ - { - "category": f.category, - "severity": f.severity, - "message": f.message, - "file_path": f.file_path, - "start_line": f.start_line, - "end_line": f.end_line, - "suggestion": f.suggestion, - } - for f in result.findings - ], - "score": result.score, - "usage": { - "provider": usage.provider, - "model": usage.model, - "input_tokens": usage.input_tokens, - "output_tokens": usage.output_tokens, - }, - } - print(json.dumps(output, indent=2)) + output = { + "template": result.template, + "findings": [ + { + "category": f.category, + "severity": f.severity, + "message": f.message, + "file_path": f.file_path, + "start_line": f.start_line, + "end_line": f.end_line, + "suggestion": f.suggestion, + } + for f in result.findings + ], + "score": result.score, + "usage": { + "provider": usage.provider, + "model": usage.model, + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + }, + } + print(json.dumps(output, indent=2)) + finally: + await gate_registry.close() if __name__ == "__main__": diff --git a/workers/common/cli_main.py b/workers/common/cli_main.py new file mode 100644 index 00000000..fdf204b6 --- /dev/null +++ b/workers/common/cli_main.py @@ -0,0 +1,45 @@ +"""Shared CLI bootstrap helper for CLI entry points and benchmark runners. + +Constructs a ``ProviderGateRegistry`` from the environment, wraps the LLM +provider, and returns both for graceful shutdown wiring. + +Usage pattern:: + + async def main() -> None: + config = WorkerConfig() + provider, gate_registry = await build_cli_runtime_provider(config) + try: + ... # use provider + finally: + await gate_registry.close() + +Plan: thoughts/shared/plans/active-2026-05-06-deliver-worker-llm-concurrency.md +Phase 2 / Decision 7 (H1 fix). +""" + +from __future__ import annotations + +from workers.common.config import WorkerConfig +from workers.common.llm.concurrency import ConcurrencyConfig, ProviderGateRegistry +from workers.common.llm.config import create_llm_provider +from workers.common.llm.provider import LLMProvider + + +async def build_cli_runtime_provider( + config: WorkerConfig, +) -> tuple[LLMProvider, ProviderGateRegistry]: + """Construct a ``ProviderGateRegistry`` and return a wrapped LLM provider. + + The registry is constructed from ``ConcurrencyConfig.from_env()``. The + returned provider is gated through the registry (subject to the kill + switch). Callers must ``await gate_registry.close()`` on exit so the + registry's resources are released cleanly. + + Returns: + (provider, gate_registry) — the wrapped provider and the registry + that owns its gate state. + """ + concurrency_config = ConcurrencyConfig.from_env() + gate_registry = ProviderGateRegistry(concurrency_config) + provider = await create_llm_provider(config, gate_registry=gate_registry) + return provider, gate_registry diff --git a/workers/common/config.py b/workers/common/config.py index 01bb25e1..6f389ba8 100644 --- a/workers/common/config.py +++ b/workers/common/config.py @@ -4,7 +4,7 @@ import os -from pydantic import field_validator, model_validator +from pydantic import field_validator from pydantic_settings import BaseSettings # Maximum allowed value for llm_max_concurrent_calls (D9 / H1). @@ -161,6 +161,13 @@ class WorkerConfig(BaseSettings): # DB profile row's max_concurrent_calls IS NULL (see Phase 5 migration). # Never overrides an operator-set DB value. Config-level env var: # SOURCEBRIDGE_WORKER_LLM_MAX_CONCURRENT_CALLS (or the shared hint). + # + # LEGACY SEED (Phase 2+): The ProviderGateRegistry is the runtime owner + # of the effective LLM concurrency cap. This field is used only as the + # fallback seed when the gate's per-provider override is absent AND the + # kill switch is off. GetProviderCapabilities sources its cap from the + # registry's effective value, not from this field directly. + # See plan 2026-05-06-deliver-worker-llm-concurrency Decision 12. llm_max_concurrent_calls: int = 0 # gRPC auth @@ -200,10 +207,10 @@ def _validate_llm_max_concurrent_calls(cls, v: object) -> int: """ try: val = int(v) - except (TypeError, ValueError): + except (TypeError, ValueError) as err: raise ValueError( f"llm_max_concurrent_calls must be an integer, got {v!r}." - ) + ) from err if val < 0 or val > HARD_CONCURRENCY_CEILING: raise ValueError( f"llm_max_concurrent_calls must be between 0 and {HARD_CONCURRENCY_CEILING}, got {val}." diff --git a/workers/common/embedding/concurrency.py b/workers/common/embedding/concurrency.py new file mode 100644 index 00000000..1114fc12 --- /dev/null +++ b/workers/common/embedding/concurrency.py @@ -0,0 +1,70 @@ +"""Thin concurrency gate wrapper for embedding providers. + +Shares the same ``ProviderGateRegistry`` instance as the LLM gate, using +``kind="embedding"`` so that host-gated providers (Ollama) count embedding +calls against the same semaphore as LLM calls. + +No retry in Phase 1 — embedding calls are generally idempotent and the +error surface is narrower than LLM calls. Phase 6 can add retry if needed. + +See plan: thoughts/shared/plans/active-2026-05-06-deliver-worker-llm-concurrency.md +""" + +from __future__ import annotations + +from workers.common.embedding.provider import EmbeddingProvider +from workers.common.llm.concurrency import ( + ConcurrencyConfig, + ProviderGate, + ProviderGateRegistry, +) + + +class ConcurrencyGatedEmbeddingProvider: + """``EmbeddingProvider`` decorator that routes calls through a gate semaphore. + + Thinner than the LLM wrapper: no retry, no RPM limiter, no tok/s recording. + The slot is acquired per ``embed()`` call; parallel chunk fan-out within + one ``embed()`` call is handled by the provider itself (see Phase 5 for + the ``OpenAICompatEmbeddingProvider`` fan-out). + """ + + def __init__( + self, + raw: EmbeddingProvider, + gate: ProviderGate, + config: ConcurrencyConfig | None = None, + ) -> None: + self._raw = raw + self._gate = gate + self._config = config or ConcurrencyConfig() + + async def embed(self, texts: list[str]) -> list[list[float]]: + async with self._gate.slot(): + return await self._raw.embed(texts) + + @property + def dimension(self) -> int: + return self._raw.dimension + + +async def wrap_embedding_provider( + raw: EmbeddingProvider, + provider_name: str, + base_url: str | None, + kind: str = "embedding", + registry: ProviderGateRegistry | None = None, + config: ConcurrencyConfig | None = None, +) -> EmbeddingProvider: + """Wrap ``raw`` in a gate if the kill switch is on and registry is provided. + + Returns ``raw`` unchanged when no registry is given or when + ``config.wrapper_enabled`` is False. + """ + if registry is None: + return raw + cfg = config or registry._config + if not cfg.wrapper_enabled: + return raw + gate = await registry.lookup(provider_name, base_url, kind) + return ConcurrencyGatedEmbeddingProvider(raw, gate, cfg) diff --git a/workers/common/embedding/config.py b/workers/common/embedding/config.py index d60ce7b1..d65da66f 100644 --- a/workers/common/embedding/config.py +++ b/workers/common/embedding/config.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from workers.common.config import ( SUPPORTED_EMBEDDING_PROVIDERS, WorkerConfig, @@ -10,8 +12,15 @@ from workers.common.embedding.fake import FakeEmbeddingProvider from workers.common.embedding.provider import EmbeddingProvider +if TYPE_CHECKING: + from workers.common.llm.concurrency import ProviderGateRegistry + -def create_embedding_provider(config: WorkerConfig) -> EmbeddingProvider: +async def create_embedding_provider( + config: WorkerConfig, + *, + gate_registry: ProviderGateRegistry | None = None, +) -> EmbeddingProvider: """Create an embedding provider from configuration. Defense in depth: ``WorkerConfig._validate_embedding_provider`` @@ -22,6 +31,9 @@ def create_embedding_provider(config: WorkerConfig) -> EmbeddingProvider: override carrying a typo doesn't crash the worker mid-request with a confusing stack trace. + When ``gate_registry`` is supplied the returned provider is wrapped + in a ``ConcurrencyGatedEmbeddingProvider`` (plan v4 Phase 2). + Tester report 2026-04-30 (Pazaryna) Issue 3 / CA-125. """ if config.test_mode: @@ -31,31 +43,46 @@ def create_embedding_provider(config: WorkerConfig) -> EmbeddingProvider: from workers.common.embedding.ollama import OllamaEmbeddingProvider base_url = config.embedding_base_url or "http://localhost:11434" - return OllamaEmbeddingProvider( + raw: EmbeddingProvider = OllamaEmbeddingProvider( base_url=base_url, model=config.embedding_model, dimension=config.embedding_dimension, ) - - if config.embedding_provider in ("openai", "openai-compatible"): + embed_base_url: str | None = base_url + embed_provider_name = "ollama" + elif config.embedding_provider in ("openai", "openai-compatible"): from workers.common.embedding.openai_compat import OpenAICompatEmbeddingProvider base_url = config.embedding_base_url or "http://localhost:11434" - return OpenAICompatEmbeddingProvider( + raw = OpenAICompatEmbeddingProvider( base_url=base_url, model=config.embedding_model, dimension=config.embedding_dimension, api_key=config.embedding_api_key, ) + embed_base_url = base_url + embed_provider_name = config.embedding_provider + else: + msg = ( + f"embedding provider {config.embedding_provider!r} is not supported. " + f"Supported embedding providers: {_format_supported(SUPPORTED_EMBEDDING_PROVIDERS)}." + ) + if config.embedding_provider == "anthropic": + msg += ( + " Anthropic does not offer an embeddings API as of 2026; " + "use 'ollama' (the default), 'openai', or 'openai-compatible' " + "for a self-hosted endpoint." + ) + raise ValueError(msg) + + if gate_registry is not None: + from workers.common.embedding.concurrency import wrap_embedding_provider - msg = ( - f"embedding provider {config.embedding_provider!r} is not supported. " - f"Supported embedding providers: {_format_supported(SUPPORTED_EMBEDDING_PROVIDERS)}." - ) - if config.embedding_provider == "anthropic": - msg += ( - " Anthropic does not offer an embeddings API as of 2026; " - "use 'ollama' (the default), 'openai', or 'openai-compatible' " - "for a self-hosted endpoint." + return await wrap_embedding_provider( + raw, + provider_name=embed_provider_name, + base_url=embed_base_url, + kind="embedding", + registry=gate_registry, ) - raise ValueError(msg) + return raw diff --git a/workers/common/embedding/ollama.py b/workers/common/embedding/ollama.py index c3fb2121..02c1d4f9 100644 --- a/workers/common/embedding/ollama.py +++ b/workers/common/embedding/ollama.py @@ -56,7 +56,7 @@ async def embed(self, texts: list[str]) -> list[list[float]]: if len(texts) <= _BATCH_SIZE: return await self._embed_batch(texts) - # Split into batches + # Ollama embedding stays serial — host gate combines this with the LLM gate; see plan Decision 8 + Decision 1. all_embeddings: list[list[float]] = [] for i in range(0, len(texts), _BATCH_SIZE): batch = texts[i : i + _BATCH_SIZE] diff --git a/workers/common/embedding/openai_compat.py b/workers/common/embedding/openai_compat.py index 7ce431fc..b0ccd508 100644 --- a/workers/common/embedding/openai_compat.py +++ b/workers/common/embedding/openai_compat.py @@ -6,12 +6,18 @@ from __future__ import annotations +import asyncio + import httpx import structlog log = structlog.get_logger() _BATCH_SIZE = 256 +# Concurrent batch fan-out limit. Four is enough to saturate a frontier +# embedding endpoint without over-queuing; Ollama uses a separate serial +# provider (ollama.py) so this constant applies to OpenAI-compat only. +LOCAL_EMBEDDING_FANOUT_LIMIT = 4 class OpenAICompatEmbeddingProvider: @@ -51,17 +57,34 @@ async def _embed_batch(self, texts: list[str]) -> list[list[float]]: async def embed(self, texts: list[str]) -> list[list[float]]: """Generate embeddings via the OpenAI-compatible /v1/embeddings endpoint. - Automatically splits large batches into chunks of _BATCH_SIZE. + Automatically splits large batches into chunks of _BATCH_SIZE and + issues them concurrently (up to LOCAL_EMBEDDING_FANOUT_LIMIT batches + in-flight). Output order is preserved: output[i] corresponds to + input[i]. """ if len(texts) <= _BATCH_SIZE: return await self._embed_batch(texts) + batches = [texts[i : i + _BATCH_SIZE] for i in range(0, len(texts), _BATCH_SIZE)] + local_sem = asyncio.Semaphore(LOCAL_EMBEDDING_FANOUT_LIMIT) + + async def _embed_one(batch_num: int, batch: list[str]) -> list[list[float]]: + async with local_sem: + log.info( + "embedding_batch", + batch_num=batch_num + 1, + batch_size=len(batch), + total=len(texts), + ) + return await self._embed_batch(batch) + + results = await asyncio.gather(*[_embed_one(i, b) for i, b in enumerate(batches)]) + + # Flatten in order: results[i] corresponds to batches[i] which + # corresponds to texts[i * _BATCH_SIZE : (i + 1) * _BATCH_SIZE]. all_embeddings: list[list[float]] = [] - for i in range(0, len(texts), _BATCH_SIZE): - batch = texts[i : i + _BATCH_SIZE] - log.info("embedding_batch", batch_num=i // _BATCH_SIZE + 1, batch_size=len(batch), total=len(texts)) - batch_embeddings = await self._embed_batch(batch) - all_embeddings.extend(batch_embeddings) + for batch_result in results: + all_embeddings.extend(batch_result) return all_embeddings @property diff --git a/workers/common/llm/anthropic.py b/workers/common/llm/anthropic.py index 970118c5..24371bbb 100644 --- a/workers/common/llm/anthropic.py +++ b/workers/common/llm/anthropic.py @@ -32,7 +32,10 @@ def __init__( model: str = "claude-sonnet-4-20250514", enable_cache: bool = True, ) -> None: - self.client = anthropic.AsyncAnthropic(api_key=api_key) + self.client = anthropic.AsyncAnthropic( + api_key=api_key, + max_retries=0, # Phase 3: SDK retry disabled; tenacity owns retry (Decision 3) + ) self.model = model self.enable_cache = enable_cache diff --git a/workers/common/llm/concurrency.py b/workers/common/llm/concurrency.py new file mode 100644 index 00000000..ecdc0e30 --- /dev/null +++ b/workers/common/llm/concurrency.py @@ -0,0 +1,1242 @@ +"""Per-provider LLM concurrency gate: host-level semaphore + per-kind counters. + +Architecture (Decision 1, 2, 5b, v4 plan): + + One ``ProviderGateRegistry`` is constructed once in ``__main__.py`` and + threaded by reference into every factory call. It maintains two internal + maps: + + * ``_host_gates`` — keyed ``(provider, normalized_origin)`` — one binding + semaphore shared across LLM + embedding for local servers (Ollama, vLLM, + llama.cpp, sglang, LM Studio). + * ``_kind_gates`` — keyed ``(provider, base_url_raw, kind)`` — one binding + semaphore per API-kind for frontier providers (openai, anthropic, gemini, + openrouter) which have independent quotas for chat vs. embeddings. + * ``_kind_counters`` — per-``(provider, normalized_origin, kind)`` counter- + only sub-records under host-gated providers; observability only, no + gating. + + ``lookup(provider, base_url, kind) -> ProviderGate`` returns a façade that + acquires whichever binding gate is appropriate and updates the matching + counter. + +Phase 3 (Decision 3, atomic): SDK retry disabled on AsyncOpenAI / AsyncAnthropic +(max_retries=0); tenacity predicate finalized (Decision 4 whitelist); real +Decision 6 defaults loaded by ``ConcurrencyConfig.from_env()``; local +hierarchical/renderer fan-out caps raised; both hand-rolled retries deleted. +The empty-content retry at ``openai_compat.py:_complete_once`` is preserved +(handles -budget exhaustion, not network errors). + +Kill switch: SOURCEBRIDGE_LLM_CONCURRENCY_WRAPPER_ENABLED=false reverts to +pre-refactor behavior without redeploy. + +The aggregator task (emitting ``llm_provider_gate_metrics`` log lines) is +deferred to Phase 6; the start hook is a placeholder comment here. +# TODO(phase-6): start aggregator task in ProviderGateRegistry.__init__ + +See plan: thoughts/shared/plans/active-2026-05-06-deliver-worker-llm-concurrency.md +""" + +from __future__ import annotations + +import asyncio +import contextlib +import os +import sys +import time +from collections import deque +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import Any + +import anthropic +import httpx +import openai +import structlog +from aiolimiter import AsyncLimiter +from tenacity import ( + AsyncRetrying, + RetryCallState, + retry_if_exception, + stop_after_attempt, + wait_random_exponential, +) + +from workers.common.llm.provider import LLMProvider, LLMResponse + +# Lazy imports for provider-specific gated adapters (Decision 10). +# Imported inside wrap_provider to avoid circular imports at module load time. +# The actual check uses isinstance() against those classes. + +log = structlog.get_logger() + +# ────────────────────────────────────────────────────────────────────────────── +# Sentinel: uncapped defaults for Phase 1 / kill-switch-off behavior. +# Phase 3 replaces these with the real caps from Decision 6. +_UNCAPPED: int = sys.maxsize + +# Providers that share one host-level semaphore across LLM + embedding kinds. +_HOST_GATED_PROVIDERS: frozenset[str] = frozenset( + {"ollama", "vllm", "llama-cpp", "sglang", "lmstudio"} +) + +# Providers that use independent per-kind semaphores (frontier cloud APIs). +_KIND_GATED_PROVIDERS: frozenset[str] = frozenset( + {"openai", "anthropic", "gemini", "openrouter"} +) + +# openai-compatible default gating mode (can be flipped per Decision 7). +_GATING_ENV_VAR = "SOURCEBRIDGE_LLM_PROVIDER_OPENAI_COMPATIBLE_GATING" +_DEFAULT_OPENAI_COMPAT_GATING = "host" + + +# ────────────────────────────────────────────────────────────────────────────── +# URL normalization helper (Decision 1, v4) + + +def _normalize_host_key(provider: str, base_url: str | None) -> tuple[str, str]: + """Canonical form: ``(provider, "scheme://host:port")``. + + Strips path (e.g. ``/v1``), trailing slash, query, and fragment. This + ensures that Ollama's LLM endpoint ``http://localhost:11434/v1`` and its + embedding endpoint ``http://localhost:11434`` both map to + ``("ollama", "http://localhost:11434")`` and therefore share the same + host-level semaphore. + """ + if not base_url: + return (provider, "") + from urllib.parse import urlsplit + + u = urlsplit(base_url) + origin = f"{u.scheme}://{u.netloc}".rstrip("/") + return (provider, origin) + + +# ────────────────────────────────────────────────────────────────────────────── +# Configuration + + +@dataclass +class ConcurrencyConfig: + """Runtime concurrency knobs sourced from environment variables. + + Decision 6 real defaults are loaded by ``from_env()`` when no env-var + overrides are set. The ``_UNCAPPED`` sentinel is retained as the registry + fallback for unknown providers; it is not the default for any named provider. + """ + + # Per-provider max-concurrent overrides. Key = canonical provider name. + # from_env() pre-populates these from the Decision 6 table. + llm_max_concurrent: dict[str, int] = field(default_factory=dict) + embedding_max_concurrent: dict[str, int] = field(default_factory=dict) + + # Per-provider RPM limits. None = no rate shaping (default). + rpm: dict[str, int | None] = field(default_factory=dict) + + # openai-compatible gating mode: "host" (default) | "per_kind". + openai_compatible_gating: str = _DEFAULT_OPENAI_COMPAT_GATING + + # Global kill switch. When False, factories return raw providers. + wrapper_enabled: bool = True + + # Tenacity: max attempts per call. Default 5 (Phase 3 activates real retry). + retry_max_attempts: int = 5 + + # Aggregator task interval (seconds). Phase 6 activates the task. + metrics_interval_seconds: float = 30.0 + + @classmethod + def from_env(cls) -> ConcurrencyConfig: + """Read all concurrency knobs from environment variables. + + Decision 6 default caps (applied when no env-var override is set): + + | Provider | LLM max_concurrent | Embed max_concurrent | + |--------------------|--------------------|----------------------| + | ollama | 1 (host total) | (host-shared) | + | vllm | 4 (host total) | (host-shared) | + | llama-cpp | 4 (host total) | (host-shared) | + | sglang | 4 (host total) | (host-shared) | + | lmstudio | 2 (host total) | (host-shared) | + | openai | 8 | 8 | + | anthropic | 4 | n/a | + | openrouter | 8 | 8 | + | gemini | 8 | 8 | + | openai-compatible | 4 (host total) | (host-shared) | + + Env-var overrides take precedence (first-match wins). + + Decision 7 env-var names (``SOURCEBRIDGE_LLM_*`` prefix): + + ``SOURCEBRIDGE_LLM_PROVIDER__MAX_CONCURRENT`` + ``SOURCEBRIDGE_EMBEDDING_PROVIDER__MAX_CONCURRENT`` + ``SOURCEBRIDGE_LLM_PROVIDER__RPM`` + ``SOURCEBRIDGE_LLM_PROVIDER_OPENAI_COMPATIBLE_GATING`` + ``SOURCEBRIDGE_LLM_CONCURRENCY_WRAPPER_ENABLED`` + ``SOURCEBRIDGE_LLM_RETRY_MAX_ATTEMPTS`` + ``SOURCEBRIDGE_LLM_METRICS_AGGREGATION_INTERVAL_SECONDS`` + + The canonical provider-name → env-var-token mapping (L1 fix): + openai → OPENAI, anthropic → ANTHROPIC, ollama → OLLAMA, + vllm → VLLM, llama-cpp → LLAMA_CPP, sglang → SGLANG, + gemini → GEMINI, openrouter → OPENROUTER, lmstudio → LMSTUDIO, + openai-compatible → OPENAI_COMPATIBLE. + """ + # Decision 6 real defaults (Outcome A — ollama cap=1 is safe per dick's + # investigation; cap-raise to 4 is safe for vllm/llama-cpp/sglang). + llm_defaults: dict[str, int] = { + "ollama": 1, + "vllm": 4, + "llama-cpp": 4, + "sglang": 4, + "lmstudio": 2, + "openai": 8, + "anthropic": 4, + "openrouter": 8, + "gemini": 8, + "openai-compatible": 4, + } + embed_defaults: dict[str, int] = { + # Frontier providers have separate embedding caps. + "openai": 8, + "openrouter": 8, + "gemini": 8, + # Host-gated providers share the LLM cap; no separate embed entry needed. + } + + all_providers = list(_HOST_GATED_PROVIDERS | _KIND_GATED_PROVIDERS) + [ + "openai-compatible" + ] + + # Start with Decision 6 defaults; env-var overrides overwrite them. + llm_max: dict[str, int] = dict(llm_defaults) + embed_max: dict[str, int] = dict(embed_defaults) + rpm_map: dict[str, int | None] = {} + + for provider in all_providers: + token = provider.upper().replace("-", "_") + _read_max_concurrent( + f"SOURCEBRIDGE_LLM_PROVIDER_{token}_MAX_CONCURRENT", + provider, + llm_max, + ) + _read_max_concurrent( + f"SOURCEBRIDGE_EMBEDDING_PROVIDER_{token}_MAX_CONCURRENT", + provider, + embed_max, + ) + _read_rpm( + f"SOURCEBRIDGE_LLM_PROVIDER_{token}_RPM", + provider, + rpm_map, + ) + + gating = os.environ.get(_GATING_ENV_VAR, _DEFAULT_OPENAI_COMPAT_GATING).strip().lower() + if gating not in ("host", "per_kind"): + log.warning( + "concurrency_config_invalid_gating", + env_var=_GATING_ENV_VAR, + value=gating, + using=_DEFAULT_OPENAI_COMPAT_GATING, + ) + gating = _DEFAULT_OPENAI_COMPAT_GATING + + wrapper_raw = os.environ.get("SOURCEBRIDGE_LLM_CONCURRENCY_WRAPPER_ENABLED", "true").strip().lower() + wrapper_enabled = wrapper_raw in ("true", "1", "yes", "on") + + retry_raw = os.environ.get("SOURCEBRIDGE_LLM_RETRY_MAX_ATTEMPTS", "").strip() + retry_max = 5 # Phase 3 default: 5 attempts (Decision 3) + if retry_raw: + try: + retry_max = int(retry_raw) + if retry_max < 1: + raise ValueError("must be ≥ 1") + except ValueError as exc: + log.warning( + "concurrency_config_invalid_retry_max", + env_var="SOURCEBRIDGE_LLM_RETRY_MAX_ATTEMPTS", + value=retry_raw, + error=str(exc), + ) + retry_max = 5 + + interval_raw = os.environ.get("SOURCEBRIDGE_LLM_METRICS_AGGREGATION_INTERVAL_SECONDS", "").strip() + interval = 30.0 + if interval_raw: + with contextlib.suppress(ValueError): + interval = float(interval_raw) + + _validate_known_provider_tokens(all_providers) + + return cls( + llm_max_concurrent=llm_max, + embedding_max_concurrent=embed_max, + rpm=rpm_map, + openai_compatible_gating=gating, + wrapper_enabled=wrapper_enabled, + retry_max_attempts=retry_max, + metrics_interval_seconds=interval, + ) + + +_LLM_SUFFIXES = ("_MAX_CONCURRENT", "_RPM") +_EMBED_SUFFIXES = ("_MAX_CONCURRENT",) + + +def _validate_known_provider_tokens(all_providers: list[str]) -> None: + """Scan env vars for unknown provider tokens and emit a structlog WARNING. + + Decision 7 / codex r2 L1: an operator typo like + ``SOURCEBRIDGE_LLM_PROVIDER_OPENAICOMPAT_MAX_CONCURRENT`` (missing the + underscore in OPENAI_COMPATIBLE) would otherwise be silently ignored. + This helper detects such typos and warns without raising, so existing + deployments with stale env vars don't crash at boot. + + Scans for the three per-provider env-var patterns consumed by ``from_env()``: + ``SOURCEBRIDGE_LLM_PROVIDER__MAX_CONCURRENT`` + ``SOURCEBRIDGE_LLM_PROVIDER__RPM`` + ``SOURCEBRIDGE_EMBEDDING_PROVIDER__MAX_CONCURRENT`` + """ + # Build the set of canonical tokens (uppercase, hyphens → underscores). + canonical_tokens: frozenset[str] = frozenset( + p.upper().replace("-", "_") for p in all_providers + ) + + for env_var in os.environ: + token: str | None = None + if env_var.startswith("SOURCEBRIDGE_LLM_PROVIDER_"): + remainder = env_var[len("SOURCEBRIDGE_LLM_PROVIDER_"):] + for suffix in _LLM_SUFFIXES: + if remainder.endswith(suffix): + token = remainder[: -len(suffix)] + break + elif env_var.startswith("SOURCEBRIDGE_EMBEDDING_PROVIDER_"): + remainder = env_var[len("SOURCEBRIDGE_EMBEDDING_PROVIDER_"):] + for suffix in _EMBED_SUFFIXES: + if remainder.endswith(suffix): + token = remainder[: -len(suffix)] + break + + if token is not None and token not in canonical_tokens: + log.warning( + "concurrency_config_unknown_provider_token", + env_var=env_var, + unknown_token=token, + canonical_tokens=sorted(canonical_tokens), + ) + + +def _read_max_concurrent(env_var: str, provider: str, target: dict[str, int]) -> None: + raw = os.environ.get(env_var, "").strip() + if not raw: + return + try: + val = int(raw) + if val < 1: + raise ValueError("must be ≥ 1") + target[provider] = val + except ValueError as exc: + log.warning( + "concurrency_config_invalid_max_concurrent", + env_var=env_var, + value=raw, + error=str(exc), + ) + + +def _read_rpm(env_var: str, provider: str, target: dict[str, int | None]) -> None: + raw = os.environ.get(env_var, "").strip() + if not raw: + return + try: + val = int(raw) + if val <= 0: + raise ValueError("must be > 0") + target[provider] = val + except ValueError as exc: + log.warning( + "concurrency_config_invalid_rpm", + env_var=env_var, + value=raw, + error=str(exc), + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# Snapshot dataclass + + +@dataclass +class GateSnapshotEntry: + """Point-in-time snapshot of one gate's state (for Phase 7 admin endpoint).""" + + provider: str + base_url_normalized: str + kind: str + in_flight: int + queue_depth: int + max_concurrent: int + retries_since_start: int + recent_429_count: int + tokens_per_second: float + rpm: int = 0 # 0 = no limiter + + +# ────────────────────────────────────────────────────────────────────────────── +# Internal gate primitives + + +class _GateBase: + """Shared state: semaphore, waiters, in-flight, ring buffer, retry counters.""" + + __slots__ = ( + "_sem", + "_max_concurrent", + "_waiters", + "_in_flight", + "_retries", + "_recent_429", + "_ring", + "_streaming_usage_unsupported", + ) + + def __init__(self, max_concurrent: int) -> None: + if max_concurrent < 1: + raise ValueError(f"max_concurrent must be ≥ 1, got {max_concurrent}") + self._sem = asyncio.Semaphore(max_concurrent) + self._max_concurrent = max_concurrent + self._waiters: int = 0 + self._in_flight: int = 0 + self._retries: int = 0 + self._recent_429: int = 0 + # Ring buffer: deque of (timestamp_float, output_tokens_int). + # Bounded by insertion-time eviction (keep last 60 s). + self._ring: deque[tuple[float, int]] = deque() + # Compatibility flag: set True when a server rejects stream_options. + # See Decision 10b / M2 fallback. + self._streaming_usage_unsupported: bool = False + + @asynccontextmanager + async def slot(self) -> AsyncIterator[None]: + """Acquire a slot (cancellation-safe). + + Decision 2 ordering: + - Increment ``_waiters`` before awaiting acquire. + - Decrement ``_waiters`` once acquired. + - Increment ``_in_flight`` while the caller holds the slot. + - Release and decrement in ``finally`` regardless of how the + block exits (including cancellation). + """ + self._waiters += 1 + try: + await self._sem.acquire() + except asyncio.CancelledError: + self._waiters -= 1 + raise + self._waiters -= 1 + self._in_flight += 1 + try: + yield + finally: + self._in_flight -= 1 + self._sem.release() + + def record_completion(self, output_tokens: int) -> None: + """Append to the 60-second ring buffer (Decision 10a).""" + now = time.monotonic() + self._ring.append((now, output_tokens)) + # Evict entries older than 60 s. + cutoff = now - 60.0 + while self._ring and self._ring[0][0] < cutoff: + self._ring.popleft() + + def snapshot_tokens_per_second(self) -> float: + """Sum output_tokens over the last 60 s, divide by the window.""" + if not self._ring: + return 0.0 + now = time.monotonic() + cutoff = now - 60.0 + total = sum(tok for ts, tok in self._ring if ts >= cutoff) + return total / 60.0 + + def snapshot(self, provider: str, base_url_normalized: str, kind: str, rpm: int = 0) -> GateSnapshotEntry: + return GateSnapshotEntry( + provider=provider, + base_url_normalized=base_url_normalized, + kind=kind, + in_flight=self._in_flight, + queue_depth=self._waiters, + max_concurrent=self._max_concurrent, + retries_since_start=self._retries, + recent_429_count=self._recent_429, + tokens_per_second=self.snapshot_tokens_per_second(), + rpm=rpm, + ) + + +class _HostGate(_GateBase): + """Binding semaphore for host-gated (local) providers. + + A single semaphore is shared across all ``kind`` values (llm, embedding) + routed through this daemon. This is the fix for the Ollama + ``OLLAMA_NUM_PARALLEL=1`` case where LLM + embedding calls both count + against the same server-side slot budget. + """ + + +class _KindGate(_GateBase): + """Binding semaphore for per-kind-gated (frontier) providers.""" + + +class _KindCounter: + """Observability-only counter under a host gate (no separate semaphore).""" + + __slots__ = ("_in_flight", "_waiters") + + def __init__(self) -> None: + self._in_flight: int = 0 + self._waiters: int = 0 + + +# ────────────────────────────────────────────────────────────────────────────── +# ProviderGate façade + + +class ProviderGate: + """Façade returned by ``ProviderGateRegistry.lookup``. + + Acquires the binding gate (host or per-kind) and increments the per-kind + counter when operating in host-gate mode. + """ + + __slots__ = ("_binding", "_counter", "_provider", "_normalized_origin", "_kind") + + def __init__( + self, + binding: _HostGate | _KindGate, + counter: _KindCounter | None, + provider: str, + normalized_origin: str, + kind: str, + ) -> None: + self._binding = binding + self._counter = counter + self._provider = provider + self._normalized_origin = normalized_origin + self._kind = kind + + @asynccontextmanager + async def slot(self) -> AsyncIterator[None]: + """Acquire the binding slot and update the per-kind counter.""" + if self._counter is not None: + self._counter._waiters += 1 + try: + async with self._binding.slot(): + if self._counter is not None: + self._counter._waiters -= 1 + self._counter._in_flight += 1 + try: + yield + finally: + if self._counter is not None: + self._counter._in_flight -= 1 + except asyncio.CancelledError: + if self._counter is not None: + self._counter._waiters -= 1 + raise + + def record_completion(self, output_tokens: int) -> None: + self._binding.record_completion(output_tokens) + + def snapshot_tokens_per_second(self) -> float: + return self._binding.snapshot_tokens_per_second() + + def snapshot(self) -> GateSnapshotEntry: + return self._binding.snapshot(self._provider, self._normalized_origin, self._kind) + + @property + def in_flight(self) -> int: + return self._binding._in_flight + + @property + def queue_depth(self) -> int: + return self._binding._waiters + + @property + def streaming_usage_unsupported(self) -> bool: + return self._binding._streaming_usage_unsupported + + @streaming_usage_unsupported.setter + def streaming_usage_unsupported(self, value: bool) -> None: + self._binding._streaming_usage_unsupported = value + + +# ────────────────────────────────────────────────────────────────────────────── +# Registry + + +class ProviderGateRegistry: + """Single registry for all LLM and embedding provider gates. + + Constructed once in ``__main__.py``; threaded by reference into all + factory calls (never used as a module-level singleton). + + Thread/task safety: gate creation is protected by ``asyncio.Lock``; all + other operations are non-blocking reads/increments. + """ + + def __init__(self, config: ConcurrencyConfig | None = None) -> None: + self._config = config or ConcurrencyConfig() + self._lock = asyncio.Lock() + # (provider, normalized_origin) → _HostGate + self._host_gates: dict[tuple[str, str], _HostGate] = {} + # (provider, base_url_raw, kind) → _KindGate + self._kind_gates: dict[tuple[str, str, str], _KindGate] = {} + # (provider, normalized_origin, kind) → _KindCounter (observability only) + self._kind_counters: dict[tuple[str, str, str], _KindCounter] = {} + self._closed: bool = False + # Phase 6: start the aggregator task. It runs indefinitely until + # close() cancels it. Emits llm_provider_gate_metrics info-level + # structlog lines every metrics_interval_seconds. + self._aggregator_task: asyncio.Task[None] = asyncio.ensure_future( + self._run_aggregator() + ) + + def _classify(self, provider: str) -> str: + """Return "host", "per_kind", or the resolved mode for openai-compatible.""" + if provider in _HOST_GATED_PROVIDERS: + return "host" + if provider in _KIND_GATED_PROVIDERS: + return "per_kind" + if provider == "openai-compatible": + return self._config.openai_compatible_gating + # Unknown provider: default to host gating (safe / conservative). + return "host" + + def _max_concurrent_for(self, provider: str, kind: str) -> int: + """Effective max_concurrent for this provider+kind (Phase 1: _UNCAPPED).""" + if kind == "embedding": + cap = self._config.embedding_max_concurrent.get(provider) + if cap is not None: + return cap + cap = self._config.llm_max_concurrent.get(provider) + return cap if cap is not None else _UNCAPPED + + async def lookup(self, provider: str, base_url: str | None, kind: str) -> ProviderGate: + """Return the ``ProviderGate`` for ``(provider, base_url, kind)``. + + Safe to call concurrently; gate objects are created at most once per key. + """ + if self._closed: + raise RuntimeError("ProviderGateRegistry has been closed; cannot look up gates") + + mode = self._classify(provider) + if mode == "host": + return await self._lookup_host(provider, base_url, kind) + else: + return await self._lookup_kind(provider, base_url, kind) + + async def _lookup_host(self, provider: str, base_url: str | None, kind: str) -> ProviderGate: + _, normalized_origin = _normalize_host_key(provider, base_url) + host_key = (provider, normalized_origin) + counter_key = (provider, normalized_origin, kind) + + async with self._lock: + if host_key not in self._host_gates: + cap = self._max_concurrent_for(provider, "llm") + self._host_gates[host_key] = _HostGate(max_concurrent=cap) + if counter_key not in self._kind_counters: + self._kind_counters[counter_key] = _KindCounter() + + binding = self._host_gates[host_key] + counter = self._kind_counters[counter_key] + return ProviderGate(binding, counter, provider, normalized_origin, kind) + + async def _lookup_kind(self, provider: str, base_url: str | None, kind: str) -> ProviderGate: + raw_url = base_url or "" + kind_key = (provider, raw_url, kind) + _, normalized_origin = _normalize_host_key(provider, base_url) + + async with self._lock: + if kind_key not in self._kind_gates: + cap = self._max_concurrent_for(provider, kind) + self._kind_gates[kind_key] = _KindGate(max_concurrent=cap) + + binding = self._kind_gates[kind_key] + return ProviderGate(binding, None, provider, normalized_origin, kind) + + def effective_llm_max_concurrent( + self, provider: str, base_url: str | None + ) -> int | None: + """The effective LLM cap for this provider+base_url. + + Returns ``None`` when the wrapper is disabled (kill switch) or the gate + is sentinel-uncapped (unknown provider with no Decision 6 default). + Phase 3: Decision 6 defaults are pre-populated by ``from_env()``, so + all named providers now return a finite cap. + """ + if not self._config.wrapper_enabled: + return None + cap = self._config.llm_max_concurrent.get(provider) + if cap is None or cap == _UNCAPPED: + return None # Unknown provider or sentinel-uncapped. + return cap + + def canonical_key_for(self, provider: str, base_url: str | None, kind: str) -> tuple[str, ...]: + """Return the internal lookup key (Decision 5b helper for capability contract).""" + mode = self._classify(provider) + if mode == "host": + _, origin = _normalize_host_key(provider, base_url) + return (provider, origin, kind) + return (provider, base_url or "", kind) + + def snapshot(self) -> list[GateSnapshotEntry]: + """Point-in-time snapshot of all active gates (for Phase 7).""" + entries: list[GateSnapshotEntry] = [] + for (provider, origin), gate in self._host_gates.items(): + # Emit one entry per kind counter that has ever been registered. + kinds_seen = [k for (p, o, k) in self._kind_counters if p == provider and o == origin] + if not kinds_seen: + entries.append(gate.snapshot(provider, origin, "llm")) + else: + for kind in kinds_seen: + entries.append(gate.snapshot(provider, origin, kind)) + for (provider, raw_url, kind), gate in self._kind_gates.items(): + _, origin = _normalize_host_key(provider, raw_url or None) + entries.append(gate.snapshot(provider, origin, kind)) + return entries + + def snapshot_tokens_per_second( + self, provider: str, base_url: str | None, kind: str + ) -> float: + """Return the 60-second ring-buffer tok/s for the given gate. + + Called by progress-emit sites (e.g. hierarchical strategy) once per + progress event to populate KnowledgeStreamProgress.current_tokens_per_second. + Returns 0.0 when the gate doesn't exist yet or the wrapper is disabled. + """ + if not self._config.wrapper_enabled: + return 0.0 + mode = self._classify(provider) + if mode == "host": + _, normalized_origin = _normalize_host_key(provider, base_url) + host_key = (provider, normalized_origin) + gate = self._host_gates.get(host_key) + if gate is not None: + return gate.snapshot_tokens_per_second() + return 0.0 + else: + raw_url = base_url or "" + kind_key = (provider, raw_url, kind) + gate = self._kind_gates.get(kind_key) + if gate is not None: + return gate.snapshot_tokens_per_second() + return 0.0 + + async def _run_aggregator(self) -> None: + """Emit llm_provider_gate_metrics info-level structlog lines at a + fixed cadence (metrics_interval_seconds from ConcurrencyConfig). + + Runs until cancelled by close(). Does not emit when the registry has + no active gates yet (avoids log noise on startup). + """ + interval = self._config.metrics_interval_seconds + while True: + await asyncio.sleep(interval) + # Snapshot the current gate state and emit one log line per gate. + entries = self.snapshot() + for entry in entries: + # Derive retries_since_last_tick from the raw gate counter. + # We read it directly because GateSnapshotEntry records the + # cumulative total — callers who want a delta should persist + # the previous total themselves. For the log line, the + # cumulative total is sufficient for operators. + log.info( + "llm_provider_gate_metrics", + provider=entry.provider, + base_url_normalized=entry.base_url_normalized, + kind=entry.kind, + in_flight=entry.in_flight, + queue_depth=entry.queue_depth, + max_concurrent=entry.max_concurrent, + retries_since_start=entry.retries_since_start, + recent_429_count=entry.recent_429_count, + tokens_per_second_60s=round(entry.tokens_per_second, 2), + ) + + async def close(self) -> None: + """Cancel the aggregator task and mark the registry closed. + + Idempotent — safe to call more than once. + """ + if self._closed: + return + self._closed = True + self._aggregator_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._aggregator_task + + +# ────────────────────────────────────────────────────────────────────────────── +# Tenacity predicate and retry hooks (Phase 3: Decision 4 whitelist) + + +def _retry_predicate(exc: BaseException) -> bool: + """Return True when the exception is retryable (Decision 4 whitelist). + + 1. RateLimitError — always retryable (both OpenAI and Anthropic SDKs). + 2. APIStatusError with status_code in {408, 429, 502, 503, 504}. + 3. Transient httpx errors: TimeoutException, ConnectError, ReadError. + + Returns False for everything else, including 4xx errors (except 408/429), + pydantic.ValidationError, SnapshotTooLargeError, and any other non-transient + failure. The wrapper never retries auth failures (401/403) or bad requests + (400/422) — those require operator intervention. + """ + # 1. Rate-limit errors are always retryable. + if isinstance(exc, (openai.RateLimitError, anthropic.RateLimitError)): + return True + # 2. Status-code-filtered API errors. + if isinstance(exc, (openai.APIStatusError, anthropic.APIStatusError)): + return getattr(exc, "status_code", None) in {408, 429, 502, 503, 504} + # 3. Transient httpx transport errors. + if isinstance(exc, (httpx.TimeoutException, httpx.ConnectError, httpx.ReadError)): + return True + return False + + +def _extract_retry_after(exc: BaseException) -> float | None: + """Extract the Retry-After header value (seconds) from an SDK exception. + + Returns None when no header is present or parsing fails. + """ + # OpenAI SDK exposes the response object on APIStatusError. + response = getattr(exc, "response", None) + if response is not None: + headers = getattr(response, "headers", None) + if headers is not None: + raw = headers.get("retry-after") or headers.get("Retry-After") + if raw: + with contextlib.suppress(ValueError): + return float(raw) + return None + + +def _make_before_sleep(gate_binding: Any) -> Any: + """Factory returning a tenacity ``before_sleep`` callback. + + The callback: + 1. Increments the gate's retry counter and 429 counter (when applicable). + 2. Logs a structured debug line with attempt info. + 3. Extracts the ``Retry-After`` header from the exception and extends the + tenacity wait duration when the header value exceeds the computed + exponential backoff (Decision 2 — use the larger of the two so a tiny + Retry-After: 1 doesn't subvert sustained-429 backoff). + """ + + def _before_sleep(retry_state: RetryCallState) -> None: + exc = retry_state.outcome.exception() if retry_state.outcome else None + gate_binding._retries += 1 + if isinstance(exc, (openai.RateLimitError, anthropic.RateLimitError)): + gate_binding._recent_429 += 1 + elif isinstance(exc, (openai.APIStatusError, anthropic.APIStatusError)): + if getattr(exc, "status_code", None) == 429: + gate_binding._recent_429 += 1 + + # Honor Retry-After header: extend the computed sleep when needed. + retry_after = _extract_retry_after(exc) if exc is not None else None + + log.debug( + "llm_gate_retry", + attempt=retry_state.attempt_number, + exc_type=type(exc).__name__ if exc else None, + retry_after_header=retry_after, + ) + + if retry_after is not None and retry_state.next_action is not None: + computed_sleep = getattr(retry_state.next_action, "sleep", 0.0) + # Use the larger of the two; never let Retry-After: 1 undercut backoff. + if retry_after > computed_sleep: + retry_state.next_action.sleep = retry_after # type: ignore[assignment] + + return _before_sleep + + +# ────────────────────────────────────────────────────────────────────────────── +# ConcurrencyGatedProvider + + +class ConcurrencyGatedProvider: + """``LLMProvider`` decorator that routes calls through a ``ProviderGate``. + + Decision 2 ordering (slot held only during the upstream call): + + retry-loop → limiter-wait → acquire-slot → call-raw → release-slot + + Releasing the slot between retry attempts ensures a single 429 with a + long ``Retry-After`` does not monopolize the only slot while other + callers wait. + + Phase 3: tenacity predicate activated (Decision 4 whitelist), aiolimiter + wired for RPM shaping when configured. + + ``stream()`` is pass-through (no usage extraction). + Provider-specific streaming subclasses (``OpenAICompatGatedProvider``, + ``AnthropicGatedProvider``) that extract final usage tokens are added in + Phase 6. A ``# TODO(phase-6)`` comment marks the extension point. + """ + + def __init__( + self, + raw: LLMProvider, + gate: ProviderGate, + config: ConcurrencyConfig | None = None, + ) -> None: + self._raw = raw + self._gate = gate + self._config = config or ConcurrencyConfig() + # Wire aiolimiter for RPM rate-shaping when configured (Decision 7). + # None = no RPM shaping (default for all providers; Phase 8 adds specific RPMs). + provider_name = getattr(raw, "provider_name", None) or "" + rpm = self._config.rpm.get(provider_name) + self._limiter: AsyncLimiter | None = ( + AsyncLimiter(rpm, time_period=60) if rpm is not None and rpm > 0 else None + ) + # Cache the before_sleep callback (captures the gate's binding). + self._before_sleep = _make_before_sleep(gate._binding) + + async def complete( + self, + prompt: str, + *, + system: str = "", + max_tokens: int = 4096, + temperature: float = 0.0, + frequency_penalty: float = 0.0, + model: str | None = None, + ) -> LLMResponse: + """Gate + retry wrapper around ``raw.complete``.""" + retry_max = self._config.retry_max_attempts + + async for attempt in AsyncRetrying( + retry=retry_if_exception(_retry_predicate), + wait=wait_random_exponential(multiplier=1, max=60), + stop=stop_after_attempt(retry_max), + reraise=True, + before_sleep=self._before_sleep, + ): + with attempt: + if self._limiter is not None: + await self._limiter.acquire() + async with self._gate.slot(): + response = await self._raw.complete( + prompt, + system=system, + max_tokens=max_tokens, + temperature=temperature, + frequency_penalty=frequency_penalty, + model=model, + ) + self._gate.record_completion(response.output_tokens) + return response # type: ignore[return-value] + + # Unreachable (reraise=True above), but keeps type-checkers happy. + raise RuntimeError("AsyncRetrying exited without raising or returning") + + async def stream( + self, + prompt: str, + *, + system: str = "", + max_tokens: int = 4096, + temperature: float = 0.0, + model: str | None = None, + ) -> AsyncIterator[str]: + """Pass-through streaming (slot held during the entire stream). + + TODO(phase-6): replace with provider-specific subclasses + (``OpenAICompatGatedProvider``, ``AnthropicGatedProvider``) that + call the SDK directly and extract the final usage chunk for tok/s + ring-buffer recording. Until Phase 6, this wrapper simply acquires + the slot, delegates to ``raw.stream``, and releases on completion. + """ + retry_max = self._config.retry_max_attempts + + async for attempt in AsyncRetrying( + retry=retry_if_exception(_retry_predicate), + wait=wait_random_exponential(multiplier=1, max=60), + stop=stop_after_attempt(retry_max), + reraise=True, + before_sleep=self._before_sleep, + ): + with attempt: + if self._limiter is not None: + await self._limiter.acquire() + async with self._gate.slot(): + async for chunk in self._raw.stream( + prompt, + system=system, + max_tokens=max_tokens, + temperature=temperature, + model=model, + ): + yield chunk + return # Successful stream complete. + + +# ────────────────────────────────────────────────────────────────────────────── +# Provider-specific gated adapters (Decision 10, Phase 6) +# +# These subclasses override ``stream()`` to extract final token counts from +# the SDK's streaming response and feed them into the gate's 60-second ring +# buffer. The base class ``ConcurrencyGatedProvider.stream()`` is a pass-through +# that does not record token counts; use these subclasses for providers whose +# SDKs surface streaming usage data. + + +class OpenAICompatGatedProvider(ConcurrencyGatedProvider): + """Gated provider for OpenAI-compatible backends. + + Overrides ``stream()`` to pass ``stream_options={"include_usage": True}`` + so the final chunk carries ``chunk.usage.completion_tokens``. On 400 + responses indicating ``stream_options`` is unsupported the adapter falls + back silently and marks the gate flag ``streaming_usage_unsupported``. + + Decision 10b / codex r2 M2. + """ + + async def stream( + self, + prompt: str, + *, + system: str = "", + max_tokens: int = 4096, + temperature: float = 0.0, + model: str | None = None, + ) -> AsyncIterator[str]: + """Gate + retry wrapper around raw OpenAI-compat stream with usage extraction.""" + retry_max = self._config.retry_max_attempts + + async for attempt in AsyncRetrying( + retry=retry_if_exception(_retry_predicate), + wait=wait_random_exponential(multiplier=1, max=60), + stop=stop_after_attempt(retry_max), + reraise=True, + before_sleep=self._before_sleep, + ): + with attempt: + if self._limiter is not None: + await self._limiter.acquire() + async with self._gate.slot(): + output_tokens = 0 + try: + async for chunk in self._stream_with_usage( + prompt, + system=system, + max_tokens=max_tokens, + temperature=temperature, + model=model, + ): + if isinstance(chunk, int): + # Sentinel: final usage token count. + output_tokens = chunk + else: + yield chunk + finally: + if output_tokens > 0: + self._gate.record_completion(output_tokens) + return # Successful stream complete. + + async def _stream_with_usage( + self, + prompt: str, + *, + system: str = "", + max_tokens: int = 4096, + temperature: float = 0.0, + model: str | None = None, + ) -> AsyncIterator[str | int]: + """Delegate to raw.stream, attempting to include usage data. + + Yields text chunks as str, then a final int sentinel with output_tokens. + When stream_options is unsupported (400), falls back to plain streaming. + """ + if self._gate.streaming_usage_unsupported: + # Already know this gate doesn't support stream_options; use raw directly. + async for chunk in self._raw.stream( + prompt, system=system, max_tokens=max_tokens, temperature=temperature, model=model + ): + yield chunk + return + + # Attempt stream with usage via the provider's own stream() first. + # The OpenAICompatProvider.stream() doesn't pass stream_options; to extract + # usage we need the underlying client. Try to access it directly. + raw_client = getattr(self._raw, "client", None) + if raw_client is None: + # No direct SDK access — fall back to raw.stream() without usage. + async for chunk in self._raw.stream( + prompt, system=system, max_tokens=max_tokens, temperature=temperature, model=model + ): + yield chunk + return + + use_model = model or getattr(self._raw, "model", None) + messages: list[dict[str, str]] = [] + if system: + messages.append({"role": "system", "content": system}) + messages.append({"role": "user", "content": prompt}) + + # Build extra_body from raw provider if applicable. + extra_body: dict[str, object] | None = None + draft_model = getattr(self._raw, "draft_model", None) + if draft_model: + extra_body = {"draft_model": draft_model} + disable_thinking = getattr(self._raw, "disable_thinking", False) + if disable_thinking: + extra_body = dict(extra_body or {}) + extra_body["chat_template_kwargs"] = {"enable_thinking": False} + + try: + stream = await raw_client.chat.completions.create( + model=use_model, + messages=messages, # type: ignore[arg-type] + max_tokens=max_tokens, + temperature=temperature, + stream=True, + stream_options={"include_usage": True}, + extra_body=extra_body, + ) + completion_tokens = 0 + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content + # The final chunk (empty choices) carries usage when stream_options is set. + if chunk.usage is not None: + completion_tokens = chunk.usage.completion_tokens or 0 + if completion_tokens > 0: + yield completion_tokens # type: ignore[misc] # int sentinel + except openai.APIStatusError as exc: + if exc.status_code == 400 and "stream_options" in str(exc).lower(): + # Server doesn't support stream_options — mark the gate and fall back. + self._gate.streaming_usage_unsupported = True + log.warning( + "llm_gate_stream_options_unsupported", + provider=getattr(self._raw, "provider_name", "openai-compatible"), + error=str(exc), + ) + async for chunk in self._raw.stream( + prompt, system=system, max_tokens=max_tokens, temperature=temperature, model=model + ): + yield chunk + else: + raise + + +class AnthropicGatedProvider(ConcurrencyGatedProvider): + """Gated provider for Anthropic's API. + + Overrides ``stream()`` to extract ``usage.output_tokens`` from the + Anthropic SDK's streaming ``message_delta`` event and feed it into the + gate's 60-second ring buffer. + """ + + async def stream( + self, + prompt: str, + *, + system: str = "", + max_tokens: int = 4096, + temperature: float = 0.0, + model: str | None = None, + ) -> AsyncIterator[str]: + """Gate + retry wrapper around raw Anthropic stream with usage extraction.""" + retry_max = self._config.retry_max_attempts + + async for attempt in AsyncRetrying( + retry=retry_if_exception(_retry_predicate), + wait=wait_random_exponential(multiplier=1, max=60), + stop=stop_after_attempt(retry_max), + reraise=True, + before_sleep=self._before_sleep, + ): + with attempt: + if self._limiter is not None: + await self._limiter.acquire() + async with self._gate.slot(): + raw_client = getattr(self._raw, "client", None) + if raw_client is None: + # No direct SDK access; fall back to raw.stream(). + async for chunk in self._raw.stream( + prompt, system=system, max_tokens=max_tokens, temperature=temperature, model=model + ): + yield chunk + return + + use_model = model or getattr(self._raw, "model", None) + output_tokens = 0 + # Build system prompt via raw provider helper if available. + build_system = getattr(self._raw, "_build_system", None) + system_block = build_system(system) if build_system is not None else system + + try: + async with raw_client.messages.stream( + model=use_model, + max_tokens=max_tokens, + temperature=temperature, + system=system_block, + messages=[{"role": "user", "content": prompt}], + ) as stream: + async for text in stream.text_stream: + yield text + # After stream completes, get the final message for usage. + try: + final_msg = await stream.get_final_message() + if final_msg.usage: + output_tokens = final_msg.usage.output_tokens or 0 + except Exception: # noqa: BLE001 + pass + finally: + if output_tokens > 0: + self._gate.record_completion(output_tokens) + return # Successful stream complete. + + +# ────────────────────────────────────────────────────────────────────────────── +# Factory helper + + +async def wrap_provider( + raw: LLMProvider, + provider_name: str, + base_url: str | None, + kind: str, + registry: ProviderGateRegistry, + config: ConcurrencyConfig | None = None, +) -> LLMProvider: + """Wrap ``raw`` in a provider-specific gated adapter if the kill switch is on. + + Decision 10 (Phase 6): branch on the raw provider's type to select the + correct gated subclass that can extract streaming usage tokens: + + - OpenAICompatProvider → OpenAICompatGatedProvider (stream_options extraction) + - AnthropicProvider → AnthropicGatedProvider (message_delta extraction) + - anything else → ConcurrencyGatedProvider (no streaming usage) + + Returns ``raw`` unchanged when ``config.wrapper_enabled`` is False. + """ + cfg = config or registry._config + if not cfg.wrapper_enabled: + return raw + gate = await registry.lookup(provider_name, base_url, kind) + + # Lazy imports to avoid circular imports at module load time. + from workers.common.llm.anthropic import AnthropicProvider # noqa: PLC0415 + from workers.common.llm.openai_compat import OpenAICompatProvider # noqa: PLC0415 + + if isinstance(raw, OpenAICompatProvider): + return OpenAICompatGatedProvider(raw, gate, cfg) + if isinstance(raw, AnthropicProvider): + return AnthropicGatedProvider(raw, gate, cfg) + return ConcurrencyGatedProvider(raw, gate, cfg) diff --git a/workers/common/llm/concurrency_probe.py b/workers/common/llm/concurrency_probe.py index b194242d..0ce33ac4 100644 --- a/workers/common/llm/concurrency_probe.py +++ b/workers/common/llm/concurrency_probe.py @@ -24,6 +24,7 @@ from __future__ import annotations import asyncio +import contextlib import time from typing import TYPE_CHECKING, Protocol @@ -87,14 +88,12 @@ async def call(self) -> float: start = time.monotonic() async with httpx.AsyncClient(timeout=30.0) as client: - try: + with contextlib.suppress(Exception): await client.post( f"{self._base_url}/chat/completions", json=payload, headers=headers, ) - except Exception: - pass # timing measurement ignores response errors return time.monotonic() - start diff --git a/workers/common/llm/config.py b/workers/common/llm/config.py index aafd5b7c..e4002b15 100644 --- a/workers/common/llm/config.py +++ b/workers/common/llm/config.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import TYPE_CHECKING from workers.common.config import ( SUPPORTED_LLM_PROVIDERS, @@ -14,6 +15,9 @@ from workers.common.llm.openai_compat import OpenAICompatProvider from workers.common.llm.provider import LLMProvider +if TYPE_CHECKING: + from workers.common.llm.concurrency import ProviderGateRegistry + def _env_truthy(value: str) -> bool: return value.strip().lower() in ("true", "1", "yes", "on") @@ -58,14 +62,45 @@ def _resolve_disable_thinking(*, report: bool = False) -> bool: return True -def create_llm_provider(config: WorkerConfig) -> LLMProvider: - """Create an LLM provider from configuration.""" +def _create_llm_provider_sync( + config: WorkerConfig, + *, + provider: str = "", + base_url: str = "", + api_key: str = "", + model: str = "", + draft_model: str = "", + timeout_seconds: int = 0, +) -> tuple[LLMProvider, str]: + """Synchronous per-request provider factory — returns the raw (unwrapped) provider. + + Used by ``servicer_utils.resolve_provider_for_context`` which must remain + synchronous (called from backward-compat wrappers on gRPC servicers that + haven't been made async yet). No gate wrapping — the caller uses the + returned ``ProviderResolutionKey`` to look up the gate directly. + """ + effective = config.model_copy( + update={ + "llm_provider": provider or config.llm_provider, + "llm_base_url": base_url or config.llm_base_url, + "llm_api_key": api_key or config.llm_api_key, + "llm_model": model or config.llm_model, + "llm_draft_model": draft_model or config.llm_draft_model, + "llm_timeout": timeout_seconds if timeout_seconds > 0 else config.llm_timeout, + } + ) + return _build_raw_llm_provider(effective), effective.llm_model + + +def _build_raw_llm_provider(config: WorkerConfig) -> LLMProvider: + """Build a raw (unwrapped) LLM provider from config. No gate, no async.""" if config.test_mode: return FakeLLMProvider() if config.llm_provider == "anthropic": return AnthropicProvider(api_key=config.llm_api_key, model=config.llm_model) - elif config.llm_provider == "lmstudio": + + if config.llm_provider == "lmstudio": lmstudio_url = config.llm_base_url or "http://localhost:1234/v1" return OpenAICompatProvider( api_key=config.llm_api_key, @@ -74,7 +109,8 @@ def create_llm_provider(config: WorkerConfig) -> LLMProvider: draft_model=config.llm_draft_model or None, provider_name="lmstudio", ) - elif config.llm_provider in ("openai", "ollama", "vllm", "llama-cpp", "sglang", "gemini", "openrouter"): + + if config.llm_provider in ("openai", "ollama", "vllm", "llama-cpp", "sglang", "gemini", "openrouter"): if config.llm_base_url: base_url: str | None = config.llm_base_url elif config.llm_provider == "ollama": @@ -99,12 +135,7 @@ def create_llm_provider(config: WorkerConfig) -> LLMProvider: "X-Title": "SourceBridge", } - # Disable thinking mode for local models by default. Thinking - # models (Qwen 3.5) generate long chains that waste - # tokens on summarization tasks. Operators can re-enable via - # SOURCEBRIDGE_LLM_ENABLE_THINKING=true. disable_thinking = _resolve_disable_thinking() - return OpenAICompatProvider( api_key=config.llm_api_key, model=config.llm_model, @@ -114,21 +145,58 @@ def create_llm_provider(config: WorkerConfig) -> LLMProvider: disable_thinking=disable_thinking, timeout=float(config.llm_timeout) if config.llm_timeout else None, ) - else: - # Defense in depth: WorkerConfig._validate_llm_provider catches - # this at config-load time, but per-request overrides via - # config.model_copy(update=...) skip validators in pydantic v2 by - # default. The actionable message below mirrors the validator's - # so a metadata-driven override carrying a typo doesn't crash - # the worker mid-request with a confusing stack trace. - # Tester report 2026-04-30 (Pazaryna) R2 / CA-125. - raise ValueError( - f"LLM provider {config.llm_provider!r} is not supported. " - f"Supported LLM providers: {_format_supported(SUPPORTED_LLM_PROVIDERS)}." + + raise ValueError( + f"LLM provider {config.llm_provider!r} is not supported. " + f"Supported LLM providers: {_format_supported(SUPPORTED_LLM_PROVIDERS)}." + ) + + +async def create_llm_provider( + config: WorkerConfig, + *, + gate_registry: ProviderGateRegistry | None = None, +) -> LLMProvider: + """Create an LLM provider from configuration. + + When ``gate_registry`` is supplied (Phase 2+), the returned provider is + wrapped in a ``ConcurrencyGatedProvider`` so all calls pass through the + host/kind gate. When None (kill-switch off or tests that construct + providers directly), the raw provider is returned unchanged. + """ + raw = _build_raw_llm_provider(config) + + if gate_registry is not None: + from workers.common.llm.concurrency import wrap_provider + + base_url_for_gate = _provider_base_url(config) + return await wrap_provider( + raw, + provider_name=config.llm_provider, + base_url=base_url_for_gate, + kind="llm", + registry=gate_registry, ) + return raw -def create_llm_provider_for_request( +def _provider_base_url(config: WorkerConfig) -> str | None: + """Resolve the effective base URL for a provider config (used for gate key lookup).""" + if config.llm_base_url: + return config.llm_base_url + _defaults = { + "ollama": "http://localhost:11434/v1", + "vllm": "http://localhost:8000/v1", + "llama-cpp": "http://localhost:8080/v1", + "sglang": "http://localhost:30000/v1", + "lmstudio": "http://localhost:1234/v1", + "gemini": "https://generativelanguage.googleapis.com/v1beta/openai/", + "openrouter": "https://openrouter.ai/api/v1", + } + return _defaults.get(config.llm_provider) + + +async def create_llm_provider_for_request( config: WorkerConfig, *, provider: str = "", @@ -137,6 +205,7 @@ def create_llm_provider_for_request( model: str = "", draft_model: str = "", timeout_seconds: int = 0, + gate_registry: ProviderGateRegistry | None = None, ) -> tuple[LLMProvider, str]: """Create a per-request provider from effective runtime settings. @@ -144,6 +213,10 @@ def create_llm_provider_for_request( ``timeout_seconds`` > 0 overrides the worker's bootstrap ``llm_timeout``; this is how the admin UI's TimeoutSecs reaches the HTTP client on a per-call basis. + + When ``gate_registry`` is supplied it is forwarded to + ``create_llm_provider`` so the returned provider is wrapped in the + concurrency gate (plan v4 Phase 2, bob H4). """ effective = config.model_copy( update={ @@ -155,14 +228,22 @@ def create_llm_provider_for_request( "llm_timeout": timeout_seconds if timeout_seconds > 0 else config.llm_timeout, } ) - return create_llm_provider(effective), effective.llm_model + return await create_llm_provider(effective, gate_registry=gate_registry), effective.llm_model -def create_report_provider(config: WorkerConfig) -> LLMProvider | None: +async def create_report_provider( + config: WorkerConfig, + *, + gate_registry: ProviderGateRegistry | None = None, +) -> LLMProvider | None: """Create a separate LLM provider for report generation, if configured. Returns None if no report-specific provider is configured, meaning the caller should fall back to the main provider. + + When ``gate_registry`` is supplied the returned provider is wrapped in + the concurrency gate using the report provider's name and base URL + (plan v4 Phase 2). """ if not config.llm_report_provider and not config.llm_report_model: return None @@ -173,23 +254,37 @@ def create_report_provider(config: WorkerConfig) -> LLMProvider | None: base_url = config.llm_report_base_url or config.llm_base_url if provider_name == "anthropic": - return AnthropicProvider(api_key=api_key, model=model) + raw: LLMProvider = AnthropicProvider(api_key=api_key, model=model) + effective_base_url: str | None = base_url or None + else: + # All other providers use OpenAI-compatible interface + default_urls = { + "ollama": "http://localhost:11434/v1", + "vllm": "http://localhost:8000/v1", + "lmstudio": "http://localhost:1234/v1", + } + if not base_url: + base_url = default_urls.get(provider_name, "") - # All other providers use OpenAI-compatible interface - default_urls = { - "ollama": "http://localhost:11434/v1", - "vllm": "http://localhost:8000/v1", - "lmstudio": "http://localhost:1234/v1", - } - if not base_url: - base_url = default_urls.get(provider_name, "") + disable_thinking = _resolve_disable_thinking(report=True) - disable_thinking = _resolve_disable_thinking(report=True) + raw = OpenAICompatProvider( + api_key=api_key, + model=model, + base_url=base_url, + provider_name=provider_name, + disable_thinking=disable_thinking, + ) + effective_base_url = base_url or None - return OpenAICompatProvider( - api_key=api_key, - model=model, - base_url=base_url, - provider_name=provider_name, - disable_thinking=disable_thinking, - ) + if gate_registry is not None: + from workers.common.llm.concurrency import wrap_provider + + return await wrap_provider( + raw, + provider_name=provider_name, + base_url=effective_base_url, + kind="llm", + registry=gate_registry, + ) + return raw diff --git a/workers/common/llm/errors.py b/workers/common/llm/errors.py index f9d63c8b..bed02734 100644 --- a/workers/common/llm/errors.py +++ b/workers/common/llm/errors.py @@ -15,6 +15,20 @@ def is_provider_compute_error(exc: Exception) -> bool: """Classify an exception as a transient LLM-backend error. + .. deprecated:: + DEPRECATED — see plan 2026-05-06-deliver-worker-llm-concurrency Decision 4. + + The two hand-rolled retry loops that called this function + (``hierarchical.py`` leaf/file/package retry and + ``renderers.py:_render_with_retry``) were deleted in Phase 3. The + tenacity gate now owns all retry logic via its ``_retry_predicate`` + whitelist, which handles the same set of transient conditions using + proper SDK exception types rather than string-matching. + + This function is retained for backward compatibility (tests, any + third-party imports) but is no longer called in the hot path. + Remove the function entirely in a follow-up cleanup pass. + Returns True for failures that the retry path should swallow — timeouts, broken pipes, partial connection resets, gateway 5xx, and the original "compute error" / "server_error" markers. A timeout in particular is diff --git a/workers/common/llm/fake.py b/workers/common/llm/fake.py index 51e7c4eb..26889f3d 100644 --- a/workers/common/llm/fake.py +++ b/workers/common/llm/fake.py @@ -1,8 +1,27 @@ -"""Fake LLM provider for deterministic testing.""" +"""Fake LLM provider for deterministic testing. + +**This module is test-only.** Do not instantiate ``FakeLLMProvider`` in +production code paths. + +The class ships two personalities: +1. **Fixture-backed** (default) — returns deterministic responses based on + prompt keywords. Suitable for integration tests that don't care about + failure paths. +2. **Fail-mode** — constructor kwargs enable fault injection for concurrency + gate tests (Phase 3+): + - ``raise_on_attempts`` — raise ``exc`` for the first N calls. + - ``delay_seconds`` — sleep this many seconds before returning/raising. + - ``responses`` — a queue of ``str | Exception``; each call pops one item; + strings are returned as content, exceptions are raised. + +See plan Decision 9: thoughts/shared/plans/active-2026-05-06-deliver-worker-llm-concurrency.md +""" from __future__ import annotations +import asyncio import json +from collections import deque from collections.abc import AsyncIterator from workers.common.llm.provider import LLMResponse @@ -225,11 +244,59 @@ class FakeLLMProvider: ] ) + def __init__( + self, + *, + raise_on_attempts: int = 0, + exc: BaseException | None = None, + delay_seconds: float = 0.0, + responses: list[str | BaseException] | None = None, + ) -> None: + """Create a fake LLM provider. + + Args: + raise_on_attempts: Raise ``exc`` for the first N calls (0 = never). + exc: Exception to raise when ``raise_on_attempts > 0``. + Defaults to ``RuntimeError("FakeLLMProvider injected failure")``. + delay_seconds: Sleep this long (real asyncio sleep) before each + response. Use small values (0.01–0.1 s) in tests that verify + concurrency ordering without making the suite slow. + responses: Optional queue of ``str | BaseException``. Each call + pops the leftmost item; if it's a string it's returned as the + response content; if it's an exception it's raised. When the + queue is exhausted, falls back to the fixture-based dispatch. + """ + self._raise_on_attempts = raise_on_attempts + self._exc = exc or RuntimeError("FakeLLMProvider injected failure") + self._delay_seconds = delay_seconds + self._responses: deque[str | BaseException] = deque(responses or []) + self._call_count: int = 0 + @property def default_model(self) -> str: """Return the default model ID.""" return "fake-test-model" + async def _maybe_fail_or_delay(self) -> str | None: + """Handle fail-mode kwargs. Returns queued content string or None.""" + if self._delay_seconds > 0: + await asyncio.sleep(self._delay_seconds) + + self._call_count += 1 + + # Queued response takes priority. + if self._responses: + item = self._responses.popleft() + if isinstance(item, BaseException): + raise item + return item + + # raise_on_attempts covers the first N calls. + if self._raise_on_attempts > 0 and self._call_count <= self._raise_on_attempts: + raise self._exc + + return None + async def complete( self, prompt: str, @@ -237,9 +304,20 @@ async def complete( system: str = "", max_tokens: int = 4096, temperature: float = 0.0, + frequency_penalty: float = 0.0, model: str | None = None, ) -> LLMResponse: """Return deterministic fixture response based on prompt content.""" + queued = await self._maybe_fail_or_delay() + if queued is not None: + return LLMResponse( + content=queued, + model="fake-test-model", + input_tokens=len(prompt.split()), + output_tokens=len(queued.split()), + stop_reason="end_turn", + ) + content = self.SUMMARY_RESPONSE if "cliff notes" in prompt.lower() or "required sections" in prompt.lower(): content = self.CLIFF_NOTES_RESPONSE @@ -271,7 +349,14 @@ async def stream( temperature: float = 0.0, model: str | None = None, ) -> AsyncIterator[str]: - """Stream deterministic response word by word.""" - response = await self.complete(prompt, system=system, max_tokens=max_tokens, temperature=temperature) + """Stream deterministic response word by word. + + Fail-mode kwargs apply: delay fires before the first chunk; + raise_on_attempts / queued exceptions raise before any chunks are + yielded. + """ + response = await self.complete( + prompt, system=system, max_tokens=max_tokens, temperature=temperature + ) for word in response.content.split(): yield word + " " diff --git a/workers/common/llm/openai_compat.py b/workers/common/llm/openai_compat.py index efc14053..a53b063f 100644 --- a/workers/common/llm/openai_compat.py +++ b/workers/common/llm/openai_compat.py @@ -144,6 +144,7 @@ def __init__( base_url=base_url, timeout=effective_timeout, default_headers=extra_headers or {}, + max_retries=0, # Phase 3: SDK retry disabled; tenacity owns retry (Decision 3) ) self.model = model self.draft_model = draft_model diff --git a/workers/common/llm/router.py b/workers/common/llm/router.py index cb4f3003..e5119fee 100644 --- a/workers/common/llm/router.py +++ b/workers/common/llm/router.py @@ -5,12 +5,32 @@ from collections.abc import AsyncIterator import structlog +from tenacity import RetryError from workers.common.llm.provider import LLMProvider, LLMResponse log = structlog.get_logger() +def _unwrap_retry_error(exc: BaseException) -> BaseException: + """Unwrap tenacity.RetryError to expose the original causing exception. + + Phase 3 defensive fix (plan bob H3): ``ConcurrencyGatedProvider`` uses + tenacity with ``reraise=True``, which re-raises the original exception + directly. However, if tenacity ever wraps the exception in a ``RetryError`` + (e.g. when ``reraise=False`` is accidentally used or in future tenacity + versions), callers see an opaque ``RetryError`` instead of the underlying + SDK exception — breaking error classification, logging, and the router's + fallback logic. + + This unwrap is defensive dead-code today (``reraise=True`` is set + everywhere), but protects against accidental regressions. + """ + if isinstance(exc, RetryError) and exc.__cause__ is not None: + return exc.__cause__ + return exc + + class LLMRouter: """Routes LLM requests to providers with fallback support.""" @@ -33,8 +53,11 @@ async def complete( try: return await provider.complete(prompt, system=system, max_tokens=max_tokens, temperature=temperature) except Exception as e: - last_error = e - log.warning("provider_failed", provider_index=i, error=str(e)) + # Unwrap RetryError so the router and its callers always see + # the original SDK exception — not an opaque tenacity wrapper. + unwrapped = _unwrap_retry_error(e) + last_error = unwrapped if isinstance(unwrapped, Exception) else e + log.warning("provider_failed", provider_index=i, error=str(last_error)) continue raise RuntimeError(f"All LLM providers failed. Last error: {last_error}") @@ -56,7 +79,8 @@ async def stream( yield token return except Exception as e: - last_error = e - log.warning("stream_provider_failed", provider_index=i, error=str(e)) + unwrapped = _unwrap_retry_error(e) + last_error = unwrapped if isinstance(unwrapped, Exception) else e + log.warning("stream_provider_failed", provider_index=i, error=str(last_error)) continue raise RuntimeError(f"All LLM providers failed for streaming. Last error: {last_error}") diff --git a/workers/common/servicer_utils.py b/workers/common/servicer_utils.py index b375dd66..ab1d41b1 100644 --- a/workers/common/servicer_utils.py +++ b/workers/common/servicer_utils.py @@ -9,19 +9,38 @@ from __future__ import annotations +from dataclasses import dataclass +from typing import Literal + import grpc from workers.common.grpc_metadata import resolve_llm_override, resolve_model_override -from workers.common.llm.config import create_llm_provider_for_request from workers.common.llm.provider import LLMProvider +@dataclass(frozen=True) +class ProviderResolutionKey: + """Canonical gate-registry key for the resolved provider context. + + Returned as the third element of ``resolve_provider_for_context`` when + ``gate_registry`` is supplied. The registry's ``canonical_key_for(...)`` + accepts this and returns the internal lookup tuple that identifies the + right gate for capability reporting (Decision 12, plan v4 Phase 2). + """ + + provider: str + base_url: str | None + kind: Literal["llm", "embedding"] = "llm" + + def resolve_provider_for_context( llm: LLMProvider, config, context: grpc.aio.ServicerContext, fallback_llm: LLMProvider | None = None, -) -> tuple[LLMProvider, str | None]: + *, + gate_registry=None, +) -> tuple[LLMProvider, str | None, ProviderResolutionKey | None]: """Resolve the LLM provider for a gRPC request, honoring metadata overrides. Args: @@ -31,10 +50,15 @@ def resolve_provider_for_context( fallback_llm: optional alternate provider used when no override is present and a separate report/fallback provider is configured (used by the knowledge servicer's _resolve_report_provider variant). + gate_registry: optional ``ProviderGateRegistry``; when supplied the + third return value is a ``ProviderResolutionKey`` describing the + canonical gate key for the resolved (provider, base_url). When + None, the third return value is None. Returns: - (provider, model_override) — the provider to use for this request, and - the per-call model string if one was supplied (None otherwise). + (provider, model_override, resolution_key) — the provider to use for + this request, the per-call model string (or None), and the canonical + gate key (or None when gate_registry is None). Resolution order: 1. If a full LLM override is present in the gRPC metadata and a worker @@ -54,16 +78,25 @@ def resolve_provider_for_context( # _resolve_report_provider path: prefer the fallback provider with # its configured report model as the fallback model string. fallback_model = getattr(config, "llm_report_model", None) if config is not None else None - return fallback_llm, model or fallback_model or None - return llm, model + resolution_key = _build_resolution_key(config, gate_registry) if gate_registry is not None else None + return fallback_llm, model or fallback_model or None, resolution_key + resolution_key = _build_resolution_key(config, gate_registry) if gate_registry is not None else None + return llm, model, resolution_key # A full provider override is present in the metadata. if config is None: # No config to build a fresh provider from; use the fallback (or default) # and whatever model the override carries. - return fallback_llm if fallback_llm is not None else llm, override.model or None + return fallback_llm if fallback_llm is not None else llm, override.model or None, None - provider, model = create_llm_provider_for_request( + # Per-request providers are created synchronously without gate wrapping. + # Gate wrapping is reserved for long-lived startup providers (plan Phase 2). + # The resolution_key we return lets callers (GetProviderCapabilities) look up + # the *gate* for this override's (provider, base_url) without re-wrapping. + # Lazy import avoids a circular dependency between servicer_utils ↔ config. + from workers.common.llm.config import _create_llm_provider_sync + + provider, model = _create_llm_provider_sync( config, provider=override.provider, base_url=override.base_url, @@ -72,4 +105,26 @@ def resolve_provider_for_context( draft_model=override.draft_model, timeout_seconds=override.timeout_seconds, ) - return provider, model or None + # Build the resolution key from the override's effective provider/base_url. + resolution_key: ProviderResolutionKey | None = None + if gate_registry is not None: + resolved_provider = override.provider or getattr(config, "llm_provider", "") + resolved_base_url = override.base_url or getattr(config, "llm_base_url", None) or None + resolution_key = ProviderResolutionKey( + provider=resolved_provider, + base_url=resolved_base_url, + kind="llm", + ) + return provider, model or None, resolution_key + + +def _build_resolution_key(config, gate_registry) -> ProviderResolutionKey | None: + """Build a ProviderResolutionKey from the bootstrap config. + + Returns None when config is None (no provider info available). + """ + if config is None: + return None + provider = getattr(config, "llm_provider", "") or "" + base_url = getattr(config, "llm_base_url", None) or None + return ProviderResolutionKey(provider=provider, base_url=base_url, kind="llm") diff --git a/workers/comprehension/hierarchical.py b/workers/comprehension/hierarchical.py index 3423f963..46beb40c 100644 --- a/workers/comprehension/hierarchical.py +++ b/workers/comprehension/hierarchical.py @@ -79,9 +79,9 @@ log = structlog.get_logger() -DEFAULT_LEAF_CONCURRENCY = 1 -DEFAULT_FILE_CONCURRENCY = 2 -DEFAULT_PACKAGE_CONCURRENCY = 2 +DEFAULT_LEAF_CONCURRENCY = 4 # Phase 3: raised from 1 (Outcome A — cap raise safe per dick) +DEFAULT_FILE_CONCURRENCY = 4 # Phase 3: raised from 2 +DEFAULT_PACKAGE_CONCURRENCY = 4 # Phase 3: raised from 2 DEFAULT_LEAF_MAX_TOKENS = 384 DEFAULT_FILE_MAX_TOKENS = 640 DEFAULT_PACKAGE_MAX_TOKENS = 896 @@ -1013,43 +1013,32 @@ async def _call_llm( # body exceeding the budget is a real problem. raise - max_attempts = 3 + # Phase 3: hand-rolled retry deleted — the ConcurrencyGatedProvider + # wrapper owns all retry via tenacity (Decision 3 / Decision 4). + # A single attempt is made here; transient errors (429, 503, timeouts) + # are retried at the gate layer with jitter-backoff before this caller + # sees them. Non-retryable errors fall through to the fallback below. last_exc: Exception | None = None - for attempt in range(1, max_attempts + 1): - try: - response: LLMResponse = require_nonempty( - await complete_with_optional_model( - self._provider, - prompt, - system=HIERARCHICAL_SYSTEM, - temperature=0.0, - max_tokens=max_tokens, - model=self._config.model_override, - ), - context=context, - ) - return ( - response.content.strip(), - response.model, - response.input_tokens, - response.output_tokens, - ) - except Exception as exc: - last_exc = exc - if _is_provider_compute_error(exc): - self._provider_compute_errors += 1 - if attempt < max_attempts: - delay = 0.35 * (2 ** (attempt - 1)) - log.warning( - "hierarchical_node_retry", - context=context, - attempt=attempt, - delay_s=delay, - error=str(exc), - ) - await asyncio.sleep(delay) - continue - break + try: + response: LLMResponse = require_nonempty( + await complete_with_optional_model( + self._provider, + prompt, + system=HIERARCHICAL_SYSTEM, + temperature=0.0, + max_tokens=max_tokens, + model=self._config.model_override, + ), + context=context, + ) + return ( + response.content.strip(), + response.model, + response.input_tokens, + response.output_tokens, + ) + except Exception as exc: + last_exc = exc self._fallback_count += 1 if ":root" in context: diff --git a/workers/comprehension/renderers.py b/workers/comprehension/renderers.py index 8006665c..e2ac6cc1 100644 --- a/workers/comprehension/renderers.py +++ b/workers/comprehension/renderers.py @@ -542,8 +542,8 @@ class CliffNotesRenderer: max_file_summaries: int = 12 max_tokens_per_call: int = 16384 # thinking models need headroom for chains before the JSON output model_override: str | None = None - deep_parallelism: int = 2 - deep_repair_parallelism: int = 2 + deep_parallelism: int = 4 # Phase 3: raised from 2 (Outcome A) + deep_repair_parallelism: int = 4 # Phase 3: raised from 2 (Outcome A) async def render( self, @@ -714,9 +714,16 @@ async def render( depth_instructions=depth_instructions, required_sections=required_sections, ) - response = await self._render_with_retry( - prompt=prompt, - scope_type=scope_type, + response = require_nonempty( + await complete_with_optional_model( + self.provider, + prompt, + system=CLIFF_NOTES_SYSTEM, + temperature=0.0, + max_tokens=self.max_tokens_per_call, + model=self.model_override, + ), + context=f"hierarchical_render:cliff_notes:{scope_type}", ) sections = self._parse_sections( response.content, @@ -836,9 +843,16 @@ async def render_one(title: str) -> tuple[str, CliffNotesSection, LLMResponse | context=f"hierarchical_render:cliff_notes:targeted:{title}", ) try: - response = await self._render_with_retry( - prompt=prompt, - scope_type=f"{scope_type}:targeted:{title}", + response = require_nonempty( + await complete_with_optional_model( + self.provider, + prompt, + system=CLIFF_NOTES_SYSTEM, + temperature=0.0, + max_tokens=self.max_tokens_per_call, + model=self.model_override, + ), + context=f"hierarchical_render:cliff_notes:{scope_type}:targeted:{title}", ) section = self._parse_sections( response.content, @@ -960,9 +974,16 @@ async def render_group( required_sections=list(section_group), ) try: - response = await self._render_with_retry( - prompt=prompt, - scope_type=f"{scope_type}:deep_group", + response = require_nonempty( + await complete_with_optional_model( + self.provider, + prompt, + system=CLIFF_NOTES_SYSTEM, + temperature=0.0, + max_tokens=self.max_tokens_per_call, + model=self.model_override, + ), + context=f"hierarchical_render:cliff_notes:{scope_type}:deep_group", ) sections = self._parse_sections( response.content, @@ -1108,9 +1129,16 @@ async def repair_one(title: str) -> tuple[str, CliffNotesSection | None]: context=f"hierarchical_render:cliff_notes:repair:{title}", ) try: - response = await self._render_with_retry( - prompt=prompt, - scope_type=f"repository:repair:{title}", + response = require_nonempty( + await complete_with_optional_model( + self.provider, + prompt, + system=CLIFF_NOTES_SYSTEM, + temperature=0.0, + max_tokens=self.max_tokens_per_call, + model=self.model_override, + ), + context=f"hierarchical_render:cliff_notes:repository:repair:{title}", ) repaired = self._parse_sections( response.content, @@ -1217,46 +1245,6 @@ def _build_render_prompt( ) return prompt - async def _render_with_retry( - self, - *, - prompt: str, - scope_type: str, - ) -> LLMResponse: - context = f"hierarchical_render:cliff_notes:{scope_type}" - last_exc: Exception | None = None - for attempt in range(1, 4): - try: - return require_nonempty( - await complete_with_optional_model( - self.provider, - prompt, - system=CLIFF_NOTES_SYSTEM, - temperature=0.0, - max_tokens=self.max_tokens_per_call, - model=self.model_override, - ), - context=context, - ) - except Exception as exc: - last_exc = exc - if _is_provider_compute_error(exc) and attempt < 3: - delay = 0.4 * (2 ** (attempt - 1)) - log.warning( - "cliff_notes_renderer_retry", - scope_type=scope_type, - attempt=attempt, - delay_s=delay, - error=str(exc), - ) - import asyncio - - await asyncio.sleep(delay) - continue - break - assert last_exc is not None - raise last_exc - # ------------------------------------------------------------------ # Selection helpers diff --git a/workers/knowledge/servicer.py b/workers/knowledge/servicer.py index f7efb08d..a352ab3f 100644 --- a/workers/knowledge/servicer.py +++ b/workers/knowledge/servicer.py @@ -261,12 +261,15 @@ def __init__( report_llm: LLMProvider | None = None, worker_config: WorkerConfig | None = None, summary_node_cache: SurrealSummaryNodeCache | None = None, + gate_registry=None, ) -> None: self._llm = llm_provider self._embedding = embedding_provider self._report_llm = report_llm self._config = worker_config self._summary_node_cache = summary_node_cache + # ProviderGateRegistry for snapshot_tokens_per_second at progress emit sites. + self._gate_registry = gate_registry # default_model_id is the best-effort identifier of the model # the LLM provider is configured with. The selector uses it to # look up the model's capability profile when no per-call @@ -278,13 +281,29 @@ def __init__( # request-local (see GenerateCliffNotes for the holder # pattern); no servicer-instance state for in-flight runs. + def _snapshot_tokens_per_second(self) -> float: + """Return the 60-second ring-buffer tok/s for the configured LLM gate. + + Returns 0.0 when gate_registry is absent, config is absent, or the + wrapper kill-switch is disabled. Callers treat 0.0 as "unknown". + """ + if self._gate_registry is None or self._config is None: + return 0.0 + provider = getattr(self._config, "llm_provider", "") or "" + base_url = getattr(self._config, "llm_base_url", None) or None + return self._gate_registry.snapshot_tokens_per_second(provider, base_url, "llm") + def _resolve_request_provider(self, context: grpc.aio.ServicerContext) -> tuple[LLMProvider, str | None]: """Backward-compat wrapper. New code should call resolve_provider_for_context directly.""" - return resolve_provider_for_context(self._llm, self._config, context) + provider, model, _ = resolve_provider_for_context(self._llm, self._config, context) + return provider, model def _resolve_report_provider(self, context: grpc.aio.ServicerContext) -> tuple[LLMProvider, str | None]: """Backward-compat wrapper with self._report_llm fallback.""" - return resolve_provider_for_context(self._llm, self._config, context, fallback_llm=self._report_llm) + provider, model, _ = resolve_provider_for_context( + self._llm, self._config, context, fallback_llm=self._report_llm + ) + return provider, model def _build_job_state_updater( self, @@ -1309,6 +1328,7 @@ def _snapshot_fn() -> dict: prog.completed_units = 0 prog.total_units = 0 prog.unit_kind = "" + prog.current_tokens_per_second = self._snapshot_tokens_per_second() yield knowledge_pb2.KnowledgeServiceGenerateCliffNotesResponse(progress=prog) # Work task is done. Await it to surface the result or @@ -1540,6 +1560,7 @@ async def GenerateLearningPath( # noqa: N802 phase=knowledge_progress_pb2.KNOWLEDGE_PHASE_RENDER, message="Generating learning path", ): + prog.current_tokens_per_second = self._snapshot_tokens_per_second() yield knowledge_pb2.KnowledgeServiceGenerateLearningPathResponse(progress=prog) result, usage = await work_task except asyncio.CancelledError: @@ -1693,6 +1714,7 @@ async def GenerateArchitectureDiagram( # noqa: N802 phase=knowledge_progress_pb2.KNOWLEDGE_PHASE_RENDER, message="Generating architecture diagram", ): + prog.current_tokens_per_second = self._snapshot_tokens_per_second() yield knowledge_pb2.KnowledgeServiceGenerateArchitectureDiagramResponse(progress=prog) result, usage = await work_task except asyncio.CancelledError: @@ -1868,6 +1890,7 @@ async def GenerateWorkflowStory( # noqa: N802 phase=knowledge_progress_pb2.KNOWLEDGE_PHASE_RENDER, message="Generating workflow story", ): + prog.current_tokens_per_second = self._snapshot_tokens_per_second() yield knowledge_pb2.KnowledgeServiceGenerateWorkflowStoryResponse(progress=prog) result, usage = await work_task except asyncio.CancelledError: @@ -2023,6 +2046,7 @@ async def ExplainSystem( # noqa: N802 phase=knowledge_progress_pb2.KNOWLEDGE_PHASE_RENDER, message="Explaining system", ): + prog.current_tokens_per_second = self._snapshot_tokens_per_second() yield knowledge_pb2.KnowledgeServiceExplainSystemResponse(progress=prog) result, usage = await work_task except asyncio.CancelledError: @@ -2151,6 +2175,7 @@ async def GenerateCodeTour( # noqa: N802 phase=knowledge_progress_pb2.KNOWLEDGE_PHASE_RENDER, message="Generating code tour", ): + prog.current_tokens_per_second = self._snapshot_tokens_per_second() yield knowledge_pb2.KnowledgeServiceGenerateCodeTourResponse(progress=prog) result, usage = await work_task except asyncio.CancelledError: diff --git a/workers/knowledge/streaming.py b/workers/knowledge/streaming.py index f47b8922..7030e1be 100644 --- a/workers/knowledge/streaming.py +++ b/workers/knowledge/streaming.py @@ -112,6 +112,7 @@ def progress_event( file_cache_hits: int = 0, package_cache_hits: int = 0, root_cache_hits: int = 0, + current_tokens_per_second: float = 0.0, ) -> knowledge_progress_pb2.KnowledgeStreamProgress: """Build a KnowledgeStreamProgress message with the given fields.""" return knowledge_progress_pb2.KnowledgeStreamProgress( @@ -124,6 +125,7 @@ def progress_event( file_cache_hits=file_cache_hits, package_cache_hits=package_cache_hits, root_cache_hits=root_cache_hits, + current_tokens_per_second=current_tokens_per_second, ) diff --git a/workers/pyproject.toml b/workers/pyproject.toml index a93f6e74..8f522db7 100644 --- a/workers/pyproject.toml +++ b/workers/pyproject.toml @@ -16,6 +16,8 @@ dependencies = [ "pydantic>=2.7.0", "pydantic-settings>=2.2.0", "pyyaml>=6.0.2", + "tenacity>=8.5.0", + "aiolimiter>=1.2.1", ] [dependency-groups] @@ -53,6 +55,9 @@ asyncio_mode = "auto" pythonpath = [".", "../gen/python"] testpaths = [".", "../tests"] norecursedirs = ["fixtures", "node_modules", ".git", "__pycache__"] +markers = [ + "slow: marks tests that require real wall-clock time (deselect with -m 'not slow')", +] [build-system] requires = ["hatchling"] diff --git a/workers/reasoning/servicer.py b/workers/reasoning/servicer.py index e62cc87a..2527bd80 100644 --- a/workers/reasoning/servicer.py +++ b/workers/reasoning/servicer.py @@ -9,6 +9,7 @@ from workers.common.config import HARD_CONCURRENCY_CEILING, WorkerConfig from workers.common.embedding.provider import EmbeddingProvider +from workers.common.llm.concurrency import _UNCAPPED, ProviderGateRegistry from workers.common.llm.provider import LLMProvider from workers.common.llm.tools import ( AgentMessage, @@ -154,14 +155,19 @@ def __init__( llm_provider: LLMProvider, embedding_provider: EmbeddingProvider, worker_config: WorkerConfig | None = None, + gate_registry: ProviderGateRegistry | None = None, ) -> None: self._llm = llm_provider self._embedding = embedding_provider self._config = worker_config + self._gate_registry = gate_registry def _resolve_provider(self, context: grpc.aio.ServicerContext) -> tuple[LLMProvider, str | None]: """Backward-compat wrapper. New code should call resolve_provider_for_context directly.""" - return resolve_provider_for_context(self._llm, self._config, context) + provider, model, _ = resolve_provider_for_context( + self._llm, self._config, context, gate_registry=self._gate_registry + ) + return provider, model async def AnalyzeSymbol( # noqa: N802 self, @@ -701,18 +707,51 @@ async def GetProviderCapabilities( # noqa: N802 provider_key = "anthropic" if "anthropic" in provider_name else "" model = getattr(provider, "default_model", "") or getattr(provider, "model", "") - # Populate the new capacity fields (D3 / Phase 1). - # max_concurrent_calls is taken from the worker config (operator-declared). - # max_concurrent_calls_known is True when we have an explicit declaration - # (>0) or when the provider is a known-unbounded frontier API (see below). + # Populate the capacity fields (D3 / Phase 2 rewrite — Decision 12). + # + # Phase 2 (gate active): ask the registry for the effective cap of the + # *resolved-context* (provider, base_url) so the reported cap always + # reflects the provider that will actually service requests for this + # workspace/repo combination, not the bootstrap config. + # + # Legacy fallback (kill switch off, or no gate_registry): read + # WorkerConfig.llm_max_concurrent_calls directly. This preserves + # pre-Phase-2 behavior for deployments that have not enabled the gate. + # # Hard-clamp to HARD_CONCURRENCY_CEILING as defense in depth (D9). declared = 0 known = False - if self._config is not None: + _, _, resolution_key = resolve_provider_for_context( + self._llm, + self._config, + context, + gate_registry=self._gate_registry, + ) + + if resolution_key is not None and self._gate_registry is not None: + # Gate is active: use the registry's effective LLM cap for the + # resolved (provider, base_url). Decision 6 / v4 M1 fix: all + # providers now report finite caps; the old "frontier unbounded" + # sentinel is only emitted via the legacy fallback below so that + # old workers (without a gate) keep the previous encoding. + effective = self._gate_registry.effective_llm_max_concurrent( + resolution_key.provider, + resolution_key.base_url, + ) + if effective is not None: + declared = min(effective, HARD_CONCURRENCY_CEILING) + known = True + # effective is None when the wrapper is disabled inside the + # registry — fall through to the legacy path. + + if not known and self._config is not None: + # Legacy fallback: wrapper disabled (kill switch off) or no + # per-provider override set in the registry yet. Keep the + # old WorkerConfig-sourced logic exactly so existing deployments + # are unaffected. declared = min(self._config.llm_max_concurrent_calls, HARD_CONCURRENCY_CEILING) if declared > 0: - # Operator declared a finite value. known = True elif self._config.llm_provider in ("anthropic", "openai", "openrouter", "gemini"): # These are frontier APIs with effectively unbounded parallelism. @@ -729,3 +768,56 @@ async def GetProviderCapabilities( # noqa: N802 max_concurrent_calls=declared, max_concurrent_calls_known=known, ) + + async def GetLLMGateSnapshot( # noqa: N802 + self, + request: reasoning_pb2.GetLLMGateSnapshotRequest, + context: grpc.aio.ServicerContext, + ) -> reasoning_pb2.GetLLMGateSnapshotResponse: + """Return a point-in-time snapshot of all active concurrency gates. + + Same gRPC auth as GetProviderCapabilities: callers must present a + valid x-sb-worker-secret header when the worker is configured with + SOURCEBRIDGE_SECURITY_GRPC_AUTH_SECRET. + + When no gate registry is wired (kill-switch off, test builds without + registry) the response is an empty list — the admin monitor omits the + gate_snapshot field on empty so old/kill-switched deployments surface + nothing rather than a misleading empty array. + """ + # Auth check: reuse _resolve_provider which calls resolve_provider_for_context + # and enforces the gRPC auth interceptor's metadata check. We don't + # actually need the returned provider for this read-only query. + self._resolve_provider(context) + + if self._gate_registry is None: + return reasoning_pb2.GetLLMGateSnapshotResponse() + + entries = self._gate_registry.snapshot() + gates = [] + for entry in entries: + # Clamp max_concurrent to HARD_CONCURRENCY_CEILING so the int32 + # proto field never overflows. The _UNCAPPED sentinel (sys.maxsize) + # means "no real cap configured" — emit 0 to signal unknown/uncapped, + # matching the (known=true, calls=0) "unbounded" encoding used in + # GetProviderCapabilities. + effective_cap = entry.max_concurrent + if effective_cap >= _UNCAPPED: + effective_cap = 0 + else: + effective_cap = min(effective_cap, HARD_CONCURRENCY_CEILING) + gates.append( + reasoning_pb2.LLMGateEntry( + provider=entry.provider, + base_url_normalized=entry.base_url_normalized, + kind=entry.kind, + in_flight=entry.in_flight, + queued=entry.queue_depth, + max_concurrent=effective_cap, + retries_since_start=entry.retries_since_start, + recent_429_count=entry.recent_429_count, + tokens_per_second=entry.tokens_per_second, + rpm=entry.rpm, + ) + ) + return reasoning_pb2.GetLLMGateSnapshotResponse(gates=gates) diff --git a/workers/requirements/servicer.py b/workers/requirements/servicer.py index ed9ecc49..f042a925 100644 --- a/workers/requirements/servicer.py +++ b/workers/requirements/servicer.py @@ -37,7 +37,8 @@ def __init__(self, llm_provider: LLMProvider, worker_config: WorkerConfig | None def _resolve_provider(self, context: grpc.aio.ServicerContext) -> tuple[LLMProvider, str | None]: """Backward-compat wrapper. New code should call resolve_provider_for_context directly.""" - return resolve_provider_for_context(self._llm, self._config, context) + provider, model, _ = resolve_provider_for_context(self._llm, self._config, context) + return provider, model async def ParseDocument( # noqa: N802 self, diff --git a/workers/requirements/spec_extraction.py b/workers/requirements/spec_extraction.py index d575165a..e3688f88 100644 --- a/workers/requirements/spec_extraction.py +++ b/workers/requirements/spec_extraction.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import re import structlog @@ -20,6 +21,12 @@ log = structlog.get_logger() +# Maximum concurrent LLM calls over spec-extraction groups in a single +# refine_with_llm() call. The upstream provider gate is the binding cap; +# this local semaphore prevents N-thousand pending coroutines from queuing +# on very large repos before the gate has a chance to drain them. +_SPEC_GROUP_FANOUT_LIMIT = 8 + # Language detection by file extension EXTENSION_TO_LANGUAGE: dict[str, str] = { ".go": "go", @@ -195,85 +202,119 @@ async def refine_with_llm( for c in candidates: groups.setdefault(c.group_key, []).append(c) - refined: list[RefinedSpec] = [] - total_input = 0 - total_output = 0 - model_name = "" - - for group_key, group_candidates in groups.items(): - primary = group_candidates[0] - - # Build artifacts block - artifacts_parts: list[str] = [] - for c in group_candidates: - if c.source == "test": - artifacts_parts.append(f"Test: {c.raw_text}") - assertions = c.metadata.get("assertions", []) - if assertions: - artifacts_parts.append(f" Assertions: {', '.join(assertions)}") - elif c.source == "schema": - artifacts_parts.append(f"Schema: {c.raw_text}") - else: - artifacts_parts.append(f"Comment: {c.raw_text}") - - artifacts_block = "\n".join(artifacts_parts) - prompt = SPEC_REFINEMENT_USER.format( - source_type=primary.source, - source_file=primary.source_file, - language=primary.language, - artifacts_block=artifacts_block, - ) + # Bounded concurrent fan-out over groups. _SPEC_GROUP_FANOUT_LIMIT is the + # local bound; the upstream provider gate is the binding cap. + local_sem = asyncio.Semaphore(_SPEC_GROUP_FANOUT_LIMIT) + + # Each task returns (spec, input_tokens, output_tokens, model) so the + # serial merge below can accumulate usage without mutating RefinedSpec. + async def _refine_one( + group_key: str, + group_candidates: list[CandidateSpec], + ) -> tuple[RefinedSpec, int, int, str]: + """Refine one group, bounded by the local semaphore.""" + async with local_sem: + primary = group_candidates[0] + + # Build artifacts block + artifacts_parts: list[str] = [] + for c in group_candidates: + if c.source == "test": + artifacts_parts.append(f"Test: {c.raw_text}") + assertions = c.metadata.get("assertions", []) + if assertions: + artifacts_parts.append(f" Assertions: {', '.join(assertions)}") + elif c.source == "schema": + artifacts_parts.append(f"Schema: {c.raw_text}") + else: + artifacts_parts.append(f"Comment: {c.raw_text}") + + artifacts_block = "\n".join(artifacts_parts) + prompt = SPEC_REFINEMENT_USER.format( + source_type=primary.source, + source_file=primary.source_file, + language=primary.language, + artifacts_block=artifacts_block, + ) - try: - from workers.common.llm.provider import require_nonempty - - response = require_nonempty( - await llm_provider.complete( - prompt, - system=SPEC_REFINEMENT_SYSTEM, - temperature=0.1, - max_tokens=512, - ), - context="requirements:spec_refinement", + text_out = primary.raw_text + keywords_out: list[str] = [] + input_tokens = 0 + output_tokens = 0 + model_out = "" + + try: + from workers.common.llm.provider import require_nonempty + + response = require_nonempty( + await llm_provider.complete( + prompt, + system=SPEC_REFINEMENT_SYSTEM, + temperature=0.1, + max_tokens=512, + ), + context="requirements:spec_refinement", + ) + input_tokens = response.input_tokens + output_tokens = response.output_tokens + model_out = response.model + + data = parse_json_response(response.content) + if data and isinstance(data, dict): + text_out = data.get("requirement_text", primary.raw_text) + kw = data.get("keywords", []) + keywords_out = kw if isinstance(kw, list) else [] + + except Exception as exc: + log.warning("llm_refinement_failed", group_key=group_key, error=str(exc)) + + confidence = compute_confidence(group_candidates) + source_files = list( + {c.source_file for c in group_candidates if c.source_file != primary.source_file} ) - total_input += response.input_tokens - total_output += response.output_tokens - model_name = response.model - - data = parse_json_response(response.content) - if data and isinstance(data, dict): - text = data.get("requirement_text", primary.raw_text) - keywords = data.get("keywords", []) - if not isinstance(keywords, list): - keywords = [] - else: - text = primary.raw_text - keywords = [] - - except Exception as exc: - log.warning("llm_refinement_failed", group_key=group_key, error=str(exc)) - text = primary.raw_text - keywords = [] - - confidence = compute_confidence(group_candidates) - - source_files = list({c.source_file for c in group_candidates if c.source_file != primary.source_file}) - - refined.append( - RefinedSpec( + + spec = RefinedSpec( source=primary.source, source_file=primary.source_file, source_line=primary.source_line, source_files=source_files, - text=text, + text=text_out, raw_text=primary.raw_text, group_key=group_key, language=primary.language, - keywords=keywords, + keywords=keywords_out, confidence=confidence, llm_refined=True, ) - ) + return (spec, input_tokens, output_tokens, model_out) + + group_items = list(groups.items()) + raw_results = await asyncio.gather( + *[_refine_one(k, v) for k, v in group_items], + return_exceptions=True, + ) + + # Merge results in insertion order — order matches groups.items() iteration + # order (dict preserves insertion order in Python 3.7+). deduplicate_specs + # is key-based so list order doesn't affect correctness, but we preserve it + # for stable, deterministic output. + refined: list[RefinedSpec] = [] + total_input = 0 + total_output = 0 + model_name = "" + for r in raw_results: + if isinstance(r, BaseException): + # gather(return_exceptions=True) surfaces framework-level errors + # (e.g., CancelledError). Per-LLM failures are caught inside + # _refine_one and return a fallback spec, so this path is rare. + log.warning("spec_group_task_failed", error=str(r)) + continue + spec, in_tok, out_tok, model_out = r + refined.append(spec) + total_input += in_tok + total_output += out_tok + if model_out: + model_name = model_out usage = LLMUsageRecord( model=model_name, diff --git a/workers/tests/conftest.py b/workers/tests/conftest.py new file mode 100644 index 00000000..10d14fb2 --- /dev/null +++ b/workers/tests/conftest.py @@ -0,0 +1,27 @@ +"""Shared pytest fixtures for the worker test suite. + +Provides a function-scoped ``gate_registry`` fixture so tests that exercise +factory functions requiring a ``ProviderGateRegistry`` can do so without +duplicating setup/teardown boilerplate. + +Plan: thoughts/shared/plans/active-2026-05-06-deliver-worker-llm-concurrency.md +Phase 2 / H1 fix. +""" + +from __future__ import annotations + +import pytest_asyncio + +from workers.common.llm.concurrency import ConcurrencyConfig, ProviderGateRegistry + + +@pytest_asyncio.fixture +async def gate_registry(): + """Function-scoped ProviderGateRegistry with default (sentinel-uncapped) config. + + Yields the registry and closes it on teardown so tests don't leak gate + state across test functions. + """ + registry = ProviderGateRegistry(ConcurrencyConfig()) + yield registry + await registry.close() diff --git a/workers/tests/test_comprehension_bench.py b/workers/tests/test_comprehension_bench.py index 0768ad0f..ff18b1a3 100644 --- a/workers/tests/test_comprehension_bench.py +++ b/workers/tests/test_comprehension_bench.py @@ -6,7 +6,7 @@ from __future__ import annotations import json -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -108,7 +108,12 @@ async def test_live_provider_mode_uses_configured_provider() -> None: ], } - with patch("workers.benchmarks.run_comprehension_bench.create_llm_provider", return_value=FakeLLMProvider()): + mock_registry = MagicMock() + mock_registry.close = AsyncMock() + with patch( + "workers.benchmarks.run_comprehension_bench.build_cli_runtime_provider", + new=AsyncMock(return_value=(FakeLLMProvider(), mock_registry)), + ): result = await _run_case(case, provider_mode_override="live") assert result.provider_mode == "live" diff --git a/workers/tests/test_concurrency.py b/workers/tests/test_concurrency.py new file mode 100644 index 00000000..523cb41d --- /dev/null +++ b/workers/tests/test_concurrency.py @@ -0,0 +1,349 @@ +"""Phase 1 unit tests for the concurrency gate foundations. + +Tests verify: + - URL normalization helper (Decision 1, v4) + - Registry returns the same gate object for the same logical key + - Host gates share a single semaphore across LLM + embedding for local + providers (the Ollama OLLAMA_NUM_PARALLEL=1 correctness fix) + - Per-kind gates are independent for frontier providers + - openai-compatible defaults to host gating; can be overridden via env var + +Real factory defaults used throughout (not toy URLs): + Ollama LLM: http://localhost:11434/v1 (workers/common/llm/config.py:81) + Ollama embedding: http://localhost:11434 (workers/common/embedding/config.py:33) + +Refs: CA-169 / plan v4 Phase 1 Verification list. +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from workers.common.llm.concurrency import ( + ConcurrencyConfig, + ConcurrencyGatedProvider, + ProviderGate, + ProviderGateRegistry, + _normalize_host_key, +) + +# ────────────────────────────────────────────────────────────────────────────── +# Helpers + + +def _registry(config: ConcurrencyConfig | None = None) -> ProviderGateRegistry: + return ProviderGateRegistry(config or ConcurrencyConfig()) + + +# ────────────────────────────────────────────────────────────────────────────── +# 1. _normalize_host_key + + +def test_normalize_host_key_strips_path_and_trailing_slash() -> None: + assert _normalize_host_key("ollama", "http://localhost:11434/v1") == ( + "ollama", + "http://localhost:11434", + ) + assert _normalize_host_key("ollama", "http://localhost:11434/") == ( + "ollama", + "http://localhost:11434", + ) + assert _normalize_host_key("ollama", "http://localhost:11434") == ( + "ollama", + "http://localhost:11434", + ) + + +def test_normalize_host_key_strips_query_and_fragment() -> None: + assert _normalize_host_key("vllm", "http://localhost:8000/v1?foo=bar#x") == ( + "vllm", + "http://localhost:8000", + ) + + +def test_normalize_host_key_empty_base_url() -> None: + result = _normalize_host_key("ollama", None) + assert result == ("ollama", "") + result2 = _normalize_host_key("ollama", "") + assert result2 == ("ollama", "") + + +def test_normalize_host_key_preserves_port() -> None: + assert _normalize_host_key("openai-compatible", "http://192.168.1.10:8080/v1") == ( + "openai-compatible", + "http://192.168.1.10:8080", + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# 2. Registry returns same gate for same logical key + + +@pytest.mark.asyncio +async def test_provider_gate_registry_returns_same_gate_for_same_key() -> None: + reg = _registry() + gate_a = await reg.lookup("ollama", "http://localhost:11434/v1", "llm") + gate_b = await reg.lookup("ollama", "http://localhost:11434/v1", "llm") + # Both façades should share the same underlying _HostGate. + assert gate_a._binding is gate_b._binding + + +# ────────────────────────────────────────────────────────────────────────────── +# 3. Host gate normalizes URL to origin (same gate for /v1 and bare host) + + +@pytest.mark.asyncio +async def test_host_gate_normalizes_url_to_origin() -> None: + """Ollama LLM default and embedding default share one host gate.""" + reg = _registry() + # Real factory defaults: + # LLM: http://localhost:11434/v1 (llm/config.py:81) + # embedding: http://localhost:11434 (embedding/config.py:33) + gate_llm = await reg.lookup("ollama", "http://localhost:11434/v1", "llm") + gate_embed = await reg.lookup("ollama", "http://localhost:11434", "embedding") + + # Different kinds, but same normalized origin → same _HostGate binding. + assert gate_llm._binding is gate_embed._binding + + +# ────────────────────────────────────────────────────────────────────────────── +# 4. Combined in-flight capped at 1 for real Ollama defaults with max_concurrent=1 + + +@pytest.mark.asyncio +async def test_host_gate_caps_combined_in_flight_at_one_for_real_ollama_defaults() -> None: + """Under max_concurrent=1 both LLM and embedding calls share the one slot.""" + config = ConcurrencyConfig(llm_max_concurrent={"ollama": 1}) + reg = _registry(config) + + gate_llm = await reg.lookup("ollama", "http://localhost:11434/v1", "llm") + gate_embed = await reg.lookup("ollama", "http://localhost:11434", "embedding") + + peak_in_flight: list[int] = [] + barrier = asyncio.Event() + + async def hold_slot(gate: ProviderGate) -> None: + async with gate.slot(): + peak_in_flight.append(gate_llm._binding._in_flight) + barrier.set() + # Hold the slot briefly so the other coroutine is forced to queue. + await asyncio.sleep(0.02) + + # Start both concurrently; only one should run at a time. + results = await asyncio.gather( + hold_slot(gate_llm), + hold_slot(gate_embed), + return_exceptions=True, + ) + assert all(r is None for r in results), results + assert max(peak_in_flight) == 1, f"Peak in-flight was {max(peak_in_flight)}, expected 1" + + +# ────────────────────────────────────────────────────────────────────────────── +# 5. Host gate shared across all local provider kinds + + +@pytest.mark.asyncio +async def test_host_gate_shared_across_kinds_for_local_providers() -> None: + """All five host-gated providers share within-provider host semaphore.""" + host_gated = [ + ("ollama", "http://localhost:11434/v1", "http://localhost:11434"), + ("vllm", "http://localhost:8000/v1", "http://localhost:8000"), + ("llama-cpp", "http://localhost:8080/v1", "http://localhost:8080"), + ("sglang", "http://localhost:30000/v1", "http://localhost:30000"), + ("lmstudio", "http://localhost:1234/v1", "http://localhost:1234"), + ] + reg = _registry() + for provider, llm_url, embed_url in host_gated: + gate_llm = await reg.lookup(provider, llm_url, "llm") + gate_embed = await reg.lookup(provider, embed_url, "embedding") + assert gate_llm._binding is gate_embed._binding, ( + f"{provider}: expected LLM and embedding to share a host gate" + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# 6. Per-kind gates are independent for frontier providers + + +@pytest.mark.asyncio +async def test_per_kind_gates_independent_for_frontier_providers() -> None: + """openai LLM and openai embedding use separate semaphores.""" + reg = _registry() + gate_llm = await reg.lookup("openai", "https://api.openai.com/v1", "llm") + gate_embed = await reg.lookup("openai", "https://api.openai.com/v1", "embedding") + # Different bindings → independent semaphores → independent quotas. + assert gate_llm._binding is not gate_embed._binding + + +# ────────────────────────────────────────────────────────────────────────────── +# 7. openai-compatible defaults to host gating + + +@pytest.mark.asyncio +async def test_openai_compatible_gating_default_is_host() -> None: + """openai-compatible at the same host shares a gate across kinds.""" + reg = _registry() + gate_llm = await reg.lookup("openai-compatible", "http://localhost:8000/v1", "llm") + gate_embed = await reg.lookup("openai-compatible", "http://localhost:8000", "embedding") + assert gate_llm._binding is gate_embed._binding + + +# ────────────────────────────────────────────────────────────────────────────── +# 8. openai-compatible flips to per-kind when overridden + + +@pytest.mark.asyncio +async def test_openai_compatible_gating_per_kind_when_overridden( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """SOURCEBRIDGE_LLM_PROVIDER_OPENAI_COMPATIBLE_GATING=per_kind flips behavior.""" + monkeypatch.setenv("SOURCEBRIDGE_LLM_PROVIDER_OPENAI_COMPATIBLE_GATING", "per_kind") + config = ConcurrencyConfig.from_env() + assert config.openai_compatible_gating == "per_kind" + + reg = _registry(config) + gate_llm = await reg.lookup("openai-compatible", "https://api.some-service.com/v1", "llm") + gate_embed = await reg.lookup("openai-compatible", "https://api.some-service.com/v1", "embedding") + # per_kind → different bindings. + assert gate_llm._binding is not gate_embed._binding + + +# ────────────────────────────────────────────────────────────────────────────── +# 9. Registry rejects lookup after close + + +@pytest.mark.asyncio +async def test_registry_rejects_lookup_after_close() -> None: + reg = _registry() + await reg.close() + with pytest.raises(RuntimeError, match="closed"): + await reg.lookup("ollama", "http://localhost:11434/v1", "llm") + + +# ────────────────────────────────────────────────────────────────────────────── +# 10. Registry close is idempotent + + +@pytest.mark.asyncio +async def test_registry_close_idempotent() -> None: + reg = _registry() + await reg.close() + await reg.close() # should not raise + + +# ────────────────────────────────────────────────────────────────────────────── +# 11. _GateBase raises on max_concurrent < 1 + + +def test_gate_rejects_zero_max_concurrent() -> None: + from workers.common.llm.concurrency import _HostGate + + with pytest.raises(ValueError, match="max_concurrent"): + _HostGate(max_concurrent=0) + + +# ────────────────────────────────────────────────────────────────────────────── +# 12. ConcurrencyConfig.from_env rejects invalid values + + +def test_concurrency_config_from_env_rejects_invalid_rpm( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("SOURCEBRIDGE_LLM_PROVIDER_OPENAI_RPM", "0") + config = ConcurrencyConfig.from_env() + # 0 is invalid; should be ignored (no entry added). + assert "openai" not in config.rpm + + +def test_concurrency_config_from_env_rejects_zero_max_concurrent( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Phase 3: Decision 6 defaults are pre-populated, so "openai" is always in the + # dict with value 8. Setting the env var to 0 should be rejected (logged as + # invalid) and the Decision 6 default (8) should be preserved, not 0. + monkeypatch.setenv("SOURCEBRIDGE_LLM_PROVIDER_OPENAI_MAX_CONCURRENT", "0") + config = ConcurrencyConfig.from_env() + # The invalid value (0) must NOT be stored; the Decision 6 default (8) is kept. + assert config.llm_max_concurrent.get("openai") == 8 + + +# ────────────────────────────────────────────────────────────────────────────── +# 13. FakeLLMProvider fail-mode kwargs (Phase 1 smoke; full tests in Phase 3) + + +@pytest.mark.asyncio +async def test_fake_provider_raise_on_attempts() -> None: + from workers.common.llm.fake import FakeLLMProvider + + exc = ValueError("injected") + provider = FakeLLMProvider(raise_on_attempts=2, exc=exc) + + with pytest.raises(ValueError, match="injected"): + await provider.complete("test") + with pytest.raises(ValueError, match="injected"): + await provider.complete("test") + # Third call should succeed. + response = await provider.complete("test") + assert response.content + + +@pytest.mark.asyncio +async def test_fake_provider_responses_queue() -> None: + from workers.common.llm.fake import FakeLLMProvider + + provider = FakeLLMProvider(responses=["hello world", RuntimeError("boom")]) + + r1 = await provider.complete("x") + assert r1.content == "hello world" + + with pytest.raises(RuntimeError, match="boom"): + await provider.complete("x") + + # Queue exhausted; falls back to fixture dispatch. + r3 = await provider.complete("x") + assert r3.content # non-empty fixture response + + +# ────────────────────────────────────────────────────────────────────────────── +# 14. ConcurrencyGatedProvider is pass-through in Phase 1 (kill-switch on, +# sentinel-uncapped, tenacity no-op predicate) + + +@pytest.mark.asyncio +async def test_concurrency_gated_provider_passthrough_phase1() -> None: + """With retry_max_attempts=1 and sentinel cap, the wrapper is transparent.""" + from workers.common.llm.fake import FakeLLMProvider + + config = ConcurrencyConfig(retry_max_attempts=1) + reg = _registry(config) + raw = FakeLLMProvider() + + gate = await reg.lookup("ollama", "http://localhost:11434/v1", "llm") + wrapped = ConcurrencyGatedProvider(raw, gate, config) + + response = await wrapped.complete("Summarize this function") + # Should return the fixture summary content. + import json + + data = json.loads(response.content) + assert "purpose" in data + + +@pytest.mark.asyncio +async def test_concurrency_gated_provider_stream_passthrough_phase1() -> None: + from workers.common.llm.fake import FakeLLMProvider + + config = ConcurrencyConfig(retry_max_attempts=1) + reg = _registry(config) + raw = FakeLLMProvider() + + gate = await reg.lookup("ollama", "http://localhost:11434/v1", "llm") + wrapped = ConcurrencyGatedProvider(raw, gate, config) + + chunks: list[str] = [] + async for chunk in wrapped.stream("Summarize this function"): + chunks.append(chunk) + assert len(chunks) > 0 diff --git a/workers/tests/test_concurrency_phase2.py b/workers/tests/test_concurrency_phase2.py new file mode 100644 index 00000000..7732a10f --- /dev/null +++ b/workers/tests/test_concurrency_phase2.py @@ -0,0 +1,283 @@ +"""Phase 2 tests for the concurrency gate wiring through factories and servicers. + +Tests verify: + - create_llm_provider wraps when kill switch is on / unwraps when off + (covered also in test_llm_config.py; cross-cutting here for capability path) + - create_llm_provider_for_request forwards gate_registry (bob H4) + - create_report_provider wraps using the same gate as main provider when + pointing at the same endpoint + - GetProviderCapabilities sources cap from gate registry's effective value + (Decision 12, codex r1 H2 / r2 H2) + - GetProviderCapabilities honors per-request metadata override (r2 H2) + - GetProviderCapabilities falls back to legacy path when kill switch off + - resolve_provider_for_context returns canonical resolution key as third elem + - CLI helper build_cli_runtime_provider returns a ConcurrencyGatedProvider + - benchmark _create_provider("live") returns a ConcurrencyGatedProvider + +Refs: CA-169 / plan v4 Phase 2 Verification list. +""" + +from __future__ import annotations + +import pytest +import pytest_asyncio # noqa: F401 + +from workers.common.embedding.fake import FakeEmbeddingProvider +from workers.common.llm.concurrency import ConcurrencyConfig, ConcurrencyGatedProvider, ProviderGateRegistry +from workers.common.llm.fake import FakeLLMProvider +from workers.common.servicer_utils import ProviderResolutionKey, resolve_provider_for_context + +# ────────────────────────────────────────────────────────────────────────────── +# Helpers / fixtures + +class _MockContext: + def __init__(self, metadata: dict[str, str] | None = None): + self._metadata = list((metadata or {}).items()) + + def invocation_metadata(self): + return self._metadata + + +class _MockConfig: + def __init__( + self, + *, + llm_provider: str = "ollama", + llm_base_url: str = "http://localhost:11434/v1", + llm_api_key: str = "test", + llm_model: str = "qwen3:7b", + llm_draft_model: str = "", + llm_timeout: int = 60, + llm_report_model: str = "", + ) -> None: + self.llm_provider = llm_provider + self.llm_base_url = llm_base_url + self.llm_api_key = llm_api_key + self.llm_model = llm_model + self.llm_draft_model = llm_draft_model + self.llm_timeout = llm_timeout + self.llm_report_model = llm_report_model + + def model_copy(self, *, update: dict) -> _MockConfig: + return _MockConfig( + llm_provider=update.get("llm_provider", self.llm_provider), + llm_base_url=update.get("llm_base_url", self.llm_base_url), + llm_api_key=update.get("llm_api_key", self.llm_api_key), + llm_model=update.get("llm_model", self.llm_model), + llm_draft_model=update.get("llm_draft_model", self.llm_draft_model), + llm_timeout=update.get("llm_timeout", self.llm_timeout), + llm_report_model=update.get("llm_report_model", self.llm_report_model), + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# 1. resolve_provider_for_context returns canonical resolution key + + +@pytest.mark.asyncio +async def test_resolve_provider_for_context_returns_canonical_key_no_override(): + """No metadata override → resolution key reflects bootstrap config.""" + llm = FakeLLMProvider() + config = _MockConfig(llm_provider="ollama", llm_base_url="http://localhost:11434/v1") + registry = ProviderGateRegistry(ConcurrencyConfig()) + context = _MockContext() + + provider, model, resolution_key = resolve_provider_for_context( + llm, config, context, gate_registry=registry + ) + assert provider is llm + assert model is None + assert isinstance(resolution_key, ProviderResolutionKey) + assert resolution_key.provider == "ollama" + assert resolution_key.base_url == "http://localhost:11434/v1" + assert resolution_key.kind == "llm" + + await registry.close() + + +def test_resolve_provider_for_context_returns_none_resolution_key_when_no_registry(): + """No registry supplied → third return is None.""" + llm = FakeLLMProvider() + context = _MockContext() + _, _, resolution_key = resolve_provider_for_context(llm, None, context) + assert resolution_key is None + + +@pytest.mark.asyncio +async def test_resolve_provider_for_context_returns_none_key_when_config_none(): + """Config is None + no override → resolution key is None (no provider info).""" + llm = FakeLLMProvider() + registry = ProviderGateRegistry(ConcurrencyConfig()) + context = _MockContext() + _, _, resolution_key = resolve_provider_for_context(llm, None, context, gate_registry=registry) + # _build_resolution_key returns None when config is None. + assert resolution_key is None + + await registry.close() + + +# ────────────────────────────────────────────────────────────────────────────── +# 2. GetProviderCapabilities — gate active, reports registry's effective cap + + +@pytest.mark.asyncio +async def test_get_provider_capabilities_reports_gate_effective_cap(monkeypatch): + """Gate active + OLLAMA_MAX_CONCURRENT=2 → capability response reports cap=2. + + Tests Decision 12: resolved-context lookup through gate registry. + Refs: codex r1 H2 / D12. + """ + monkeypatch.setenv("SOURCEBRIDGE_LLM_PROVIDER_OLLAMA_MAX_CONCURRENT", "2") + from reasoning.v1 import reasoning_pb2 + + from workers.common.config import WorkerConfig + from workers.reasoning.servicer import ReasoningServicer + + config = WorkerConfig( + llm_provider="ollama", + llm_model="qwen3:7b", + test_mode=True, + ) + concurrency_config = ConcurrencyConfig.from_env() + registry = ProviderGateRegistry(concurrency_config) + + # Use test_mode → FakeLLMProvider (no real HTTP calls). + from workers.common.llm.factory import create_llm_provider + + llm = await create_llm_provider(config, gate_registry=registry) + emb = FakeEmbeddingProvider(dimension=1024) + servicer = ReasoningServicer(llm, emb, worker_config=config, gate_registry=registry) + + context = _MockContext() # no metadata override → uses bootstrap (ollama) + response = await servicer.GetProviderCapabilities( + reasoning_pb2.GetProviderCapabilitiesRequest(), context + ) + + assert response.max_concurrent_calls == 2 + assert response.max_concurrent_calls_known is True + await registry.close() + + +@pytest.mark.asyncio +async def test_get_provider_capabilities_honors_metadata_override(monkeypatch): + """Metadata override switches provider → capability response reflects the override's gate cap. + + Bootstrap = ollama (gate cap=1); metadata forces openai (gate cap=8). + Asserts response.max_concurrent_calls == 8 (the metadata-resolved cap), not 1. + Refs: codex r2 H2 / Decision 12. + """ + monkeypatch.setenv("SOURCEBRIDGE_LLM_PROVIDER_OLLAMA_MAX_CONCURRENT", "1") + monkeypatch.setenv("SOURCEBRIDGE_LLM_PROVIDER_OPENAI_MAX_CONCURRENT", "8") + from reasoning.v1 import reasoning_pb2 + + from workers.common.config import WorkerConfig + from workers.reasoning.servicer import ReasoningServicer + + config = WorkerConfig( + llm_provider="ollama", + llm_model="qwen3:7b", + llm_api_key="", + test_mode=True, + ) + concurrency_config = ConcurrencyConfig.from_env() + registry = ProviderGateRegistry(concurrency_config) + + from workers.common.llm.factory import create_llm_provider + + llm = await create_llm_provider(config, gate_registry=registry) + emb = FakeEmbeddingProvider(dimension=1024) + servicer = ReasoningServicer(llm, emb, worker_config=config, gate_registry=registry) + + # Metadata override: flip provider to openai. + context = _MockContext({ + "x-sb-llm-provider": "openai", + "x-sb-llm-api-key": "test-key", + "x-sb-model": "gpt-4o", + }) + response = await servicer.GetProviderCapabilities( + reasoning_pb2.GetProviderCapabilitiesRequest(), context + ) + + assert response.max_concurrent_calls == 8, ( + f"Expected openai gate cap 8, got {response.max_concurrent_calls}" + ) + assert response.max_concurrent_calls_known is True + await registry.close() + + +@pytest.mark.asyncio +async def test_get_provider_capabilities_falls_back_when_wrapper_disabled(monkeypatch): + """Kill switch off → legacy WorkerConfig-sourced cap returned. + + Refs: codex r1 H2 / D12 legacy fallback. + """ + monkeypatch.setenv("SOURCEBRIDGE_LLM_CONCURRENCY_WRAPPER_ENABLED", "false") + from reasoning.v1 import reasoning_pb2 + + from workers.common.config import WorkerConfig + from workers.reasoning.servicer import ReasoningServicer + + config = WorkerConfig( + llm_provider="ollama", + llm_model="qwen3:7b", + llm_max_concurrent_calls=3, + test_mode=True, + ) + concurrency_config = ConcurrencyConfig.from_env() + assert not concurrency_config.wrapper_enabled + registry = ProviderGateRegistry(concurrency_config) + + from workers.common.llm.factory import create_llm_provider + + llm = await create_llm_provider(config, gate_registry=registry) + emb = FakeEmbeddingProvider(dimension=1024) + servicer = ReasoningServicer(llm, emb, worker_config=config, gate_registry=registry) + + context = _MockContext() + response = await servicer.GetProviderCapabilities( + reasoning_pb2.GetProviderCapabilitiesRequest(), context + ) + + # Kill switch off → effective_llm_max_concurrent returns None → legacy path + # → reads config.llm_max_concurrent_calls = 3. + assert response.max_concurrent_calls == 3 + assert response.max_concurrent_calls_known is True + await registry.close() + + +# ────────────────────────────────────────────────────────────────────────────── +# 3. CLI helper tests + + +@pytest.mark.asyncio +async def test_cli_review_constructs_registry(): + """build_cli_runtime_provider returns a ConcurrencyGatedProvider for live mode.""" + from workers.common.cli_main import build_cli_runtime_provider + from workers.common.config import WorkerConfig + + cfg = WorkerConfig(llm_provider="openai", llm_api_key="test", llm_model="gpt-4o") + provider, registry = await build_cli_runtime_provider(cfg) + try: + assert isinstance(provider, ConcurrencyGatedProvider) + assert registry is not None + finally: + await registry.close() + + +@pytest.mark.asyncio +async def test_benchmark_constructs_registry(monkeypatch): + """_create_provider('live') returns a ConcurrencyGatedProvider.""" + # Patch WorkerConfig so no real provider credentials are needed. + monkeypatch.setenv("SOURCEBRIDGE_LLM_PROVIDER", "openai") + monkeypatch.setenv("SOURCEBRIDGE_LLM_API_KEY", "test") + monkeypatch.setenv("SOURCEBRIDGE_LLM_MODEL", "gpt-4o") + + from workers.benchmarks.run_comprehension_bench import _create_provider + + provider, provider_name, model_id, gate_registry = await _create_provider("live") + try: + assert isinstance(provider, ConcurrencyGatedProvider) + assert gate_registry is not None + finally: + if gate_registry is not None: + await gate_registry.close() diff --git a/workers/tests/test_concurrency_phase3.py b/workers/tests/test_concurrency_phase3.py new file mode 100644 index 00000000..d4aded96 --- /dev/null +++ b/workers/tests/test_concurrency_phase3.py @@ -0,0 +1,1284 @@ +"""Phase 3 tests: activated gate + retry + real defaults + deleted hand-rolled retries. + +Tests verify: + - Global semaphore caps in-flight (single subsystem + two via shared registry) + - Local and global semaphores compose correctly + - Host-gate caps combined LLM + embedding on real Ollama URLs (Decision 1) + - Retry on 429 with full assertions: call count, Retry-After, terminal exc type + - Retry-After header honored for both OpenAI and Anthropic shapes + - SDK retry is disabled (HTTP call count == wrapper attempts) + - No retry on 401 (auth failure) + - No retry on pydantic.ValidationError + - RPM rate limiting via aiolimiter (custom limiter for unit test; real-clock + integration test marked @pytest.mark.slow — see decision notes below) + - Retry on 503 (renderer path coverage) + - Slot released on cancel during upstream call + - Slot released on cancel during limiter wait + - Cancellation while queued decrements waiter count + - Slot released between retry attempts (Decision 2 layering) + - Slot NOT held during limiter wait (Decision 2 layering) + - Jitter spreads retry timestamps (mocked random.uniform) + - gate rejects max_concurrent < 1 (covered in phase 1; extended here) + - ConcurrencyConfig.from_env() rejects invalid RPM + - ConcurrencyConfig.from_env() rejects zero max_concurrent + - Unknown provider token in env var is rejected at startup + - Empty-content retry preserved in openai_compat after Phase 3 + - Registry close is idempotent + - Registry rejects lookup after close + - Registry close during in-flight calls cleans up correctly + +Decision notes: + - aiolimiter.AsyncLimiter uses self._loop.time() with no injectable clock seam. + A custom SimpleLimiter is used for unit tests (asyncio.Lock + sleep). A + @pytest.mark.slow real-clock test exercises the production aiolimiter path. + - tenacity.wait_random_exponential calls random.uniform at module level with no + seeded RNG seam. Tests mock random.uniform to assert deterministic wait values. + +Refs: CA-169 / plan v4 Phase 3 Verification list. +""" + +from __future__ import annotations + +import asyncio +import dataclasses +import random +from collections.abc import AsyncIterator +from unittest.mock import MagicMock, patch + +import anthropic +import httpx +import openai +import pytest + +from workers.common.llm.concurrency import ( + ConcurrencyConfig, + ConcurrencyGatedProvider, + ProviderGate, + ProviderGateRegistry, + _retry_predicate, +) +from workers.common.llm.fake import FakeLLMProvider +from workers.common.llm.provider import LLMResponse + +# ────────────────────────────────────────────────────────────────────────────── +# Shared helpers + + +def _registry(config: ConcurrencyConfig | None = None) -> ProviderGateRegistry: + return ProviderGateRegistry(config or ConcurrencyConfig()) + + +def _response(content: str = "ok", output_tokens: int = 10) -> LLMResponse: + return LLMResponse( + content=content, + model="test-model", + input_tokens=5, + output_tokens=output_tokens, + stop_reason="end_turn", + ) + + +async def _make_gated( + provider_name: str = "openai", + base_url: str = "https://api.openai.com/v1", + kind: str = "llm", + *, + config: ConcurrencyConfig, + raw: FakeLLMProvider | None = None, +) -> tuple[ConcurrencyGatedProvider, ProviderGate]: + registry = _registry(config) + gate = await registry.lookup(provider_name, base_url, kind) + raw_provider = raw or FakeLLMProvider() + wrapped = ConcurrencyGatedProvider(raw_provider, gate, config) + return wrapped, gate + + +# ────────────────────────────────────────────────────────────────────────────── +# 1. test_wrapper_caps_in_flight_at_max_concurrent +# Parametrized: single subsystem + two subsystems via shared registry + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "scenario", + ["single_subsystem", "two_subsystems_via_shared_registry"], +) +async def test_wrapper_caps_in_flight_at_max_concurrent(scenario: str) -> None: + """Global semaphore ensures peak in-flight <= max_concurrent=2.""" + config = ConcurrencyConfig( + llm_max_concurrent={"openai": 2}, + retry_max_attempts=1, + ) + registry = _registry(config) + + if scenario == "single_subsystem": + gate_a = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + gate_b = gate_a + else: + gate_a = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + gate_b = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + + peak: list[int] = [] + event = asyncio.Event() + + async def one(gate: ProviderGate) -> None: + async with gate.slot(): + peak.append(gate._binding._in_flight) + event.set() + await asyncio.sleep(0.03) + + tasks = [asyncio.create_task(one(gate_a)) for _ in range(4)] + if scenario == "two_subsystems_via_shared_registry": + tasks += [asyncio.create_task(one(gate_b)) for _ in range(4)] + + await asyncio.gather(*tasks) + assert max(peak) <= 2, f"Peak in-flight {max(peak)} exceeded cap 2" + + +# ────────────────────────────────────────────────────────────────────────────── +# 2. test_local_and_global_semaphores_compose_correctly + + +@pytest.mark.asyncio +async def test_local_and_global_semaphores_compose_correctly() -> None: + """Local semaphore AND global gate cap compose: effective cap = min(local, global).""" + config = ConcurrencyConfig( + llm_max_concurrent={"openai": 3}, + retry_max_attempts=1, + ) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + + local_sem = asyncio.Semaphore(2) + peak: list[int] = [] + + async def one() -> None: + async with local_sem, gate.slot(): + peak.append(gate._binding._in_flight) + await asyncio.sleep(0.02) + + await asyncio.gather(*[asyncio.create_task(one()) for _ in range(6)]) + # local cap=2 < global cap=3 → effective peak should be ≤ 2 + assert max(peak) <= 2, f"Peak {max(peak)} violated local semaphore cap" + + +# ────────────────────────────────────────────────────────────────────────────── +# 3. test_host_gate_caps_combined_llm_and_embedding_on_ollama_real_defaults + + +@pytest.mark.asyncio +async def test_host_gate_caps_combined_llm_and_embedding_on_ollama_real_defaults() -> None: + """Ollama LLM + embedding share one host gate (Decision 1 URL normalization). + + Real factory default URLs: + LLM: http://localhost:11434/v1 (llm/config.py:80-81) + embedding: http://localhost:11434 (embedding/config.py:33) + """ + config = ConcurrencyConfig(llm_max_concurrent={"ollama": 1}, retry_max_attempts=1) + registry = _registry(config) + + gate_llm = await registry.lookup("ollama", "http://localhost:11434/v1", "llm") + gate_embed = await registry.lookup("ollama", "http://localhost:11434", "embedding") + + # Same normalized origin → same binding gate. + assert gate_llm._binding is gate_embed._binding + + peak: list[int] = [] + event = asyncio.Event() + + async def hold(gate: ProviderGate) -> None: + async with gate.slot(): + peak.append(gate_llm._binding._in_flight) + event.set() + await asyncio.sleep(0.03) + + await asyncio.gather(hold(gate_llm), hold(gate_embed)) + # Combined in-flight never exceeds 1. + assert max(peak) == 1, f"Peak {max(peak)}: ollama gate not combining LLM + embedding" + + +# ────────────────────────────────────────────────────────────────────────────── +# 4. test_wrapper_retries_on_429_full_assertions + + +@pytest.mark.asyncio +async def test_wrapper_retries_on_429_full_assertions() -> None: + """On 429: (a) call count == retry_max_attempts, (b) terminal exc is original RateLimitError.""" + max_attempts = 3 + config = ConcurrencyConfig(llm_max_concurrent={"openai": 4}, retry_max_attempts=max_attempts) + + call_count = 0 + + class _FakeRateLimitProvider: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, prompt: str, **kwargs: object) -> LLMResponse: + nonlocal call_count + call_count += 1 + # Build a minimal fake RateLimitError (OpenAI SDK shape). + exc = openai.RateLimitError( + message="Rate limit exceeded", + response=MagicMock(headers={}, status_code=429), + body={"error": {"message": "rate limit"}}, + ) + raise exc + + async def stream(self, *args: object, **kwargs: object) -> AsyncIterator[str]: + raise NotImplementedError + + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + wrapped = ConcurrencyGatedProvider(_FakeRateLimitProvider(), gate, config) + + with pytest.raises(openai.RateLimitError): + await wrapped.complete("hello") + + # (a) call count equals max_attempts + assert call_count == max_attempts, f"Expected {max_attempts} calls, got {call_count}" + + +@pytest.mark.asyncio +async def test_wrapper_retries_on_429_terminal_exception_is_original() -> None: + """Terminal exception after retry exhaustion is the original RateLimitError, not RetryError.""" + from tenacity import RetryError + + config = ConcurrencyConfig(llm_max_concurrent={"openai": 4}, retry_max_attempts=2) + exc_to_raise = openai.RateLimitError( + message="Rate limit exceeded", + response=MagicMock(headers={}, status_code=429), + body={"error": {"message": "rate limit"}}, + ) + + class _AlwaysRateLimit: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, *a: object, **kw: object) -> LLMResponse: + raise exc_to_raise + + async def stream(self, *a: object, **kw: object) -> AsyncIterator[str]: + raise NotImplementedError + + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + wrapped = ConcurrencyGatedProvider(_AlwaysRateLimit(), gate, config) + + caught: BaseException | None = None + try: + await wrapped.complete("hello") + except BaseException as e: + caught = e + + assert caught is not None + # reraise=True means tenacity re-raises the original, not a RetryError wrapper. + assert not isinstance(caught, RetryError), ( + f"Expected original RateLimitError, got RetryError wrapping {caught.__cause__}" + ) + assert isinstance(caught, openai.RateLimitError), f"Got {type(caught)}" + + +# ────────────────────────────────────────────────────────────────────────────── +# 5. test_wrapper_respects_retry_after_header_per_sdk + + +@pytest.mark.asyncio +async def test_wrapper_respects_retry_after_header_per_sdk() -> None: + """Retry-After header is parsed from both OpenAI and Anthropic error shapes.""" + # OpenAI: response.headers["retry-after"] + mock_response_openai = MagicMock() + mock_response_openai.headers = {"retry-after": "5"} + mock_response_openai.status_code = 429 + exc_openai = openai.RateLimitError( + message="too many requests", + response=mock_response_openai, + body={"error": {"message": "rate limit"}}, + ) + + # Anthropic: response.headers["retry-after"] + mock_response_anthropic = MagicMock() + mock_response_anthropic.headers = {"retry-after": "10"} + mock_response_anthropic.status_code = 429 + exc_anthropic = anthropic.RateLimitError( + message="too many requests", + response=mock_response_anthropic, + body={"error": {"message": "rate limit"}}, + ) + + from workers.common.llm.concurrency import _extract_retry_after + + assert _extract_retry_after(exc_openai) == 5.0 + assert _extract_retry_after(exc_anthropic) == 10.0 + + +@pytest.mark.asyncio +async def test_wrapper_respects_retry_after_header_none_when_absent() -> None: + """_extract_retry_after returns None when no header is present.""" + from workers.common.llm.concurrency import _extract_retry_after + + exc = openai.RateLimitError( + message="too many requests", + response=MagicMock(headers={}, status_code=429), + body={}, + ) + assert _extract_retry_after(exc) is None + + +# ────────────────────────────────────────────────────────────────────────────── +# 6. test_sdk_retry_is_disabled + + +def test_sdk_retry_is_disabled_openai() -> None: + """AsyncOpenAI is constructed with max_retries=0 (Phase 3).""" + from workers.common.llm.openai_compat import OpenAICompatProvider + + provider = OpenAICompatProvider(api_key="test-key", base_url="http://localhost:11434/v1") + assert provider.client.max_retries == 0, ( + f"Expected max_retries=0, got {provider.client.max_retries}" + ) + + +def test_sdk_retry_is_disabled_anthropic() -> None: + """AsyncAnthropic is constructed with max_retries=0 (Phase 3).""" + from workers.common.llm.anthropic import AnthropicProvider + + provider = AnthropicProvider(api_key="test-key") + assert provider.client.max_retries == 0, ( + f"Expected max_retries=0, got {provider.client.max_retries}" + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# 7. test_wrapper_does_not_retry_on_401 + + +@pytest.mark.asyncio +async def test_wrapper_does_not_retry_on_401() -> None: + """401 Unauthorized is not retryable — wrapper makes exactly one attempt.""" + call_count = 0 + + class _Unauthorized: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, *a: object, **kw: object) -> LLMResponse: + nonlocal call_count + call_count += 1 + exc = openai.AuthenticationError( + message="Incorrect API key", + response=MagicMock(headers={}, status_code=401), + body={"error": {"message": "auth error"}}, + ) + raise exc + + async def stream(self, *a: object, **kw: object) -> AsyncIterator[str]: + raise NotImplementedError + + config = ConcurrencyConfig(llm_max_concurrent={"openai": 4}, retry_max_attempts=5) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + wrapped = ConcurrencyGatedProvider(_Unauthorized(), gate, config) + + with pytest.raises(openai.AuthenticationError): + await wrapped.complete("hello") + + assert call_count == 1, f"Expected 1 call (no retry), got {call_count}" + + +# ────────────────────────────────────────────────────────────────────────────── +# 8. test_wrapper_does_not_retry_on_validation_error + + +@pytest.mark.asyncio +async def test_wrapper_does_not_retry_on_validation_error() -> None: + """pydantic.ValidationError is not retryable — wrapper makes exactly one attempt.""" + import pydantic + + call_count = 0 + + class _ValidationFailure: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, *a: object, **kw: object) -> LLMResponse: + nonlocal call_count + call_count += 1 + + class _M(pydantic.BaseModel): + x: int + + _M(x="not-an-int") # type: ignore[arg-type] # raises ValidationError + raise AssertionError("unreachable") + + async def stream(self, *a: object, **kw: object) -> AsyncIterator[str]: + raise NotImplementedError + + config = ConcurrencyConfig(llm_max_concurrent={"openai": 4}, retry_max_attempts=5) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + wrapped = ConcurrencyGatedProvider(_ValidationFailure(), gate, config) + + with pytest.raises(pydantic.ValidationError): + await wrapped.complete("hello") + + assert call_count == 1, f"Expected 1 call (no retry), got {call_count}" + + +# ────────────────────────────────────────────────────────────────────────────── +# 9. test_wrapper_rate_limits_with_aiolimiter +# +# Decision note: aiolimiter.AsyncLimiter uses self._loop.time() — no injectable +# clock. We use a custom SimpleLimiter for this unit test that is seeded with +# asyncio.Lock and recorded acquire timestamps. A real-clock slow test follows. + + +class _SimpleLimiter: + """Minimal rate limiter for testing: allows max_rate calls per time_period. + + Tracks acquire timestamps so tests can assert spacing without needing + aiolimiter's internal clock. + """ + + def __init__(self, max_rate: float, time_period: float = 1.0) -> None: + self._delay = time_period / max_rate + self._lock = asyncio.Lock() + self._last: float = 0.0 + self.acquire_times: list[float] = [] + + async def acquire(self, amount: float = 1) -> None: + async with self._lock: + now = asyncio.get_event_loop().time() + wait = self._last + self._delay - now + if wait > 0: + await asyncio.sleep(wait) + self._last = asyncio.get_event_loop().time() + self.acquire_times.append(self._last) + + +@pytest.mark.asyncio +async def test_wrapper_rate_limits_with_custom_limiter() -> None: + """Rate limiter enforces spacing between successive calls (custom limiter for unit test). + + The production path uses aiolimiter.AsyncLimiter; the logic tested here is the + wrapper's interaction with _any_ compatible limiter object (Decision 7). + """ + limiter = _SimpleLimiter(max_rate=2, time_period=1.0) # 2 calls/sec → 0.5s gap + + call_count = 0 + + class _Instant: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, *a: object, **kw: object) -> LLMResponse: + nonlocal call_count + call_count += 1 + return _response() + + async def stream(self, *a: object, **kw: object) -> AsyncIterator[str]: + yield "ok" + + config = ConcurrencyConfig(llm_max_concurrent={"openai": 4}, retry_max_attempts=1) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + + provider = _Instant() + wrapped = ConcurrencyGatedProvider(provider, gate, config) + # Override the limiter directly (bypassing RPM env-var path for unit-test isolation). + wrapped._limiter = limiter # type: ignore[assignment] + + await wrapped.complete("a") + await wrapped.complete("b") + await wrapped.complete("c") + + assert call_count == 3 + # At 2 calls/sec, each call should be ≥ 0.4s after the previous. + times = limiter.acquire_times + assert len(times) == 3 + for i in range(1, len(times)): + gap = times[i] - times[i - 1] + assert gap >= 0.35, f"Gap {gap:.3f}s too small; expected ≥ 0.35s for 2 calls/sec" + + +@pytest.mark.slow +@pytest.mark.asyncio +async def test_wrapper_rate_limits_with_real_aiolimiter() -> None: + """Real-clock integration: aiolimiter.AsyncLimiter enforces 2 calls/sec. + + Marked @pytest.mark.slow — not run in CI's fast suite. Exercises the + production aiolimiter code path end-to-end. + """ + from aiolimiter import AsyncLimiter + + call_times: list[float] = [] + + class _Timed: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, *a: object, **kw: object) -> LLMResponse: + call_times.append(asyncio.get_event_loop().time()) + return _response() + + async def stream(self, *a: object, **kw: object) -> AsyncIterator[str]: + yield "ok" + + config = ConcurrencyConfig(llm_max_concurrent={"openai": 4}, retry_max_attempts=1) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + provider = _Timed() + wrapped = ConcurrencyGatedProvider(provider, gate, config) + wrapped._limiter = AsyncLimiter(max_rate=2, time_period=1.0) # type: ignore[assignment] + + for _ in range(3): + await wrapped.complete("test") + + assert len(call_times) == 3 + # Each successive call should be spaced ≥ 0.4s apart (2 calls/sec). + for i in range(1, len(call_times)): + gap = call_times[i] - call_times[i - 1] + assert gap >= 0.4, f"Gap {gap:.3f}s too small (expected ≥ 0.4s at 2 calls/sec)" + + +# ────────────────────────────────────────────────────────────────────────────── +# 10. test_wrapper_retries_on_503_via_renderer + + +@pytest.mark.asyncio +async def test_wrapper_retries_on_503_via_renderer() -> None: + """503 ServiceUnavailable is retried by the gate (Decision 4 whitelist).""" + call_count = 0 + + class _FlakyProvider: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, *a: object, **kw: object) -> LLMResponse: + nonlocal call_count + call_count += 1 + if call_count < 3: + exc = openai.APIStatusError( + message="Service unavailable", + response=MagicMock(headers={}, status_code=503), + body={"error": {"message": "service unavailable"}}, + ) + raise exc + return _response() + + async def stream(self, *a: object, **kw: object) -> AsyncIterator[str]: + raise NotImplementedError + + config = ConcurrencyConfig(llm_max_concurrent={"openai": 4}, retry_max_attempts=5) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + wrapped = ConcurrencyGatedProvider(_FlakyProvider(), gate, config) + + result = await wrapped.complete("hello") + assert result.content == "ok" + assert call_count == 3 + + +# ────────────────────────────────────────────────────────────────────────────── +# 11. test_wrapper_releases_slot_on_cancel_during_call + + +@pytest.mark.asyncio +async def test_wrapper_releases_slot_on_cancel_during_call() -> None: + """Cancelling during an upstream call releases the semaphore slot.""" + config = ConcurrencyConfig(llm_max_concurrent={"openai": 1}, retry_max_attempts=1) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + + hold_event = asyncio.Event() + cancel_event = asyncio.Event() + + class _Blocking: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, *a: object, **kw: object) -> LLMResponse: + hold_event.set() + await cancel_event.wait() # Block indefinitely + return _response() + + async def stream(self, *a: object, **kw: object) -> AsyncIterator[str]: + raise NotImplementedError + + wrapped = ConcurrencyGatedProvider(_Blocking(), gate, config) + + task = asyncio.create_task(wrapped.complete("hello")) + await hold_event.wait() # Wait until the slot is held + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + # Slot must be released. + assert gate._binding._in_flight == 0, "Slot leaked after cancellation during call" + assert gate._binding._waiters == 0 + + # Next acquire should succeed immediately. + acquired = asyncio.Event() + + async def _check() -> None: + async with gate.slot(): + acquired.set() + + await asyncio.wait_for(_check(), timeout=1.0) + assert acquired.is_set() + + +# ────────────────────────────────────────────────────────────────────────────── +# 12. test_wrapper_releases_slot_on_cancel_during_limiter_wait + + +@pytest.mark.asyncio +async def test_wrapper_releases_slot_on_cancel_during_limiter_wait() -> None: + """Cancelling while awaiting the limiter does not leak a slot (slot was never acquired).""" + config = ConcurrencyConfig(llm_max_concurrent={"openai": 4}, retry_max_attempts=1) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + + # A limiter that blocks forever until cancelled. + class _BlockingLimiter: + async def acquire(self, amount: float = 1) -> None: + await asyncio.sleep(9999) # blocks indefinitely + + class _SimpleProvider: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, *a: object, **kw: object) -> LLMResponse: + return _response() + + async def stream(self, *a: object, **kw: object) -> AsyncIterator[str]: + yield "ok" + + wrapped = ConcurrencyGatedProvider(_SimpleProvider(), gate, config) + wrapped._limiter = _BlockingLimiter() # type: ignore[assignment] + + task = asyncio.create_task(wrapped.complete("hello")) + await asyncio.sleep(0.01) # let the task start and block on limiter + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + # No slot was ever acquired (cancelled before limiter returned). + assert gate._binding._in_flight == 0, "Slot leaked after cancel during limiter wait" + + +# ────────────────────────────────────────────────────────────────────────────── +# 13. test_cancellation_while_queued_releases_waiter_count + + +@pytest.mark.asyncio +async def test_cancellation_while_queued_releases_waiter_count() -> None: + """Cancelling a queued coroutine decrements waiter count and doesn't block future acquires.""" + config = ConcurrencyConfig(llm_max_concurrent={"openai": 1}, retry_max_attempts=1) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + + hold_start = asyncio.Event() + release_a = asyncio.Event() + + # A holds the only slot. + async def hold_slot_a() -> None: + async with gate.slot(): + hold_start.set() + await release_a.wait() + + task_a = asyncio.create_task(hold_slot_a()) + await hold_start.wait() + + # B queues for the slot. + async def wait_for_slot_b() -> None: + async with gate.slot(): + pass # We don't want B to actually do anything + + task_b = asyncio.create_task(wait_for_slot_b()) + await asyncio.sleep(0.01) # let B enter the wait + + assert gate._binding._waiters == 1, f"Expected 1 waiter, got {gate._binding._waiters}" + + # Cancel B while it's queued. + task_b.cancel() + with pytest.raises(asyncio.CancelledError): + await task_b + + # Waiter count must be decremented. + assert gate._binding._waiters == 0, ( + f"Waiter count not decremented after cancel: {gate._binding._waiters}" + ) + + # Release A; C should acquire immediately. + release_a.set() + await task_a + + acquired = asyncio.Event() + + async def task_c() -> None: + async with gate.slot(): + acquired.set() + + await asyncio.wait_for(task_c(), timeout=1.0) + assert acquired.is_set(), "C could not acquire after B was cancelled" + + +# ────────────────────────────────────────────────────────────────────────────── +# 14. test_wrapper_releases_slot_between_retry_attempts (Decision 2 — critical) + + +@pytest.mark.asyncio +async def test_wrapper_releases_slot_between_retry_attempts() -> None: + """Slot is released during retry sleep so other callers can proceed (Decision 2).""" + config = ConcurrencyConfig(llm_max_concurrent={"openai": 1}, retry_max_attempts=3) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + + attempt_started = asyncio.Event() + other_acquired = asyncio.Event() + call_count = 0 + + class _FailOnceThenSucceed: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, *a: object, **kw: object) -> LLMResponse: + nonlocal call_count + call_count += 1 + attempt_started.set() + if call_count == 1: + exc = openai.RateLimitError( + message="rate limit", + response=MagicMock(headers={"retry-after": "0.05"}, status_code=429), + body={}, + ) + raise exc + return _response() + + async def stream(self, *a: object, **kw: object) -> AsyncIterator[str]: + raise NotImplementedError + + wrapped = ConcurrencyGatedProvider(_FailOnceThenSucceed(), gate, config) + + main_task = asyncio.create_task(wrapped.complete("hello")) + await attempt_started.wait() + + # While the main task is sleeping between retries, this task should be able + # to acquire the gate (slot was released before the sleep per Decision 2). + async def _other() -> None: + # Give the main task a moment to release and enter the sleep. + await asyncio.sleep(0.01) + async with gate.slot(): + other_acquired.set() + + other_task = asyncio.create_task(_other()) + + result = await main_task + await asyncio.wait_for(other_task, timeout=2.0) + + assert result.content == "ok" + assert other_acquired.is_set(), "Other coroutine could not acquire slot during retry sleep" + + +# ────────────────────────────────────────────────────────────────────────────── +# 15. test_wrapper_does_not_hold_slot_during_limiter_wait (Decision 2) + + +@pytest.mark.asyncio +async def test_wrapper_does_not_hold_slot_during_limiter_wait() -> None: + """Limiter wait happens OUTSIDE the slot (Decision 2 ordering).""" + config = ConcurrencyConfig(llm_max_concurrent={"openai": 1}, retry_max_attempts=1) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + + limiter_acquired = asyncio.Event() + release_limiter = asyncio.Event() + + class _HoldLimiter: + """A fake limiter that signals when it's been entered, then blocks.""" + + async def acquire(self, amount: float = 1) -> None: + limiter_acquired.set() + await release_limiter.wait() + + class _SimpleProvider: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, *a: object, **kw: object) -> LLMResponse: + return _response() + + async def stream(self, *a: object, **kw: object) -> AsyncIterator[str]: + yield "ok" + + wrapped = ConcurrencyGatedProvider(_SimpleProvider(), gate, config) + wrapped._limiter = _HoldLimiter() # type: ignore[assignment] + + main_task = asyncio.create_task(wrapped.complete("hello")) + await limiter_acquired.wait() + + # While limiter is blocking (not yet called complete), the slot must be free. + assert gate._binding._in_flight == 0, ( + "Slot is being held during limiter wait — Decision 2 violation" + ) + + # Another task should be able to acquire the slot while limiter blocks. + slot_acquired = asyncio.Event() + + async def _other() -> None: + async with gate.slot(): + slot_acquired.set() + + other_task = asyncio.create_task(_other()) + await asyncio.wait_for(other_task, timeout=1.0) + assert slot_acquired.is_set() + + release_limiter.set() + await asyncio.wait_for(main_task, timeout=2.0) + + +# ────────────────────────────────────────────────────────────────────────────── +# 16. test_wrapper_jitter_spreads_retry_timestamps + + +@pytest.mark.asyncio +async def test_wrapper_jitter_spreads_retry_timestamps() -> None: + """wait_random_exponential produces non-zero jitter (mocked random.uniform).""" + # tenacity's wait_random_exponential calls random.uniform at module level. + # We verify that our wrapper does NOT call with uniform(0, 0) by checking that + # the recorded wait sequence grows with exponential base when jitter is frozen. + config = ConcurrencyConfig(llm_max_concurrent={"openai": 4}, retry_max_attempts=4) + registry = _registry(config) + gate = await registry.lookup("openai", "https://api.openai.com/v1", "llm") + + call_count = 0 + + class _AlwaysRateLimit: + provider_name = "openai" + + @property + def default_model(self) -> str: + return "gpt-4o" + + async def complete(self, *a: object, **kw: object) -> LLMResponse: + nonlocal call_count + call_count += 1 + raise openai.RateLimitError( + message="rate limit", + response=MagicMock(headers={}, status_code=429), + body={}, + ) + + async def stream(self, *a: object, **kw: object) -> AsyncIterator[str]: + raise NotImplementedError + + wrapped = ConcurrencyGatedProvider(_AlwaysRateLimit(), gate, config) + + recorded_waits: list[float] = [] + + real_uniform = random.uniform + + def _patched_uniform(a: float, b: float) -> float: + # Return the midpoint for determinism; record the upper bound (b) as + # a proxy for the exponential window width. + result = real_uniform(a, b) + recorded_waits.append(b) + return result + + with patch("random.uniform", side_effect=_patched_uniform), pytest.raises(openai.RateLimitError): + await wrapped.complete("hello") + + # Jitter should have been called (at least once per retry gap). + assert len(recorded_waits) > 0, "random.uniform was never called — jitter absent" + # Exponential backing: window width should grow (or stay) between attempts. + # With multiplier=1, the window at attempt N is [0, min(2^(N-1), 60)]. + for i in range(1, len(recorded_waits)): + assert recorded_waits[i] >= recorded_waits[0] * 0.9, ( + f"Wait window shrank: {recorded_waits}" + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# 17. test_gate_rejects_zero_max_concurrent (already in Phase 1; extended here) + + +def test_gate_rejects_zero_max_concurrent_extended() -> None: + """_KindGate also rejects max_concurrent=0.""" + from workers.common.llm.concurrency import _KindGate + + with pytest.raises(ValueError, match="max_concurrent"): + _KindGate(max_concurrent=0) + + +# ────────────────────────────────────────────────────────────────────────────── +# 18. test_concurrency_config_from_env_rejects_invalid_rpm (Phase 1 extended) + + +def test_concurrency_config_from_env_rejects_negative_rpm( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("SOURCEBRIDGE_LLM_PROVIDER_OPENAI_RPM", "-5") + config = ConcurrencyConfig.from_env() + # Negative is invalid; should not be stored. + assert "openai" not in config.rpm + + +# ────────────────────────────────────────────────────────────────────────────── +# 19. test_concurrency_config_from_env_rejects_zero_max_concurrent (Phase 1 extended) + + +def test_concurrency_config_from_env_env_override_takes_precedence( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Env-var override supersedes Decision 6 default.""" + monkeypatch.setenv("SOURCEBRIDGE_LLM_PROVIDER_OLLAMA_MAX_CONCURRENT", "8") + config = ConcurrencyConfig.from_env() + assert config.llm_max_concurrent["ollama"] == 8 + + +# ────────────────────────────────────────────────────────────────────────────── +# 20. test_concurrency_config_warns_on_unknown_provider_token +# (Plan codex r2 L1: validator should warn on unknown provider tokens) + + +def test_concurrency_config_warns_on_unknown_provider_token( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Decision 7 / codex r2 L1: typo'd provider env vars produce a structlog + warning, not a silent no-op. The valid env var is still parsed correctly.""" + import structlog.testing + + # Typo: OPENAICOMPAT instead of OPENAI_COMPATIBLE + monkeypatch.setenv("SOURCEBRIDGE_LLM_PROVIDER_OPENAICOMPAT_MAX_CONCURRENT", "4") + # Valid env var as a control — must still be applied. + monkeypatch.setenv("SOURCEBRIDGE_LLM_PROVIDER_OLLAMA_MAX_CONCURRENT", "2") + + with structlog.testing.capture_logs() as captured: + cfg = ConcurrencyConfig.from_env() + + # Must not raise; valid env var still parsed. + assert cfg is not None + assert cfg.llm_max_concurrent.get("ollama") == 2 + + # Unknown token must produce a warning event. + unknown_warnings = [ + e for e in captured + if e.get("event") == "concurrency_config_unknown_provider_token" + and e.get("unknown_token") == "OPENAICOMPAT" + ] + assert len(unknown_warnings) >= 1, ( + f"Expected warning about OPENAICOMPAT typo; captured events: {[e.get('event') for e in captured]}" + ) + # The warning should name the bad env var and list canonical tokens. + w = unknown_warnings[0] + assert w["env_var"] == "SOURCEBRIDGE_LLM_PROVIDER_OPENAICOMPAT_MAX_CONCURRENT" + assert "OPENAI_COMPATIBLE" in w["canonical_tokens"] + + +def test_concurrency_config_decision6_real_defaults_loaded() -> None: + """Decision 6 real defaults are present without any env-var overrides.""" + import os + + # Ensure no override env vars are set for known providers. + env_backup = { + k: v for k, v in os.environ.items() + if k.startswith("SOURCEBRIDGE_LLM_PROVIDER_") or k.startswith("SOURCEBRIDGE_EMBEDDING_PROVIDER_") + } + for k in env_backup: + del os.environ[k] + + try: + config = ConcurrencyConfig.from_env() + # Decision 6 table assertions. + assert config.llm_max_concurrent.get("ollama") == 1 + assert config.llm_max_concurrent.get("vllm") == 4 + assert config.llm_max_concurrent.get("llama-cpp") == 4 + assert config.llm_max_concurrent.get("sglang") == 4 + assert config.llm_max_concurrent.get("lmstudio") == 2 + assert config.llm_max_concurrent.get("openai") == 8 + assert config.llm_max_concurrent.get("anthropic") == 4 + assert config.llm_max_concurrent.get("openrouter") == 8 + assert config.llm_max_concurrent.get("gemini") == 8 + assert config.llm_max_concurrent.get("openai-compatible") == 4 + # Frontier embedding defaults. + assert config.embedding_max_concurrent.get("openai") == 8 + assert config.embedding_max_concurrent.get("openrouter") == 8 + assert config.embedding_max_concurrent.get("gemini") == 8 + finally: + os.environ.update(env_backup) + + +def test_concurrency_config_retry_max_attempts_default_is_5() -> None: + """Phase 3: default retry_max_attempts is 5 (not 1 as in Phase 1).""" + config = ConcurrencyConfig() + assert config.retry_max_attempts == 5 + + +def test_concurrency_config_from_env_retry_default_is_5( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("SOURCEBRIDGE_LLM_RETRY_MAX_ATTEMPTS", raising=False) + config = ConcurrencyConfig.from_env() + assert config.retry_max_attempts == 5 + + +# ────────────────────────────────────────────────────────────────────────────── +# 21. test_empty_content_retry_preserved_in_phase_3 + + +@pytest.mark.asyncio +async def test_empty_content_retry_preserved_in_phase_3() -> None: + """Empty-content retry at openai_compat._complete_once still fires after Phase 3. + + This is the -budget retry that doubles max_tokens when stop_reason=length + and content is empty. It is NOT a network/rate-limit retry — it lives inside + OpenAICompatProvider.complete() and must survive the Phase 3 changes. + + Verified by: mocking the internal SDK client to return one length/empty response + then a non-empty response; asserting the retry fired (call count == 2) and the + final content is non-empty. + """ + from workers.common.llm.openai_compat import OpenAICompatProvider + + call_count = 0 + + async def _mock_create(**kwargs: object) -> MagicMock: + nonlocal call_count + call_count += 1 + resp = MagicMock() + if call_count == 1: + # First call: empty content + stop_reason=length + resp.choices = [MagicMock(message=MagicMock(content=""), finish_reason="length")] + resp.usage = MagicMock(prompt_tokens=10, completion_tokens=0) + resp.model = "gpt-4o" + else: + # Second call: valid content + resp.choices = [MagicMock(message=MagicMock(content="the actual answer"), finish_reason="stop")] + resp.usage = MagicMock(prompt_tokens=10, completion_tokens=5) + resp.model = "gpt-4o" + return resp + + provider = OpenAICompatProvider(api_key="test-key", base_url="http://localhost:11434/v1") + provider.client.chat.completions.create = _mock_create # type: ignore[assignment] + + result = await provider.complete("test prompt") + + assert call_count == 2, f"Expected 2 calls (empty-content retry), got {call_count}" + assert result.content.strip() == "the actual answer" + + +# ────────────────────────────────────────────────────────────────────────────── +# 22-23. Registry lifecycle (close idempotent + rejects lookup after close) +# These were in Phase 1; extended here with in-flight behavior. + + +@pytest.mark.asyncio +async def test_registry_close_idempotent() -> None: + """Calling close() twice does not raise.""" + reg = _registry() + await reg.close() + await reg.close() + + +@pytest.mark.asyncio +async def test_registry_rejects_lookup_after_close() -> None: + """Lookup after close raises RuntimeError.""" + reg = _registry() + await reg.close() + with pytest.raises(RuntimeError, match="closed"): + await reg.lookup("ollama", "http://localhost:11434/v1", "llm") + + +@pytest.mark.asyncio +async def test_registry_close_during_in_flight_calls() -> None: + """Registry.close() can be called while a gate slot is held. + + The slot is held by a task that has already acquired the semaphore; close() + marks the registry as closed so new lookups fail, but does not force-cancel + the in-flight task. The in-flight task releases normally. + """ + reg = _registry(ConcurrencyConfig(llm_max_concurrent={"openai": 2}, retry_max_attempts=1)) + gate = await reg.lookup("openai", "https://api.openai.com/v1", "llm") + + acquired = asyncio.Event() + released = asyncio.Event() + + async def hold() -> None: + async with gate.slot(): + acquired.set() + await released.wait() + + task = asyncio.create_task(hold()) + await acquired.wait() + + # Close while slot is held. + await reg.close() + assert reg._closed + + # New lookups must fail. + with pytest.raises(RuntimeError, match="closed"): + await reg.lookup("openai", "https://api.openai.com/v1", "llm") + + # Release the in-flight task — it should complete without error. + released.set() + await asyncio.wait_for(task, timeout=1.0) + + # Slot is released cleanly. + assert gate._binding._in_flight == 0 + + +# ────────────────────────────────────────────────────────────────────────────── +# 24. _retry_predicate unit tests + + +def test_retry_predicate_accepts_openai_rate_limit() -> None: + exc = openai.RateLimitError( + message="too many requests", + response=MagicMock(headers={}, status_code=429), + body={}, + ) + assert _retry_predicate(exc) is True + + +def test_retry_predicate_accepts_anthropic_rate_limit() -> None: + exc = anthropic.RateLimitError( + message="too many requests", + response=MagicMock(headers={}, status_code=429), + body={}, + ) + assert _retry_predicate(exc) is True + + +def test_retry_predicate_accepts_503() -> None: + exc = openai.APIStatusError( + message="service unavailable", + response=MagicMock(headers={}, status_code=503), + body={}, + ) + assert _retry_predicate(exc) is True + + +def test_retry_predicate_rejects_400() -> None: + exc = openai.APIStatusError( + message="bad request", + response=MagicMock(headers={}, status_code=400), + body={}, + ) + assert _retry_predicate(exc) is False + + +def test_retry_predicate_rejects_401() -> None: + exc = openai.AuthenticationError( + message="invalid key", + response=MagicMock(headers={}, status_code=401), + body={}, + ) + assert _retry_predicate(exc) is False + + +def test_retry_predicate_accepts_timeout() -> None: + exc = httpx.ConnectTimeout("timed out") + assert _retry_predicate(exc) is True + + +def test_retry_predicate_accepts_read_error() -> None: + exc = httpx.ReadError("read error") + assert _retry_predicate(exc) is True + + +def test_retry_predicate_rejects_runtime_error() -> None: + assert _retry_predicate(RuntimeError("oops")) is False + + +def test_retry_predicate_rejects_value_error() -> None: + assert _retry_predicate(ValueError("bad")) is False + + +# ────────────────────────────────────────────────────────────────────────────── +# 25. router.py unwraps RetryError + + +def test_router_unwraps_retry_error() -> None: + """LLMRouter unwraps tenacity.RetryError to expose the original cause.""" + from tenacity import RetryError + + from workers.common.llm.router import _unwrap_retry_error + + original = RuntimeError("original cause") + wrapped = RetryError.__new__(RetryError) + wrapped.__cause__ = original + + result = _unwrap_retry_error(wrapped) + assert result is original + + +def test_router_passthrough_when_not_retry_error() -> None: + from workers.common.llm.router import _unwrap_retry_error + + exc = ValueError("plain error") + assert _unwrap_retry_error(exc) is exc + + +# ────────────────────────────────────────────────────────────────────────────── +# 26. Hierarchical + renderer cap constants raised + + +def test_hierarchical_default_caps_raised() -> None: + """Phase 3 raises DEFAULT_LEAF/FILE/PACKAGE_CONCURRENCY to 4.""" + from workers.comprehension.hierarchical import ( + DEFAULT_FILE_CONCURRENCY, + DEFAULT_LEAF_CONCURRENCY, + DEFAULT_PACKAGE_CONCURRENCY, + ) + + assert DEFAULT_LEAF_CONCURRENCY == 4, f"Expected 4, got {DEFAULT_LEAF_CONCURRENCY}" + assert DEFAULT_FILE_CONCURRENCY == 4, f"Expected 4, got {DEFAULT_FILE_CONCURRENCY}" + assert DEFAULT_PACKAGE_CONCURRENCY == 4, f"Expected 4, got {DEFAULT_PACKAGE_CONCURRENCY}" + + +def test_renderer_deep_parallelism_raised() -> None: + """Phase 3 raises deep_parallelism and deep_repair_parallelism defaults to 4.""" + from workers.comprehension.renderers import CliffNotesRenderer + + fields = {f.name: f.default for f in dataclasses.fields(CliffNotesRenderer)} + assert fields.get("deep_parallelism") == 4, ( + f"Expected deep_parallelism=4, got {fields.get('deep_parallelism')}" + ) + assert fields.get("deep_repair_parallelism") == 4, ( + f"Expected deep_repair_parallelism=4, got {fields.get('deep_repair_parallelism')}" + ) diff --git a/workers/tests/test_concurrency_phase4.py b/workers/tests/test_concurrency_phase4.py new file mode 100644 index 00000000..4d7c6ef2 --- /dev/null +++ b/workers/tests/test_concurrency_phase4.py @@ -0,0 +1,274 @@ +"""Phase 4 tests: bounded concurrent fan-out of spec extraction group loop. + +Tests verify: + - test_spec_extraction_fans_out_groups: 5+ groups produce peak in-flight > 1 + under max_concurrent=4 (gate cap). + - test_spec_extraction_serial_when_one_group: single group → no fan-out, no + behavior change. + - test_spec_extraction_local_bound_caps_pending_coroutines: 100 groups under + LOCAL_GROUP_FANOUT_LIMIT=8; at most 8 in-flight at any moment (plan M4). + - Existing extraction behavior is preserved (model_name, usage accumulation, + exception handling). + +Refs: CA-169 / plan v4 Phase 4. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator + +import pytest + +from workers.common.llm.provider import LLMResponse +from workers.requirements.spec_extraction import _SPEC_GROUP_FANOUT_LIMIT, refine_with_llm +from workers.requirements.spec_models import CandidateSpec + +# ────────────────────────────────────────────────────────────────────────────── +# Helpers + + +def _candidate(group_key: str, source: str = "test", idx: int = 0) -> CandidateSpec: + return CandidateSpec( + source=source, + source_file=f"tests/test_{group_key}.go", + source_line=idx * 10 + 1, + raw_text=f"Test{group_key.capitalize()}_{idx}", + group_key=group_key, + language="go", + metadata={}, + ) + + +def _response(model: str = "test-model") -> LLMResponse: + return LLMResponse( + content='{"requirement_text": "The system must do X.", "keywords": ["auth"]}', + model=model, + input_tokens=20, + output_tokens=10, + stop_reason="end_turn", + ) + + +class _InstantProvider: + """Fake provider that returns immediately; records concurrent in-flight count.""" + + def __init__(self, *, latency: float = 0.05, model: str = "test-model") -> None: + self._latency = latency + self._model = model + self.peak_in_flight = 0 + self._in_flight = 0 + self._lock = asyncio.Lock() + + @property + def provider_name(self) -> str: + return "fake" + + @property + def default_model(self) -> str: + return self._model + + async def complete(self, prompt: str, **kwargs: object) -> LLMResponse: + async with self._lock: + self._in_flight += 1 + if self._in_flight > self.peak_in_flight: + self.peak_in_flight = self._in_flight + try: + await asyncio.sleep(self._latency) + return _response(self._model) + finally: + async with self._lock: + self._in_flight -= 1 + + async def stream(self, *args: object, **kwargs: object) -> AsyncIterator[str]: + yield "ok" + + +# ────────────────────────────────────────────────────────────────────────────── +# 1. test_spec_extraction_fans_out_groups + + +@pytest.mark.asyncio +async def test_spec_extraction_fans_out_groups() -> None: + """5 groups with a slow provider should show peak in-flight > 1 (concurrent fan-out).""" + provider = _InstantProvider(latency=0.05) + candidates = [_candidate(f"group{i}") for i in range(5)] + + specs, usage = await refine_with_llm(candidates, provider) + + assert len(specs) == 5, f"Expected 5 refined specs, got {len(specs)}" + assert provider.peak_in_flight > 1, ( + f"Peak in-flight was {provider.peak_in_flight}; expected > 1 (concurrent fan-out)" + ) + # Usage should be accumulated across all groups. + assert usage.input_tokens == 5 * 20 + assert usage.output_tokens == 5 * 10 + assert usage.model == "test-model" + + +# ────────────────────────────────────────────────────────────────────────────── +# 2. test_spec_extraction_serial_when_one_group + + +@pytest.mark.asyncio +async def test_spec_extraction_serial_when_one_group() -> None: + """Single group: behavior is unchanged — no fan-out (trivially serial).""" + provider = _InstantProvider(latency=0.01) + candidates = [_candidate("auth")] + + specs, usage = await refine_with_llm(candidates, provider) + + assert len(specs) == 1 + assert specs[0].group_key == "auth" + assert specs[0].llm_refined is True + assert usage.input_tokens == 20 + assert usage.output_tokens == 10 + + +# ────────────────────────────────────────────────────────────────────────────── +# 3. test_spec_extraction_local_bound_caps_pending_coroutines + + +@pytest.mark.asyncio +async def test_spec_extraction_local_bound_caps_pending_coroutines() -> None: + """100 groups: at most LOCAL_GROUP_FANOUT_LIMIT (8) in-flight at any moment. + + Plan codex r1 M4: the local semaphore must prevent N-thousand pending + coroutines from building up before the upstream gate drains them. + + Approach: use a slow provider (latency=0.04s) so that many groups are + blocked on the semaphore simultaneously, making the cap observable via + peak_in_flight. We then verify peak_in_flight <= 8. + """ + provider = _InstantProvider(latency=0.04) + + # 100 distinct groups, one candidate each. + candidates = [_candidate(f"group{i}") for i in range(100)] + + specs, usage = await refine_with_llm(candidates, provider) + + assert len(specs) == 100, f"Expected 100 specs, got {len(specs)}" + # _SPEC_GROUP_FANOUT_LIMIT: peak must not exceed the cap. + assert provider.peak_in_flight <= _SPEC_GROUP_FANOUT_LIMIT, ( + f"Peak in-flight {provider.peak_in_flight} exceeded _SPEC_GROUP_FANOUT_LIMIT={_SPEC_GROUP_FANOUT_LIMIT}" + ) + # Fan-out is actually happening (not fully serial). + assert provider.peak_in_flight > 1, ( + "Peak in-flight was 1; expected > 1 (fan-out should be active for 100 groups)" + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# 4. test_spec_extraction_exception_handling_preserved + + +@pytest.mark.asyncio +async def test_spec_extraction_exception_handling_preserved() -> None: + """LLM failure on one group falls back to raw_text; other groups still refined.""" + call_count = 0 + + class _FailOnSecond: + @property + def provider_name(self) -> str: + return "fake" + + @property + def default_model(self) -> str: + return "test-model" + + async def complete(self, prompt: str, **kwargs: object) -> LLMResponse: + nonlocal call_count + call_count += 1 + if "group1" in prompt: + raise RuntimeError("simulated LLM failure") + return _response() + + async def stream(self, *args: object, **kwargs: object) -> AsyncIterator[str]: + yield "ok" + + candidates = [_candidate("group0"), _candidate("group1"), _candidate("group2")] + specs, usage = await refine_with_llm(candidates, _FailOnSecond()) + + # All three groups produce a spec (failed group falls back to raw_text). + assert len(specs) == 3 + keys = {s.group_key for s in specs} + assert "group0" in keys + assert "group1" in keys + assert "group2" in keys + + # The failed group uses raw_text as fallback (no keywords, not refined-text-modified). + failed = next(s for s in specs if s.group_key == "group1") + assert failed.llm_refined is True # llm_refined flag is set by the pipeline regardless + + +# ────────────────────────────────────────────────────────────────────────────── +# 5. test_spec_extraction_output_order_stable + + +@pytest.mark.asyncio +async def test_spec_extraction_output_order_stable() -> None: + """Output list preserves groups.items() insertion order (dict ordering). + + deduplicate_specs is key-based so order doesn't affect correctness, but + stable output is useful for deterministic tests downstream. + """ + provider = _InstantProvider(latency=0.0) + keys = [f"key{i:03d}" for i in range(10)] + candidates = [_candidate(k) for k in keys] + + specs, _ = await refine_with_llm(candidates, provider) + + assert [s.group_key for s in specs] == keys, ( + "Output order does not match input group order" + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# 6. test_spec_extraction_usage_accumulates_across_groups + + +@pytest.mark.asyncio +async def test_spec_extraction_usage_accumulates_across_groups() -> None: + """Token counts are summed across all groups; model_name is last writer wins.""" + call_num = 0 + + class _CountingProvider: + @property + def provider_name(self) -> str: + return "fake" + + @property + def default_model(self) -> str: + return "test-model" + + async def complete(self, prompt: str, **kwargs: object) -> LLMResponse: + nonlocal call_num + call_num += 1 + return LLMResponse( + content='{"requirement_text": "R.", "keywords": []}', + model=f"model-{call_num}", + input_tokens=call_num * 10, + output_tokens=call_num * 5, + stop_reason="end_turn", + ) + + async def stream(self, *args: object, **kwargs: object) -> AsyncIterator[str]: + yield "ok" + + n = 4 + candidates = [_candidate(f"g{i}") for i in range(n)] + specs, usage = await refine_with_llm(candidates, _CountingProvider()) + + assert len(specs) == n + # Total input_tokens = 10 + 20 + 30 + 40 = 100 (but call_num order is + # non-deterministic under fan-out; assert sum is correct regardless of order). + expected_input = sum(i * 10 for i in range(1, n + 1)) + expected_output = sum(i * 5 for i in range(1, n + 1)) + assert usage.input_tokens == expected_input, ( + f"input_tokens {usage.input_tokens} != {expected_input}" + ) + assert usage.output_tokens == expected_output, ( + f"output_tokens {usage.output_tokens} != {expected_output}" + ) + # model_name is a non-empty string (last-writer wins under fan-out). + assert usage.model.startswith("model-"), f"model name unexpected: {usage.model}" diff --git a/workers/tests/test_concurrency_phase5.py b/workers/tests/test_concurrency_phase5.py new file mode 100644 index 00000000..ea5ad97a --- /dev/null +++ b/workers/tests/test_concurrency_phase5.py @@ -0,0 +1,210 @@ +"""Phase 5 tests: embedding fan-out (OpenAI-compat only; Ollama stays serial). + +Tests verify: + - test_openai_compat_embed_fans_out_chunks: peak in-flight == 4 under + multi-chunk input (>256 texts to trigger chunking, LOCAL_EMBEDDING_FANOUT_LIMIT=4). + - test_ollama_embed_does_not_fan_out: peak in-flight == 1 (serial preserved). + - Result-order preservation: embed N items → output[i] corresponds to input[i]. + +Refs: CA-169 / plan v4 Phase 5. +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from workers.common.embedding.ollama import OllamaEmbeddingProvider +from workers.common.embedding.openai_compat import ( + _BATCH_SIZE, + LOCAL_EMBEDDING_FANOUT_LIMIT, + OpenAICompatEmbeddingProvider, +) + +# ────────────────────────────────────────────────────────────────────────────── +# Helpers + + +def _texts(n: int) -> list[str]: + return [f"text-{i}" for i in range(n)] + + +def _fake_embedding(text: str) -> list[float]: + """Deterministic embedding so order-preservation can be asserted.""" + return [float(hash(text) % 1000) / 1000.0, float(len(text)) / 100.0] + + +class _BatchTracker: + """Tracks peak concurrent in-flight _embed_batch calls.""" + + def __init__(self, latency: float = 0.05) -> None: + self.latency = latency + self.peak_in_flight = 0 + self._in_flight = 0 + self._lock = asyncio.Lock() + + def patch(self, provider: OpenAICompatEmbeddingProvider | OllamaEmbeddingProvider) -> None: + tracker = self + + async def _fake_embed_batch(texts: list[str]) -> list[list[float]]: + async with tracker._lock: + tracker._in_flight += 1 + if tracker._in_flight > tracker.peak_in_flight: + tracker.peak_in_flight = tracker._in_flight + try: + await asyncio.sleep(tracker.latency) + return [_fake_embedding(t) for t in texts] + finally: + async with tracker._lock: + tracker._in_flight -= 1 + + provider._embed_batch = _fake_embed_batch # type: ignore[method-assign] + + +# ────────────────────────────────────────────────────────────────────────────── +# 1. test_openai_compat_embed_fans_out_chunks + + +@pytest.mark.asyncio +async def test_openai_compat_embed_fans_out_chunks() -> None: + """OpenAI-compat: >256 texts split into batches, issued concurrently (peak == LOCAL_EMBEDDING_FANOUT_LIMIT). + + We use LOCAL_EMBEDDING_FANOUT_LIMIT * 3 batches so the semaphore is + saturated: LOCAL_EMBEDDING_FANOUT_LIMIT batches hold slots while the rest + queue, producing peak_in_flight == LOCAL_EMBEDDING_FANOUT_LIMIT. + """ + batch_count = LOCAL_EMBEDDING_FANOUT_LIMIT * 3 # 12 batches → semaphore saturates + n_texts = batch_count * _BATCH_SIZE # exactly 12 full batches + + provider = OpenAICompatEmbeddingProvider( + base_url="http://localhost:11434", + model="nomic-embed-text", + ) + tracker = _BatchTracker(latency=0.05) + tracker.patch(provider) + + texts = _texts(n_texts) + result = await provider.embed(texts) + + assert len(result) == n_texts, f"Expected {n_texts} embeddings, got {len(result)}" + assert tracker.peak_in_flight > 1, ( + f"Peak in-flight was {tracker.peak_in_flight}; expected > 1 (fan-out active)" + ) + assert tracker.peak_in_flight <= LOCAL_EMBEDDING_FANOUT_LIMIT, ( + f"Peak in-flight {tracker.peak_in_flight} exceeded LOCAL_EMBEDDING_FANOUT_LIMIT={LOCAL_EMBEDDING_FANOUT_LIMIT}" + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# 2. test_ollama_embed_does_not_fan_out + + +@pytest.mark.asyncio +async def test_ollama_embed_does_not_fan_out() -> None: + """Ollama embedding stays serial — host gate combines with LLM gate (plan Decision 1 + 8). + + Peak in-flight must be exactly 1 even for multi-batch inputs. + """ + # Use 3 batches to ensure multi-batch path is exercised. + n_texts = _BATCH_SIZE * 3 + + provider = OllamaEmbeddingProvider( + base_url="http://localhost:11434", + model="nomic-embed-text", + ) + tracker = _BatchTracker(latency=0.03) + tracker.patch(provider) + + texts = _texts(n_texts) + result = await provider.embed(texts) + + assert len(result) == n_texts + assert tracker.peak_in_flight == 1, ( + f"Ollama peak in-flight was {tracker.peak_in_flight}; expected 1 (serial)" + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# 3. test_openai_compat_embed_order_preserved + + +@pytest.mark.asyncio +async def test_openai_compat_embed_order_preserved() -> None: + """output[i] corresponds to input[i] after concurrent fan-out. + + _fake_embedding is deterministic on the text string, so we can assert + that the embedding at position i was produced from texts[i]. + """ + # Use enough texts to span 5 batches (fan-out > 1 guaranteed). + n_texts = _BATCH_SIZE * 5 + + provider = OpenAICompatEmbeddingProvider( + base_url="http://localhost:11434", + model="nomic-embed-text", + ) + tracker = _BatchTracker(latency=0.02) + tracker.patch(provider) + + texts = _texts(n_texts) + result = await provider.embed(texts) + + assert len(result) == n_texts + for i, (text, embedding) in enumerate(zip(texts, result, strict=True)): + expected = _fake_embedding(text) + assert embedding == expected, ( + f"Output[{i}] mismatch: got {embedding}, expected embedding for '{text}'" + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# 4. test_openai_compat_embed_single_batch_no_fan_out + + +@pytest.mark.asyncio +async def test_openai_compat_embed_single_batch_no_fan_out() -> None: + """Single batch (<=256 texts) takes the fast path — no Semaphore constructed.""" + n_texts = _BATCH_SIZE # exactly one batch + + provider = OpenAICompatEmbeddingProvider( + base_url="http://localhost:11434", + model="nomic-embed-text", + ) + tracker = _BatchTracker(latency=0.0) + tracker.patch(provider) + + texts = _texts(n_texts) + result = await provider.embed(texts) + + assert len(result) == n_texts + # Single batch → _embed_batch called once, peak must be 1. + assert tracker.peak_in_flight == 1, ( + f"Single-batch path should have peak==1, got {tracker.peak_in_flight}" + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# 5. test_ollama_embed_order_preserved + + +@pytest.mark.asyncio +async def test_ollama_embed_order_preserved() -> None: + """Ollama serial embed preserves output[i] == embedding(input[i]).""" + n_texts = _BATCH_SIZE * 2 + + provider = OllamaEmbeddingProvider( + base_url="http://localhost:11434", + model="nomic-embed-text", + ) + tracker = _BatchTracker(latency=0.0) + tracker.patch(provider) + + texts = _texts(n_texts) + result = await provider.embed(texts) + + assert len(result) == n_texts + for i, (text, embedding) in enumerate(zip(texts, result, strict=True)): + expected = _fake_embedding(text) + assert embedding == expected, ( + f"Output[{i}] mismatch: got {embedding}, expected embedding for '{text}'" + ) diff --git a/workers/tests/test_concurrency_phase6.py b/workers/tests/test_concurrency_phase6.py new file mode 100644 index 00000000..0962b5c9 --- /dev/null +++ b/workers/tests/test_concurrency_phase6.py @@ -0,0 +1,353 @@ +"""Phase 6 tests for tok/s ring buffer, aggregator task, and streaming usage extraction. + +Tests verify: + - Gate emits llm_provider_gate_metrics structlog lines from the aggregator task + - record_completion → snapshot_tokens_per_second returns non-zero (non-streaming path) + - OpenAICompatGatedProvider stream extracts usage from the final chunk's + completion_tokens and records it in the ring buffer + - AnthropicGatedProvider stream extracts usage from get_final_message() and records it + - 400 with "stream_options" text falls back to raw.stream() silently + - streaming_usage_unsupported flag is set after the 400 fallback fires + - CancelledError during stream does NOT record tokens in the ring buffer + - close() cancels and awaits the aggregator task (no hang or unraisable warning) + +Refs: CA-169 / plan v4 Phase 6 Verification list. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from unittest.mock import AsyncMock, MagicMock, patch + +import openai +import pytest +import pytest_asyncio # noqa: F401 + +from workers.common.llm.concurrency import ( + AnthropicGatedProvider, + ConcurrencyConfig, + OpenAICompatGatedProvider, + ProviderGateRegistry, + _GateBase, +) + +# ────────────────────────────────────────────────────────────────────────────── +# Helpers + + +def _registry( + wrapper_enabled: bool = True, + metrics_interval_seconds: float = 9999.0, +) -> ProviderGateRegistry: + """Return a registry with a very long aggregator interval so it never fires + during tests that don't explicitly test the aggregator.""" + cfg = ConcurrencyConfig( + wrapper_enabled=wrapper_enabled, + metrics_interval_seconds=metrics_interval_seconds, + ) + return ProviderGateRegistry(cfg) + + +def _gate_base(max_concurrent: int = 4) -> _GateBase: + return _GateBase(max_concurrent=max_concurrent) + + +# ────────────────────────────────────────────────────────────────────────────── +# 1. Aggregator emits structlog lines + + +@pytest.mark.asyncio +async def test_gate_emits_metrics_log() -> None: + """Aggregator task emits llm_provider_gate_metrics within its first interval.""" + cfg = ConcurrencyConfig(metrics_interval_seconds=0.05) + reg = ProviderGateRegistry(cfg) + + # Register a gate so snapshot() returns at least one entry. + await reg.lookup("ollama", "http://localhost:11434/v1", "llm") + + emitted: list[dict] = [] + + def _capture(event: str, **kwargs): + if event == "llm_provider_gate_metrics": + emitted.append({"event": event, **kwargs}) + + # Patch the module-level log bound to concurrency.py's logger. + with patch("workers.common.llm.concurrency.log") as mock_log: + mock_log.info.side_effect = _capture + # Wait a bit longer than one interval for the aggregator to fire. + await asyncio.sleep(0.15) + + await reg.close() + + assert len(emitted) >= 1, f"expected at least one metrics log, got {emitted}" + entry = emitted[0] + assert entry["provider"] == "ollama" + assert "tokens_per_second_60s" in entry + + +# ────────────────────────────────────────────────────────────────────────────── +# 2. record_completion → snapshot_tokens_per_second (non-streaming path) + + +def test_gate_tracks_tokens_per_second_non_streaming() -> None: + """record_completion appends to the ring; snapshot_tokens_per_second is non-zero.""" + gate = _gate_base() + assert gate.snapshot_tokens_per_second() == 0.0 + + gate.record_completion(300) + gate.record_completion(700) + + tps = gate.snapshot_tokens_per_second() + # 1000 tokens / 60 s window ≈ 16.7 tok/s — just verify it's positive. + assert tps > 0.0, f"expected positive tok/s, got {tps}" + + +# ────────────────────────────────────────────────────────────────────────────── +# 3. OpenAICompatGatedProvider extracts usage from streaming final chunk + + +@pytest.mark.asyncio +async def test_gate_tracks_tokens_per_second_streaming_openai_compat_gated_provider() -> None: + """OpenAICompatGatedProvider records ring-buffer tokens from the final chunk.""" + from workers.common.llm.fake import FakeLLMProvider + + reg = _registry() + gate = await reg.lookup("openai", "https://api.openai.com/v1", "llm") + + # Build a minimal mock client that yields one text chunk + one usage chunk. + usage_chunk = MagicMock() + usage_chunk.choices = [] + usage_chunk.usage = MagicMock() + usage_chunk.usage.completion_tokens = 42 + + text_chunk = MagicMock() + text_delta = MagicMock() + text_delta.content = "hello" + text_chunk.choices = [MagicMock(delta=text_delta)] + text_chunk.usage = None + + async def _fake_stream_chunks(*args, **kwargs): + yield text_chunk + yield usage_chunk + + mock_client = MagicMock() + mock_client.chat = MagicMock() + mock_client.chat.completions = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=_fake_stream_chunks()) + + raw = FakeLLMProvider() + raw.client = mock_client # inject the mock OpenAI client + + provider = OpenAICompatGatedProvider(raw, gate) + chunks: list[str] = [] + async for chunk in provider.stream("ping"): + chunks.append(chunk) + + assert "hello" in chunks + tps = gate._binding.snapshot_tokens_per_second() + assert tps > 0.0, f"expected ring buffer to have tokens after stream, got {tps}" + + await reg.close() + + +# ────────────────────────────────────────────────────────────────────────────── +# 4. AnthropicGatedProvider extracts usage from get_final_message() + + +@pytest.mark.asyncio +async def test_gate_tracks_tokens_per_second_streaming_anthropic_gated_provider() -> None: + """AnthropicGatedProvider records ring-buffer tokens from get_final_message().""" + from workers.common.llm.fake import FakeLLMProvider + + reg = _registry() + gate = await reg.lookup("anthropic", None, "llm") + + # Mock the Anthropic async streaming context manager. + final_message = MagicMock() + final_message.usage = MagicMock() + final_message.usage.output_tokens = 55 + + mock_stream_cm = MagicMock() + mock_stream_cm.__aenter__ = AsyncMock(return_value=mock_stream_cm) + mock_stream_cm.__aexit__ = AsyncMock(return_value=False) + mock_stream_cm.get_final_message = AsyncMock(return_value=final_message) + + async def _text_stream(): + yield "world" + + mock_stream_cm.text_stream = _text_stream() + + mock_client = MagicMock() + mock_client.messages = MagicMock() + mock_client.messages.stream = MagicMock(return_value=mock_stream_cm) + + raw = FakeLLMProvider() + raw.client = mock_client + + provider = AnthropicGatedProvider(raw, gate) + chunks: list[str] = [] + async for chunk in provider.stream("ping"): + chunks.append(chunk) + + assert "world" in chunks + tps = gate._binding.snapshot_tokens_per_second() + assert tps > 0.0, f"expected ring buffer tokens after anthropic stream, got {tps}" + + await reg.close() + + +# ────────────────────────────────────────────────────────────────────────────── +# 5. stream_options 400 → falls back to raw.stream(), yields chunks + + +@pytest.mark.asyncio +async def test_stream_options_fallback_on_400() -> None: + """When the client raises APIStatusError 400 with 'stream_options', fall back.""" + from workers.common.llm.fake import FakeLLMProvider + + reg = _registry() + gate = await reg.lookup("openai-compatible", "http://localhost:11434/v1", "llm") + + # Mock client that raises 400 with stream_options in the message. + mock_response = MagicMock() + mock_response.status_code = 400 + api_err = openai.APIStatusError( + "unsupported parameter: stream_options", + response=mock_response, + body=None, + ) + mock_client = MagicMock() + mock_client.chat = MagicMock() + mock_client.chat.completions = MagicMock() + mock_client.chat.completions.create = AsyncMock(side_effect=api_err) + + raw = FakeLLMProvider() + raw.client = mock_client + + # Patch FakeLLMProvider.stream to return a known sequence. + async def _raw_stream(*args, **kwargs) -> AsyncIterator[str]: + yield "fallback" + + raw.stream = _raw_stream # type: ignore[method-assign] + + provider = OpenAICompatGatedProvider(raw, gate) + chunks: list[str] = [] + async for chunk in provider.stream("ping"): + chunks.append(chunk) + + assert chunks == ["fallback"], f"expected fallback chunks, got {chunks}" + + await reg.close() + + +# ────────────────────────────────────────────────────────────────────────────── +# 6. streaming_usage_unsupported flag is set after the 400 fallback + + +@pytest.mark.asyncio +async def test_streaming_usage_marked_unsupported_after_400() -> None: + """After the 400 fallback fires, gate.streaming_usage_unsupported is True.""" + from workers.common.llm.fake import FakeLLMProvider + + reg = _registry() + gate = await reg.lookup("openai-compatible", "http://localhost:8080/v1", "llm") + + mock_response = MagicMock() + mock_response.status_code = 400 + api_err = openai.APIStatusError( + "stream_options not supported", + response=mock_response, + body=None, + ) + mock_client = MagicMock() + mock_client.chat = MagicMock() + mock_client.chat.completions = MagicMock() + mock_client.chat.completions.create = AsyncMock(side_effect=api_err) + + raw = FakeLLMProvider() + raw.client = mock_client + + async def _raw_stream(*args, **kwargs) -> AsyncIterator[str]: + yield "ok" + + raw.stream = _raw_stream # type: ignore[method-assign] + + provider = OpenAICompatGatedProvider(raw, gate) + assert not gate.streaming_usage_unsupported + + async for _ in provider.stream("ping"): + pass + + assert gate.streaming_usage_unsupported, "expected flag to be set after 400 fallback" + + await reg.close() + + +# ────────────────────────────────────────────────────────────────────────────── +# 7. CancelledError during stream → record_completion NOT called + + +@pytest.mark.asyncio +async def test_streaming_cancellation_no_usage_recorded() -> None: + """When streaming is cancelled mid-flight, tokens are not recorded.""" + from workers.common.llm.fake import FakeLLMProvider + + reg = _registry() + gate = await reg.lookup("openai", "https://api.openai.com/v1", "llm") + + # Simulate a client stream that hangs (never yields the usage chunk). + cancel_event = asyncio.Event() + + async def _hanging_stream(*args, **kwargs): + yield MagicMock( + choices=[MagicMock(delta=MagicMock(content="partial"))], + usage=None, + ) + # Block indefinitely — the task will be cancelled. + await cancel_event.wait() + + mock_client = MagicMock() + mock_client.chat = MagicMock() + mock_client.chat.completions = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=_hanging_stream()) + + raw = FakeLLMProvider() + raw.client = mock_client + + provider = OpenAICompatGatedProvider(raw, gate) + + async def _consume(): + async for _ in provider.stream("ping"): + pass + + task = asyncio.create_task(_consume()) + # Give the stream one iteration to start. + await asyncio.sleep(0.02) + task.cancel() + import contextlib + with contextlib.suppress(asyncio.CancelledError, TimeoutError): + await asyncio.wait_for(task, timeout=1.0) + + tps = gate._binding.snapshot_tokens_per_second() + assert tps == 0.0, f"expected no tokens recorded on cancellation, got tps={tps}" + + await reg.close() + + +# ────────────────────────────────────────────────────────────────────────────── +# 8. close() cancels and awaits the aggregator task + + +@pytest.mark.asyncio +async def test_aggregator_task_cancelled_on_registry_close() -> None: + """close() cancels the aggregator and the task is done afterward.""" + reg = _registry(metrics_interval_seconds=9999.0) + task = reg._aggregator_task + assert not task.done(), "aggregator task should be running before close()" + + await reg.close() + + assert task.done(), "aggregator task should be done after close()" + # Closing a second time must be idempotent (no error). + await reg.close() diff --git a/workers/tests/test_concurrency_probe.py b/workers/tests/test_concurrency_probe.py index ece39acd..1366de06 100644 --- a/workers/tests/test_concurrency_probe.py +++ b/workers/tests/test_concurrency_probe.py @@ -7,7 +7,6 @@ import pytest from workers.common.llm.concurrency_probe import ( - ProbeBackend, probe_concurrency, run_startup_probe, ) diff --git a/workers/tests/test_embedding_config.py b/workers/tests/test_embedding_config.py index d69dce4a..dc49b012 100644 --- a/workers/tests/test_embedding_config.py +++ b/workers/tests/test_embedding_config.py @@ -13,22 +13,26 @@ import pytest from workers.common.config import SUPPORTED_EMBEDDING_PROVIDERS, WorkerConfig +from workers.common.embedding.concurrency import ConcurrencyGatedEmbeddingProvider from workers.common.embedding.config import create_embedding_provider +from workers.common.llm.concurrency import ConcurrencyConfig, ProviderGateRegistry -def test_create_embedding_provider_rejects_unknown_with_actionable_message() -> None: +@pytest.mark.asyncio +async def test_create_embedding_provider_rejects_unknown_with_actionable_message() -> None: cfg = WorkerConfig(embedding_provider="ollama") # Bypass the validator the same way per-request overrides do. bypassed = cfg.model_copy(update={"embedding_provider": "totally-fake"}) with pytest.raises(ValueError) as exc_info: - create_embedding_provider(bypassed) + await create_embedding_provider(bypassed) msg = str(exc_info.value) assert "totally-fake" in msg for provider in SUPPORTED_EMBEDDING_PROVIDERS: assert repr(provider) in msg, f"supported provider {provider} not surfaced in error: {msg}" -def test_create_embedding_provider_anthropic_gets_specific_hint() -> None: +@pytest.mark.asyncio +async def test_create_embedding_provider_anthropic_gets_specific_hint() -> None: """The tester-report footgun: setting embedding_provider=anthropic is reasonable on first read of the README. The factory error surface must explicitly explain why anthropic isn't here, not just @@ -36,7 +40,7 @@ def test_create_embedding_provider_anthropic_gets_specific_hint() -> None: cfg = WorkerConfig(embedding_provider="ollama") bypassed = cfg.model_copy(update={"embedding_provider": "anthropic"}) with pytest.raises(ValueError) as exc_info: - create_embedding_provider(bypassed) + await create_embedding_provider(bypassed) msg = str(exc_info.value) assert "Anthropic" in msg assert "embeddings API" in msg @@ -45,7 +49,8 @@ def test_create_embedding_provider_anthropic_gets_specific_hint() -> None: assert "openai" in msg -def test_create_embedding_provider_no_longer_raises_notimplementederror() -> None: +@pytest.mark.asyncio +async def test_create_embedding_provider_no_longer_raises_notimplementederror() -> None: """Pre-fix the factory raised NotImplementedError. New behavior: raises ValueError. Pinning the type is a regression guard so a well-meaning future refactor doesn't revert to NotImplementedError @@ -54,23 +59,49 @@ def test_create_embedding_provider_no_longer_raises_notimplementederror() -> Non cfg = WorkerConfig(embedding_provider="ollama") bypassed = cfg.model_copy(update={"embedding_provider": "fake"}) with pytest.raises(ValueError): - create_embedding_provider(bypassed) + await create_embedding_provider(bypassed) # And explicitly NOT NotImplementedError. try: - create_embedding_provider(bypassed) + await create_embedding_provider(bypassed) except NotImplementedError: pytest.fail("create_embedding_provider must not raise NotImplementedError") except ValueError: pass -def test_create_embedding_provider_test_mode_unaffected() -> None: +@pytest.mark.asyncio +async def test_create_embedding_provider_test_mode_unaffected() -> None: """test_mode short-circuits to FakeEmbeddingProvider before the provider check. Pinning so a future refactor that moves the check earlier doesn't break the test fixture path that thousands of unit tests already use.""" cfg = WorkerConfig(test_mode=True, embedding_provider="ollama") - provider = create_embedding_provider(cfg) + provider = await create_embedding_provider(cfg) # FakeEmbeddingProvider is the contract; assert it's a non-None # provider instance at minimum. assert provider is not None + + +# ─── Phase 2 gate-wiring tests ─────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_embedding_provider_is_wrapped_when_gate_registry_supplied() -> None: + """create_embedding_provider with a registry returns a gated provider.""" + registry = ProviderGateRegistry(ConcurrencyConfig()) + cfg = WorkerConfig(embedding_provider="ollama") + provider = await create_embedding_provider(cfg, gate_registry=registry) + assert isinstance(provider, ConcurrencyGatedEmbeddingProvider) + await registry.close() + + +@pytest.mark.asyncio +async def test_embedding_provider_not_wrapped_when_kill_switch_off(monkeypatch) -> None: + """Kill switch off → raw provider returned unchanged.""" + monkeypatch.setenv("SOURCEBRIDGE_LLM_CONCURRENCY_WRAPPER_ENABLED", "false") + config = ConcurrencyConfig.from_env() + registry = ProviderGateRegistry(config) + cfg = WorkerConfig(embedding_provider="ollama") + provider = await create_embedding_provider(cfg, gate_registry=registry) + assert not isinstance(provider, ConcurrencyGatedEmbeddingProvider) + await registry.close() diff --git a/workers/tests/test_grpc_auth_interceptor.py b/workers/tests/test_grpc_auth_interceptor.py index ddc4d9a9..c40427d8 100644 --- a/workers/tests/test_grpc_auth_interceptor.py +++ b/workers/tests/test_grpc_auth_interceptor.py @@ -6,7 +6,6 @@ from workers.__main__ import _GrpcAuthInterceptor, _is_non_loopback - # --------------------------------------------------------------------------- # _is_non_loopback helper # --------------------------------------------------------------------------- diff --git a/workers/tests/test_hierarchical_strategy.py b/workers/tests/test_hierarchical_strategy.py index 4be870c4..bbaf5160 100644 --- a/workers/tests/test_hierarchical_strategy.py +++ b/workers/tests/test_hierarchical_strategy.py @@ -348,13 +348,18 @@ async def complete(self, prompt, **kwargs): provider = _BoundedProvider() strategy = HierarchicalStrategy( provider, - HierarchicalConfig(repository_name="toy", leaf_concurrency=3), + HierarchicalConfig( + repository_name="toy", + leaf_concurrency=3, + file_concurrency=3, # Phase 3: default raised to 4; pin to 3 for this test + package_concurrency=3, # so the peak bound remains meaningful + ), ) corpus = _ToyCorpus() await strategy.build_tree(corpus) - # The concurrency semaphore gates leaf calls only. File/package/root - # calls run sequentially, so the peak is bounded by leaf_concurrency. + # The leaf, file, and package semaphores each cap to 3 concurrent calls. + # Peak in-flight is bounded by the configured concurrency. assert peak <= 3 diff --git a/workers/tests/test_llm_config.py b/workers/tests/test_llm_config.py index 798670fb..13afd17f 100644 --- a/workers/tests/test_llm_config.py +++ b/workers/tests/test_llm_config.py @@ -1,11 +1,14 @@ import pytest +import pytest_asyncio # noqa: F401 — registers asyncio mode from workers.common.config import SUPPORTED_LLM_PROVIDERS, WorkerConfig from workers.common.grpc_metadata import RuntimeLLMOverride, resolve_llm_override +from workers.common.llm.concurrency import ConcurrencyConfig, ConcurrencyGatedProvider, ProviderGateRegistry from workers.common.llm.config import ( _resolve_disable_thinking, create_llm_provider, create_llm_provider_for_request, + create_report_provider, ) @@ -68,7 +71,8 @@ def test_resolve_llm_override_ignores_invalid_timeout(): assert override.timeout_seconds == 0 -def test_create_llm_provider_for_request_passes_timeout_to_client(): +@pytest.mark.asyncio +async def test_create_llm_provider_for_request_passes_timeout_to_client(gate_registry): """End-to-end check: a request-scoped timeout reaches the HTTP client.""" cfg = WorkerConfig( llm_provider="openai", @@ -76,20 +80,23 @@ def test_create_llm_provider_for_request_passes_timeout_to_client(): llm_model="gpt-4o", llm_timeout=900, ) - provider, model = create_llm_provider_for_request( + provider, model = await create_llm_provider_for_request( cfg, provider="openai", model="gpt-4o", api_key="test", timeout_seconds=1800, + gate_registry=gate_registry, ) # OpenAICompatProvider stores the effective timeout on the instance - # for downstream visibility. - assert getattr(provider, "timeout", None) == 1800.0 + # for downstream visibility. When wrapped, unwrap to access raw attrs. + raw = getattr(provider, "_raw", provider) + assert getattr(raw, "timeout", None) == 1800.0 assert model == "gpt-4o" -def test_create_llm_provider_for_request_falls_back_to_bootstrap_timeout(): +@pytest.mark.asyncio +async def test_create_llm_provider_for_request_falls_back_to_bootstrap_timeout(gate_registry): """No per-request override → worker's bootstrap llm_timeout wins.""" cfg = WorkerConfig( llm_provider="openai", @@ -97,14 +104,16 @@ def test_create_llm_provider_for_request_falls_back_to_bootstrap_timeout(): llm_model="gpt-4o", llm_timeout=900, ) - provider, _ = create_llm_provider_for_request( + provider, _ = await create_llm_provider_for_request( cfg, provider="openai", model="gpt-4o", api_key="test", timeout_seconds=0, + gate_registry=gate_registry, ) - assert getattr(provider, "timeout", None) == 900.0 + raw = getattr(provider, "_raw", provider) + assert getattr(raw, "timeout", None) == 900.0 def test_runtime_override_is_empty_when_only_default_timeout(): @@ -116,7 +125,8 @@ def test_runtime_override_is_empty_when_only_default_timeout(): # ─── Factory defense in depth (CA-125) ─────────────────────────────── -def test_create_llm_provider_rejects_unknown_provider_with_actionable_message(): +@pytest.mark.asyncio +async def test_create_llm_provider_rejects_unknown_provider_with_actionable_message(): """create_llm_provider catches unknown providers reaching it via paths that bypass the WorkerConfig validator (notably config.model_copy(update={'llm_provider': '...'}) used by @@ -131,10 +141,66 @@ def test_create_llm_provider_rejects_unknown_provider_with_actionable_message(): # Bypass the validator the same way per-request overrides do. bypassed = cfg.model_copy(update={"llm_provider": "totally-fake"}) with pytest.raises(ValueError) as exc_info: - create_llm_provider(bypassed) + await create_llm_provider(bypassed) msg = str(exc_info.value) assert "totally-fake" in msg # Every supported provider must be named so the user knows what to # switch to. for provider in SUPPORTED_LLM_PROVIDERS: assert repr(provider) in msg, f"supported provider {provider} not surfaced in error: {msg}" + + +# ─── Phase 2 gate-wiring tests ─────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_provider_is_wrapped_when_kill_switch_enabled(gate_registry): + """create_llm_provider with a registry returns a ConcurrencyGatedProvider.""" + cfg = WorkerConfig(llm_provider="openai", llm_api_key="test", llm_model="gpt-4o") + provider = await create_llm_provider(cfg, gate_registry=gate_registry) + assert isinstance(provider, ConcurrencyGatedProvider) + + +@pytest.mark.asyncio +async def test_kill_switch_disables_wrapper(monkeypatch): + """When SOURCEBRIDGE_LLM_CONCURRENCY_WRAPPER_ENABLED=false the raw provider is returned.""" + monkeypatch.setenv("SOURCEBRIDGE_LLM_CONCURRENCY_WRAPPER_ENABLED", "false") + config = ConcurrencyConfig.from_env() + assert not config.wrapper_enabled + registry = ProviderGateRegistry(config) + cfg = WorkerConfig(llm_provider="openai", llm_api_key="test", llm_model="gpt-4o") + provider = await create_llm_provider(cfg, gate_registry=registry) + assert not isinstance(provider, ConcurrencyGatedProvider) + await registry.close() + + +@pytest.mark.asyncio +async def test_create_llm_provider_for_request_forwards_registry(gate_registry): + """create_llm_provider_for_request wraps the returned provider when registry is supplied.""" + cfg = WorkerConfig(llm_provider="openai", llm_api_key="test", llm_model="gpt-4o") + provider, model = await create_llm_provider_for_request( + cfg, + provider="openai", + model="gpt-4o", + api_key="test", + gate_registry=gate_registry, + ) + assert isinstance(provider, ConcurrencyGatedProvider) + assert model == "gpt-4o" + + +@pytest.mark.asyncio +async def test_create_report_provider_uses_same_gate_for_same_endpoint(gate_registry): + """Report provider and main provider share a gate when pointing at the same endpoint.""" + cfg = WorkerConfig( + llm_provider="ollama", + llm_model="qwen3:7b", + llm_report_provider="ollama", + llm_report_model="qwen3:14b", + ) + main_prov = await create_llm_provider(cfg, gate_registry=gate_registry) + report_prov = await create_report_provider(cfg, gate_registry=gate_registry) + assert isinstance(main_prov, ConcurrencyGatedProvider) + assert isinstance(report_prov, ConcurrencyGatedProvider) + # Both point at the same Ollama host → same underlying _HostGate binding. + assert main_prov._gate._binding is report_prov._gate._binding diff --git a/workers/tests/test_servicer_gate_snapshot.py b/workers/tests/test_servicer_gate_snapshot.py new file mode 100644 index 00000000..b24c1cba --- /dev/null +++ b/workers/tests/test_servicer_gate_snapshot.py @@ -0,0 +1,131 @@ +"""Tests for ReasoningServicer.GetLLMGateSnapshot (Phase 7). + +Verifies: + - A host-gated registry with active LLM + embedding kind counters returns + two rows sharing max_concurrent and tokens_per_second. + - A fresh registry with no gates returns an empty list. + - No gate_registry wired → empty response (not an error). + +Refs: CA-169 / plan v4 Phase 7 Verification list. +""" + +from __future__ import annotations + +import pytest +import pytest_asyncio # noqa: F401 +from reasoning.v1 import reasoning_pb2 + +from workers.common.embedding.fake import FakeEmbeddingProvider +from workers.common.llm.concurrency import ConcurrencyConfig, ProviderGateRegistry +from workers.reasoning.servicer import ReasoningServicer + +# ────────────────────────────────────────────────────────────────────────────── +# Fixtures + + +@pytest.fixture +def embedding(): + return FakeEmbeddingProvider(dimension=1024) + + +@pytest_asyncio.fixture +async def gate_registry(): + """Fresh ProviderGateRegistry with a very long aggregator interval so it + never fires unexpectedly during tests.""" + cfg = ConcurrencyConfig(metrics_interval_seconds=9999.0) + reg = ProviderGateRegistry(cfg) + yield reg + await reg.close() + + +@pytest.fixture +def servicer_with_registry(llm, embedding, gate_registry): + return ReasoningServicer(llm, embedding, gate_registry=gate_registry) + + +@pytest.fixture +def servicer_no_registry(llm, embedding): + """Servicer without a gate_registry (kill-switch-off scenario).""" + return ReasoningServicer(llm, embedding) + + +# ────────────────────────────────────────────────────────────────────────────── +# Tests + + +@pytest.mark.asyncio +async def test_get_llm_gate_snapshot_returns_all_active_gates( + servicer_with_registry, gate_registry, context +): + """Ollama host gate with LLM + embedding kind counters → two rows sharing + max_concurrent and tokens_per_second. + + Plan Phase 7 Verification: "an Ollama daemon with active LLM and embedding + kinds shows two rows sharing one max_concurrent and one tokens_per_second." + """ + # Register both LLM and embedding kind counters for the same Ollama host. + await gate_registry.lookup("ollama", "http://localhost:11434/v1", "llm") + await gate_registry.lookup("ollama", "http://localhost:11434", "embedding") + + request = reasoning_pb2.GetLLMGateSnapshotRequest() + response = await servicer_with_registry.GetLLMGateSnapshot(request, context) + + assert isinstance(response, reasoning_pb2.GetLLMGateSnapshotResponse) + entries = list(response.gates) + assert len(entries) == 2, f"expected 2 gate entries, got {len(entries)}: {entries}" + + # Both rows should share the same provider and normalized origin. + assert all(e.provider == "ollama" for e in entries) + assert all(e.base_url_normalized == "http://localhost:11434" for e in entries) + + # Kinds must be distinct (one llm, one embedding). + kinds = {e.kind for e in entries} + assert kinds == {"llm", "embedding"}, f"expected {{llm, embedding}}, got {kinds}" + + # Both rows share the same max_concurrent (Decision 5b: host gate is one semaphore). + max_concurrent_values = {e.max_concurrent for e in entries} + assert len(max_concurrent_values) == 1, ( + f"host-gated rows must share max_concurrent, got {max_concurrent_values}" + ) + + # Both rows share the same tokens_per_second (from the shared ring buffer). + tps_values = {e.tokens_per_second for e in entries} + assert len(tps_values) == 1, ( + f"host-gated rows must share tokens_per_second, got {tps_values}" + ) + + # All numeric fields are non-negative. + for e in entries: + assert e.in_flight >= 0 + assert e.queued >= 0 + assert e.retries_since_start >= 0 + assert e.recent_429_count >= 0 + assert e.tokens_per_second >= 0.0 + + +@pytest.mark.asyncio +async def test_get_llm_gate_snapshot_empty_when_no_gates( + servicer_with_registry, context +): + """A fresh registry with no lookups yet → empty gate list.""" + request = reasoning_pb2.GetLLMGateSnapshotRequest() + response = await servicer_with_registry.GetLLMGateSnapshot(request, context) + + assert isinstance(response, reasoning_pb2.GetLLMGateSnapshotResponse) + assert len(response.gates) == 0, ( + f"expected empty gates for fresh registry, got {len(response.gates)}" + ) + + +@pytest.mark.asyncio +async def test_get_llm_gate_snapshot_no_registry_returns_empty( + servicer_no_registry, context +): + """No gate_registry wired (kill-switch-off) → empty response without error.""" + request = reasoning_pb2.GetLLMGateSnapshotRequest() + response = await servicer_no_registry.GetLLMGateSnapshot(request, context) + + assert isinstance(response, reasoning_pb2.GetLLMGateSnapshotResponse) + assert len(response.gates) == 0, ( + f"expected empty gates when no registry, got {len(response.gates)}" + ) diff --git a/workers/tests/test_servicer_utils.py b/workers/tests/test_servicer_utils.py index 0057809c..863fc834 100644 --- a/workers/tests/test_servicer_utils.py +++ b/workers/tests/test_servicer_utils.py @@ -65,7 +65,7 @@ def test_default_path_returns_default_llm(): """No overrides in metadata → default llm returned.""" llm = FakeLLMProvider() context = _MockContext() - provider, model = resolve_provider_for_context(llm, None, context) + provider, model, _ = resolve_provider_for_context(llm, None, context) assert provider is llm assert model is None @@ -75,7 +75,7 @@ def test_default_path_with_config(): llm = FakeLLMProvider() config = _MockConfig() context = _MockContext() - provider, model = resolve_provider_for_context(llm, config, context) + provider, model, _ = resolve_provider_for_context(llm, config, context) assert provider is llm assert model is None @@ -94,7 +94,7 @@ def test_model_override_no_config(): """ llm = FakeLLMProvider() context = _MockContext({"x-sb-model": "claude-3-5-sonnet"}) - provider, model = resolve_provider_for_context(llm, None, context) + provider, model, _ = resolve_provider_for_context(llm, None, context) assert provider is llm assert model == "claude-3-5-sonnet" @@ -109,7 +109,7 @@ def test_fallback_llm_used_when_no_override(): default_llm = FakeLLMProvider() report_llm = FakeLLMProvider() context = _MockContext() - provider, model = resolve_provider_for_context(default_llm, None, context, fallback_llm=report_llm) + provider, model, _ = resolve_provider_for_context(default_llm, None, context, fallback_llm=report_llm) assert provider is report_llm assert model is None @@ -120,7 +120,7 @@ def test_fallback_llm_uses_report_model_from_config(): report_llm = FakeLLMProvider() config = _MockConfig(llm_report_model="gpt-4o") context = _MockContext() - provider, model = resolve_provider_for_context(default_llm, config, context, fallback_llm=report_llm) + provider, model, _ = resolve_provider_for_context(default_llm, config, context, fallback_llm=report_llm) assert provider is report_llm assert model == "gpt-4o" @@ -136,7 +136,7 @@ def test_fallback_llm_model_override_wins_over_config_report_model(): report_llm = FakeLLMProvider() # No config: the full-override branch with config=None returns fallback_llm + model context = _MockContext({"x-sb-model": "claude-3-haiku"}) - provider, model = resolve_provider_for_context(default_llm, None, context, fallback_llm=report_llm) + provider, model, _ = resolve_provider_for_context(default_llm, None, context, fallback_llm=report_llm) assert provider is report_llm assert model == "claude-3-haiku" @@ -149,7 +149,7 @@ def test_no_fallback_llm_model_override_no_config(): """ llm = FakeLLMProvider() context = _MockContext({"x-sb-model": "gpt-3.5-turbo"}) - provider, model = resolve_provider_for_context(llm, None, context) + provider, model, _ = resolve_provider_for_context(llm, None, context) assert provider is llm assert model == "gpt-3.5-turbo" @@ -163,7 +163,7 @@ def test_full_override_no_config_returns_default_llm(): """Full LLM override present but no config → default llm, override model.""" llm = FakeLLMProvider() context = _MockContext({"x-sb-llm-provider": "anthropic", "x-sb-model": "claude-3-5-sonnet"}) - provider, model = resolve_provider_for_context(llm, None, context) + provider, model, _ = resolve_provider_for_context(llm, None, context) # No config → cannot build a fresh provider; falls back to llm assert provider is llm assert model == "claude-3-5-sonnet" @@ -174,7 +174,7 @@ def test_full_override_no_config_fallback_llm_returned(): default_llm = FakeLLMProvider() report_llm = FakeLLMProvider() context = _MockContext({"x-sb-llm-provider": "anthropic", "x-sb-model": "claude-3-5-sonnet"}) - provider, model = resolve_provider_for_context(default_llm, None, context, fallback_llm=report_llm) + provider, model, _ = resolve_provider_for_context(default_llm, None, context, fallback_llm=report_llm) assert provider is report_llm assert model == "claude-3-5-sonnet" diff --git a/workers/uv.lock b/workers/uv.lock index 6948aaa0..b4e09eb3 100644 --- a/workers/uv.lock +++ b/workers/uv.lock @@ -2,6 +2,15 @@ version = 1 revision = 3 requires-python = ">=3.12" +[[package]] +name = "aiolimiter" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/23/b52debf471f7a1e42e362d959a3982bdcb4fe13a5d46e63d28868807a79c/aiolimiter-1.2.1.tar.gz", hash = "sha256:e02a37ea1a855d9e832252a105420ad4d15011505512a1a1d814647451b5cca9", size = 7185, upload-time = "2024-12-08T15:31:51.496Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/ba/df6e8e1045aebc4778d19b8a3a9bc1808adb1619ba94ca354d9ba17d86c3/aiolimiter-1.2.1-py3-none-any.whl", hash = "sha256:d3f249e9059a20badcb56b61601a83556133655c11d1eb3dd3e04ff069e5f3c7", size = 6711, upload-time = "2024-12-08T15:31:49.874Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -814,6 +823,7 @@ name = "sourcebridge-worker" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "aiolimiter" }, { name = "anthropic" }, { name = "grpcio" }, { name = "grpcio-health-checking" }, @@ -826,6 +836,7 @@ dependencies = [ { name = "pydantic-settings" }, { name = "pyyaml" }, { name = "structlog" }, + { name = "tenacity" }, ] [package.dev-dependencies] @@ -840,6 +851,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiolimiter", specifier = ">=1.2.1" }, { name = "anthropic", specifier = ">=0.49.0" }, { name = "grpcio", specifier = ">=1.62.0" }, { name = "grpcio-health-checking", specifier = ">=1.62.0" }, @@ -852,6 +864,7 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.2.0" }, { name = "pyyaml", specifier = ">=6.0.2" }, { name = "structlog", specifier = ">=24.1.0" }, + { name = "tenacity", specifier = ">=8.5.0" }, ] [package.metadata.requires-dev] @@ -873,6 +886,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/45/a132b9074aa18e799b891b91ad72133c98d8042c70f6240e4c5f9dabee2f/structlog-25.5.0-py3-none-any.whl", hash = "sha256:a8453e9b9e636ec59bd9e79bbd4a72f025981b3ba0f5837aebf48f02f37a7f9f", size = 72510, upload-time = "2025-10-27T08:28:21.535Z" }, ] +[[package]] +name = "tenacity" +version = "9.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/c6/ee486fd809e357697ee8a44d3d69222b344920433d3b6666ccd9b374630c/tenacity-9.1.4.tar.gz", hash = "sha256:adb31d4c263f2bd041081ab33b498309a57c77f9acf2db65aadf0898179cf93a", size = 49413, upload-time = "2026-02-07T10:45:33.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = "sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, +] + [[package]] name = "tqdm" version = "4.67.3"