From 1b8dc120a77a4b5da885fc3f49fa2fb413d2db0a Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Wed, 1 Apr 2026 01:50:47 +0200 Subject: [PATCH 01/11] perf: inline respHdrBuf for response headers (#145) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Initialize respHeaders to use the inline respHdrBuf backing array in the context pool constructor and reset(). This eliminates the heap allocation triggered by the first SetHeader() call via append(). Fix the aliasing hazard in Blob() where respHeaders and respHdrBuf now share the same backing array — copy user headers to a stack temporary before overwriting the buffer with content-type and content-length. --- context.go | 8 ++++++-- context_response.go | 23 +++++++++++++++++------ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/context.go b/context.go index 80099df..c96e3cd 100644 --- a/context.go +++ b/context.go @@ -86,7 +86,11 @@ type Context struct { respHdrBuf [8][2]string // reusable buffer for response headers (avoids heap escape) } -var contextPool = sync.Pool{New: func() any { return &Context{} }} +var contextPool = sync.Pool{New: func() any { + c := &Context{} + c.respHeaders = c.respHdrBuf[:0] + return c +}} const abortIndex int16 = math.MaxInt16 / 2 @@ -256,7 +260,7 @@ func (c *Context) reset() { c.rawQuery = "" c.fullPath = "" c.statusCode = 200 - c.respHeaders = c.respHeaders[:0] + c.respHeaders = c.respHdrBuf[:0] c.written = false c.aborted = false c.bytesWritten = 0 diff --git a/context_response.go b/context_response.go index 9358243..2e3875a 100644 --- a/context_response.go +++ b/context_response.go @@ -125,13 +125,24 @@ func (c *Context) Blob(code int, contentType string, data []byte) error { c.capturedStatus = code c.capturedType = contentType } - headers := c.respHdrBuf[:0:8] - if len(c.respHeaders)+2 > 8 { - headers = make([][2]string, 0, len(c.respHeaders)+2) + nUser := len(c.respHeaders) + total := nUser + 2 + var headers [][2]string + if total <= 8 { + // respHeaders shares backing array with respHdrBuf — copy user + // headers to a stack temporary before overwriting the buffer. + var tmp [6][2]string + copy(tmp[:nUser], c.respHeaders) + headers = c.respHdrBuf[:0:8] + headers = append(headers, [2]string{"content-type", stripCRLF(contentType)}) + headers = append(headers, [2]string{"content-length", itoa(len(data))}) + headers = append(headers, tmp[:nUser]...) + } else { + headers = make([][2]string, 0, total) + headers = append(headers, [2]string{"content-type", stripCRLF(contentType)}) + headers = append(headers, [2]string{"content-length", itoa(len(data))}) + headers = append(headers, c.respHeaders...) } - headers = append(headers, [2]string{"content-type", stripCRLF(contentType)}) - headers = append(headers, [2]string{"content-length", itoa(len(data))}) - headers = append(headers, c.respHeaders...) if c.stream.ResponseWriter != nil { return c.stream.ResponseWriter.WriteResponse(c.stream, code, headers, data) } From c998faf7e01adf4f1b84128fc3cae1892aea88b4 Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Wed, 1 Apr 2026 01:52:00 +0200 Subject: [PATCH 02/11] perf: pre-allocate Context.keys map (#141) Initialize keys with make(map[string]any, 4) in the pool constructor and use clear() in reset() instead of niling the map. This retains the hash table buckets across requests, eliminating the heap allocation on the first c.Set() call. --- context.go | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/context.go b/context.go index c96e3cd..3ceeb06 100644 --- a/context.go +++ b/context.go @@ -87,7 +87,7 @@ type Context struct { } var contextPool = sync.Pool{New: func() any { - c := &Context{} + c := &Context{keys: make(map[string]any, 4)} c.respHeaders = c.respHdrBuf[:0] return c }} @@ -217,17 +217,11 @@ func (c *Context) SetContext(ctx context.Context) { // Set stores a key-value pair for this request. func (c *Context) Set(key string, value any) { c.extended = true - if c.keys == nil { - c.keys = make(map[string]any) - } c.keys[key] = value } // Get returns the value for a key. func (c *Context) Get(key string) (any, bool) { - if c.keys == nil { - return nil, false - } v, ok := c.keys[key] return v, ok } @@ -235,7 +229,7 @@ func (c *Context) Get(key string) (any, bool) { // Keys returns a copy of all key-value pairs stored on this context. // Returns nil if no values have been set. func (c *Context) Keys() map[string]any { - if c.keys == nil { + if len(c.keys) == 0 { return nil } cp := make(map[string]any, len(c.keys)) @@ -266,7 +260,7 @@ func (c *Context) reset() { c.bytesWritten = 0 c.maxFormSize = 0 if c.extended { - c.keys = nil + clear(c.keys) c.queryCache = nil c.queryCached = false c.cookieCache = c.cookieCache[:0] From 85558355582c4c82416e2413b859abd32d8a6cf1 Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Wed, 1 Apr 2026 01:53:15 +0200 Subject: [PATCH 03/11] perf: lazy OutboundBuffer allocation in NewStream (#143) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove eager getBuf() for OutboundBuffer from NewStream(). The buffer is only needed when flow control prevents immediate sends. BufferOutbound() already has a nil guard and now uses getBuf() (pool) instead of new(bytes.Buffer). Saves 1 alloc per NewStream() call — all test contexts and H2 streams that don't need flow-control buffering. --- protocol/h2/stream/stream.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/protocol/h2/stream/stream.go b/protocol/h2/stream/stream.go index 2ccd4b0..a079b1a 100644 --- a/protocol/h2/stream/stream.go +++ b/protocol/h2/stream/stream.go @@ -104,7 +104,6 @@ func NewStream(id uint32) *Stream { s.ID = id s.state.Store(int32(StateIdle)) s.Data = getBuf() - s.OutboundBuffer = getBuf() s.windowSize.Store(65535) s.phase = PhaseInit if cap(s.Headers) < 8 { @@ -402,7 +401,7 @@ func (s *Stream) SetHandlerStarted() { func (s *Stream) BufferOutbound(data []byte, endStream bool) { s.mu.Lock() if s.OutboundBuffer == nil { - s.OutboundBuffer = new(bytes.Buffer) + s.OutboundBuffer = getBuf() } s.OutboundBuffer.Write(data) s.OutboundEndStream = endStream From b85f230057278790a60af81fa71b9c59987a4a69 Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Wed, 1 Apr 2026 01:54:56 +0200 Subject: [PATCH 04/11] perf: inline stream headers array (#146) Add hdrBuf [16][2]string to Stream struct and use it as the backing array for Headers in NewStream, NewH1Stream, Release, ResetH1Stream, and ResetH2StreamInline. This eliminates the make([][2]string, 0, N) heap allocation on first pool retrieval and re-anchors to the inline buffer on release (in case append grew beyond 16). --- protocol/h2/stream/stream.go | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/protocol/h2/stream/stream.go b/protocol/h2/stream/stream.go index a079b1a..d1152be 100644 --- a/protocol/h2/stream/stream.go +++ b/protocol/h2/stream/stream.go @@ -61,6 +61,7 @@ type Stream struct { phase Phase CachedCtx any // per-connection cached context (avoids pool Get/Put per request) OnDetach func() // called by Context.Detach to install write-thread safety + hdrBuf [16][2]string } // streamContext is a zero-alloc context.Context for streams. @@ -106,9 +107,7 @@ func NewStream(id uint32) *Stream { s.Data = getBuf() s.windowSize.Store(65535) s.phase = PhaseInit - if cap(s.Headers) < 8 { - s.Headers = make([][2]string, 0, 8) - } + s.Headers = s.hdrBuf[:0] return s } @@ -118,9 +117,7 @@ func NewH1Stream(id uint32) *Stream { s.ID = id s.state.Store(int32(StateIdle)) s.h1Mode = true - if cap(s.Headers) < 16 { - s.Headers = make([][2]string, 0, 16) - } + s.Headers = s.hdrBuf[:0] return s } @@ -180,7 +177,7 @@ func (s *Stream) Release() { s.ID = 0 s.state.Store(0) s.manager = nil - s.Headers = s.Headers[:0] + s.Headers = s.hdrBuf[:0] s.Trailers = s.Trailers[:0] s.OutboundEndStream = false s.headersSent.Store(false) @@ -217,7 +214,7 @@ func ResetH1Stream(s *Stream) { bufferPool.Put(s.Data) s.Data = nil } - s.Headers = s.Headers[:0] + s.Headers = s.hdrBuf[:0] s.headersSent.Store(false) s.EndStream = false s.ResponseWriter = nil @@ -240,7 +237,7 @@ func ResetH2StreamInline(s *Stream, id uint32) { s.OutboundBuffer = getBuf() } s.ID = id - s.Headers = s.Headers[:0] + s.Headers = s.hdrBuf[:0] s.Trailers = s.Trailers[:0] s.OutboundEndStream = false s.headersSent.Store(false) From 44e82860317e10d95e909a7fd2e43f1e43137a71 Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Wed, 1 Apr 2026 01:56:10 +0200 Subject: [PATCH 05/11] perf: BasicAuth stack-buffer decode (#142) Use base64.StdEncoding.Decode with a [128]byte stack buffer instead of DecodeString which heap-allocates. Convert the auth payload string to []byte via unsafe.Slice(unsafe.StringData(...)) without allocation (read-only, safe for Decode which only reads src). Eliminates the intermediate []byte heap allocation; the two returned strings (username, password) are individually smaller than the full decoded buffer. --- context_request.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/context_request.go b/context_request.go index 4cbe383..1a31c24 100644 --- a/context_request.go +++ b/context_request.go @@ -13,6 +13,7 @@ import ( "net/url" "strconv" "strings" + "unsafe" "github.com/goceleris/celeris/internal/negotiate" ) @@ -340,16 +341,18 @@ func (c *Context) BasicAuth() (username, password string, ok bool) { if len(auth) < len(prefix) || auth[:len(prefix)] != prefix { return } - decoded, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) + payload := auth[len(prefix):] + var buf [128]byte + n, err := base64.StdEncoding.Decode(buf[:], + unsafe.Slice(unsafe.StringData(payload), len(payload))) if err != nil { return } - s := string(decoded) - colon := strings.IndexByte(s, ':') - if colon < 0 { + i := bytes.IndexByte(buf[:n], ':') + if i < 0 { return } - return s[:colon], s[colon+1:], true + return string(buf[:i]), string(buf[i+1 : n]), true } // FormValue returns the first value for the named form field. From ab297f2f2b1b2dc7e6465cc67c89929e2bac5dd8 Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Wed, 1 Apr 2026 02:02:12 +0200 Subject: [PATCH 06/11] perf: pool ResponseRecorder in celeristest (#144) Bundle ResponseRecorder and recorderWriter into a single recorderCombo struct managed by sync.Pool. NewContext gets the combo from the pool instead of allocating two heap objects per call. ReleaseContext extracts the combo via the new ctxkit.GetResponseWriter hook and returns it. WriteResponse reuses the Body slice via append(w.rec.Body[:0], body...) instead of allocating a new []byte each time. --- celeristest/celeristest.go | 41 +++++++++++++++++++++++++++++++++----- context.go | 7 +++++++ internal/ctxkit/ctxkit.go | 9 +++++---- 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/celeristest/celeristest.go b/celeristest/celeristest.go index fe49083..072528e 100644 --- a/celeristest/celeristest.go +++ b/celeristest/celeristest.go @@ -13,6 +13,7 @@ import ( "encoding/base64" "net/url" "strings" + "sync" "testing" "github.com/goceleris/celeris" @@ -48,18 +49,32 @@ func (r *ResponseRecorder) BodyString() string { return string(r.Body) } +// recorderCombo bundles a ResponseRecorder and its recorderWriter in a +// single allocation so they can be pooled together. +type recorderCombo struct { + rec ResponseRecorder + rw recorderWriter +} + +var recorderPool = sync.Pool{New: func() any { + combo := &recorderCombo{} + combo.rw.rec = &combo.rec + combo.rw.combo = combo + return combo +}} + // recorderWriter adapts a ResponseRecorder to the internal // stream.ResponseWriter interface without leaking internal types // in the public API. type recorderWriter struct { - rec *ResponseRecorder + rec *ResponseRecorder + combo *recorderCombo } func (w *recorderWriter) WriteResponse(_ *stream.Stream, status int, headers [][2]string, body []byte) error { w.rec.StatusCode = status w.rec.Headers = headers - w.rec.Body = make([]byte, len(body)) - copy(w.rec.Body, body) + w.rec.Body = append(w.rec.Body[:0], body...) return nil } @@ -144,6 +159,18 @@ func WithHandlers(handlers ...celeris.HandlerFunc) Option { // ReleaseContext returns a [celeris.Context] to the pool. The context must not // be used after this call. func ReleaseContext(ctx *celeris.Context) { + // Extract recorder combo before releasing the context (which nils the stream). + if rw := ctxkit.GetResponseWriter(ctx); rw != nil { + if w, ok := rw.(*recorderWriter); ok && w.combo != nil { + combo := w.combo + ctxkit.ReleaseContext(ctx) + combo.rec.StatusCode = 0 + combo.rec.Headers = nil + combo.rec.Body = combo.rec.Body[:0] + recorderPool.Put(combo) + return + } + } ctxkit.ReleaseContext(ctx) } @@ -193,8 +220,12 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons s.Data.Write(cfg.body) } - rec := &ResponseRecorder{} - s.ResponseWriter = &recorderWriter{rec: rec} + combo := recorderPool.Get().(*recorderCombo) + combo.rec.StatusCode = 0 + combo.rec.Headers = nil + combo.rec.Body = combo.rec.Body[:0] + rec := &combo.rec + s.ResponseWriter = &combo.rw if cfg.remoteAddr != "" { s.RemoteAddr = cfg.remoteAddr diff --git a/context.go b/context.go index 3ceeb06..e3bddde 100644 --- a/context.go +++ b/context.go @@ -32,6 +32,13 @@ func init() { ctx.handlers = chain ctx.index = -1 } + ctxkit.GetResponseWriter = func(c any) any { + ctx := c.(*Context) + if ctx.stream != nil { + return ctx.stream.ResponseWriter + } + return nil + } } // Context is the request context passed to handlers. It is pooled via sync.Pool. diff --git a/internal/ctxkit/ctxkit.go b/internal/ctxkit/ctxkit.go index e0af1a9..89ce20c 100644 --- a/internal/ctxkit/ctxkit.go +++ b/internal/ctxkit/ctxkit.go @@ -7,8 +7,9 @@ import "github.com/goceleris/celeris/protocol/h2/stream" // Hooks registered by the celeris package at init time. var ( - NewContext func(s *stream.Stream) any - ReleaseContext func(c any) - AddParam func(c any, key, value string) - SetHandlers func(c any, handlers []any) + NewContext func(s *stream.Stream) any + ReleaseContext func(c any) + AddParam func(c any, key, value string) + SetHandlers func(c any, handlers []any) + GetResponseWriter func(c any) any ) From a91562dfee18ec125efe553edf8beab018d34f3b Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Wed, 1 Apr 2026 03:45:51 +0200 Subject: [PATCH 07/11] perf: zero-alloc context lifecycle + stream recycling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Comprehensive allocation elimination across the request lifecycle: - Stream.Data lazy allocation: remove eager getBuf() from NewStream; Data is allocated on first write via GetBuf()/AddData() - Stream.ResetForPool: new function for test recycling without Cancel (avoids race with context.WithTimeout propagation goroutines) - Stream.HasDoneCh: detect derived contexts for safe recycling decisions - Context.handlerBuf [8]HandlerFunc: inline handler chain buffer; SetHandlers uses it for chains <=8 (avoids make([]HandlerFunc)) - ctxkit.GetStream hook: enables celeristest to extract and recycle the stream before context release - celeristest config pooling: sync.Pool with inline headersBuf[4] and handlersBuf[4]; WithHandlers uses inline buffer for <=4 handlers - celeristest headers: append to hdrBuf instead of slice literal - celeristest ReleaseContext: recycles stream to pool (ResetForPool) for streams without derived contexts; Cancel-only for streams with doneCh to avoid goroutine reference races - BasicAuth: single string allocation (decoded[:i], decoded[i+1:]) instead of two separate string() conversions - stdlib.go, tests: s.Data.Write → s.GetBuf().Write for lazy Data Result: 0 allocs/op on 6 of 9 middleware benchmarks (was 10-17). --- celeristest/celeristest.go | 108 ++++++++++++++++++++++++++--------- context.go | 39 +++++++++++-- context_request.go | 3 +- context_request_test.go | 36 ++++++------ internal/ctxkit/ctxkit.go | 1 + protocol/h2/stream/stream.go | 19 +++++- server_test.go | 2 +- stdlib.go | 2 +- 8 files changed, 156 insertions(+), 54 deletions(-) diff --git a/celeristest/celeristest.go b/celeristest/celeristest.go index 072528e..517b191 100644 --- a/celeristest/celeristest.go +++ b/celeristest/celeristest.go @@ -92,13 +92,38 @@ var _ stream.ResponseWriter = (*recorderWriter)(nil) type Option func(*config) type config struct { - body []byte - headers [][2]string - queries [][2]string - params [][2]string - cookies [][2]string - remoteAddr string - handlers []any + body []byte + headers [][2]string + queries [][2]string + params [][2]string + cookies [][2]string + remoteAddr string + handlers []any + headersBuf [4][2]string + handlersBuf [4]any +} + +var configPool = sync.Pool{New: func() any { + c := &config{} + c.headers = c.headersBuf[:0] + c.handlers = c.handlersBuf[:0] + return c +}} + +func (c *config) reset() { + c.body = nil + for i := range c.headers { + c.headers[i] = [2]string{} + } + c.headers = c.headersBuf[:0] + c.queries = nil + c.params = nil + c.cookies = nil + c.remoteAddr = "" + for i := range c.handlers { + c.handlers[i] = nil + } + c.handlers = c.handlersBuf[:0] } // WithBody sets the request body. @@ -149,9 +174,17 @@ func WithRemoteAddr(addr string) Option { // Pass celeris.HandlerFunc values; they are stored as []any to avoid import cycles. func WithHandlers(handlers ...celeris.HandlerFunc) Option { return func(c *config) { - c.handlers = make([]any, len(handlers)) - for i, h := range handlers { - c.handlers[i] = h + n := len(handlers) + if n <= len(c.handlersBuf) { + for i, h := range handlers { + c.handlersBuf[i] = h + } + c.handlers = c.handlersBuf[:n] + } else { + c.handlers = make([]any, n) + for i, h := range handlers { + c.handlers[i] = h + } } } } @@ -159,19 +192,38 @@ func WithHandlers(handlers ...celeris.HandlerFunc) Option { // ReleaseContext returns a [celeris.Context] to the pool. The context must not // be used after this call. func ReleaseContext(ctx *celeris.Context) { - // Extract recorder combo before releasing the context (which nils the stream). + // Extract stream and recorder before releasing context (reset nils them). + var s *stream.Stream + if raw := ctxkit.GetStream(ctx); raw != nil { + s = raw.(*stream.Stream) + } + var combo *recorderCombo if rw := ctxkit.GetResponseWriter(ctx); rw != nil { if w, ok := rw.(*recorderWriter); ok && w.combo != nil { - combo := w.combo - ctxkit.ReleaseContext(ctx) - combo.rec.StatusCode = 0 - combo.rec.Headers = nil - combo.rec.Body = combo.rec.Body[:0] - recorderPool.Put(combo) - return + combo = w.combo } } + ctxkit.ReleaseContext(ctx) + + // Return stream to pool so NewStream reuses it on the next call. + if s != nil { + if s.HasDoneCh() { + // A derived context (e.g. context.WithTimeout) spawned a + // goroutine that holds a reference to this stream. Cancel to + // terminate it, but don't pool the stream — the goroutine may + // still read Err()/Done() after the pool recycles the struct. + s.Cancel() + } else { + stream.ResetForPool(s) + } + } + if combo != nil { + combo.rec.StatusCode = 0 + combo.rec.Headers = nil + combo.rec.Body = combo.rec.Body[:0] + recorderPool.Put(combo) + } } // NewContextT is like [NewContext] but registers an automatic cleanup with @@ -187,7 +239,7 @@ func NewContextT(t *testing.T, method, path string, opts ...Option) (*celeris.Co // The returned context has the given method and path, plus any options applied. // Call [ReleaseContext] when done to clean up pooled resources. func NewContext(method, path string, opts ...Option) (*celeris.Context, *ResponseRecorder) { - cfg := &config{} + cfg := configPool.Get().(*config) for _, o := range opts { o(cfg) } @@ -202,12 +254,12 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons } s := stream.NewStream(1) - s.Headers = [][2]string{ - {":method", method}, - {":path", fullPath}, - {":scheme", "http"}, - {":authority", "localhost"}, - } + s.Headers = append(s.Headers, + [2]string{":method", method}, + [2]string{":path", fullPath}, + [2]string{":scheme", "http"}, + [2]string{":authority", "localhost"}, + ) s.Headers = append(s.Headers, cfg.headers...) if len(cfg.cookies) > 0 { parts := make([]string, 0, len(cfg.cookies)) @@ -217,7 +269,7 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons s.Headers = append(s.Headers, [2]string{"cookie", strings.Join(parts, "; ")}) } if len(cfg.body) > 0 { - s.Data.Write(cfg.body) + s.GetBuf().Write(cfg.body) } combo := recorderPool.Get().(*recorderCombo) @@ -238,5 +290,9 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons if len(cfg.handlers) > 0 { ctxkit.SetHandlers(ctx, cfg.handlers) } + + cfg.reset() + configPool.Put(cfg) + return ctx, rec } diff --git a/context.go b/context.go index e3bddde..b5ae8fe 100644 --- a/context.go +++ b/context.go @@ -25,9 +25,18 @@ func init() { } ctxkit.SetHandlers = func(c any, handlers []any) { ctx := c.(*Context) - chain := make([]HandlerFunc, len(handlers)) - for i, h := range handlers { - chain[i] = h.(HandlerFunc) + n := len(handlers) + var chain []HandlerFunc + if n <= len(ctx.handlerBuf) { + for i, h := range handlers { + ctx.handlerBuf[i] = h.(HandlerFunc) + } + chain = ctx.handlerBuf[:n] + } else { + chain = make([]HandlerFunc, n) + for i, h := range handlers { + chain[i] = h.(HandlerFunc) + } } ctx.handlers = chain ctx.index = -1 @@ -39,15 +48,23 @@ func init() { } return nil } + ctxkit.GetStream = func(c any) any { + ctx := c.(*Context) + if ctx.stream != nil { + return ctx.stream + } + return nil + } } // Context is the request context passed to handlers. It is pooled via sync.Pool. // A Context is obtained from the pool and must not be retained after the handler returns. type Context struct { stream *stream.Stream - index int16 - handlers []HandlerFunc - params Params + index int16 + handlers []HandlerFunc + handlerBuf [8]HandlerFunc + params Params keys map[string]any ctx context.Context @@ -249,6 +266,16 @@ func (c *Context) Keys() map[string]any { func (c *Context) reset() { c.stream = nil c.index = -1 + // Clear handler references so closures can be GCed, but only when + // the slice is owned by this context (backed by handlerBuf or a + // SetHandlers allocation). The production path assigns the router's + // pre-composed chain directly; nilling those would corrupt shared state. + if len(c.handlers) > 0 && cap(c.handlers) <= len(c.handlerBuf) && + &c.handlers[:cap(c.handlers)][0] == &c.handlerBuf[0] { + for i := range c.handlers { + c.handlers[i] = nil + } + } if cap(c.handlers) > 64 { c.handlers = nil } else { diff --git a/context_request.go b/context_request.go index 1a31c24..8206754 100644 --- a/context_request.go +++ b/context_request.go @@ -352,7 +352,8 @@ func (c *Context) BasicAuth() (username, password string, ok bool) { if i < 0 { return } - return string(buf[:i]), string(buf[i+1 : n]), true + decoded := string(buf[:n]) + return decoded[:i], decoded[i+1:], true } // FormValue returns the first value for the named form field. diff --git a/context_request_test.go b/context_request_test.go index 23aac6f..7386720 100644 --- a/context_request_test.go +++ b/context_request_test.go @@ -15,7 +15,7 @@ import ( func TestContextBind(t *testing.T) { s, _ := newTestStream("POST", "/bind") - s.Data.Write([]byte(`{"name":"test"}`)) + s.GetBuf().Write([]byte(`{"name":"test"}`)) defer s.Release() c := acquireContext(s) @@ -221,7 +221,7 @@ func TestContextQueryInt(t *testing.T) { func TestBindJSON(t *testing.T) { s, _ := newTestStream("POST", "/bind-json") - s.Data.Write([]byte(`{"name":"test"}`)) + s.GetBuf().Write([]byte(`{"name":"test"}`)) defer s.Release() c := acquireContext(s) @@ -238,7 +238,7 @@ func TestBindJSON(t *testing.T) { func TestBindXML(t *testing.T) { s, _ := newTestStream("POST", "/bind-xml") - s.Data.Write([]byte(`test`)) + s.GetBuf().Write([]byte(`test`)) defer s.Release() c := acquireContext(s) @@ -259,7 +259,7 @@ func TestBindXML(t *testing.T) { func TestBindContentTypeDetection(t *testing.T) { // JSON (default, no Content-Type). s, _ := newTestStream("POST", "/bind-auto") - s.Data.Write([]byte(`{"name":"json"}`)) + s.GetBuf().Write([]byte(`{"name":"json"}`)) defer s.Release() c := acquireContext(s) @@ -276,7 +276,7 @@ func TestBindContentTypeDetection(t *testing.T) { // XML with Content-Type header. s2, _ := newTestStream("POST", "/bind-auto-xml") s2.Headers = append(s2.Headers, [2]string{"content-type", "application/xml"}) - s2.Data.Write([]byte(`xml`)) + s2.GetBuf().Write([]byte(`xml`)) defer s2.Release() c2 := acquireContext(s2) @@ -440,7 +440,7 @@ func TestContextBasicAuthPasswordWithColons(t *testing.T) { func TestContextFormValueURLEncoded(t *testing.T) { s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("name=alice&age=30")) + s.GetBuf().Write([]byte("name=alice&age=30")) defer s.Release() c := acquireContext(s) @@ -460,7 +460,7 @@ func TestContextFormValueURLEncoded(t *testing.T) { func TestContextFormValuesMultiple(t *testing.T) { s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("color=red&color=blue&color=green")) + s.GetBuf().Write([]byte("color=red&color=blue&color=green")) defer s.Release() c := acquireContext(s) @@ -484,7 +484,7 @@ func TestContextMultipartFormValue(t *testing.T) { s, _ := newTestStream("POST", "/upload") s.Headers = append(s.Headers, [2]string{"content-type", w.FormDataContentType()}) - s.Data.Write(buf.Bytes()) + s.GetBuf().Write(buf.Bytes()) defer s.Release() c := acquireContext(s) @@ -513,7 +513,7 @@ func TestContextFormFile(t *testing.T) { s, _ := newTestStream("POST", "/upload") s.Headers = append(s.Headers, [2]string{"content-type", w.FormDataContentType()}) - s.Data.Write(buf.Bytes()) + s.GetBuf().Write(buf.Bytes()) defer s.Release() c := acquireContext(s) @@ -538,7 +538,7 @@ func TestContextFormFile(t *testing.T) { func TestContextFormFileNonMultipart(t *testing.T) { s, _ := newTestStream("POST", "/upload") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("key=value")) + s.GetBuf().Write([]byte("key=value")) defer s.Release() c := acquireContext(s) @@ -565,7 +565,7 @@ func TestContextFormEmptyBody(t *testing.T) { func TestContextFormCaching(t *testing.T) { s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("a=1&b=2")) + s.GetBuf().Write([]byte("a=1&b=2")) defer s.Release() c := acquireContext(s) @@ -584,7 +584,7 @@ func TestContextFormCaching(t *testing.T) { func TestContextFormResetClearsForm(t *testing.T) { s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("key=value")) + s.GetBuf().Write([]byte("key=value")) defer s.Release() c := acquireContext(s) @@ -607,7 +607,7 @@ func TestContextMaxFormSizeUnlimited(t *testing.T) { s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", w.FormDataContentType()}) - s.Data.Write(buf.Bytes()) + s.GetBuf().Write(buf.Bytes()) defer s.Release() c := acquireContext(s) @@ -624,7 +624,7 @@ func TestContextMaxFormSizeUnlimited(t *testing.T) { func TestContextBodyCopy(t *testing.T) { s, _ := newTestStream("POST", "/data") - s.Data.Write([]byte("original")) + s.GetBuf().Write([]byte("original")) defer s.Release() c := acquireContext(s) @@ -655,7 +655,7 @@ func TestContextBodyCopyEmpty(t *testing.T) { func TestContextBodyReader(t *testing.T) { s, _ := newTestStream("POST", "/data") - s.Data.Write([]byte("read me")) + s.GetBuf().Write([]byte("read me")) defer s.Release() c := acquireContext(s) @@ -1070,7 +1070,7 @@ func TestAcceptsLanguagesNoMatch(t *testing.T) { func TestFormFileNotMultipart(t *testing.T) { s, _ := newTestStream("POST", "/upload") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("key=value")) + s.GetBuf().Write([]byte("key=value")) defer s.Release() c := acquireContext(s) @@ -1092,7 +1092,7 @@ func TestFormFileNotMultipart(t *testing.T) { func TestMultipartFormNotMultipart(t *testing.T) { s, _ := newTestStream("POST", "/upload") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("key=value")) + s.GetBuf().Write([]byte("key=value")) defer s.Release() c := acquireContext(s) @@ -1170,7 +1170,7 @@ func TestContextFormValueOk(t *testing.T) { body := "name=alice&empty=" s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte(body)) + s.GetBuf().Write([]byte(body)) defer s.Release() c := acquireContext(s) diff --git a/internal/ctxkit/ctxkit.go b/internal/ctxkit/ctxkit.go index 89ce20c..d7858ec 100644 --- a/internal/ctxkit/ctxkit.go +++ b/internal/ctxkit/ctxkit.go @@ -12,4 +12,5 @@ var ( AddParam func(c any, key, value string) SetHandlers func(c any, handlers []any) GetResponseWriter func(c any) any + GetStream func(c any) any ) diff --git a/protocol/h2/stream/stream.go b/protocol/h2/stream/stream.go index d1152be..adcf45a 100644 --- a/protocol/h2/stream/stream.go +++ b/protocol/h2/stream/stream.go @@ -104,7 +104,6 @@ func NewStream(id uint32) *Stream { s := streamPool.Get().(*Stream) s.ID = id s.state.Store(int32(StateIdle)) - s.Data = getBuf() s.windowSize.Store(65535) s.phase = PhaseInit s.Headers = s.hdrBuf[:0] @@ -158,12 +157,30 @@ func (s *Stream) IsCancelled() bool { return s.flags.Load()&flagCancelled != 0 } +// HasDoneCh reports whether a Done channel was created, indicating +// a derived context (e.g. context.WithTimeout) is watching this stream. +func (s *Stream) HasDoneCh() bool { + return s.doneCh.Load() != nil +} + // Release returns pooled buffers, cancels the context, and returns the stream // to its pool. Safe to call multiple times; subsequent calls are no-ops. func (s *Stream) Release() { if !s.h1Mode { s.Cancel() } + s.resetAndPool() +} + +// ResetForPool returns pooled buffers and returns the stream to its pool +// WITHOUT cancelling the context. Use this in test harnesses where derived +// contexts (e.g. from context.WithTimeout) may have propagation goroutines +// that race with cancellation flag clearing. +func ResetForPool(s *Stream) { + s.resetAndPool() +} + +func (s *Stream) resetAndPool() { if s.Data != nil { s.Data.Reset() bufferPool.Put(s.Data) diff --git a/server_test.go b/server_test.go index dea10fb..8f9cbc4 100644 --- a/server_test.go +++ b/server_test.go @@ -65,7 +65,7 @@ func TestServerRouting(t *testing.T) { // Test POST /echo. st2, rw2 := newTestStream("POST", "/echo") - st2.Data.Write([]byte("payload")) + st2.GetBuf().Write([]byte("payload")) if err := adapter.HandleStream(context.Background(), st2); err != nil { t.Fatal(err) } diff --git a/stdlib.go b/stdlib.go index b743981..97f1723 100644 --- a/stdlib.go +++ b/stdlib.go @@ -57,7 +57,7 @@ func ToHandler(h HandlerFunc) http.Handler { return } if len(body) > 0 { - _, _ = s.Data.Write(body) + _, _ = s.GetBuf().Write(body) } } From 3a86446c118d27fe0e1c7865638eee8abda35944 Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Wed, 1 Apr 2026 13:15:33 +0200 Subject: [PATCH 08/11] perf: v1.2.2 comprehensive optimization, security, DX, and quality sweep MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Performance: - Embedded paramBuf [4]Param on Context for inline route parameter storage - Hoisted stripCRLF/stripCookieUnsafe replacers to package-level vars - Inline fast-path scans in AddHeader/SetHeader (skip non-inlineable calls) - Consolidated three 100MB constants into single maxBodySize - Scheme() fast-path for common "http"/"https" values - Pre-allocated notFound/methodNotAllowed handler chains on routerAdapter - Clear backing arrays (respHdrBuf, paramBuf, hdrBuf, Trailers) in reset for GC hygiene — prevents stale strings from being pinned - SetCookie Builder.Grow(128) pre-allocation Security: - H1 parser: reject duplicate Content-Length with conflicting values (CL-CL request smuggling prevention, RFC 7230 §3.3.3) - FileFromDir: resolve symlinks via filepath.EvalSymlinks + recheck prefix to prevent symlink-based directory traversal - FileFromDir: reject directory paths with IsDir() check DX: - Header(key) auto-lowercases uppercase keys (net/http compat) - FormValueOK canonical name (FormValueOk deprecated alias kept) - QueryBool(key, default) convenience method - QueryInt64(key, default) convenience method - ParamDefault(key, default) convenience method Code quality: - Removed dead internal/timer/ package - Mock ResponseWriter copies headers (prevents clear() aliasing) - Blob header assembly uses len(c.respHdrBuf) instead of magic 8 Tests: - TestQueryBool (16 cases), TestQueryInt64 (8 cases), TestParamDefault - TestWithHandlers (4 tests: chain order, many handlers, error, abort) - TestFormValueOkDeprecated - TestParseRequest_DuplicateContentLength (3 subtests × 2 modes) --- bridge.go | 2 +- celeristest/celeristest_test.go | 108 ++++++++++++++++++++++ context.go | 6 +- context_request.go | 69 +++++++++++++- context_request_test.go | 154 +++++++++++++++++++++++++++++++- context_response.go | 69 +++++++++++--- context_test.go | 3 +- doc.go | 4 +- handler.go | 20 +++-- internal/timer/wheel.go | 145 ------------------------------ internal/timer/wheel_test.go | 133 --------------------------- protocol/h1/parser.go | 15 +++- protocol/h1/parser_test.go | 75 ++++++++++++++++ protocol/h2/stream/stream.go | 3 + server.go | 9 +- stdlib.go | 2 +- types.go | 6 +- 17 files changed, 508 insertions(+), 315 deletions(-) delete mode 100644 internal/timer/wheel.go delete mode 100644 internal/timer/wheel_test.go diff --git a/bridge.go b/bridge.go index 6dce906..910cc36 100644 --- a/bridge.go +++ b/bridge.go @@ -89,7 +89,7 @@ func buildHTTPRequest(c *Context) (*http.Request, error) { return req, nil } -const maxBridgeResponseBytes = 100 << 20 // 100 MB +const maxBridgeResponseBytes = maxBodySize var errBridgeResponseTooLarge = errors.New("bridge: response body exceeds 100MB limit") diff --git a/celeristest/celeristest_test.go b/celeristest/celeristest_test.go index e5190f0..33e21f2 100644 --- a/celeristest/celeristest_test.go +++ b/celeristest/celeristest_test.go @@ -205,3 +205,111 @@ func TestWithRemoteAddr(t *testing.T) { t.Fatalf("expected 192.168.1.1:54321, got %s", ctx.RemoteAddr()) } } + +func TestWithHandlers(t *testing.T) { + var order []string + mw := func(c *celeris.Context) error { + order = append(order, "mw") + return c.Next() + } + handler := func(c *celeris.Context) error { + order = append(order, "handler") + return c.String(200, "ok") + } + ctx, rec := NewContext("GET", "/test", WithHandlers(mw, handler)) + defer ReleaseContext(ctx) + + err := ctx.Next() + if err != nil { + t.Fatal(err) + } + if rec.StatusCode != 200 { + t.Fatalf("expected 200, got %d", rec.StatusCode) + } + if rec.BodyString() != "ok" { + t.Fatalf("expected body 'ok', got %q", rec.BodyString()) + } + if len(order) != 2 || order[0] != "mw" || order[1] != "handler" { + t.Fatalf("unexpected execution order: %v", order) + } +} + +func TestWithHandlersErrorPropagation(t *testing.T) { + mw := func(c *celeris.Context) error { + return c.Next() + } + handler := func(c *celeris.Context) error { + return celeris.NewHTTPError(403, "forbidden") + } + ctx, _ := NewContext("GET", "/test", WithHandlers(mw, handler)) + defer ReleaseContext(ctx) + + err := ctx.Next() + if err == nil { + t.Fatal("expected error from handler chain") + } + he, ok := err.(*celeris.HTTPError) + if !ok { + t.Fatalf("expected *HTTPError, got %T", err) + } + if he.Code != 403 { + t.Fatalf("expected 403, got %d", he.Code) + } +} + +func TestWithHandlersManyHandlers(t *testing.T) { + var order []string + makeMW := func(name string) celeris.HandlerFunc { + return func(c *celeris.Context) error { + order = append(order, name) + return c.Next() + } + } + handler := func(c *celeris.Context) error { + order = append(order, "handler") + return c.String(200, "ok") + } + // 5 middleware + 1 handler = 6 total, exceeds the 4-element handlersBuf. + ctx, rec := NewContext("GET", "/test", WithHandlers( + makeMW("mw1"), makeMW("mw2"), makeMW("mw3"), makeMW("mw4"), makeMW("mw5"), handler, + )) + defer ReleaseContext(ctx) + + err := ctx.Next() + if err != nil { + t.Fatal(err) + } + if rec.StatusCode != 200 { + t.Fatalf("expected 200, got %d", rec.StatusCode) + } + expected := []string{"mw1", "mw2", "mw3", "mw4", "mw5", "handler"} + if len(order) != len(expected) { + t.Fatalf("expected %d entries, got %d: %v", len(expected), len(order), order) + } + for i, e := range expected { + if order[i] != e { + t.Fatalf("order[%d] = %q, want %q", i, order[i], e) + } + } +} + +func TestWithHandlersAbort(t *testing.T) { + var reached bool + mw := func(c *celeris.Context) error { + return c.AbortWithStatus(401) + } + handler := func(c *celeris.Context) error { + reached = true + return c.String(200, "ok") + } + ctx, rec := NewContext("GET", "/test", WithHandlers(mw, handler)) + defer ReleaseContext(ctx) + + _ = ctx.Next() + if reached { + t.Fatal("handler should not have been reached after abort") + } + if rec.StatusCode != 401 { + t.Fatalf("expected 401, got %d", rec.StatusCode) + } +} diff --git a/context.go b/context.go index b5ae8fe..52c7098 100644 --- a/context.go +++ b/context.go @@ -65,6 +65,7 @@ type Context struct { handlers []HandlerFunc handlerBuf [8]HandlerFunc params Params + paramBuf [4]Param keys map[string]any ctx context.Context @@ -112,6 +113,7 @@ type Context struct { var contextPool = sync.Pool{New: func() any { c := &Context{keys: make(map[string]any, 4)} + c.params = c.paramBuf[:0] c.respHeaders = c.respHdrBuf[:0] return c }} @@ -281,13 +283,15 @@ func (c *Context) reset() { } else { c.handlers = c.handlers[:0] } - c.params = c.params[:0] + clear(c.paramBuf[:]) + c.params = c.paramBuf[:0] c.ctx = nil c.method = "" c.path = "" c.rawQuery = "" c.fullPath = "" c.statusCode = 200 + clear(c.respHdrBuf[:]) c.respHeaders = c.respHdrBuf[:0] c.written = false c.aborted = false diff --git a/context_request.go b/context_request.go index 8206754..b47423a 100644 --- a/context_request.go +++ b/context_request.go @@ -61,9 +61,21 @@ func (c *Context) ContentLength() int64 { return n } -// Header returns the value of the named request header. Keys must be lowercase -// (HTTP/2 mandates lowercase; the H1 parser normalizes to lowercase). +// Header returns the value of the named request header. Keys are normalized +// to lowercase automatically (HTTP/2 mandates lowercase; the H1 parser +// normalizes to lowercase). func (c *Context) Header(key string) string { + // Fast path: most programmatic keys are already lowercase. + needsLower := false + for i := range len(key) { + if key[i] >= 'A' && key[i] <= 'Z' { + needsLower = true + break + } + } + if needsLower { + key = strings.ToLower(key) + } for _, h := range c.stream.Headers { if h[0] == key { return h[1] @@ -98,6 +110,15 @@ func (c *Context) ParamInt64(key string) (int64, error) { return strconv.ParseInt(v, 10, 64) } +// ParamDefault returns the value of a URL parameter, or the default if absent or empty. +func (c *Context) ParamDefault(key, defaultValue string) string { + v := c.Param(key) + if v == "" { + return defaultValue + } + return v +} + // Query returns the value of a query parameter by name. Results are cached // so repeated calls for different keys do not re-parse the query string. func (c *Context) Query(key string) string { @@ -136,6 +157,38 @@ func (c *Context) QueryInt(key string, defaultValue int) int { return n } +// QueryInt64 returns a query parameter parsed as an int64. +// Returns the provided default value if the key is absent or not a valid integer. +func (c *Context) QueryInt64(key string, defaultValue int64) int64 { + v := c.Query(key) + if v == "" { + return defaultValue + } + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return defaultValue + } + return n +} + +// QueryBool returns a query parameter parsed as a bool. +// Returns the provided default value if the key is absent or not a valid bool. +// Recognizes "true", "1", "yes" as true and "false", "0", "no" as false. +func (c *Context) QueryBool(key string, defaultValue bool) bool { + v := c.Query(key) + if v == "" { + return defaultValue + } + switch strings.ToLower(v) { + case "true", "1", "yes": + return true + case "false", "0", "no": + return false + default: + return defaultValue + } +} + // QueryValues returns all values for the given query parameter key. // Returns nil if the key is not present. func (c *Context) QueryValues(key string) []string { @@ -283,6 +336,9 @@ func (c *Context) Scheme() string { return c.schemeOverride } if proto := c.Header("x-forwarded-proto"); proto != "" { + if proto == "https" || proto == "http" { + return proto + } return strings.ToLower(strings.TrimSpace(proto)) } if scheme := c.Header(":scheme"); scheme != "" { @@ -365,10 +421,10 @@ func (c *Context) FormValue(name string) string { return c.formValues.Get(name) } -// FormValueOk returns the first value for the named form field plus a boolean +// FormValueOK returns the first value for the named form field plus a boolean // indicating whether the field was present. Unlike FormValue, callers can // distinguish a missing field from an empty value. -func (c *Context) FormValueOk(name string) (string, bool) { +func (c *Context) FormValueOK(name string) (string, bool) { if err := c.parseForm(); err != nil { return "", false } @@ -379,6 +435,11 @@ func (c *Context) FormValueOk(name string) (string, bool) { return vs[0], true } +// Deprecated: Use [Context.FormValueOK] instead. +func (c *Context) FormValueOk(name string) (string, bool) { + return c.FormValueOK(name) +} + // FormValues returns all values for the named form field. func (c *Context) FormValues(name string) []string { if err := c.parseForm(); err != nil { diff --git a/context_request_test.go b/context_request_test.go index 7386720..98f04b7 100644 --- a/context_request_test.go +++ b/context_request_test.go @@ -1166,7 +1166,7 @@ func TestRequestHeaders(t *testing.T) { } } -func TestContextFormValueOk(t *testing.T) { +func TestContextFormValueOK(t *testing.T) { body := "name=alice&empty=" s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) @@ -1176,12 +1176,12 @@ func TestContextFormValueOk(t *testing.T) { c := acquireContext(s) defer releaseContext(c) - v, ok := c.FormValueOk("name") + v, ok := c.FormValueOK("name") if !ok || v != "alice" { t.Fatalf("expected (alice, true), got (%s, %v)", v, ok) } - v, ok = c.FormValueOk("empty") + v, ok = c.FormValueOK("empty") if !ok { t.Fatal("expected field 'empty' to be present") } @@ -1189,7 +1189,7 @@ func TestContextFormValueOk(t *testing.T) { t.Fatalf("expected empty string, got %q", v) } - _, ok = c.FormValueOk("missing") + _, ok = c.FormValueOK("missing") if ok { t.Fatal("expected missing field to return false") } @@ -1373,3 +1373,149 @@ func TestContextOverridesResetOnReuse(t *testing.T) { t.Fatal("expected hostOverride to be cleared") } } + +func TestQueryBool(t *testing.T) { + tests := []struct { + query string + key string + def bool + expected bool + }{ + {"debug=true", "debug", false, true}, + {"debug=1", "debug", false, true}, + {"debug=yes", "debug", false, true}, + {"debug=false", "debug", true, false}, + {"debug=0", "debug", true, false}, + {"debug=no", "debug", true, false}, + {"debug=TRUE", "debug", false, true}, + {"debug=FALSE", "debug", true, false}, + {"debug=Yes", "debug", false, true}, + {"debug=No", "debug", true, false}, + {"debug=invalid", "debug", true, true}, + {"debug=invalid", "debug", false, false}, + {"", "debug", true, true}, + {"", "debug", false, false}, + {"other=1", "debug", false, false}, + {"other=1", "debug", true, true}, + } + for _, tt := range tests { + name := tt.query + "_" + tt.key + if tt.def { + name += "_def=true" + } else { + name += "_def=false" + } + t.Run(name, func(t *testing.T) { + path := "/test" + if tt.query != "" { + path += "?" + tt.query + } + s, _ := newTestStream("GET", path) + defer s.Release() + + c := acquireContext(s) + defer releaseContext(c) + + got := c.QueryBool(tt.key, tt.def) + if got != tt.expected { + t.Fatalf("QueryBool(%q, %v) = %v, want %v", tt.key, tt.def, got, tt.expected) + } + }) + } +} + +func TestQueryInt64(t *testing.T) { + tests := []struct { + query string + key string + def int64 + expected int64 + }{ + {"page=42", "page", 0, 42}, + {"page=9999999999", "page", 0, 9999999999}, + {"page=-100", "page", 0, -100}, + {"page=0", "page", 99, 0}, + {"page=abc", "page", 10, 10}, + {"", "page", 5, 5}, + {"other=1", "page", 7, 7}, + {"page=9223372036854775807", "page", 0, 9223372036854775807}, + } + for _, tt := range tests { + name := tt.query + "_" + tt.key + t.Run(name, func(t *testing.T) { + path := "/test" + if tt.query != "" { + path += "?" + tt.query + } + s, _ := newTestStream("GET", path) + defer s.Release() + + c := acquireContext(s) + defer releaseContext(c) + + got := c.QueryInt64(tt.key, tt.def) + if got != tt.expected { + t.Fatalf("QueryInt64(%q, %d) = %d, want %d", tt.key, tt.def, got, tt.expected) + } + }) + } +} + +func TestParamDefault(t *testing.T) { + s, _ := newTestStream("GET", "/users/alice") + defer s.Release() + + c := acquireContext(s) + defer releaseContext(c) + + c.params = Params{ + {Key: "name", Value: "alice"}, + {Key: "empty", Value: ""}, + } + + if got := c.ParamDefault("name", "fallback"); got != "alice" { + t.Fatalf("expected alice, got %s", got) + } + if got := c.ParamDefault("empty", "fallback"); got != "fallback" { + t.Fatalf("expected fallback for empty param, got %s", got) + } + if got := c.ParamDefault("missing", "fallback"); got != "fallback" { + t.Fatalf("expected fallback for missing param, got %s", got) + } +} + +func TestFormValueOkDeprecated(t *testing.T) { + body := "name=alice&empty=" + s, _ := newTestStream("POST", "/form") + s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) + s.GetBuf().Write([]byte(body)) + defer s.Release() + + c := acquireContext(s) + defer releaseContext(c) + + // FormValueOk (deprecated) should return the same results as FormValueOK. + v1, ok1 := c.FormValueOK("name") + v2, ok2 := c.FormValueOk("name") + if v1 != v2 || ok1 != ok2 { + t.Fatalf("FormValueOk != FormValueOK: (%q,%v) vs (%q,%v)", v1, ok1, v2, ok2) + } + if !ok1 || v1 != "alice" { + t.Fatalf("expected (alice, true), got (%s, %v)", v1, ok1) + } + + v1, ok1 = c.FormValueOK("empty") + v2, ok2 = c.FormValueOk("empty") + if v1 != v2 || ok1 != ok2 { + t.Fatalf("FormValueOk != FormValueOK for empty: (%q,%v) vs (%q,%v)", v1, ok1, v2, ok2) + } + + _, ok1 = c.FormValueOK("missing") + _, ok2 = c.FormValueOk("missing") + if ok1 != ok2 { + t.Fatalf("FormValueOk != FormValueOK for missing: %v vs %v", ok1, ok2) + } + if ok1 { + t.Fatal("expected missing field to return false") + } +} diff --git a/context_response.go b/context_response.go index 2e3875a..c9542c4 100644 --- a/context_response.go +++ b/context_response.go @@ -128,12 +128,13 @@ func (c *Context) Blob(code int, contentType string, data []byte) error { nUser := len(c.respHeaders) total := nUser + 2 var headers [][2]string - if total <= 8 { + if total <= len(c.respHdrBuf) { // respHeaders shares backing array with respHdrBuf — copy user // headers to a stack temporary before overwriting the buffer. + // Max user headers in fast path: len(respHdrBuf) - 2 = 6. var tmp [6][2]string copy(tmp[:nUser], c.respHeaders) - headers = c.respHdrBuf[:0:8] + headers = c.respHdrBuf[:0:len(c.respHdrBuf)] headers = append(headers, [2]string{"content-type", stripCRLF(contentType)}) headers = append(headers, [2]string{"content-length", itoa(len(data))}) headers = append(headers, tmp[:nUser]...) @@ -186,7 +187,14 @@ func (c *Context) SetHeader(key, value string) { break } } - v := stripCRLF(value) + v := value + for i := range len(value) { + b := value[i] + if b == '\r' || b == '\n' || b == 0 { + v = stripCRLF(value) + break + } + } for i, h := range c.respHeaders { if h[0] == k { c.respHeaders[i][1] = v @@ -200,10 +208,23 @@ func (c *Context) SetHeader(key, value string) { // replace existing values — use this for headers that allow multiple values // (e.g. set-cookie). func (c *Context) AddHeader(key, value string) { - c.respHeaders = append(c.respHeaders, [2]string{ - sanitizeHeaderKey(key), - stripCRLF(value), - }) + k := key + for i := range len(key) { + b := key[i] + if b >= 'A' && b <= 'Z' || b == '\r' || b == '\n' || b == 0 { + k = sanitizeHeaderKey(key) + break + } + } + v := value + for i := range len(value) { + b := value[i] + if b == '\r' || b == '\n' || b == 0 { + v = stripCRLF(value) + break + } + } + c.respHeaders = append(c.respHeaders, [2]string{k, v}) } // sanitizeHeaderKey lowercases and strips CRLF/null bytes. Fast path avoids @@ -218,11 +239,13 @@ func sanitizeHeaderKey(s string) string { return s } +var crlfReplacer = strings.NewReplacer("\r", "", "\n", "", "\x00", "") + // stripCRLF removes \r, \n, and \x00 to prevent HTTP response splitting // (CWE-113) and null-byte header smuggling. func stripCRLF(s string) string { if strings.ContainsAny(s, "\r\n\x00") { - return strings.NewReplacer("\r", "", "\n", "", "\x00", "").Replace(s) + return crlfReplacer.Replace(s) } return s } @@ -256,11 +279,13 @@ func (c *Context) Redirect(code int, url string) error { return c.NoContent(code) } +var cookieUnsafeReplacer = strings.NewReplacer(";", "", "\r", "", "\n", "") + // stripCookieUnsafe strips characters that could inject cookie attributes // (semicolons) or cause header injection (CRLF) from cookie field values. func stripCookieUnsafe(s string) string { if strings.ContainsAny(s, ";\r\n") { - return strings.NewReplacer(";", "", "\r", "", "\n", "").Replace(s) + return cookieUnsafeReplacer.Replace(s) } return s } @@ -272,6 +297,7 @@ func stripCookieUnsafe(s string) string { // cookie attribute injection. func (c *Context) SetCookie(cookie *Cookie) { var b strings.Builder + b.Grow(128) b.WriteString(stripCookieUnsafe(cookie.Name)) b.WriteByte('=') b.WriteString(stripCookieUnsafe(cookie.Value)) @@ -368,15 +394,34 @@ func (c *Context) File(filePath string) error { // FileFromDir safely serves a file from within baseDir. The userPath is // cleaned and joined with baseDir; if the result escapes baseDir, a 400 -// error is returned. This prevents directory traversal when serving -// user-supplied paths. +// error is returned. Symlinks are resolved and rechecked to prevent a +// symlink under baseDir from escaping the directory boundary. func (c *Context) FileFromDir(baseDir, userPath string) error { abs := filepath.Clean(filepath.Join(baseDir, filepath.FromSlash(userPath))) base := filepath.Clean(baseDir) if abs != base && !strings.HasPrefix(abs, base+string(filepath.Separator)) { return NewHTTPError(400, "invalid file path") } - return c.File(abs) + // Resolve symlinks and recheck prefix to prevent symlink escape. + resolved, err := filepath.EvalSymlinks(abs) + if err != nil { + return err + } + resolvedBase, err := filepath.EvalSymlinks(base) + if err != nil { + return err + } + if resolved != resolvedBase && !strings.HasPrefix(resolved, resolvedBase+string(filepath.Separator)) { + return NewHTTPError(400, "invalid file path") + } + info, err := os.Stat(resolved) + if err != nil { + return err + } + if info.IsDir() { + return NewHTTPError(400, "invalid file path") + } + return c.File(resolved) } func parseRange(header string, size int64) (start, end int64, ok bool) { diff --git a/context_test.go b/context_test.go index aa762f7..cf60ebf 100644 --- a/context_test.go +++ b/context_test.go @@ -19,7 +19,8 @@ type mockResponseWriter struct { func (m *mockResponseWriter) WriteResponse(_ *stream.Stream, status int, headers [][2]string, body []byte) error { m.status = status - m.headers = headers + m.headers = make([][2]string, len(headers)) + copy(m.headers, headers) m.body = make([]byte, len(body)) copy(m.body, body) return nil diff --git a/doc.go b/doc.go index 32aae88..88d06f4 100644 --- a/doc.go +++ b/doc.go @@ -431,9 +431,9 @@ // // # Form Presence // -// FormValueOk distinguishes a missing field from an empty value: +// FormValueOK distinguishes a missing field from an empty value: // -// val, ok := c.FormValueOk("name") +// val, ok := c.FormValueOK("name") // if !ok { // // field was not submitted // } diff --git a/handler.go b/handler.go index e277e4a..4ad38f2 100644 --- a/handler.go +++ b/handler.go @@ -12,7 +12,9 @@ import ( ) type routerAdapter struct { - server *Server + server *Server + notFoundChain []HandlerFunc + methodNotAllowedChain []HandlerFunc } func (a *routerAdapter) HandleStream(_ context.Context, s *stream.Stream) error { @@ -111,9 +113,13 @@ func (a *routerAdapter) handleUnmatched(c *Context, s *stream.Stream) { if len(allowed) > 0 { c.statusCode = 405 allowVal := strings.Join(allowed, ", ") - if a.server.methodNotAllowedHandler != nil { + chain := a.methodNotAllowedChain + if chain == nil && a.server.methodNotAllowedHandler != nil { + chain = []HandlerFunc{a.server.methodNotAllowedHandler} + } + if chain != nil { c.SetHeader("allow", allowVal) - c.handlers = []HandlerFunc{a.server.methodNotAllowedHandler} + c.handlers = chain a.handleError(c, s, c.Next()) } if !c.written && s.ResponseWriter != nil { @@ -125,8 +131,12 @@ func (a *routerAdapter) handleUnmatched(c *Context, s *stream.Stream) { } } else { c.statusCode = 404 - if a.server.notFoundHandler != nil { - c.handlers = []HandlerFunc{a.server.notFoundHandler} + chain := a.notFoundChain + if chain == nil && a.server.notFoundHandler != nil { + chain = []HandlerFunc{a.server.notFoundHandler} + } + if chain != nil { + c.handlers = chain a.handleError(c, s, c.Next()) } if !c.written && s.ResponseWriter != nil { diff --git a/internal/timer/wheel.go b/internal/timer/wheel.go deleted file mode 100644 index 040a491..0000000 --- a/internal/timer/wheel.go +++ /dev/null @@ -1,145 +0,0 @@ -// Package timer implements a hierarchical timer wheel for efficient -// timeout management in event-driven I/O engines. -package timer - -import ( - "sync" - "time" -) - -const ( - // WheelSize is the number of slots in the timer wheel (power of 2 for fast modulo). - WheelSize = 8192 - // TickInterval is the granularity of the timer wheel. - TickInterval = 100 * time.Millisecond - - wheelMask = WheelSize - 1 -) - -// TimeoutKind identifies the type of timeout. -type TimeoutKind uint8 - -// Timeout kinds for the timer wheel. -const ( - ReadTimeout TimeoutKind = iota - WriteTimeout - IdleTimeout -) - -// Entry represents a scheduled timeout. -type Entry struct { - FD int - Deadline int64 // unix nano - Kind TimeoutKind - next *Entry -} - -var entryPool = sync.Pool{New: func() any { return &Entry{} }} - -func getEntry() *Entry { - return entryPool.Get().(*Entry) -} - -func putEntry(e *Entry) { - e.FD = 0 - e.Deadline = 0 - e.Kind = 0 - e.next = nil - entryPool.Put(e) -} - -// Wheel is a hashed timer wheel with O(1) schedule and cancel. -type Wheel struct { - slots [WheelSize]*Entry - current uint64 - startNs int64 - onExpire func(fd int, kind TimeoutKind) - // fdSlots tracks the slot index for each FD to enable O(1) cancel. - fdSlots map[int]int64 // fd → deadline -} - -// New creates a timer wheel that calls onExpire for each expired entry. -func New(onExpire func(fd int, kind TimeoutKind)) *Wheel { - return &Wheel{ - startNs: time.Now().UnixNano(), - onExpire: onExpire, - fdSlots: make(map[int]int64), - } -} - -func (w *Wheel) slotIndex(deadline int64) uint64 { - tick := uint64(deadline-w.startNs) / uint64(TickInterval) - return tick & wheelMask -} - -// Schedule registers a timeout for the given FD. Any previous timeout for -// the same FD is logically cancelled (the old entry is ignored on expiry). -func (w *Wheel) Schedule(fd int, timeout time.Duration, kind TimeoutKind) { - w.ScheduleAt(fd, timeout, kind, time.Now().UnixNano()) -} - -// ScheduleAt is like Schedule but accepts a pre-computed now timestamp (UnixNano) -// to avoid redundant time.Now() calls when scheduling multiple timeouts in the -// same request processing cycle. -func (w *Wheel) ScheduleAt(fd int, timeout time.Duration, kind TimeoutKind, nowNs int64) { - deadline := nowNs + int64(timeout) - slot := w.slotIndex(deadline) - - e := getEntry() - e.FD = fd - e.Deadline = deadline - e.Kind = kind - e.next = w.slots[slot] - w.slots[slot] = e - - w.fdSlots[fd] = deadline -} - -// Cancel removes any pending timeout for the given FD. -func (w *Wheel) Cancel(fd int) { - delete(w.fdSlots, fd) -} - -// Tick advances the wheel and fires expired entries. Returns the number -// of entries that expired. Call this from the event loop after processing -// I/O events. -func (w *Wheel) Tick() int { - now := time.Now().UnixNano() - currentTick := uint64(now-w.startNs) / uint64(TickInterval) - count := 0 - - for w.current <= currentTick { - slot := w.current & wheelMask - var prev *Entry - e := w.slots[slot] - for e != nil { - next := e.next - if e.Deadline <= now { - // Check if this entry is still the active one for this FD. - if active, ok := w.fdSlots[e.FD]; ok && active == e.Deadline { - delete(w.fdSlots, e.FD) - w.onExpire(e.FD, e.Kind) - count++ - } - // Remove from list. - if prev == nil { - w.slots[slot] = next - } else { - prev.next = next - } - putEntry(e) - } else { - prev = e - } - e = next - } - w.current++ - } - - return count -} - -// Len returns the number of tracked FDs (approximate). -func (w *Wheel) Len() int { - return len(w.fdSlots) -} diff --git a/internal/timer/wheel_test.go b/internal/timer/wheel_test.go deleted file mode 100644 index 3e3d3eb..0000000 --- a/internal/timer/wheel_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package timer - -import ( - "sync/atomic" - "testing" - "time" -) - -func TestWheelScheduleAndExpire(t *testing.T) { - var expired atomic.Int32 - var expiredFD int - var expiredKind TimeoutKind - - w := New(func(fd int, kind TimeoutKind) { - expiredFD = fd - expiredKind = kind - expired.Add(1) - }) - - w.Schedule(42, 150*time.Millisecond, IdleTimeout) - - // Should not expire immediately. - n := w.Tick() - if n != 0 { - t.Fatalf("expected 0 expired, got %d", n) - } - - time.Sleep(200 * time.Millisecond) - n = w.Tick() - if n != 1 { - t.Fatalf("expected 1 expired, got %d", n) - } - if expiredFD != 42 { - t.Fatalf("expected FD 42, got %d", expiredFD) - } - if expiredKind != IdleTimeout { - t.Fatalf("expected IdleTimeout, got %d", expiredKind) - } -} - -func TestWheelCancel(t *testing.T) { - var expired atomic.Int32 - w := New(func(_ int, _ TimeoutKind) { - expired.Add(1) - }) - - w.Schedule(10, 150*time.Millisecond, ReadTimeout) - w.Cancel(10) - - time.Sleep(200 * time.Millisecond) - w.Tick() - if expired.Load() != 0 { - t.Fatalf("expected 0 expired after cancel, got %d", expired.Load()) - } -} - -func TestWheelMultipleEntries(t *testing.T) { - expired := make(map[int]bool) - w := New(func(fd int, _ TimeoutKind) { - expired[fd] = true - }) - - w.Schedule(1, 150*time.Millisecond, IdleTimeout) - w.Schedule(2, 150*time.Millisecond, IdleTimeout) - w.Schedule(3, 150*time.Millisecond, IdleTimeout) - - time.Sleep(200 * time.Millisecond) - n := w.Tick() - if n != 3 { - t.Fatalf("expected 3 expired, got %d", n) - } - for _, fd := range []int{1, 2, 3} { - if !expired[fd] { - t.Fatalf("expected FD %d to expire", fd) - } - } -} - -func TestWheelReschedule(t *testing.T) { - var lastKind TimeoutKind - count := 0 - w := New(func(_ int, kind TimeoutKind) { - lastKind = kind - count++ - }) - - // Schedule then reschedule with different kind. - w.Schedule(5, 150*time.Millisecond, ReadTimeout) - w.Schedule(5, 150*time.Millisecond, WriteTimeout) - - time.Sleep(200 * time.Millisecond) - w.Tick() - // Only the latest schedule should fire. - if count != 1 { - t.Fatalf("expected 1 expiry, got %d", count) - } - if lastKind != WriteTimeout { - t.Fatalf("expected WriteTimeout, got %d", lastKind) - } -} - -func TestWheelLen(t *testing.T) { - w := New(func(_ int, _ TimeoutKind) {}) - w.Schedule(1, time.Second, IdleTimeout) - w.Schedule(2, time.Second, IdleTimeout) - if w.Len() != 2 { - t.Fatalf("expected Len=2, got %d", w.Len()) - } - w.Cancel(1) - if w.Len() != 1 { - t.Fatalf("expected Len=1, got %d", w.Len()) - } -} - -func BenchmarkWheelSchedule(b *testing.B) { - w := New(func(_ int, _ TimeoutKind) {}) - b.ResetTimer() - for i := range b.N { - w.Schedule(i%10000, 5*time.Second, IdleTimeout) - } -} - -func BenchmarkWheelTick(b *testing.B) { - w := New(func(_ int, _ TimeoutKind) {}) - // Pre-fill with entries that won't expire. - for i := range 1000 { - w.Schedule(i, time.Hour, IdleTimeout) - } - b.ResetTimer() - for range b.N { - w.Tick() - } -} diff --git a/protocol/h1/parser.go b/protocol/h1/parser.go index bf18908..0c35083 100644 --- a/protocol/h1/parser.go +++ b/protocol/h1/parser.go @@ -13,7 +13,8 @@ var ( ErrMissingHost = errors.New("missing Host header") ErrUnsupportedVersion = errors.New("unsupported HTTP version") ErrHeadersTooLarge = errors.New("headers too large") - ErrInvalidContentLength = errors.New("invalid content-length") + ErrInvalidContentLength = errors.New("invalid content-length") + ErrDuplicateContentLength = errors.New("duplicate content-length with conflicting values") ) // Parser is a zero-allocation HTTP/1.x request parser. @@ -171,6 +172,9 @@ func (p *Parser) appendHeader(req *Request, rawName, rawValue []byte) error { if !ok { return ErrInvalidContentLength } + if req.ContentLength >= 0 && req.ContentLength != cl { + return ErrDuplicateContentLength + } req.ContentLength = cl case "transfer-encoding": if asciiContainsFoldString(value, "chunked") { @@ -206,11 +210,14 @@ func (p *Parser) appendHeader(req *Request, rawName, rawValue []byte) error { if req.ChunkedEncoding { return nil } - if cl, ok := parseInt64Bytes(rawValue); ok { - req.ContentLength = cl - } else { + cl, ok := parseInt64Bytes(rawValue) + if !ok { return ErrInvalidContentLength } + if req.ContentLength >= 0 && req.ContentLength != cl { + return ErrDuplicateContentLength + } + req.ContentLength = cl return nil } if asciiEqualFold(rawName, "Connection") { diff --git a/protocol/h1/parser_test.go b/protocol/h1/parser_test.go index 7586d6b..f0b6aa2 100644 --- a/protocol/h1/parser_test.go +++ b/protocol/h1/parser_test.go @@ -530,6 +530,81 @@ func TestFindHeaderEnd_MultipleCR(t *testing.T) { } } +func TestParseRequest_DuplicateContentLength_Conflicting(t *testing.T) { + for _, zeroCopy := range []bool{false, true} { + name := "standard" + if zeroCopy { + name = "zerocopy" + } + t.Run(name, func(t *testing.T) { + raw := "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\nContent-Length: 50\r\n\r\n" + p := NewParser() + if zeroCopy { + p.noStringHeaders = true + } + p.Reset([]byte(raw)) + var req Request + _, err := p.ParseRequest(&req) + if !errors.Is(err, ErrDuplicateContentLength) { + t.Fatalf("got error %v, want %v", err, ErrDuplicateContentLength) + } + }) + } +} + +func TestParseRequest_DuplicateContentLength_Identical(t *testing.T) { + for _, zeroCopy := range []bool{false, true} { + name := "standard" + if zeroCopy { + name = "zerocopy" + } + t.Run(name, func(t *testing.T) { + raw := "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 42\r\nContent-Length: 42\r\n\r\n" + p := NewParser() + if zeroCopy { + p.noStringHeaders = true + } + p.Reset([]byte(raw)) + var req Request + _, err := p.ParseRequest(&req) + if err != nil { + t.Fatalf("identical CL should be accepted, got: %v", err) + } + if req.ContentLength != 42 { + t.Fatalf("content-length = %d, want 42", req.ContentLength) + } + }) + } +} + +func TestParseRequest_DuplicateContentLength_ChunkedIgnored(t *testing.T) { + for _, zeroCopy := range []bool{false, true} { + name := "standard" + if zeroCopy { + name = "zerocopy" + } + t.Run(name, func(t *testing.T) { + raw := "POST / HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\nContent-Length: 0\r\nContent-Length: 50\r\n\r\n" + p := NewParser() + if zeroCopy { + p.noStringHeaders = true + } + p.Reset([]byte(raw)) + var req Request + _, err := p.ParseRequest(&req) + if err != nil { + t.Fatalf("conflicting CL with chunked TE should be ignored, got: %v", err) + } + if !req.ChunkedEncoding { + t.Fatal("chunked encoding not detected") + } + if req.ContentLength != -1 { + t.Fatalf("content-length = %d, want -1 for chunked", req.ContentLength) + } + }) + } +} + // TestFindHeaderEnd_PartialSequence ensures partial \r\n sequences // (without the full \r\n\r\n) are not mistakenly matched. func TestFindHeaderEnd_PartialSequence(t *testing.T) { diff --git a/protocol/h2/stream/stream.go b/protocol/h2/stream/stream.go index adcf45a..9ba93fd 100644 --- a/protocol/h2/stream/stream.go +++ b/protocol/h2/stream/stream.go @@ -194,7 +194,10 @@ func (s *Stream) resetAndPool() { s.ID = 0 s.state.Store(0) s.manager = nil + clear(s.hdrBuf[:]) s.Headers = s.hdrBuf[:0] + s.Trailers = s.Trailers[:cap(s.Trailers)] + clear(s.Trailers) s.Trailers = s.Trailers[:0] s.OutboundEndStream = false s.headersSent.Store(false) diff --git a/server.go b/server.go index 7b32daa..c2da332 100644 --- a/server.go +++ b/server.go @@ -375,7 +375,14 @@ func (s *Server) doPrepare(configureFn func(cfg *resource.Config)) (engine.Engin s.collector = observe.NewCollector() } - var handler stream.Handler = &routerAdapter{server: s} + ra := &routerAdapter{server: s} + if s.notFoundHandler != nil { + ra.notFoundChain = []HandlerFunc{s.notFoundHandler} + } + if s.methodNotAllowedHandler != nil { + ra.methodNotAllowedChain = []HandlerFunc{s.methodNotAllowedHandler} + } + var handler stream.Handler = ra var err error eng, err = createEngine(cfg, handler) if err != nil { diff --git a/stdlib.go b/stdlib.go index 97f1723..f82bf19 100644 --- a/stdlib.go +++ b/stdlib.go @@ -11,7 +11,7 @@ import ( "golang.org/x/net/http2" ) -const maxToHandlerBodySize = 100 << 20 // 100 MB +const maxToHandlerBodySize = maxBodySize // ToHandler wraps a celeris HandlerFunc as an http.Handler for use with // net/http routers, middleware, or test infrastructure. The returned handler diff --git a/types.go b/types.go index 31a29d7..859035a 100644 --- a/types.go +++ b/types.go @@ -11,7 +11,11 @@ type HandlerFunc func(*Context) error // parsing (32 MB), matching net/http. const DefaultMaxFormSize int64 = 32 << 20 -const maxStreamBodySize = 100 << 20 // 100MB +// maxBodySize is the maximum request/response body size (100 MB), shared by +// stream responses (File, Stream), bridge adapter output, and ToHandler input. +const maxBodySize = 100 << 20 + +const maxStreamBodySize = maxBodySize // SameSite controls the SameSite attribute of a cookie. type SameSite int From c94fa5e06dc77656f693d6cc0b814907ad2b0b85 Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Wed, 1 Apr 2026 13:38:41 +0200 Subject: [PATCH 09/11] fix: guard BasicAuth against oversized credentials base64.StdEncoding.Decode panics (not returns error) when the decoded output exceeds the destination buffer. Add a DecodedLen pre-check before decoding to gracefully return ok=false for credentials exceeding the 128-byte stack buffer, preventing a panic on crafted Authorization headers with long payloads. --- context_request.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/context_request.go b/context_request.go index b47423a..0a6dbde 100644 --- a/context_request.go +++ b/context_request.go @@ -399,6 +399,9 @@ func (c *Context) BasicAuth() (username, password string, ok bool) { } payload := auth[len(prefix):] var buf [128]byte + if base64.StdEncoding.DecodedLen(len(payload)) > len(buf) { + return + } n, err := base64.StdEncoding.Decode(buf[:], unsafe.Slice(unsafe.StringData(payload), len(payload))) if err != nil { From 542370b6e5921171cc0cf4a4dd92753459857df2 Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Wed, 1 Apr 2026 13:51:57 +0200 Subject: [PATCH 10/11] =?UTF-8?q?fix:=20lint=20=E2=80=94=20gofmt,=20unused?= =?UTF-8?q?=20param,=20exported=20func=20comment?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- celeristest/celeristest_test.go | 2 +- context.go | 6 +++--- context_request.go | 2 ++ handler.go | 6 +++--- internal/ctxkit/ctxkit.go | 12 ++++++------ 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/celeristest/celeristest_test.go b/celeristest/celeristest_test.go index 33e21f2..ccb86c1 100644 --- a/celeristest/celeristest_test.go +++ b/celeristest/celeristest_test.go @@ -238,7 +238,7 @@ func TestWithHandlersErrorPropagation(t *testing.T) { mw := func(c *celeris.Context) error { return c.Next() } - handler := func(c *celeris.Context) error { + handler := func(_ *celeris.Context) error { return celeris.NewHTTPError(403, "forbidden") } ctx, _ := NewContext("GET", "/test", WithHandlers(mw, handler)) diff --git a/context.go b/context.go index 52c7098..8eb926c 100644 --- a/context.go +++ b/context.go @@ -60,14 +60,14 @@ func init() { // Context is the request context passed to handlers. It is pooled via sync.Pool. // A Context is obtained from the pool and must not be retained after the handler returns. type Context struct { - stream *stream.Stream + stream *stream.Stream index int16 handlers []HandlerFunc handlerBuf [8]HandlerFunc params Params paramBuf [4]Param - keys map[string]any - ctx context.Context + keys map[string]any + ctx context.Context method string path string diff --git a/context_request.go b/context_request.go index 0a6dbde..b19ee11 100644 --- a/context_request.go +++ b/context_request.go @@ -438,6 +438,8 @@ func (c *Context) FormValueOK(name string) (string, bool) { return vs[0], true } +// FormValueOk is a deprecated alias for [Context.FormValueOK]. +// // Deprecated: Use [Context.FormValueOK] instead. func (c *Context) FormValueOk(name string) (string, bool) { return c.FormValueOK(name) diff --git a/handler.go b/handler.go index 4ad38f2..81bcb63 100644 --- a/handler.go +++ b/handler.go @@ -12,9 +12,9 @@ import ( ) type routerAdapter struct { - server *Server - notFoundChain []HandlerFunc - methodNotAllowedChain []HandlerFunc + server *Server + notFoundChain []HandlerFunc + methodNotAllowedChain []HandlerFunc } func (a *routerAdapter) HandleStream(_ context.Context, s *stream.Stream) error { diff --git a/internal/ctxkit/ctxkit.go b/internal/ctxkit/ctxkit.go index d7858ec..2a8f06f 100644 --- a/internal/ctxkit/ctxkit.go +++ b/internal/ctxkit/ctxkit.go @@ -7,10 +7,10 @@ import "github.com/goceleris/celeris/protocol/h2/stream" // Hooks registered by the celeris package at init time. var ( - NewContext func(s *stream.Stream) any - ReleaseContext func(c any) - AddParam func(c any, key, value string) - SetHandlers func(c any, handlers []any) - GetResponseWriter func(c any) any - GetStream func(c any) any + NewContext func(s *stream.Stream) any + ReleaseContext func(c any) + AddParam func(c any, key, value string) + SetHandlers func(c any, handlers []any) + GetResponseWriter func(c any) any + GetStream func(c any) any ) From 2118c7a4966663dcb6f66674e5ebfd0bb0ab5d41 Mon Sep 17 00:00:00 2001 From: Albert Bausili Date: Wed, 1 Apr 2026 13:55:11 +0200 Subject: [PATCH 11/11] fix: gofmt parser.go --- protocol/h1/parser.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/protocol/h1/parser.go b/protocol/h1/parser.go index 0c35083..c9b4b82 100644 --- a/protocol/h1/parser.go +++ b/protocol/h1/parser.go @@ -7,14 +7,14 @@ import ( // H1 parser sentinel errors. var ( - ErrBufferExhausted = errors.New("buffer exhausted") - ErrInvalidRequestLine = errors.New("invalid request line") - ErrInvalidHeader = errors.New("invalid header line") - ErrMissingHost = errors.New("missing Host header") - ErrUnsupportedVersion = errors.New("unsupported HTTP version") - ErrHeadersTooLarge = errors.New("headers too large") - ErrInvalidContentLength = errors.New("invalid content-length") - ErrDuplicateContentLength = errors.New("duplicate content-length with conflicting values") + ErrBufferExhausted = errors.New("buffer exhausted") + ErrInvalidRequestLine = errors.New("invalid request line") + ErrInvalidHeader = errors.New("invalid header line") + ErrMissingHost = errors.New("missing Host header") + ErrUnsupportedVersion = errors.New("unsupported HTTP version") + ErrHeadersTooLarge = errors.New("headers too large") + ErrInvalidContentLength = errors.New("invalid content-length") + ErrDuplicateContentLength = errors.New("duplicate content-length with conflicting values") ) // Parser is a zero-allocation HTTP/1.x request parser.