From a570c3cc13f51ceff3d42dc37509fefa1cf793b2 Mon Sep 17 00:00:00 2001
From: Grant McCloskey
Date: Wed, 20 May 2026 18:53:46 +0000
Subject: [PATCH] feat(scheduler): optimize worker scheduling to O(1) using
 Valkey/Redis Sets

Replace the ListWorkers scan in AssignWorkerStep with a per-pool Redis set
of idle worker pods. CreateWorker and UpdateWorker (on a busy -> idle
transition) add pods to the set, DeleteWorker removes them, and
ClaimIdleWorker pops a random member and binds it to the actor.

A popped pod name is put back into the set on every failure path, and a
failure to put it back is reported rather than dropped.
---
Post-image blob ids on the "index" lines below predate the review fixes;
regenerate the patch with git format-patch after applying.

 .../ateapi/controlapi/workflow_resume.go      |  65 +++++-------
 cmd/servers/ateapi/store/ateredis/ateredis.go | 104 +++++++++++++++++++
 cmd/servers/ateapi/store/store.go             |   4 +
 3 files changed, 133 insertions(+), 40 deletions(-)

diff --git a/cmd/servers/ateapi/controlapi/workflow_resume.go b/cmd/servers/ateapi/controlapi/workflow_resume.go
index 744422a..f1459f9 100644
--- a/cmd/servers/ateapi/controlapi/workflow_resume.go
+++ b/cmd/servers/ateapi/controlapi/workflow_resume.go
@@ -19,7 +19,6 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
-	"math/rand"
 	"time"
 
 	atev1alpha1 "github.com/agent-substrate/substrate/api/v1alpha1"
@@ -85,38 +84,41 @@ func (s *AssignWorkerStep) IsComplete(ctx context.Context, input *ResumeInput, s
 	return state.Actor.GetStatus() == ateapipb.Actor_STATUS_RUNNING, nil
 }
 func (s *AssignWorkerStep) Execute(ctx context.Context, input *ResumeInput, state *ResumeState) error {
-	workers, err := s.store.ListWorkers(ctx)
-	if err != nil {
-		return fmt.Errorf("while listing workers: %w", err)
-	}
-
 	var assignedWorker *ateapipb.Worker
 
-	// Check if we already have a worker assigned from a previous failed attempt
-	for _, worker := range workers {
-		if worker.GetActorId() == input.ActorID && worker.GetWorkerPool() == state.ActorTemplate.Spec.WorkerPoolRef.Name && worker.GetWorkerNamespace() == state.ActorTemplate.Spec.WorkerPoolRef.Namespace {
+	// Re-use the worker recorded on the actor by a previous attempt, provided it
+	// is still bound to this actor.
+	if state.Actor.GetAteomPodName() != "" {
+		worker, err := s.store.GetWorker(ctx, state.Actor.GetAteomPodNamespace(), state.ActorTemplate.Spec.WorkerPoolRef.Name, state.Actor.GetAteomPodName())
+		switch {
+		case errors.Is(err, store.ErrNotFound):
+			// The recorded worker is gone; claim a new one below.
+		case err != nil:
+			return fmt.Errorf("while loading previously assigned worker: %w", err)
+		case worker.GetActorId() == input.ActorID:
 			assignedWorker = worker
-			break
 		}
 	}
 
-	// If not, find a free one using randomized shuffling
+	// Claim a new idle worker.
 	if assignedWorker == nil {
-		pickedWorker := s.findFreeWorker(workers, state.ActorTemplate.Spec.WorkerPoolRef.Namespace, state.ActorTemplate.Spec.WorkerPoolRef.Name)
-		if pickedWorker == nil {
-			return status.Errorf(codes.FailedPrecondition, "no free workers available")
+		pickedWorker, err := s.store.ClaimIdleWorker(
+			ctx,
+			state.ActorTemplate.Spec.WorkerPoolRef.Namespace,
+			state.ActorTemplate.Spec.WorkerPoolRef.Name,
+			input.ActorID,
+			state.Actor.GetActorTemplateNamespace(),
+			state.Actor.GetActorTemplateName(),
+		)
+		if err != nil {
+			if errors.Is(err, store.ErrNotFound) {
+				return status.Errorf(codes.FailedPrecondition, "no free workers available")
+			}
+			return fmt.Errorf("while claiming idle worker: %w", err)
 		}
 
 		assignedWorker = pickedWorker
-		slog.InfoContext(ctx, "Picked worker", slog.Any("worker", pickedWorker.String()))
-	}
-
-	assignedWorker.ActorId = input.ActorID
-	assignedWorker.ActorNamespace = state.Actor.GetActorTemplateNamespace()
-	assignedWorker.ActorTemplate = state.Actor.GetActorTemplateName()
-
-	if err := s.store.UpdateWorker(ctx, assignedWorker, assignedWorker.Version); err != nil {
-		return err
+		slog.InfoContext(ctx, "Claimed idle worker", slog.Any("worker", pickedWorker.String()))
 	}
 
 	state.Actor.Status = ateapipb.Actor_STATUS_RESUMING
@@ -139,23 +141,6 @@ func (s *AssignWorkerStep) RetryBackoff() *wait.Backoff {
 	}
 }
 
-func (s *AssignWorkerStep) findFreeWorker(workers []*ateapipb.Worker, workerPoolNamespace, workerPoolName string) *ateapipb.Worker {
-	var freeWorkers []*ateapipb.Worker
-	for _, worker := range workers {
-		if worker.GetActorId() == "" && worker.GetWorkerPool() == workerPoolName && worker.GetWorkerNamespace() == workerPoolNamespace {
-			freeWorkers = append(freeWorkers, worker)
-		}
-	}
-
-	if len(freeWorkers) > 0 {
-		rand.Shuffle(len(freeWorkers), func(i, j int) {
-			freeWorkers[i], freeWorkers[j] = freeWorkers[j], freeWorkers[i]
-		})
-		return freeWorkers[0]
-	}
-	return nil
-}
-
 type CallAteletRestoreStep struct {
 	dialer *AteletDialer
 }
diff --git a/cmd/servers/ateapi/store/ateredis/ateredis.go b/cmd/servers/ateapi/store/ateredis/ateredis.go
index ebde322..f77f08b 100644
--- a/cmd/servers/ateapi/store/ateredis/ateredis.go
+++ b/cmd/servers/ateapi/store/ateredis/ateredis.go
@@ -164,6 +164,13 @@ func (s *Persistence) CreateWorker(ctx context.Context, worker *ateapipb.Worker)
 		return store.ErrAlreadyExists
 	}
 
+	// Add to the idle set.
+	setKey := idleWorkersKey(worker.GetWorkerNamespace(), worker.GetWorkerPool())
+	err = s.rdb.SAdd(ctx, setKey, worker.GetWorkerPod()).Err()
+	if err != nil {
+		return fmt.Errorf("while registering worker in idle set: %w", err)
+	}
+
 	return nil
 }
 
@@ -192,6 +199,7 @@ func (s *Persistence) GetWorker(ctx context.Context, namespace, pool, pod string
 
 func (s *Persistence) UpdateWorker(ctx context.Context, worker *ateapipb.Worker, expectedVersion int64) error {
 	dbKey := workerDBKey(worker.GetWorkerNamespace(), worker.GetWorkerPool(), worker.GetWorkerPod())
+	var shouldAddToIdle bool
 
 	// Clone because we will update the version field, and we don't want to
 	// stomp the caller's copy.
@@ -228,6 +236,10 @@ func (s *Persistence) UpdateWorker(ctx context.Context, worker *ateapipb.Worker,
 			return fmt.Errorf("ip is immutable")
 		}
 
+		if currentWorker.GetActorId() != "" && dbWorker.GetActorId() == "" {
+			shouldAddToIdle = true
+		}
+
 		newVal, err := protojson.Marshal(dbWorker)
 		if err != nil {
 			return fmt.Errorf("in protojson.Marshal: %w", err)
@@ -246,15 +258,32 @@ func (s *Persistence) UpdateWorker(ctx context.Context, worker *ateapipb.Worker,
 		return fmt.Errorf("while executing update worker transaction: %w", err)
 	}
 
+	// Run SAdd sequentially outside the transaction to avoid cluster slot restrictions.
+	if shouldAddToIdle {
+		setKey := idleWorkersKey(worker.GetWorkerNamespace(), worker.GetWorkerPool())
+		err = s.rdb.SAdd(ctx, setKey, worker.GetWorkerPod()).Err()
+		if err != nil {
+			return fmt.Errorf("while returning worker to idle set: %w", err)
+		}
+	}
+
 	return nil
 }
 
 func (s *Persistence) DeleteWorker(ctx context.Context, namespace, pool, pod string) error {
 	dbKey := workerDBKey(namespace, pool, pod)
+	setKey := idleWorkersKey(namespace, pool)
+
 	err := s.rdb.Del(ctx, dbKey).Err()
 	if err != nil {
 		return fmt.Errorf("while deleting worker key %q: %w", dbKey, err)
 	}
+
+	err = s.rdb.SRem(ctx, setKey, pod).Err()
+	if err != nil {
+		return fmt.Errorf("while removing worker from idle set: %w", err)
+	}
+
 	return nil
 }
 
@@ -452,3 +481,78 @@ func (s *Persistence) ReleaseLock(ctx context.Context, key string, value string)
 	}
 	return nil
 }
+
+// idleWorkersKey returns the name of the Redis set that tracks the idle worker
+// pods of one worker pool.
+func idleWorkersKey(namespace, pool string) string {
+	return fmt.Sprintf("pool:%s:%s:idle_workers", namespace, pool)
+}
+
+// ClaimIdleWorker pops pod names from the pool's idle set until it finds a
+// worker that is still idle, binds it to the given actor and returns it. It
+// returns store.ErrNotFound once the idle set is empty.
+//
+// The pop and the worker update are separate operations, so every failure after
+// the pop puts the pod name back into the set; a requeue failure is reported
+// instead of being dropped, because it leaves the worker unclaimable.
+//
+// NOTE(review): UpdateWorker bumps the version on its own clone, so the returned
+// worker presumably still carries the pre-update version; confirm before reusing
+// it for another UpdateWorker call.
+func (s *Persistence) ClaimIdleWorker(ctx context.Context, namespace, pool string, actorID string, actorNamespace string, actorTemplate string) (*ateapipb.Worker, error) {
+	setKey := idleWorkersKey(namespace, pool)
+
+	// requeue returns podName to the idle set and folds a failure to do so into
+	// cause. It detaches from ctx cancellation so that a claim abandoned by the
+	// caller still gives its worker back.
+	requeue := func(podName string, cause error) error {
+		err := s.rdb.SAdd(context.WithoutCancel(ctx), setKey, podName).Err()
+		if err != nil {
+			return errors.Join(cause, fmt.Errorf("while returning worker %q to idle set: %w", podName, err))
+		}
+		return cause
+	}
+
+	for {
+		// Pop a random idle worker name.
+		podName, err := s.rdb.SPop(ctx, setKey).Result()
+		if err != nil {
+			if errors.Is(err, redis.Nil) {
+				return nil, store.ErrNotFound
+			}
+			return nil, fmt.Errorf("while popping idle worker from set: %w", err)
+		}
+
+		worker, err := s.GetWorker(ctx, namespace, pool, podName)
+		if err != nil {
+			// If the worker was deleted, skip and pop the next one.
+			if errors.Is(err, store.ErrNotFound) {
+				continue
+			}
+			return nil, requeue(podName, fmt.Errorf("while loading popped worker metadata: %w", err))
+		}
+
+		if worker.GetActorId() != "" {
+			// Skip busy workers.
+			continue
+		}
+
+		worker.ActorId = actorID
+		worker.ActorNamespace = actorNamespace
+		worker.ActorTemplate = actorTemplate
+
+		err = s.UpdateWorker(ctx, worker, worker.Version)
+		if err == nil {
+			return worker, nil
+		}
+		if !errors.Is(err, store.ErrPersistenceRetry) {
+			return nil, requeue(podName, fmt.Errorf("while claiming popped worker: %w", err))
+		}
+
+		// Lost an optimistic-locking race: return the worker to the idle set and
+		// retry with a fresh pop.
+		if rerr := requeue(podName, nil); rerr != nil {
+			return nil, rerr
+		}
+	}
+}
diff --git a/cmd/servers/ateapi/store/store.go b/cmd/servers/ateapi/store/store.go
index ab76aa8..d4126e3 100644
--- a/cmd/servers/ateapi/store/store.go
+++ b/cmd/servers/ateapi/store/store.go
@@ -69,6 +69,10 @@ type Interface interface {
 	// Lists all known workers. Returns nil if none found.
 	ListWorkers(ctx context.Context) ([]*ateapipb.Worker, error)
 
+	// ClaimIdleWorker atomically claims a random idle worker from the specified pool
+	// and associates it with the given Actor. Returns ErrNotFound if no idle workers are available.
+	ClaimIdleWorker(ctx context.Context, namespace, pool string, actorID string, actorNamespace string, actorTemplate string) (*ateapipb.Worker, error)
+
 	// AcquireLock attempts to acquire a distributed lock with a TTL.
 	// Returns true if the lock was successfully acquired.
 	// Returns false if the lock is already held by another client (conflict).