Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func buildHTTPRequest(c *Context) (*http.Request, error) {
return req, nil
}

const maxBridgeResponseBytes = 100 << 20 // 100 MB
const maxBridgeResponseBytes = maxBodySize

var errBridgeResponseTooLarge = errors.New("bridge: response body exceeds 100MB limit")

Expand Down
133 changes: 110 additions & 23 deletions celeristest/celeristest.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"encoding/base64"
"net/url"
"strings"
"sync"
"testing"

"github.com/goceleris/celeris"
Expand Down Expand Up @@ -48,18 +49,32 @@ func (r *ResponseRecorder) BodyString() string {
return string(r.Body)
}

// recorderCombo bundles a ResponseRecorder and its recorderWriter in a
// single allocation so they can be pooled together.
type recorderCombo struct {
rec ResponseRecorder
rw recorderWriter
}

var recorderPool = sync.Pool{New: func() any {
combo := &recorderCombo{}
combo.rw.rec = &combo.rec
combo.rw.combo = combo
return combo
}}

// recorderWriter adapts a ResponseRecorder to the internal
// stream.ResponseWriter interface without leaking internal types
// in the public API.
type recorderWriter struct {
rec *ResponseRecorder
rec *ResponseRecorder
combo *recorderCombo
}

func (w *recorderWriter) WriteResponse(_ *stream.Stream, status int, headers [][2]string, body []byte) error {
w.rec.StatusCode = status
w.rec.Headers = headers
w.rec.Body = make([]byte, len(body))
copy(w.rec.Body, body)
w.rec.Body = append(w.rec.Body[:0], body...)
return nil
}

Expand All @@ -77,13 +92,38 @@ var _ stream.ResponseWriter = (*recorderWriter)(nil)
type Option func(*config)

type config struct {
body []byte
headers [][2]string
queries [][2]string
params [][2]string
cookies [][2]string
remoteAddr string
handlers []any
body []byte
headers [][2]string
queries [][2]string
params [][2]string
cookies [][2]string
remoteAddr string
handlers []any
headersBuf [4][2]string
handlersBuf [4]any
}

var configPool = sync.Pool{New: func() any {
c := &config{}
c.headers = c.headersBuf[:0]
c.handlers = c.handlersBuf[:0]
return c
}}

func (c *config) reset() {
c.body = nil
for i := range c.headers {
c.headers[i] = [2]string{}
}
c.headers = c.headersBuf[:0]
c.queries = nil
c.params = nil
c.cookies = nil
c.remoteAddr = ""
for i := range c.handlers {
c.handlers[i] = nil
}
c.handlers = c.handlersBuf[:0]
}

// WithBody sets the request body.
Expand Down Expand Up @@ -134,17 +174,56 @@ func WithRemoteAddr(addr string) Option {
// Pass celeris.HandlerFunc values; they are stored as []any to avoid import cycles.
func WithHandlers(handlers ...celeris.HandlerFunc) Option {
return func(c *config) {
c.handlers = make([]any, len(handlers))
for i, h := range handlers {
c.handlers[i] = h
n := len(handlers)
if n <= len(c.handlersBuf) {
for i, h := range handlers {
c.handlersBuf[i] = h
}
c.handlers = c.handlersBuf[:n]
} else {
c.handlers = make([]any, n)
for i, h := range handlers {
c.handlers[i] = h
}
}
}
}

// ReleaseContext returns a [celeris.Context] to the pool. The context must not
// be used after this call.
func ReleaseContext(ctx *celeris.Context) {
// Extract stream and recorder before releasing context (reset nils them).
var s *stream.Stream
if raw := ctxkit.GetStream(ctx); raw != nil {
s = raw.(*stream.Stream)
}
var combo *recorderCombo
if rw := ctxkit.GetResponseWriter(ctx); rw != nil {
if w, ok := rw.(*recorderWriter); ok && w.combo != nil {
combo = w.combo
}
}

ctxkit.ReleaseContext(ctx)

// Return stream to pool so NewStream reuses it on the next call.
if s != nil {
if s.HasDoneCh() {
// A derived context (e.g. context.WithTimeout) spawned a
// goroutine that holds a reference to this stream. Cancel to
// terminate it, but don't pool the stream — the goroutine may
// still read Err()/Done() after the pool recycles the struct.
s.Cancel()
} else {
stream.ResetForPool(s)
}
}
if combo != nil {
combo.rec.StatusCode = 0
combo.rec.Headers = nil
combo.rec.Body = combo.rec.Body[:0]
recorderPool.Put(combo)
}
}

// NewContextT is like [NewContext] but registers an automatic cleanup with
Expand All @@ -160,7 +239,7 @@ func NewContextT(t *testing.T, method, path string, opts ...Option) (*celeris.Co
// The returned context has the given method and path, plus any options applied.
// Call [ReleaseContext] when done to clean up pooled resources.
func NewContext(method, path string, opts ...Option) (*celeris.Context, *ResponseRecorder) {
cfg := &config{}
cfg := configPool.Get().(*config)
for _, o := range opts {
o(cfg)
}
Expand All @@ -175,12 +254,12 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons
}

s := stream.NewStream(1)
s.Headers = [][2]string{
{":method", method},
{":path", fullPath},
{":scheme", "http"},
{":authority", "localhost"},
}
s.Headers = append(s.Headers,
[2]string{":method", method},
[2]string{":path", fullPath},
[2]string{":scheme", "http"},
[2]string{":authority", "localhost"},
)
s.Headers = append(s.Headers, cfg.headers...)
if len(cfg.cookies) > 0 {
parts := make([]string, 0, len(cfg.cookies))
Expand All @@ -190,11 +269,15 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons
s.Headers = append(s.Headers, [2]string{"cookie", strings.Join(parts, "; ")})
}
if len(cfg.body) > 0 {
s.Data.Write(cfg.body)
s.GetBuf().Write(cfg.body)
}

rec := &ResponseRecorder{}
s.ResponseWriter = &recorderWriter{rec: rec}
combo := recorderPool.Get().(*recorderCombo)
combo.rec.StatusCode = 0
combo.rec.Headers = nil
combo.rec.Body = combo.rec.Body[:0]
rec := &combo.rec
s.ResponseWriter = &combo.rw

if cfg.remoteAddr != "" {
s.RemoteAddr = cfg.remoteAddr
Expand All @@ -207,5 +290,9 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons
if len(cfg.handlers) > 0 {
ctxkit.SetHandlers(ctx, cfg.handlers)
}

cfg.reset()
configPool.Put(cfg)

return ctx, rec
}
108 changes: 108 additions & 0 deletions celeristest/celeristest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,111 @@ func TestWithRemoteAddr(t *testing.T) {
t.Fatalf("expected 192.168.1.1:54321, got %s", ctx.RemoteAddr())
}
}

func TestWithHandlers(t *testing.T) {
var order []string
mw := func(c *celeris.Context) error {
order = append(order, "mw")
return c.Next()
}
handler := func(c *celeris.Context) error {
order = append(order, "handler")
return c.String(200, "ok")
}
ctx, rec := NewContext("GET", "/test", WithHandlers(mw, handler))
defer ReleaseContext(ctx)

err := ctx.Next()
if err != nil {
t.Fatal(err)
}
if rec.StatusCode != 200 {
t.Fatalf("expected 200, got %d", rec.StatusCode)
}
if rec.BodyString() != "ok" {
t.Fatalf("expected body 'ok', got %q", rec.BodyString())
}
if len(order) != 2 || order[0] != "mw" || order[1] != "handler" {
t.Fatalf("unexpected execution order: %v", order)
}
}

func TestWithHandlersErrorPropagation(t *testing.T) {
mw := func(c *celeris.Context) error {
return c.Next()
}
handler := func(_ *celeris.Context) error {
return celeris.NewHTTPError(403, "forbidden")
}
ctx, _ := NewContext("GET", "/test", WithHandlers(mw, handler))
defer ReleaseContext(ctx)

err := ctx.Next()
if err == nil {
t.Fatal("expected error from handler chain")
}
he, ok := err.(*celeris.HTTPError)
if !ok {
t.Fatalf("expected *HTTPError, got %T", err)
}
if he.Code != 403 {
t.Fatalf("expected 403, got %d", he.Code)
}
}

func TestWithHandlersManyHandlers(t *testing.T) {
var order []string
makeMW := func(name string) celeris.HandlerFunc {
return func(c *celeris.Context) error {
order = append(order, name)
return c.Next()
}
}
handler := func(c *celeris.Context) error {
order = append(order, "handler")
return c.String(200, "ok")
}
// 5 middleware + 1 handler = 6 total, exceeds the 4-element handlersBuf.
ctx, rec := NewContext("GET", "/test", WithHandlers(
makeMW("mw1"), makeMW("mw2"), makeMW("mw3"), makeMW("mw4"), makeMW("mw5"), handler,
))
defer ReleaseContext(ctx)

err := ctx.Next()
if err != nil {
t.Fatal(err)
}
if rec.StatusCode != 200 {
t.Fatalf("expected 200, got %d", rec.StatusCode)
}
expected := []string{"mw1", "mw2", "mw3", "mw4", "mw5", "handler"}
if len(order) != len(expected) {
t.Fatalf("expected %d entries, got %d: %v", len(expected), len(order), order)
}
for i, e := range expected {
if order[i] != e {
t.Fatalf("order[%d] = %q, want %q", i, order[i], e)
}
}
}

func TestWithHandlersAbort(t *testing.T) {
var reached bool
mw := func(c *celeris.Context) error {
return c.AbortWithStatus(401)
}
handler := func(c *celeris.Context) error {
reached = true
return c.String(200, "ok")
}
ctx, rec := NewContext("GET", "/test", WithHandlers(mw, handler))
defer ReleaseContext(ctx)

_ = ctx.Next()
if reached {
t.Fatal("handler should not have been reached after abort")
}
if rec.StatusCode != 401 {
t.Fatalf("expected 401, got %d", rec.StatusCode)
}
}
Loading