From 8c966c2100902d67ecc259c0b3640fc2a9636ced Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Tue, 31 Mar 2026 18:31:50 +0200 Subject: [PATCH] feat: add celeristest.WithHandlers for middleware chain testing (#139) --- celeristest/celeristest.go | 16 ++++++++++++++++ context.go | 9 +++++++++ internal/ctxkit/ctxkit.go | 1 + 3 files changed, 26 insertions(+) diff --git a/celeristest/celeristest.go b/celeristest/celeristest.go index be7e3b0..fe49083 100644 --- a/celeristest/celeristest.go +++ b/celeristest/celeristest.go @@ -83,6 +83,7 @@ type config struct { params [][2]string cookies [][2]string remoteAddr string + handlers []any } // WithBody sets the request body. @@ -128,6 +129,18 @@ func WithRemoteAddr(addr string) Option { return func(c *config) { c.remoteAddr = addr } } +// WithHandlers sets the handler chain on the test context. This enables +// middleware chain testing where mw1 calls c.Next() → mw2 runs → ... → final handler. +// Pass celeris.HandlerFunc values; they are stored as []any to avoid import cycles. +func WithHandlers(handlers ...celeris.HandlerFunc) Option { + return func(c *config) { + c.handlers = make([]any, len(handlers)) + for i, h := range handlers { + c.handlers[i] = h + } + } +} + // ReleaseContext returns a [celeris.Context] to the pool. The context must not // be used after this call. func ReleaseContext(ctx *celeris.Context) { @@ -191,5 +204,8 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons for _, p := range cfg.params { ctxkit.AddParam(ctx, p[0], p[1]) } + if len(cfg.handlers) > 0 { + ctxkit.SetHandlers(ctx, cfg.handlers) + } return ctx, rec } diff --git a/context.go b/context.go index 707eb47..80099df 100644 --- a/context.go +++ b/context.go @@ -23,6 +23,15 @@ func init() { ctx := c.(*Context) ctx.params = append(ctx.params, Param{Key: key, Value: value}) } + ctxkit.SetHandlers = func(c any, handlers []any) { + ctx := c.(*Context) + chain := make([]HandlerFunc, len(handlers)) + for i, h := range handlers { + chain[i] = h.(HandlerFunc) + } + ctx.handlers = chain + ctx.index = -1 + } } // Context is the request context passed to handlers. It is pooled via sync.Pool. diff --git a/internal/ctxkit/ctxkit.go b/internal/ctxkit/ctxkit.go index 453287d..e0af1a9 100644 --- a/internal/ctxkit/ctxkit.go +++ b/internal/ctxkit/ctxkit.go @@ -10,4 +10,5 @@ var ( NewContext func(s *stream.Stream) any ReleaseContext func(c any) AddParam func(c any, key, value string) + SetHandlers func(c any, handlers []any) )