From 68091f98931c33dcc54d4e6d18e4f15e7414dc57 Mon Sep 17 00:00:00 2001 From: "Masih H. Derkani" Date: Wed, 6 Aug 2025 12:10:49 +0100 Subject: [PATCH] Use context to safely break out of loops and cancel tickers Wherever possible, use context to ensure that the code path cannot enter a forever loop as a result of a developer error. This would have the added benefit in some cases of avoiding an additional iteration when context is done. Add defer statements to stop tickers as required. Fix a minor error type check. --- generator/weighted.go | 5 +++-- sender/dispatcher.go | 7 ++++--- sender/worker.go | 16 +++++++++++----- stats/block_collector.go | 3 ++- stats/logger.go | 7 +++++-- stats/user_latency_tracker.go | 5 +++-- 6 files changed, 28 insertions(+), 15 deletions(-) diff --git a/generator/weighted.go b/generator/weighted.go index d1c4452..c423826 100644 --- a/generator/weighted.go +++ b/generator/weighted.go @@ -2,9 +2,10 @@ package generator import ( "context" - "github.com/sei-protocol/sei-load/types" "math/rand" "sync" + + "github.com/sei-protocol/sei-load/types" ) // WeightedCfg is a configuration for a weighted scenarioGenerator. @@ -32,7 +33,7 @@ func (w *weightedGenerator) GenerateInfinite(ctx context.Context) <-chan *types. output := make(chan *types.LoadTx, 10000) go func() { defer close(output) - for { + for ctx.Err() == nil { select { case <-ctx.Done(): return diff --git a/sender/dispatcher.go b/sender/dispatcher.go index 814072e..3e0c5c7 100644 --- a/sender/dispatcher.go +++ b/sender/dispatcher.go @@ -61,7 +61,7 @@ func (d *Dispatcher) Prewarm(ctx context.Context) error { logInterval := 100 // Run prewarm generator until completion - for { + for ctx.Err() == nil { tx, ok := gen.Generate() if !ok { break // Prewarming is complete @@ -87,7 +87,7 @@ func (d *Dispatcher) Prewarm(ctx context.Context) error { // Start begins the dispatcher's transaction generation and sending loop func (d *Dispatcher) Run(ctx context.Context) error { - for { + for ctx.Err() == nil { // Generate a transaction from main generator tx, ok := d.generator.Generate() if !ok { @@ -103,6 +103,7 @@ func (d *Dispatcher) Run(ctx context.Context) error { d.totalSent++ d.mu.Unlock() } + return ctx.Err() } // StartBatch generates and sends a specific number of transactions then stops @@ -126,7 +127,7 @@ func (d *Dispatcher) RunBatch(ctx context.Context, count int) error { d.mu.Unlock() } } - return nil + return ctx.Err() } // GetStats returns dispatcher statistics diff --git a/sender/worker.go b/sender/worker.go index 0c3d6a2..59eba30 100644 --- a/sender/worker.go +++ b/sender/worker.go @@ -3,6 +3,7 @@ package sender import ( "bytes" "context" + "errors" "fmt" "io" "log" @@ -10,6 +11,7 @@ import ( "net/http" "time" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/ethclient" "golang.org/x/time/rate" @@ -111,7 +113,7 @@ func (w *Worker) watchTransactions(ctx context.Context) error { if err != nil { return fmt.Errorf("ethclient.Dial(%q): %w", w.endpoint, err) } - for { + for ctx.Err() == nil { tx, err := utils.Recv(ctx, w.sentTxs) if err != nil { return err @@ -122,17 +124,19 @@ func (w *Worker) watchTransactions(ctx context.Context) error { log.Printf("❌ %v", err) } } + return ctx.Err() } func (w *Worker) waitForReceipt(ctx context.Context, eth *ethclient.Client, tx *types.LoadTx) error { ticker := time.NewTicker(100 * time.Millisecond) - for { + defer ticker.Stop() + for ctx.Err() == nil { if _, err := utils.Recv(ctx, ticker.C); err != nil { return fmt.Errorf("timeout waiting for receipt for tx %s", tx.EthTx.Hash().Hex()) } - receipt, err := eth.TransactionReceipt(context.Background(), tx.EthTx.Hash()) + receipt, err := eth.TransactionReceipt(ctx, tx.EthTx.Hash()) if err != nil { - if err.Error() == "not found" { + if errors.Is(err, ethereum.NotFound) { continue } log.Printf("❌ error getting receipt for tx %s: %v", tx.EthTx.Hash().Hex(), err) @@ -147,11 +151,12 @@ func (w *Worker) waitForReceipt(ctx context.Context, eth *ethclient.Client, tx * } return nil } + return ctx.Err() } // processTransactions is the main worker loop that processes transactions func (w *Worker) processTransactions(ctx context.Context, client *http.Client) error { - for { + for ctx.Err() == nil { tx, err := utils.Recv(ctx, w.txChan) if err != nil { return err @@ -174,6 +179,7 @@ func (w *Worker) processTransactions(ctx context.Context, client *http.Client) e log.Printf("%v", err) } } + return ctx.Err() } // sendTransaction sends a single transaction to the endpoint diff --git a/stats/block_collector.go b/stats/block_collector.go index e65831d..7bf6a35 100644 --- a/stats/block_collector.go +++ b/stats/block_collector.go @@ -69,13 +69,14 @@ func (bc *BlockCollector) Run(ctx context.Context, firstEndpoint string) error { return subErr }) log.Printf("📡 Subscribed to new blocks on %s", wsEndpoint) - for { + for ctx.Err() == nil { header, err := utils.Recv(ctx, headers) if err != nil { return err } bc.processNewBlock(header) } + return ctx.Err() }) } diff --git a/stats/logger.go b/stats/logger.go index ff6cca9..427652e 100644 --- a/stats/logger.go +++ b/stats/logger.go @@ -2,9 +2,10 @@ package stats import ( "context" - "github.com/sei-protocol/sei-load/utils" "log" "time" + + "github.com/sei-protocol/sei-load/utils" ) // Logger handles periodic statistics logging and dry-run transaction printing @@ -26,12 +27,14 @@ func NewLogger(collector *Collector, interval time.Duration, debug bool) *Logger // Start begins periodic statistics logging func (l *Logger) Run(ctx context.Context) error { ticker := time.NewTicker(l.interval) - for { + defer ticker.Stop() + for ctx.Err() == nil { if _, err := utils.Recv(ctx, ticker.C); err != nil { return err } l.logCurrentStats() } + return ctx.Err() } // logCurrentStats logs the current statistics diff --git a/stats/user_latency_tracker.go b/stats/user_latency_tracker.go index d6525de..bd3966c 100644 --- a/stats/user_latency_tracker.go +++ b/stats/user_latency_tracker.go @@ -28,7 +28,7 @@ func NewUserLatencyTracker(interval time.Duration) *UserLatencyTracker { func (ult *UserLatencyTracker) Run(ctx context.Context, endpoint string) error { // Create ticker for the configured interval ticker := time.NewTicker(ult.interval) - + defer ticker.Stop() // Connect to the endpoint client, err := ethclient.Dial(endpoint) if err != nil { @@ -36,7 +36,7 @@ func (ult *UserLatencyTracker) Run(ctx context.Context, endpoint string) error { } defer client.Close() - for { + for ctx.Err() == nil { if _, err := utils.Recv(ctx, ticker.C); err != nil { return err } @@ -45,6 +45,7 @@ func (ult *UserLatencyTracker) Run(ctx context.Context, endpoint string) error { // Continue on error - don't stop the tracker } } + return ctx.Err() } // trackLatency fetches the latest block and calculates user latency statistics