Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 19 additions & 41 deletions cmd/servers/ateapi/controlapi/workflow_resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"errors"
"fmt"
"log/slog"
"math/rand"
"time"

atev1alpha1 "github.com/agent-substrate/substrate/api/v1alpha1"
Expand Down Expand Up @@ -85,38 +84,34 @@ func (s *AssignWorkerStep) IsComplete(ctx context.Context, input *ResumeInput, s
return state.Actor.GetStatus() == ateapipb.Actor_STATUS_RUNNING, nil
}
func (s *AssignWorkerStep) Execute(ctx context.Context, input *ResumeInput, state *ResumeState) error {
workers, err := s.store.ListWorkers(ctx)
if err != nil {
return fmt.Errorf("while listing workers: %w", err)
}

var assignedWorker *ateapipb.Worker

// Check if we already have a worker assigned from a previous failed attempt
for _, worker := range workers {
if worker.GetActorId() == input.ActorID && worker.GetWorkerPool() == state.ActorTemplate.Spec.WorkerPoolRef.Name && worker.GetWorkerNamespace() == state.ActorTemplate.Spec.WorkerPoolRef.Namespace {
// Re-use previously assigned worker if available.
if state.Actor.GetAteomPodName() != "" {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will never be non-empty during resume (except maybe on retry of a failed resume?)

worker, err := s.store.GetWorker(ctx, state.Actor.GetAteomPodNamespace(), state.ActorTemplate.Spec.WorkerPoolRef.Name, state.Actor.GetAteomPodName())
if err == nil && worker.GetActorId() == input.ActorID {
assignedWorker = worker
break
}
}

// If not, find a free one using randomized shuffling
// Claim a new idle worker.
if assignedWorker == nil {
pickedWorker := s.findFreeWorker(workers, state.ActorTemplate.Spec.WorkerPoolRef.Namespace, state.ActorTemplate.Spec.WorkerPoolRef.Name)
if pickedWorker == nil {
return status.Errorf(codes.FailedPrecondition, "no free workers available")
pickedWorker, err := s.store.ClaimIdleWorker(
ctx,
state.ActorTemplate.Spec.WorkerPoolRef.Namespace,
state.ActorTemplate.Spec.WorkerPoolRef.Name,
input.ActorID,
state.Actor.GetActorTemplateNamespace(),
state.Actor.GetActorTemplateName(),
)
if err != nil {
if errors.Is(err, store.ErrNotFound) {
return status.Errorf(codes.FailedPrecondition, "no free workers available")
}
return fmt.Errorf("while claiming idle worker: %w", err)
}

assignedWorker = pickedWorker
slog.InfoContext(ctx, "Picked worker", slog.Any("worker", pickedWorker.String()))
}

assignedWorker.ActorId = input.ActorID
assignedWorker.ActorNamespace = state.Actor.GetActorTemplateNamespace()
assignedWorker.ActorTemplate = state.Actor.GetActorTemplateName()

if err := s.store.UpdateWorker(ctx, assignedWorker, assignedWorker.Version); err != nil {
return err
slog.InfoContext(ctx, "Claimed idle worker", slog.Any("worker", pickedWorker.String()))
}

state.Actor.Status = ateapipb.Actor_STATUS_RESUMING
Expand All @@ -139,23 +134,6 @@ func (s *AssignWorkerStep) RetryBackoff() *wait.Backoff {
}
}

// findFreeWorker returns one worker, picked uniformly at random, among the
// entries of workers that have no actor assigned (empty ActorId) and whose
// pool name and namespace match workerPoolName and workerPoolNamespace.
// It returns nil when no entry qualifies.
func (s *AssignWorkerStep) findFreeWorker(workers []*ateapipb.Worker, workerPoolNamespace, workerPoolName string) *ateapipb.Worker {
	var freeWorkers []*ateapipb.Worker
	for _, worker := range workers {
		if worker.GetActorId() == "" && worker.GetWorkerPool() == workerPoolName && worker.GetWorkerNamespace() == workerPoolNamespace {
			freeWorkers = append(freeWorkers, worker)
		}
	}

	if len(freeWorkers) == 0 {
		return nil
	}

	// Only one candidate is needed, so draw a single random index instead of
	// shuffling the whole slice and taking its first element.
	return freeWorkers[rand.Intn(len(freeWorkers))]
}

// CallAteletRestoreStep holds an AteletDialer.
// NOTE(review): presumably the resume-workflow step that asks an atelet to
// restore the actor; its methods are not visible here, so confirm.
type CallAteletRestoreStep struct {
// dialer is the step's AteletDialer dependency.
dialer *AteletDialer
}
Expand Down
76 changes: 76 additions & 0 deletions cmd/servers/ateapi/store/ateredis/ateredis.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@ func (s *Persistence) CreateWorker(ctx context.Context, worker *ateapipb.Worker)
return store.ErrAlreadyExists
}

// Add to the idle set.
setKey := fmt.Sprintf("pool:%s:%s:idle_workers", worker.GetWorkerNamespace(), worker.GetWorkerPool())
err = s.rdb.SAdd(ctx, setKey, worker.GetWorkerPod()).Err()
if err != nil {
return fmt.Errorf("while registering worker in idle set: %w", err)
}

return nil
}

Expand Down Expand Up @@ -192,6 +199,7 @@ func (s *Persistence) GetWorker(ctx context.Context, namespace, pool, pod string

func (s *Persistence) UpdateWorker(ctx context.Context, worker *ateapipb.Worker, expectedVersion int64) error {
dbKey := workerDBKey(worker.GetWorkerNamespace(), worker.GetWorkerPool(), worker.GetWorkerPod())
var shouldAddToIdle bool

// Clone because we will update the version field, and we don't want to
// stomp the caller's copy.
Expand Down Expand Up @@ -228,6 +236,10 @@ func (s *Persistence) UpdateWorker(ctx context.Context, worker *ateapipb.Worker,
return fmt.Errorf("ip is immutable")
}

if currentWorker.GetActorId() != "" && dbWorker.GetActorId() == "" {
shouldAddToIdle = true
}

newVal, err := protojson.Marshal(dbWorker)
if err != nil {
return fmt.Errorf("in protojson.Marshal: %w", err)
Expand All @@ -246,15 +258,32 @@ func (s *Persistence) UpdateWorker(ctx context.Context, worker *ateapipb.Worker,
return fmt.Errorf("while executing update worker transaction: %w", err)
}

// Run SAdd sequentially outside the transaction to avoid cluster slot restrictions.
if shouldAddToIdle {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to be a three phase process.

  1. Mark the worker as free, but not yet returned to the free set.
  2. Return to the free set
  3. Mark the worker as fully free.

Otherwise, a crash between steps 1 and 2 will permanently leak the worker from consideration.

With the three-phase approach, we can have an additional background thread sweeping workers that are "free, but not yet returned to the free set" back to the free set.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we do have a re-sync process? Will need to double check. But I will make this change. Thanks Taahir.

setKey := fmt.Sprintf("pool:%s:%s:idle_workers", worker.GetWorkerNamespace(), worker.GetWorkerPool())
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of our design goals is to let all ActorTemplates share the same WorkerPool to maximize efficiency. In that case, this will mean that we are sending all restore traffic through a single hash bucket (and thus, one single valkey node).

We are going to need to solve this problem, and I'm having difficulty seeing how it could be solved within redis.

However, this is heaps better than the current strategy of "read every worker all the time".

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

100%. I have a much more advanced change that handles locality, but because it is much more ossifying, I felt this was a decent intermediate step while the data model and DB choices are finalized.

err = s.rdb.SAdd(ctx, setKey, worker.GetWorkerPod()).Err()
if err != nil {
return fmt.Errorf("while returning worker to idle set: %w", err)
}
}

return nil
}

// DeleteWorker deletes the stored record of the worker identified by
// namespace, pool and pod, then removes the pod name from that pool's
// idle-worker set. The record is deleted first; if that fails the idle set is
// left untouched. Either failure is returned wrapped with context.
func (s *Persistence) DeleteWorker(ctx context.Context, namespace, pool, pod string) error {
	workerKey := workerDBKey(namespace, pool, pod)
	if err := s.rdb.Del(ctx, workerKey).Err(); err != nil {
		return fmt.Errorf("while deleting worker key %q: %w", workerKey, err)
	}

	idleSetKey := fmt.Sprintf("pool:%s:%s:idle_workers", namespace, pool)
	if err := s.rdb.SRem(ctx, idleSetKey, pod).Err(); err != nil {
		return fmt.Errorf("while removing worker from idle set: %w", err)
	}

	return nil
}

Expand Down Expand Up @@ -452,3 +481,50 @@ func (s *Persistence) ReleaseLock(ctx context.Context, key string, value string)
}
return nil
}

// ClaimIdleWorker pops pod names from the pool's idle-worker set until it finds
// a worker it can assign to the given actor, and returns that worker.
//
// For every popped name the worker record is loaded with GetWorker. Names whose
// record no longer exists (store.ErrNotFound) and records that already carry an
// ActorId are dropped and the next name is popped. A claimable worker gets its
// ActorId, ActorNamespace and ActorTemplate set and is persisted through
// UpdateWorker using the version that was just read.
//
// It returns store.ErrNotFound when the idle set is empty. When a popped worker
// cannot be claimed because loading or updating it failed, its name is put back
// into the idle set. If that put-back itself fails, the failure is returned
// joined with the original error rather than discarded, because the worker is
// then no longer reachable through the idle set.
//
// NOTE(review): UpdateWorker clones its argument before touching the version
// field, so the returned worker presumably still carries the pre-update
// Version; confirm callers do not reuse it for a follow-up UpdateWorker.
func (s *Persistence) ClaimIdleWorker(ctx context.Context, namespace, pool string, actorID string, actorNamespace string, actorTemplate string) (*ateapipb.Worker, error) {
	setKey := fmt.Sprintf("pool:%s:%s:idle_workers", namespace, pool)

	// requeue puts a popped pod name back into the idle set.
	requeue := func(podName string) error {
		if err := s.rdb.SAdd(ctx, setKey, podName).Err(); err != nil {
			return fmt.Errorf("while returning worker %q to idle set: %w", podName, err)
		}
		return nil
	}

	for {
		// Pop a random idle worker name.
		podName, err := s.rdb.SPop(ctx, setKey).Result()
		if err != nil {
			if errors.Is(err, redis.Nil) {
				return nil, store.ErrNotFound
			}
			return nil, fmt.Errorf("while popping idle worker from set: %w", err)
		}

		worker, err := s.GetWorker(ctx, namespace, pool, podName)
		if err != nil {
			// If the worker was deleted, skip and pop the next one.
			if errors.Is(err, store.ErrNotFound) {
				continue
			}
			// errors.Join drops a nil requeue result, so the message is
			// unchanged when the put-back succeeds.
			return nil, errors.Join(
				fmt.Errorf("while loading popped worker metadata: %w", err),
				requeue(podName),
			)
		}

		if worker.GetActorId() != "" {
			// Skip busy workers; they do not belong in the idle set.
			continue
		}

		worker.ActorId = actorID
		worker.ActorNamespace = actorNamespace
		worker.ActorTemplate = actorTemplate

		if err := s.UpdateWorker(ctx, worker, worker.Version); err != nil {
			// The claim did not land, so the worker goes back to the idle set
			// whatever the cause of the failure.
			requeueErr := requeue(podName)
			if requeueErr == nil && errors.Is(err, store.ErrPersistenceRetry) {
				// Locking conflict: try again with a freshly popped worker.
				continue
			}
			return nil, errors.Join(
				fmt.Errorf("while claiming popped worker: %w", err),
				requeueErr,
			)
		}

		return worker, nil
	}
}
4 changes: 4 additions & 0 deletions cmd/servers/ateapi/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ type Interface interface {
// Lists all known workers. Returns nil if none found.
ListWorkers(ctx context.Context) ([]*ateapipb.Worker, error)

// ClaimIdleWorker atomically claims a random idle worker from the specified pool
// and associates it with the given Actor. Returns ErrNotFound if no idle workers are available.
ClaimIdleWorker(ctx context.Context, namespace, pool string, actorID string, actorNamespace string, actorTemplate string) (*ateapipb.Worker, error)

// AcquireLock attempts to acquire a distributed lock with a TTL.
// Returns true if the lock was successfully acquired.
// Returns false if the lock is already held by another client (conflict).
Expand Down
Loading