diff --git a/bridge.go b/bridge.go
index 6dce906..910cc36 100644
--- a/bridge.go
+++ b/bridge.go
@@ -89,7 +89,7 @@ func buildHTTPRequest(c *Context) (*http.Request, error) {
return req, nil
}
-const maxBridgeResponseBytes = 100 << 20 // 100 MB
+const maxBridgeResponseBytes = maxBodySize
var errBridgeResponseTooLarge = errors.New("bridge: response body exceeds 100MB limit")
diff --git a/celeristest/celeristest.go b/celeristest/celeristest.go
index fe49083..517b191 100644
--- a/celeristest/celeristest.go
+++ b/celeristest/celeristest.go
@@ -13,6 +13,7 @@ import (
"encoding/base64"
"net/url"
"strings"
+ "sync"
"testing"
"github.com/goceleris/celeris"
@@ -48,18 +49,32 @@ func (r *ResponseRecorder) BodyString() string {
return string(r.Body)
}
+// recorderCombo bundles a ResponseRecorder and its recorderWriter in a
+// single allocation so they can be pooled together.
+type recorderCombo struct {
+ rec ResponseRecorder
+ rw recorderWriter
+}
+
+var recorderPool = sync.Pool{New: func() any {
+ combo := &recorderCombo{}
+ combo.rw.rec = &combo.rec
+ combo.rw.combo = combo
+ return combo
+}}
+
// recorderWriter adapts a ResponseRecorder to the internal
// stream.ResponseWriter interface without leaking internal types
// in the public API.
type recorderWriter struct {
- rec *ResponseRecorder
+ rec *ResponseRecorder
+ combo *recorderCombo
}
func (w *recorderWriter) WriteResponse(_ *stream.Stream, status int, headers [][2]string, body []byte) error {
w.rec.StatusCode = status
w.rec.Headers = headers
- w.rec.Body = make([]byte, len(body))
- copy(w.rec.Body, body)
+ w.rec.Body = append(w.rec.Body[:0], body...)
return nil
}
@@ -77,13 +92,38 @@ var _ stream.ResponseWriter = (*recorderWriter)(nil)
type Option func(*config)
type config struct {
- body []byte
- headers [][2]string
- queries [][2]string
- params [][2]string
- cookies [][2]string
- remoteAddr string
- handlers []any
+ body []byte
+ headers [][2]string
+ queries [][2]string
+ params [][2]string
+ cookies [][2]string
+ remoteAddr string
+ handlers []any
+ headersBuf [4][2]string
+ handlersBuf [4]any
+}
+
+var configPool = sync.Pool{New: func() any {
+ c := &config{}
+ c.headers = c.headersBuf[:0]
+ c.handlers = c.handlersBuf[:0]
+ return c
+}}
+
+func (c *config) reset() {
+ c.body = nil
+ for i := range c.headers {
+ c.headers[i] = [2]string{}
+ }
+ c.headers = c.headersBuf[:0]
+ c.queries = nil
+ c.params = nil
+ c.cookies = nil
+ c.remoteAddr = ""
+ for i := range c.handlers {
+ c.handlers[i] = nil
+ }
+ c.handlers = c.handlersBuf[:0]
}
// WithBody sets the request body.
@@ -134,9 +174,17 @@ func WithRemoteAddr(addr string) Option {
// Pass celeris.HandlerFunc values; they are stored as []any to avoid import cycles.
func WithHandlers(handlers ...celeris.HandlerFunc) Option {
return func(c *config) {
- c.handlers = make([]any, len(handlers))
- for i, h := range handlers {
- c.handlers[i] = h
+ n := len(handlers)
+ if n <= len(c.handlersBuf) {
+ for i, h := range handlers {
+ c.handlersBuf[i] = h
+ }
+ c.handlers = c.handlersBuf[:n]
+ } else {
+ c.handlers = make([]any, n)
+ for i, h := range handlers {
+ c.handlers[i] = h
+ }
}
}
}
@@ -144,7 +192,38 @@ func WithHandlers(handlers ...celeris.HandlerFunc) Option {
// ReleaseContext returns a [celeris.Context] to the pool. The context must not
// be used after this call.
func ReleaseContext(ctx *celeris.Context) {
+ // Extract stream and recorder before releasing context (reset nils them).
+ var s *stream.Stream
+ if raw := ctxkit.GetStream(ctx); raw != nil {
+ s = raw.(*stream.Stream)
+ }
+ var combo *recorderCombo
+ if rw := ctxkit.GetResponseWriter(ctx); rw != nil {
+ if w, ok := rw.(*recorderWriter); ok && w.combo != nil {
+ combo = w.combo
+ }
+ }
+
ctxkit.ReleaseContext(ctx)
+
+ // Return stream to pool so NewStream reuses it on the next call.
+ if s != nil {
+ if s.HasDoneCh() {
+ // A derived context (e.g. context.WithTimeout) spawned a
+ // goroutine that holds a reference to this stream. Cancel to
+ // terminate it, but don't pool the stream — the goroutine may
+ // still read Err()/Done() after the pool recycles the struct.
+ s.Cancel()
+ } else {
+ stream.ResetForPool(s)
+ }
+ }
+ if combo != nil {
+ combo.rec.StatusCode = 0
+ combo.rec.Headers = nil
+ combo.rec.Body = combo.rec.Body[:0]
+ recorderPool.Put(combo)
+ }
}
// NewContextT is like [NewContext] but registers an automatic cleanup with
@@ -160,7 +239,7 @@ func NewContextT(t *testing.T, method, path string, opts ...Option) (*celeris.Co
// The returned context has the given method and path, plus any options applied.
// Call [ReleaseContext] when done to clean up pooled resources.
func NewContext(method, path string, opts ...Option) (*celeris.Context, *ResponseRecorder) {
- cfg := &config{}
+ cfg := configPool.Get().(*config)
for _, o := range opts {
o(cfg)
}
@@ -175,12 +254,12 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons
}
s := stream.NewStream(1)
- s.Headers = [][2]string{
- {":method", method},
- {":path", fullPath},
- {":scheme", "http"},
- {":authority", "localhost"},
- }
+ s.Headers = append(s.Headers,
+ [2]string{":method", method},
+ [2]string{":path", fullPath},
+ [2]string{":scheme", "http"},
+ [2]string{":authority", "localhost"},
+ )
s.Headers = append(s.Headers, cfg.headers...)
if len(cfg.cookies) > 0 {
parts := make([]string, 0, len(cfg.cookies))
@@ -190,11 +269,15 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons
s.Headers = append(s.Headers, [2]string{"cookie", strings.Join(parts, "; ")})
}
if len(cfg.body) > 0 {
- s.Data.Write(cfg.body)
+ s.GetBuf().Write(cfg.body)
}
- rec := &ResponseRecorder{}
- s.ResponseWriter = &recorderWriter{rec: rec}
+ combo := recorderPool.Get().(*recorderCombo)
+ combo.rec.StatusCode = 0
+ combo.rec.Headers = nil
+ combo.rec.Body = combo.rec.Body[:0]
+ rec := &combo.rec
+ s.ResponseWriter = &combo.rw
if cfg.remoteAddr != "" {
s.RemoteAddr = cfg.remoteAddr
@@ -207,5 +290,9 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons
if len(cfg.handlers) > 0 {
ctxkit.SetHandlers(ctx, cfg.handlers)
}
+
+ cfg.reset()
+ configPool.Put(cfg)
+
return ctx, rec
}
diff --git a/celeristest/celeristest_test.go b/celeristest/celeristest_test.go
index e5190f0..ccb86c1 100644
--- a/celeristest/celeristest_test.go
+++ b/celeristest/celeristest_test.go
@@ -205,3 +205,111 @@ func TestWithRemoteAddr(t *testing.T) {
t.Fatalf("expected 192.168.1.1:54321, got %s", ctx.RemoteAddr())
}
}
+
+func TestWithHandlers(t *testing.T) {
+ var order []string
+ mw := func(c *celeris.Context) error {
+ order = append(order, "mw")
+ return c.Next()
+ }
+ handler := func(c *celeris.Context) error {
+ order = append(order, "handler")
+ return c.String(200, "ok")
+ }
+ ctx, rec := NewContext("GET", "/test", WithHandlers(mw, handler))
+ defer ReleaseContext(ctx)
+
+ err := ctx.Next()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if rec.StatusCode != 200 {
+ t.Fatalf("expected 200, got %d", rec.StatusCode)
+ }
+ if rec.BodyString() != "ok" {
+ t.Fatalf("expected body 'ok', got %q", rec.BodyString())
+ }
+ if len(order) != 2 || order[0] != "mw" || order[1] != "handler" {
+ t.Fatalf("unexpected execution order: %v", order)
+ }
+}
+
+func TestWithHandlersErrorPropagation(t *testing.T) {
+ mw := func(c *celeris.Context) error {
+ return c.Next()
+ }
+ handler := func(_ *celeris.Context) error {
+ return celeris.NewHTTPError(403, "forbidden")
+ }
+ ctx, _ := NewContext("GET", "/test", WithHandlers(mw, handler))
+ defer ReleaseContext(ctx)
+
+ err := ctx.Next()
+ if err == nil {
+ t.Fatal("expected error from handler chain")
+ }
+ he, ok := err.(*celeris.HTTPError)
+ if !ok {
+ t.Fatalf("expected *HTTPError, got %T", err)
+ }
+ if he.Code != 403 {
+ t.Fatalf("expected 403, got %d", he.Code)
+ }
+}
+
+func TestWithHandlersManyHandlers(t *testing.T) {
+ var order []string
+ makeMW := func(name string) celeris.HandlerFunc {
+ return func(c *celeris.Context) error {
+ order = append(order, name)
+ return c.Next()
+ }
+ }
+ handler := func(c *celeris.Context) error {
+ order = append(order, "handler")
+ return c.String(200, "ok")
+ }
+ // 5 middleware + 1 handler = 6 total, exceeds the 4-element handlersBuf.
+ ctx, rec := NewContext("GET", "/test", WithHandlers(
+ makeMW("mw1"), makeMW("mw2"), makeMW("mw3"), makeMW("mw4"), makeMW("mw5"), handler,
+ ))
+ defer ReleaseContext(ctx)
+
+ err := ctx.Next()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if rec.StatusCode != 200 {
+ t.Fatalf("expected 200, got %d", rec.StatusCode)
+ }
+ expected := []string{"mw1", "mw2", "mw3", "mw4", "mw5", "handler"}
+ if len(order) != len(expected) {
+ t.Fatalf("expected %d entries, got %d: %v", len(expected), len(order), order)
+ }
+ for i, e := range expected {
+ if order[i] != e {
+ t.Fatalf("order[%d] = %q, want %q", i, order[i], e)
+ }
+ }
+}
+
+func TestWithHandlersAbort(t *testing.T) {
+ var reached bool
+ mw := func(c *celeris.Context) error {
+ return c.AbortWithStatus(401)
+ }
+ handler := func(c *celeris.Context) error {
+ reached = true
+ return c.String(200, "ok")
+ }
+ ctx, rec := NewContext("GET", "/test", WithHandlers(mw, handler))
+ defer ReleaseContext(ctx)
+
+ _ = ctx.Next()
+ if reached {
+ t.Fatal("handler should not have been reached after abort")
+ }
+ if rec.StatusCode != 401 {
+ t.Fatalf("expected 401, got %d", rec.StatusCode)
+ }
+}
diff --git a/context.go b/context.go
index 80099df..8eb926c 100644
--- a/context.go
+++ b/context.go
@@ -25,24 +25,49 @@ func init() {
}
ctxkit.SetHandlers = func(c any, handlers []any) {
ctx := c.(*Context)
- chain := make([]HandlerFunc, len(handlers))
- for i, h := range handlers {
- chain[i] = h.(HandlerFunc)
+ n := len(handlers)
+ var chain []HandlerFunc
+ if n <= len(ctx.handlerBuf) {
+ for i, h := range handlers {
+ ctx.handlerBuf[i] = h.(HandlerFunc)
+ }
+ chain = ctx.handlerBuf[:n]
+ } else {
+ chain = make([]HandlerFunc, n)
+ for i, h := range handlers {
+ chain[i] = h.(HandlerFunc)
+ }
}
ctx.handlers = chain
ctx.index = -1
}
+ ctxkit.GetResponseWriter = func(c any) any {
+ ctx := c.(*Context)
+ if ctx.stream != nil {
+ return ctx.stream.ResponseWriter
+ }
+ return nil
+ }
+ ctxkit.GetStream = func(c any) any {
+ ctx := c.(*Context)
+ if ctx.stream != nil {
+ return ctx.stream
+ }
+ return nil
+ }
}
// Context is the request context passed to handlers. It is pooled via sync.Pool.
// A Context is obtained from the pool and must not be retained after the handler returns.
type Context struct {
- stream *stream.Stream
- index int16
- handlers []HandlerFunc
- params Params
- keys map[string]any
- ctx context.Context
+ stream *stream.Stream
+ index int16
+ handlers []HandlerFunc
+ handlerBuf [8]HandlerFunc
+ params Params
+ paramBuf [4]Param
+ keys map[string]any
+ ctx context.Context
method string
path string
@@ -86,7 +111,12 @@ type Context struct {
respHdrBuf [8][2]string // reusable buffer for response headers (avoids heap escape)
}
-var contextPool = sync.Pool{New: func() any { return &Context{} }}
+var contextPool = sync.Pool{New: func() any {
+ c := &Context{keys: make(map[string]any, 4)}
+ c.params = c.paramBuf[:0]
+ c.respHeaders = c.respHdrBuf[:0]
+ return c
+}}
const abortIndex int16 = math.MaxInt16 / 2
@@ -213,17 +243,11 @@ func (c *Context) SetContext(ctx context.Context) {
// Set stores a key-value pair for this request.
func (c *Context) Set(key string, value any) {
c.extended = true
- if c.keys == nil {
- c.keys = make(map[string]any)
- }
c.keys[key] = value
}
// Get returns the value for a key.
func (c *Context) Get(key string) (any, bool) {
- if c.keys == nil {
- return nil, false
- }
v, ok := c.keys[key]
return v, ok
}
@@ -231,7 +255,7 @@ func (c *Context) Get(key string) (any, bool) {
// Keys returns a copy of all key-value pairs stored on this context.
// Returns nil if no values have been set.
func (c *Context) Keys() map[string]any {
- if c.keys == nil {
+ if len(c.keys) == 0 {
return nil
}
cp := make(map[string]any, len(c.keys))
@@ -244,25 +268,37 @@ func (c *Context) Keys() map[string]any {
func (c *Context) reset() {
c.stream = nil
c.index = -1
+ // Clear handler references so closures can be GCed, but only when
+ // the slice is owned by this context (backed by handlerBuf or a
+ // SetHandlers allocation). The production path assigns the router's
+ // pre-composed chain directly; nilling those would corrupt shared state.
+ if len(c.handlers) > 0 && cap(c.handlers) <= len(c.handlerBuf) &&
+ &c.handlers[:cap(c.handlers)][0] == &c.handlerBuf[0] {
+ for i := range c.handlers {
+ c.handlers[i] = nil
+ }
+ }
if cap(c.handlers) > 64 {
c.handlers = nil
} else {
c.handlers = c.handlers[:0]
}
- c.params = c.params[:0]
+ clear(c.paramBuf[:])
+ c.params = c.paramBuf[:0]
c.ctx = nil
c.method = ""
c.path = ""
c.rawQuery = ""
c.fullPath = ""
c.statusCode = 200
- c.respHeaders = c.respHeaders[:0]
+ clear(c.respHdrBuf[:])
+ c.respHeaders = c.respHdrBuf[:0]
c.written = false
c.aborted = false
c.bytesWritten = 0
c.maxFormSize = 0
if c.extended {
- c.keys = nil
+ clear(c.keys)
c.queryCache = nil
c.queryCached = false
c.cookieCache = c.cookieCache[:0]
diff --git a/context_request.go b/context_request.go
index 4cbe383..b19ee11 100644
--- a/context_request.go
+++ b/context_request.go
@@ -13,6 +13,7 @@ import (
"net/url"
"strconv"
"strings"
+ "unsafe"
"github.com/goceleris/celeris/internal/negotiate"
)
@@ -60,9 +61,21 @@ func (c *Context) ContentLength() int64 {
return n
}
-// Header returns the value of the named request header. Keys must be lowercase
-// (HTTP/2 mandates lowercase; the H1 parser normalizes to lowercase).
+// Header returns the value of the named request header. Keys are normalized
+// to lowercase automatically (HTTP/2 mandates lowercase; the H1 parser
+// normalizes to lowercase).
func (c *Context) Header(key string) string {
+ // Fast path: most programmatic keys are already lowercase.
+ needsLower := false
+ for i := range len(key) {
+ if key[i] >= 'A' && key[i] <= 'Z' {
+ needsLower = true
+ break
+ }
+ }
+ if needsLower {
+ key = strings.ToLower(key)
+ }
for _, h := range c.stream.Headers {
if h[0] == key {
return h[1]
@@ -97,6 +110,15 @@ func (c *Context) ParamInt64(key string) (int64, error) {
return strconv.ParseInt(v, 10, 64)
}
+// ParamDefault returns the value of a URL parameter, or the default if absent or empty.
+func (c *Context) ParamDefault(key, defaultValue string) string {
+ v := c.Param(key)
+ if v == "" {
+ return defaultValue
+ }
+ return v
+}
+
// Query returns the value of a query parameter by name. Results are cached
// so repeated calls for different keys do not re-parse the query string.
func (c *Context) Query(key string) string {
@@ -135,6 +157,38 @@ func (c *Context) QueryInt(key string, defaultValue int) int {
return n
}
+// QueryInt64 returns a query parameter parsed as an int64.
+// Returns the provided default value if the key is absent or not a valid integer.
+func (c *Context) QueryInt64(key string, defaultValue int64) int64 {
+ v := c.Query(key)
+ if v == "" {
+ return defaultValue
+ }
+ n, err := strconv.ParseInt(v, 10, 64)
+ if err != nil {
+ return defaultValue
+ }
+ return n
+}
+
+// QueryBool returns a query parameter parsed as a bool.
+// Returns the provided default value if the key is absent or not a valid bool.
+// Recognizes "true", "1", "yes" as true and "false", "0", "no" as false.
+func (c *Context) QueryBool(key string, defaultValue bool) bool {
+ v := c.Query(key)
+ if v == "" {
+ return defaultValue
+ }
+ switch strings.ToLower(v) {
+ case "true", "1", "yes":
+ return true
+ case "false", "0", "no":
+ return false
+ default:
+ return defaultValue
+ }
+}
+
// QueryValues returns all values for the given query parameter key.
// Returns nil if the key is not present.
func (c *Context) QueryValues(key string) []string {
@@ -282,6 +336,9 @@ func (c *Context) Scheme() string {
return c.schemeOverride
}
if proto := c.Header("x-forwarded-proto"); proto != "" {
+ if proto == "https" || proto == "http" {
+ return proto
+ }
return strings.ToLower(strings.TrimSpace(proto))
}
if scheme := c.Header(":scheme"); scheme != "" {
@@ -340,16 +397,22 @@ func (c *Context) BasicAuth() (username, password string, ok bool) {
if len(auth) < len(prefix) || auth[:len(prefix)] != prefix {
return
}
- decoded, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
+ payload := auth[len(prefix):]
+ var buf [128]byte
+ if base64.StdEncoding.DecodedLen(len(payload)) > len(buf) {
+ return
+ }
+ n, err := base64.StdEncoding.Decode(buf[:],
+ unsafe.Slice(unsafe.StringData(payload), len(payload)))
if err != nil {
return
}
- s := string(decoded)
- colon := strings.IndexByte(s, ':')
- if colon < 0 {
+ i := bytes.IndexByte(buf[:n], ':')
+ if i < 0 {
return
}
- return s[:colon], s[colon+1:], true
+ decoded := string(buf[:n])
+ return decoded[:i], decoded[i+1:], true
}
// FormValue returns the first value for the named form field.
@@ -361,10 +424,10 @@ func (c *Context) FormValue(name string) string {
return c.formValues.Get(name)
}
-// FormValueOk returns the first value for the named form field plus a boolean
+// FormValueOK returns the first value for the named form field plus a boolean
// indicating whether the field was present. Unlike FormValue, callers can
// distinguish a missing field from an empty value.
-func (c *Context) FormValueOk(name string) (string, bool) {
+func (c *Context) FormValueOK(name string) (string, bool) {
if err := c.parseForm(); err != nil {
return "", false
}
@@ -375,6 +438,13 @@ func (c *Context) FormValueOk(name string) (string, bool) {
return vs[0], true
}
+// FormValueOk is a deprecated alias for [Context.FormValueOK].
+//
+// Deprecated: Use [Context.FormValueOK] instead.
+func (c *Context) FormValueOk(name string) (string, bool) {
+ return c.FormValueOK(name)
+}
+
// FormValues returns all values for the named form field.
func (c *Context) FormValues(name string) []string {
if err := c.parseForm(); err != nil {
diff --git a/context_request_test.go b/context_request_test.go
index 23aac6f..98f04b7 100644
--- a/context_request_test.go
+++ b/context_request_test.go
@@ -15,7 +15,7 @@ import (
func TestContextBind(t *testing.T) {
s, _ := newTestStream("POST", "/bind")
- s.Data.Write([]byte(`{"name":"test"}`))
+ s.GetBuf().Write([]byte(`{"name":"test"}`))
defer s.Release()
c := acquireContext(s)
@@ -221,7 +221,7 @@ func TestContextQueryInt(t *testing.T) {
func TestBindJSON(t *testing.T) {
s, _ := newTestStream("POST", "/bind-json")
- s.Data.Write([]byte(`{"name":"test"}`))
+ s.GetBuf().Write([]byte(`{"name":"test"}`))
defer s.Release()
c := acquireContext(s)
@@ -238,7 +238,7 @@ func TestBindJSON(t *testing.T) {
func TestBindXML(t *testing.T) {
s, _ := newTestStream("POST", "/bind-xml")
- s.Data.Write([]byte(`- test
`))
+ s.GetBuf().Write([]byte(`- test
`))
defer s.Release()
c := acquireContext(s)
@@ -259,7 +259,7 @@ func TestBindXML(t *testing.T) {
func TestBindContentTypeDetection(t *testing.T) {
// JSON (default, no Content-Type).
s, _ := newTestStream("POST", "/bind-auto")
- s.Data.Write([]byte(`{"name":"json"}`))
+ s.GetBuf().Write([]byte(`{"name":"json"}`))
defer s.Release()
c := acquireContext(s)
@@ -276,7 +276,7 @@ func TestBindContentTypeDetection(t *testing.T) {
// XML with Content-Type header.
s2, _ := newTestStream("POST", "/bind-auto-xml")
s2.Headers = append(s2.Headers, [2]string{"content-type", "application/xml"})
- s2.Data.Write([]byte(`- xml
`))
+ s2.GetBuf().Write([]byte(`- xml
`))
defer s2.Release()
c2 := acquireContext(s2)
@@ -440,7 +440,7 @@ func TestContextBasicAuthPasswordWithColons(t *testing.T) {
func TestContextFormValueURLEncoded(t *testing.T) {
s, _ := newTestStream("POST", "/form")
s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"})
- s.Data.Write([]byte("name=alice&age=30"))
+ s.GetBuf().Write([]byte("name=alice&age=30"))
defer s.Release()
c := acquireContext(s)
@@ -460,7 +460,7 @@ func TestContextFormValueURLEncoded(t *testing.T) {
func TestContextFormValuesMultiple(t *testing.T) {
s, _ := newTestStream("POST", "/form")
s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"})
- s.Data.Write([]byte("color=red&color=blue&color=green"))
+ s.GetBuf().Write([]byte("color=red&color=blue&color=green"))
defer s.Release()
c := acquireContext(s)
@@ -484,7 +484,7 @@ func TestContextMultipartFormValue(t *testing.T) {
s, _ := newTestStream("POST", "/upload")
s.Headers = append(s.Headers, [2]string{"content-type", w.FormDataContentType()})
- s.Data.Write(buf.Bytes())
+ s.GetBuf().Write(buf.Bytes())
defer s.Release()
c := acquireContext(s)
@@ -513,7 +513,7 @@ func TestContextFormFile(t *testing.T) {
s, _ := newTestStream("POST", "/upload")
s.Headers = append(s.Headers, [2]string{"content-type", w.FormDataContentType()})
- s.Data.Write(buf.Bytes())
+ s.GetBuf().Write(buf.Bytes())
defer s.Release()
c := acquireContext(s)
@@ -538,7 +538,7 @@ func TestContextFormFile(t *testing.T) {
func TestContextFormFileNonMultipart(t *testing.T) {
s, _ := newTestStream("POST", "/upload")
s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"})
- s.Data.Write([]byte("key=value"))
+ s.GetBuf().Write([]byte("key=value"))
defer s.Release()
c := acquireContext(s)
@@ -565,7 +565,7 @@ func TestContextFormEmptyBody(t *testing.T) {
func TestContextFormCaching(t *testing.T) {
s, _ := newTestStream("POST", "/form")
s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"})
- s.Data.Write([]byte("a=1&b=2"))
+ s.GetBuf().Write([]byte("a=1&b=2"))
defer s.Release()
c := acquireContext(s)
@@ -584,7 +584,7 @@ func TestContextFormCaching(t *testing.T) {
func TestContextFormResetClearsForm(t *testing.T) {
s, _ := newTestStream("POST", "/form")
s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"})
- s.Data.Write([]byte("key=value"))
+ s.GetBuf().Write([]byte("key=value"))
defer s.Release()
c := acquireContext(s)
@@ -607,7 +607,7 @@ func TestContextMaxFormSizeUnlimited(t *testing.T) {
s, _ := newTestStream("POST", "/form")
s.Headers = append(s.Headers, [2]string{"content-type", w.FormDataContentType()})
- s.Data.Write(buf.Bytes())
+ s.GetBuf().Write(buf.Bytes())
defer s.Release()
c := acquireContext(s)
@@ -624,7 +624,7 @@ func TestContextMaxFormSizeUnlimited(t *testing.T) {
func TestContextBodyCopy(t *testing.T) {
s, _ := newTestStream("POST", "/data")
- s.Data.Write([]byte("original"))
+ s.GetBuf().Write([]byte("original"))
defer s.Release()
c := acquireContext(s)
@@ -655,7 +655,7 @@ func TestContextBodyCopyEmpty(t *testing.T) {
func TestContextBodyReader(t *testing.T) {
s, _ := newTestStream("POST", "/data")
- s.Data.Write([]byte("read me"))
+ s.GetBuf().Write([]byte("read me"))
defer s.Release()
c := acquireContext(s)
@@ -1070,7 +1070,7 @@ func TestAcceptsLanguagesNoMatch(t *testing.T) {
func TestFormFileNotMultipart(t *testing.T) {
s, _ := newTestStream("POST", "/upload")
s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"})
- s.Data.Write([]byte("key=value"))
+ s.GetBuf().Write([]byte("key=value"))
defer s.Release()
c := acquireContext(s)
@@ -1092,7 +1092,7 @@ func TestFormFileNotMultipart(t *testing.T) {
func TestMultipartFormNotMultipart(t *testing.T) {
s, _ := newTestStream("POST", "/upload")
s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"})
- s.Data.Write([]byte("key=value"))
+ s.GetBuf().Write([]byte("key=value"))
defer s.Release()
c := acquireContext(s)
@@ -1166,22 +1166,22 @@ func TestRequestHeaders(t *testing.T) {
}
}
-func TestContextFormValueOk(t *testing.T) {
+func TestContextFormValueOK(t *testing.T) {
body := "name=alice&empty="
s, _ := newTestStream("POST", "/form")
s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"})
- s.Data.Write([]byte(body))
+ s.GetBuf().Write([]byte(body))
defer s.Release()
c := acquireContext(s)
defer releaseContext(c)
- v, ok := c.FormValueOk("name")
+ v, ok := c.FormValueOK("name")
if !ok || v != "alice" {
t.Fatalf("expected (alice, true), got (%s, %v)", v, ok)
}
- v, ok = c.FormValueOk("empty")
+ v, ok = c.FormValueOK("empty")
if !ok {
t.Fatal("expected field 'empty' to be present")
}
@@ -1189,7 +1189,7 @@ func TestContextFormValueOk(t *testing.T) {
t.Fatalf("expected empty string, got %q", v)
}
- _, ok = c.FormValueOk("missing")
+ _, ok = c.FormValueOK("missing")
if ok {
t.Fatal("expected missing field to return false")
}
@@ -1373,3 +1373,149 @@ func TestContextOverridesResetOnReuse(t *testing.T) {
t.Fatal("expected hostOverride to be cleared")
}
}
+
+func TestQueryBool(t *testing.T) {
+ tests := []struct {
+ query string
+ key string
+ def bool
+ expected bool
+ }{
+ {"debug=true", "debug", false, true},
+ {"debug=1", "debug", false, true},
+ {"debug=yes", "debug", false, true},
+ {"debug=false", "debug", true, false},
+ {"debug=0", "debug", true, false},
+ {"debug=no", "debug", true, false},
+ {"debug=TRUE", "debug", false, true},
+ {"debug=FALSE", "debug", true, false},
+ {"debug=Yes", "debug", false, true},
+ {"debug=No", "debug", true, false},
+ {"debug=invalid", "debug", true, true},
+ {"debug=invalid", "debug", false, false},
+ {"", "debug", true, true},
+ {"", "debug", false, false},
+ {"other=1", "debug", false, false},
+ {"other=1", "debug", true, true},
+ }
+ for _, tt := range tests {
+ name := tt.query + "_" + tt.key
+ if tt.def {
+ name += "_def=true"
+ } else {
+ name += "_def=false"
+ }
+ t.Run(name, func(t *testing.T) {
+ path := "/test"
+ if tt.query != "" {
+ path += "?" + tt.query
+ }
+ s, _ := newTestStream("GET", path)
+ defer s.Release()
+
+ c := acquireContext(s)
+ defer releaseContext(c)
+
+ got := c.QueryBool(tt.key, tt.def)
+ if got != tt.expected {
+ t.Fatalf("QueryBool(%q, %v) = %v, want %v", tt.key, tt.def, got, tt.expected)
+ }
+ })
+ }
+}
+
+func TestQueryInt64(t *testing.T) {
+ tests := []struct {
+ query string
+ key string
+ def int64
+ expected int64
+ }{
+ {"page=42", "page", 0, 42},
+ {"page=9999999999", "page", 0, 9999999999},
+ {"page=-100", "page", 0, -100},
+ {"page=0", "page", 99, 0},
+ {"page=abc", "page", 10, 10},
+ {"", "page", 5, 5},
+ {"other=1", "page", 7, 7},
+ {"page=9223372036854775807", "page", 0, 9223372036854775807},
+ }
+ for _, tt := range tests {
+ name := tt.query + "_" + tt.key
+ t.Run(name, func(t *testing.T) {
+ path := "/test"
+ if tt.query != "" {
+ path += "?" + tt.query
+ }
+ s, _ := newTestStream("GET", path)
+ defer s.Release()
+
+ c := acquireContext(s)
+ defer releaseContext(c)
+
+ got := c.QueryInt64(tt.key, tt.def)
+ if got != tt.expected {
+ t.Fatalf("QueryInt64(%q, %d) = %d, want %d", tt.key, tt.def, got, tt.expected)
+ }
+ })
+ }
+}
+
+func TestParamDefault(t *testing.T) {
+ s, _ := newTestStream("GET", "/users/alice")
+ defer s.Release()
+
+ c := acquireContext(s)
+ defer releaseContext(c)
+
+ c.params = Params{
+ {Key: "name", Value: "alice"},
+ {Key: "empty", Value: ""},
+ }
+
+ if got := c.ParamDefault("name", "fallback"); got != "alice" {
+ t.Fatalf("expected alice, got %s", got)
+ }
+ if got := c.ParamDefault("empty", "fallback"); got != "fallback" {
+ t.Fatalf("expected fallback for empty param, got %s", got)
+ }
+ if got := c.ParamDefault("missing", "fallback"); got != "fallback" {
+ t.Fatalf("expected fallback for missing param, got %s", got)
+ }
+}
+
+func TestFormValueOkDeprecated(t *testing.T) {
+ body := "name=alice&empty="
+ s, _ := newTestStream("POST", "/form")
+ s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"})
+ s.GetBuf().Write([]byte(body))
+ defer s.Release()
+
+ c := acquireContext(s)
+ defer releaseContext(c)
+
+ // FormValueOk (deprecated) should return the same results as FormValueOK.
+ v1, ok1 := c.FormValueOK("name")
+ v2, ok2 := c.FormValueOk("name")
+ if v1 != v2 || ok1 != ok2 {
+ t.Fatalf("FormValueOk != FormValueOK: (%q,%v) vs (%q,%v)", v1, ok1, v2, ok2)
+ }
+ if !ok1 || v1 != "alice" {
+ t.Fatalf("expected (alice, true), got (%s, %v)", v1, ok1)
+ }
+
+ v1, ok1 = c.FormValueOK("empty")
+ v2, ok2 = c.FormValueOk("empty")
+ if v1 != v2 || ok1 != ok2 {
+ t.Fatalf("FormValueOk != FormValueOK for empty: (%q,%v) vs (%q,%v)", v1, ok1, v2, ok2)
+ }
+
+ _, ok1 = c.FormValueOK("missing")
+ _, ok2 = c.FormValueOk("missing")
+ if ok1 != ok2 {
+ t.Fatalf("FormValueOk != FormValueOK for missing: %v vs %v", ok1, ok2)
+ }
+ if ok1 {
+ t.Fatal("expected missing field to return false")
+ }
+}
diff --git a/context_response.go b/context_response.go
index 9358243..c9542c4 100644
--- a/context_response.go
+++ b/context_response.go
@@ -125,13 +125,25 @@ func (c *Context) Blob(code int, contentType string, data []byte) error {
c.capturedStatus = code
c.capturedType = contentType
}
- headers := c.respHdrBuf[:0:8]
- if len(c.respHeaders)+2 > 8 {
- headers = make([][2]string, 0, len(c.respHeaders)+2)
+ nUser := len(c.respHeaders)
+ total := nUser + 2
+ var headers [][2]string
+ if total <= len(c.respHdrBuf) {
+ // respHeaders shares backing array with respHdrBuf — copy user
+ // headers to a stack temporary before overwriting the buffer.
+ // Max user headers in fast path: len(respHdrBuf) - 2 = 6.
+ var tmp [6][2]string
+ copy(tmp[:nUser], c.respHeaders)
+ headers = c.respHdrBuf[:0:len(c.respHdrBuf)]
+ headers = append(headers, [2]string{"content-type", stripCRLF(contentType)})
+ headers = append(headers, [2]string{"content-length", itoa(len(data))})
+ headers = append(headers, tmp[:nUser]...)
+ } else {
+ headers = make([][2]string, 0, total)
+ headers = append(headers, [2]string{"content-type", stripCRLF(contentType)})
+ headers = append(headers, [2]string{"content-length", itoa(len(data))})
+ headers = append(headers, c.respHeaders...)
}
- headers = append(headers, [2]string{"content-type", stripCRLF(contentType)})
- headers = append(headers, [2]string{"content-length", itoa(len(data))})
- headers = append(headers, c.respHeaders...)
if c.stream.ResponseWriter != nil {
return c.stream.ResponseWriter.WriteResponse(c.stream, code, headers, data)
}
@@ -175,7 +187,14 @@ func (c *Context) SetHeader(key, value string) {
break
}
}
- v := stripCRLF(value)
+ v := value
+ for i := range len(value) {
+ b := value[i]
+ if b == '\r' || b == '\n' || b == 0 {
+ v = stripCRLF(value)
+ break
+ }
+ }
for i, h := range c.respHeaders {
if h[0] == k {
c.respHeaders[i][1] = v
@@ -189,10 +208,23 @@ func (c *Context) SetHeader(key, value string) {
// replace existing values — use this for headers that allow multiple values
// (e.g. set-cookie).
func (c *Context) AddHeader(key, value string) {
- c.respHeaders = append(c.respHeaders, [2]string{
- sanitizeHeaderKey(key),
- stripCRLF(value),
- })
+ k := key
+ for i := range len(key) {
+ b := key[i]
+ if b >= 'A' && b <= 'Z' || b == '\r' || b == '\n' || b == 0 {
+ k = sanitizeHeaderKey(key)
+ break
+ }
+ }
+ v := value
+ for i := range len(value) {
+ b := value[i]
+ if b == '\r' || b == '\n' || b == 0 {
+ v = stripCRLF(value)
+ break
+ }
+ }
+ c.respHeaders = append(c.respHeaders, [2]string{k, v})
}
// sanitizeHeaderKey lowercases and strips CRLF/null bytes. Fast path avoids
@@ -207,11 +239,13 @@ func sanitizeHeaderKey(s string) string {
return s
}
+var crlfReplacer = strings.NewReplacer("\r", "", "\n", "", "\x00", "")
+
// stripCRLF removes \r, \n, and \x00 to prevent HTTP response splitting
// (CWE-113) and null-byte header smuggling.
func stripCRLF(s string) string {
if strings.ContainsAny(s, "\r\n\x00") {
- return strings.NewReplacer("\r", "", "\n", "", "\x00", "").Replace(s)
+ return crlfReplacer.Replace(s)
}
return s
}
@@ -245,11 +279,13 @@ func (c *Context) Redirect(code int, url string) error {
return c.NoContent(code)
}
+var cookieUnsafeReplacer = strings.NewReplacer(";", "", "\r", "", "\n", "")
+
// stripCookieUnsafe strips characters that could inject cookie attributes
// (semicolons) or cause header injection (CRLF) from cookie field values.
func stripCookieUnsafe(s string) string {
if strings.ContainsAny(s, ";\r\n") {
- return strings.NewReplacer(";", "", "\r", "", "\n", "").Replace(s)
+ return cookieUnsafeReplacer.Replace(s)
}
return s
}
@@ -261,6 +297,7 @@ func stripCookieUnsafe(s string) string {
// cookie attribute injection.
func (c *Context) SetCookie(cookie *Cookie) {
var b strings.Builder
+ b.Grow(128)
b.WriteString(stripCookieUnsafe(cookie.Name))
b.WriteByte('=')
b.WriteString(stripCookieUnsafe(cookie.Value))
@@ -357,15 +394,34 @@ func (c *Context) File(filePath string) error {
// FileFromDir safely serves a file from within baseDir. The userPath is
// cleaned and joined with baseDir; if the result escapes baseDir, a 400
-// error is returned. This prevents directory traversal when serving
-// user-supplied paths.
+// error is returned. Symlinks are resolved and rechecked to prevent a
+// symlink under baseDir from escaping the directory boundary.
func (c *Context) FileFromDir(baseDir, userPath string) error {
abs := filepath.Clean(filepath.Join(baseDir, filepath.FromSlash(userPath)))
base := filepath.Clean(baseDir)
if abs != base && !strings.HasPrefix(abs, base+string(filepath.Separator)) {
return NewHTTPError(400, "invalid file path")
}
- return c.File(abs)
+ // Resolve symlinks and recheck prefix to prevent symlink escape.
+ resolved, err := filepath.EvalSymlinks(abs)
+ if err != nil {
+ return err
+ }
+ resolvedBase, err := filepath.EvalSymlinks(base)
+ if err != nil {
+ return err
+ }
+ if resolved != resolvedBase && !strings.HasPrefix(resolved, resolvedBase+string(filepath.Separator)) {
+ return NewHTTPError(400, "invalid file path")
+ }
+ info, err := os.Stat(resolved)
+ if err != nil {
+ return err
+ }
+ if info.IsDir() {
+ return NewHTTPError(400, "invalid file path")
+ }
+ return c.File(resolved)
}
func parseRange(header string, size int64) (start, end int64, ok bool) {
diff --git a/context_test.go b/context_test.go
index aa762f7..cf60ebf 100644
--- a/context_test.go
+++ b/context_test.go
@@ -19,7 +19,8 @@ type mockResponseWriter struct {
func (m *mockResponseWriter) WriteResponse(_ *stream.Stream, status int, headers [][2]string, body []byte) error {
m.status = status
- m.headers = headers
+ m.headers = make([][2]string, len(headers))
+ copy(m.headers, headers)
m.body = make([]byte, len(body))
copy(m.body, body)
return nil
diff --git a/doc.go b/doc.go
index 32aae88..88d06f4 100644
--- a/doc.go
+++ b/doc.go
@@ -431,9 +431,9 @@
//
// # Form Presence
//
-// FormValueOk distinguishes a missing field from an empty value:
+// FormValueOK distinguishes a missing field from an empty value:
//
-// val, ok := c.FormValueOk("name")
+// val, ok := c.FormValueOK("name")
// if !ok {
// // field was not submitted
// }
diff --git a/handler.go b/handler.go
index e277e4a..81bcb63 100644
--- a/handler.go
+++ b/handler.go
@@ -12,7 +12,9 @@ import (
)
type routerAdapter struct {
- server *Server
+ server *Server
+ notFoundChain []HandlerFunc
+ methodNotAllowedChain []HandlerFunc
}
func (a *routerAdapter) HandleStream(_ context.Context, s *stream.Stream) error {
@@ -111,9 +113,13 @@ func (a *routerAdapter) handleUnmatched(c *Context, s *stream.Stream) {
if len(allowed) > 0 {
c.statusCode = 405
allowVal := strings.Join(allowed, ", ")
- if a.server.methodNotAllowedHandler != nil {
+ chain := a.methodNotAllowedChain
+ if chain == nil && a.server.methodNotAllowedHandler != nil {
+ chain = []HandlerFunc{a.server.methodNotAllowedHandler}
+ }
+ if chain != nil {
c.SetHeader("allow", allowVal)
- c.handlers = []HandlerFunc{a.server.methodNotAllowedHandler}
+ c.handlers = chain
a.handleError(c, s, c.Next())
}
if !c.written && s.ResponseWriter != nil {
@@ -125,8 +131,12 @@ func (a *routerAdapter) handleUnmatched(c *Context, s *stream.Stream) {
}
} else {
c.statusCode = 404
- if a.server.notFoundHandler != nil {
- c.handlers = []HandlerFunc{a.server.notFoundHandler}
+ chain := a.notFoundChain
+ if chain == nil && a.server.notFoundHandler != nil {
+ chain = []HandlerFunc{a.server.notFoundHandler}
+ }
+ if chain != nil {
+ c.handlers = chain
a.handleError(c, s, c.Next())
}
if !c.written && s.ResponseWriter != nil {
diff --git a/internal/ctxkit/ctxkit.go b/internal/ctxkit/ctxkit.go
index e0af1a9..2a8f06f 100644
--- a/internal/ctxkit/ctxkit.go
+++ b/internal/ctxkit/ctxkit.go
@@ -7,8 +7,10 @@ import "github.com/goceleris/celeris/protocol/h2/stream"
// Hooks registered by the celeris package at init time.
var (
- NewContext func(s *stream.Stream) any
- ReleaseContext func(c any)
- AddParam func(c any, key, value string)
- SetHandlers func(c any, handlers []any)
+ NewContext func(s *stream.Stream) any
+ ReleaseContext func(c any)
+ AddParam func(c any, key, value string)
+ SetHandlers func(c any, handlers []any)
+ GetResponseWriter func(c any) any
+ GetStream func(c any) any
)
diff --git a/internal/timer/wheel.go b/internal/timer/wheel.go
deleted file mode 100644
index 040a491..0000000
--- a/internal/timer/wheel.go
+++ /dev/null
@@ -1,145 +0,0 @@
-// Package timer implements a hierarchical timer wheel for efficient
-// timeout management in event-driven I/O engines.
-package timer
-
-import (
- "sync"
- "time"
-)
-
-const (
- // WheelSize is the number of slots in the timer wheel (power of 2 for fast modulo).
- WheelSize = 8192
- // TickInterval is the granularity of the timer wheel.
- TickInterval = 100 * time.Millisecond
-
- wheelMask = WheelSize - 1
-)
-
-// TimeoutKind identifies the type of timeout.
-type TimeoutKind uint8
-
-// Timeout kinds for the timer wheel.
-const (
- ReadTimeout TimeoutKind = iota
- WriteTimeout
- IdleTimeout
-)
-
-// Entry represents a scheduled timeout.
-type Entry struct {
- FD int
- Deadline int64 // unix nano
- Kind TimeoutKind
- next *Entry
-}
-
-var entryPool = sync.Pool{New: func() any { return &Entry{} }}
-
-func getEntry() *Entry {
- return entryPool.Get().(*Entry)
-}
-
-func putEntry(e *Entry) {
- e.FD = 0
- e.Deadline = 0
- e.Kind = 0
- e.next = nil
- entryPool.Put(e)
-}
-
-// Wheel is a hashed timer wheel with O(1) schedule and cancel.
-type Wheel struct {
- slots [WheelSize]*Entry
- current uint64
- startNs int64
- onExpire func(fd int, kind TimeoutKind)
- // fdSlots tracks the slot index for each FD to enable O(1) cancel.
- fdSlots map[int]int64 // fd → deadline
-}
-
-// New creates a timer wheel that calls onExpire for each expired entry.
-func New(onExpire func(fd int, kind TimeoutKind)) *Wheel {
- return &Wheel{
- startNs: time.Now().UnixNano(),
- onExpire: onExpire,
- fdSlots: make(map[int]int64),
- }
-}
-
-func (w *Wheel) slotIndex(deadline int64) uint64 {
- tick := uint64(deadline-w.startNs) / uint64(TickInterval)
- return tick & wheelMask
-}
-
-// Schedule registers a timeout for the given FD. Any previous timeout for
-// the same FD is logically cancelled (the old entry is ignored on expiry).
-func (w *Wheel) Schedule(fd int, timeout time.Duration, kind TimeoutKind) {
- w.ScheduleAt(fd, timeout, kind, time.Now().UnixNano())
-}
-
-// ScheduleAt is like Schedule but accepts a pre-computed now timestamp (UnixNano)
-// to avoid redundant time.Now() calls when scheduling multiple timeouts in the
-// same request processing cycle.
-func (w *Wheel) ScheduleAt(fd int, timeout time.Duration, kind TimeoutKind, nowNs int64) {
- deadline := nowNs + int64(timeout)
- slot := w.slotIndex(deadline)
-
- e := getEntry()
- e.FD = fd
- e.Deadline = deadline
- e.Kind = kind
- e.next = w.slots[slot]
- w.slots[slot] = e
-
- w.fdSlots[fd] = deadline
-}
-
-// Cancel removes any pending timeout for the given FD.
-func (w *Wheel) Cancel(fd int) {
- delete(w.fdSlots, fd)
-}
-
-// Tick advances the wheel and fires expired entries. Returns the number
-// of entries that expired. Call this from the event loop after processing
-// I/O events.
-func (w *Wheel) Tick() int {
- now := time.Now().UnixNano()
- currentTick := uint64(now-w.startNs) / uint64(TickInterval)
- count := 0
-
- for w.current <= currentTick {
- slot := w.current & wheelMask
- var prev *Entry
- e := w.slots[slot]
- for e != nil {
- next := e.next
- if e.Deadline <= now {
- // Check if this entry is still the active one for this FD.
- if active, ok := w.fdSlots[e.FD]; ok && active == e.Deadline {
- delete(w.fdSlots, e.FD)
- w.onExpire(e.FD, e.Kind)
- count++
- }
- // Remove from list.
- if prev == nil {
- w.slots[slot] = next
- } else {
- prev.next = next
- }
- putEntry(e)
- } else {
- prev = e
- }
- e = next
- }
- w.current++
- }
-
- return count
-}
-
-// Len returns the number of tracked FDs (approximate).
-func (w *Wheel) Len() int {
- return len(w.fdSlots)
-}
diff --git a/internal/timer/wheel_test.go b/internal/timer/wheel_test.go
deleted file mode 100644
index 3e3d3eb..0000000
--- a/internal/timer/wheel_test.go
+++ /dev/null
@@ -1,133 +0,0 @@
-package timer
-
-import (
- "sync/atomic"
- "testing"
- "time"
-)
-
-func TestWheelScheduleAndExpire(t *testing.T) {
- var expired atomic.Int32
- var expiredFD int
- var expiredKind TimeoutKind
-
- w := New(func(fd int, kind TimeoutKind) {
- expiredFD = fd
- expiredKind = kind
- expired.Add(1)
- })
-
- w.Schedule(42, 150*time.Millisecond, IdleTimeout)
-
- // Should not expire immediately.
- n := w.Tick()
- if n != 0 {
- t.Fatalf("expected 0 expired, got %d", n)
- }
-
- time.Sleep(200 * time.Millisecond)
- n = w.Tick()
- if n != 1 {
- t.Fatalf("expected 1 expired, got %d", n)
- }
- if expiredFD != 42 {
- t.Fatalf("expected FD 42, got %d", expiredFD)
- }
- if expiredKind != IdleTimeout {
- t.Fatalf("expected IdleTimeout, got %d", expiredKind)
- }
-}
-
-func TestWheelCancel(t *testing.T) {
- var expired atomic.Int32
- w := New(func(_ int, _ TimeoutKind) {
- expired.Add(1)
- })
-
- w.Schedule(10, 150*time.Millisecond, ReadTimeout)
- w.Cancel(10)
-
- time.Sleep(200 * time.Millisecond)
- w.Tick()
- if expired.Load() != 0 {
- t.Fatalf("expected 0 expired after cancel, got %d", expired.Load())
- }
-}
-
-func TestWheelMultipleEntries(t *testing.T) {
- expired := make(map[int]bool)
- w := New(func(fd int, _ TimeoutKind) {
- expired[fd] = true
- })
-
- w.Schedule(1, 150*time.Millisecond, IdleTimeout)
- w.Schedule(2, 150*time.Millisecond, IdleTimeout)
- w.Schedule(3, 150*time.Millisecond, IdleTimeout)
-
- time.Sleep(200 * time.Millisecond)
- n := w.Tick()
- if n != 3 {
- t.Fatalf("expected 3 expired, got %d", n)
- }
- for _, fd := range []int{1, 2, 3} {
- if !expired[fd] {
- t.Fatalf("expected FD %d to expire", fd)
- }
- }
-}
-
-func TestWheelReschedule(t *testing.T) {
- var lastKind TimeoutKind
- count := 0
- w := New(func(_ int, kind TimeoutKind) {
- lastKind = kind
- count++
- })
-
- // Schedule then reschedule with different kind.
- w.Schedule(5, 150*time.Millisecond, ReadTimeout)
- w.Schedule(5, 150*time.Millisecond, WriteTimeout)
-
- time.Sleep(200 * time.Millisecond)
- w.Tick()
- // Only the latest schedule should fire.
- if count != 1 {
- t.Fatalf("expected 1 expiry, got %d", count)
- }
- if lastKind != WriteTimeout {
- t.Fatalf("expected WriteTimeout, got %d", lastKind)
- }
-}
-
-func TestWheelLen(t *testing.T) {
- w := New(func(_ int, _ TimeoutKind) {})
- w.Schedule(1, time.Second, IdleTimeout)
- w.Schedule(2, time.Second, IdleTimeout)
- if w.Len() != 2 {
- t.Fatalf("expected Len=2, got %d", w.Len())
- }
- w.Cancel(1)
- if w.Len() != 1 {
- t.Fatalf("expected Len=1, got %d", w.Len())
- }
-}
-
-func BenchmarkWheelSchedule(b *testing.B) {
- w := New(func(_ int, _ TimeoutKind) {})
- b.ResetTimer()
- for i := range b.N {
- w.Schedule(i%10000, 5*time.Second, IdleTimeout)
- }
-}
-
-func BenchmarkWheelTick(b *testing.B) {
- w := New(func(_ int, _ TimeoutKind) {})
- // Pre-fill with entries that won't expire.
- for i := range 1000 {
- w.Schedule(i, time.Hour, IdleTimeout)
- }
- b.ResetTimer()
- for range b.N {
- w.Tick()
- }
-}
diff --git a/protocol/h1/parser.go b/protocol/h1/parser.go
index bf18908..c9b4b82 100644
--- a/protocol/h1/parser.go
+++ b/protocol/h1/parser.go
@@ -7,13 +7,14 @@ import (
// H1 parser sentinel errors.
var (
- ErrBufferExhausted = errors.New("buffer exhausted")
- ErrInvalidRequestLine = errors.New("invalid request line")
- ErrInvalidHeader = errors.New("invalid header line")
- ErrMissingHost = errors.New("missing Host header")
- ErrUnsupportedVersion = errors.New("unsupported HTTP version")
- ErrHeadersTooLarge = errors.New("headers too large")
- ErrInvalidContentLength = errors.New("invalid content-length")
+ ErrBufferExhausted = errors.New("buffer exhausted")
+ ErrInvalidRequestLine = errors.New("invalid request line")
+ ErrInvalidHeader = errors.New("invalid header line")
+ ErrMissingHost = errors.New("missing Host header")
+ ErrUnsupportedVersion = errors.New("unsupported HTTP version")
+ ErrHeadersTooLarge = errors.New("headers too large")
+ ErrInvalidContentLength = errors.New("invalid content-length")
+ ErrDuplicateContentLength = errors.New("duplicate content-length with conflicting values")
)
// Parser is a zero-allocation HTTP/1.x request parser.
@@ -171,6 +172,9 @@ func (p *Parser) appendHeader(req *Request, rawName, rawValue []byte) error {
if !ok {
return ErrInvalidContentLength
}
+ if req.ContentLength >= 0 && req.ContentLength != cl {
+ return ErrDuplicateContentLength
+ }
req.ContentLength = cl
case "transfer-encoding":
if asciiContainsFoldString(value, "chunked") {
@@ -206,11 +210,14 @@ func (p *Parser) appendHeader(req *Request, rawName, rawValue []byte) error {
if req.ChunkedEncoding {
return nil
}
- if cl, ok := parseInt64Bytes(rawValue); ok {
- req.ContentLength = cl
- } else {
+ cl, ok := parseInt64Bytes(rawValue)
+ if !ok {
return ErrInvalidContentLength
}
+ if req.ContentLength >= 0 && req.ContentLength != cl {
+ return ErrDuplicateContentLength
+ }
+ req.ContentLength = cl
return nil
}
if asciiEqualFold(rawName, "Connection") {
diff --git a/protocol/h1/parser_test.go b/protocol/h1/parser_test.go
index 7586d6b..f0b6aa2 100644
--- a/protocol/h1/parser_test.go
+++ b/protocol/h1/parser_test.go
@@ -530,6 +530,81 @@ func TestFindHeaderEnd_MultipleCR(t *testing.T) {
}
}
+func TestParseRequest_DuplicateContentLength_Conflicting(t *testing.T) {
+ for _, zeroCopy := range []bool{false, true} {
+ name := "standard"
+ if zeroCopy {
+ name = "zerocopy"
+ }
+ t.Run(name, func(t *testing.T) {
+ raw := "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\nContent-Length: 50\r\n\r\n"
+ p := NewParser()
+ if zeroCopy {
+ p.noStringHeaders = true
+ }
+ p.Reset([]byte(raw))
+ var req Request
+ _, err := p.ParseRequest(&req)
+ if !errors.Is(err, ErrDuplicateContentLength) {
+ t.Fatalf("got error %v, want %v", err, ErrDuplicateContentLength)
+ }
+ })
+ }
+}
+
+func TestParseRequest_DuplicateContentLength_Identical(t *testing.T) {
+ for _, zeroCopy := range []bool{false, true} {
+ name := "standard"
+ if zeroCopy {
+ name = "zerocopy"
+ }
+ t.Run(name, func(t *testing.T) {
+ raw := "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 42\r\nContent-Length: 42\r\n\r\n"
+ p := NewParser()
+ if zeroCopy {
+ p.noStringHeaders = true
+ }
+ p.Reset([]byte(raw))
+ var req Request
+ _, err := p.ParseRequest(&req)
+ if err != nil {
+ t.Fatalf("identical CL should be accepted, got: %v", err)
+ }
+ if req.ContentLength != 42 {
+ t.Fatalf("content-length = %d, want 42", req.ContentLength)
+ }
+ })
+ }
+}
+
+func TestParseRequest_DuplicateContentLength_ChunkedIgnored(t *testing.T) {
+ for _, zeroCopy := range []bool{false, true} {
+ name := "standard"
+ if zeroCopy {
+ name = "zerocopy"
+ }
+ t.Run(name, func(t *testing.T) {
+ raw := "POST / HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\nContent-Length: 0\r\nContent-Length: 50\r\n\r\n"
+ p := NewParser()
+ if zeroCopy {
+ p.noStringHeaders = true
+ }
+ p.Reset([]byte(raw))
+ var req Request
+ _, err := p.ParseRequest(&req)
+ if err != nil {
+ t.Fatalf("conflicting CL with chunked TE should be ignored, got: %v", err)
+ }
+ if !req.ChunkedEncoding {
+ t.Fatal("chunked encoding not detected")
+ }
+ if req.ContentLength != -1 {
+ t.Fatalf("content-length = %d, want -1 for chunked", req.ContentLength)
+ }
+ })
+ }
+}
+
// TestFindHeaderEnd_PartialSequence ensures partial \r\n sequences
// (without the full \r\n\r\n) are not mistakenly matched.
func TestFindHeaderEnd_PartialSequence(t *testing.T) {
diff --git a/protocol/h2/stream/stream.go b/protocol/h2/stream/stream.go
index 2ccd4b0..9ba93fd 100644
--- a/protocol/h2/stream/stream.go
+++ b/protocol/h2/stream/stream.go
@@ -61,6 +61,7 @@ type Stream struct {
phase Phase
CachedCtx any // per-connection cached context (avoids pool Get/Put per request)
OnDetach func() // called by Context.Detach to install write-thread safety
+ hdrBuf [16][2]string
}
// streamContext is a zero-alloc context.Context for streams.
@@ -103,13 +104,9 @@ func NewStream(id uint32) *Stream {
s := streamPool.Get().(*Stream)
s.ID = id
s.state.Store(int32(StateIdle))
- s.Data = getBuf()
- s.OutboundBuffer = getBuf()
s.windowSize.Store(65535)
s.phase = PhaseInit
- if cap(s.Headers) < 8 {
- s.Headers = make([][2]string, 0, 8)
- }
+ s.Headers = s.hdrBuf[:0]
return s
}
@@ -119,9 +116,7 @@ func NewH1Stream(id uint32) *Stream {
s.ID = id
s.state.Store(int32(StateIdle))
s.h1Mode = true
- if cap(s.Headers) < 16 {
- s.Headers = make([][2]string, 0, 16)
- }
+ s.Headers = s.hdrBuf[:0]
return s
}
@@ -162,12 +157,30 @@ func (s *Stream) IsCancelled() bool {
return s.flags.Load()&flagCancelled != 0
}
+// HasDoneCh reports whether a Done channel was created, indicating
+// a derived context (e.g. context.WithTimeout) is watching this stream.
+func (s *Stream) HasDoneCh() bool {
+ return s.doneCh.Load() != nil
+}
+
// Release returns pooled buffers, cancels the context, and returns the stream
// to its pool. Safe to call multiple times; subsequent calls are no-ops.
func (s *Stream) Release() {
if !s.h1Mode {
s.Cancel()
}
+ s.resetAndPool()
+}
+
+// ResetForPool returns pooled buffers and returns the stream to its pool
+// WITHOUT cancelling the context. Use this in test harnesses where derived
+// contexts (e.g. from context.WithTimeout) may have propagation goroutines
+// that race with cancellation flag clearing.
+func ResetForPool(s *Stream) {
+ s.resetAndPool()
+}
+
+func (s *Stream) resetAndPool() {
if s.Data != nil {
s.Data.Reset()
bufferPool.Put(s.Data)
@@ -181,7 +194,10 @@ func (s *Stream) Release() {
s.ID = 0
s.state.Store(0)
s.manager = nil
- s.Headers = s.Headers[:0]
+ clear(s.hdrBuf[:])
+ s.Headers = s.hdrBuf[:0]
+ s.Trailers = s.Trailers[:cap(s.Trailers)]
+ clear(s.Trailers)
s.Trailers = s.Trailers[:0]
s.OutboundEndStream = false
s.headersSent.Store(false)
@@ -218,7 +234,7 @@ func ResetH1Stream(s *Stream) {
bufferPool.Put(s.Data)
s.Data = nil
}
- s.Headers = s.Headers[:0]
+ s.Headers = s.hdrBuf[:0]
s.headersSent.Store(false)
s.EndStream = false
s.ResponseWriter = nil
@@ -241,7 +257,7 @@ func ResetH2StreamInline(s *Stream, id uint32) {
s.OutboundBuffer = getBuf()
}
s.ID = id
- s.Headers = s.Headers[:0]
+ s.Headers = s.hdrBuf[:0]
s.Trailers = s.Trailers[:0]
s.OutboundEndStream = false
s.headersSent.Store(false)
@@ -402,7 +418,7 @@ func (s *Stream) SetHandlerStarted() {
func (s *Stream) BufferOutbound(data []byte, endStream bool) {
s.mu.Lock()
if s.OutboundBuffer == nil {
- s.OutboundBuffer = new(bytes.Buffer)
+ s.OutboundBuffer = getBuf()
}
s.OutboundBuffer.Write(data)
s.OutboundEndStream = endStream
diff --git a/server.go b/server.go
index 7b32daa..c2da332 100644
--- a/server.go
+++ b/server.go
@@ -375,7 +375,14 @@ func (s *Server) doPrepare(configureFn func(cfg *resource.Config)) (engine.Engin
s.collector = observe.NewCollector()
}
- var handler stream.Handler = &routerAdapter{server: s}
+ ra := &routerAdapter{server: s}
+ if s.notFoundHandler != nil {
+ ra.notFoundChain = []HandlerFunc{s.notFoundHandler}
+ }
+ if s.methodNotAllowedHandler != nil {
+ ra.methodNotAllowedChain = []HandlerFunc{s.methodNotAllowedHandler}
+ }
+ var handler stream.Handler = ra
var err error
eng, err = createEngine(cfg, handler)
if err != nil {
diff --git a/server_test.go b/server_test.go
index dea10fb..8f9cbc4 100644
--- a/server_test.go
+++ b/server_test.go
@@ -65,7 +65,7 @@ func TestServerRouting(t *testing.T) {
// Test POST /echo.
st2, rw2 := newTestStream("POST", "/echo")
- st2.Data.Write([]byte("payload"))
+ st2.GetBuf().Write([]byte("payload"))
if err := adapter.HandleStream(context.Background(), st2); err != nil {
t.Fatal(err)
}
diff --git a/stdlib.go b/stdlib.go
index b743981..f82bf19 100644
--- a/stdlib.go
+++ b/stdlib.go
@@ -11,7 +11,7 @@ import (
"golang.org/x/net/http2"
)
-const maxToHandlerBodySize = 100 << 20 // 100 MB
+const maxToHandlerBodySize = maxBodySize
// ToHandler wraps a celeris HandlerFunc as an http.Handler for use with
// net/http routers, middleware, or test infrastructure. The returned handler
@@ -57,7 +57,7 @@ func ToHandler(h HandlerFunc) http.Handler {
return
}
if len(body) > 0 {
- _, _ = s.Data.Write(body)
+ _, _ = s.GetBuf().Write(body)
}
}
diff --git a/types.go b/types.go
index 31a29d7..859035a 100644
--- a/types.go
+++ b/types.go
@@ -11,7 +11,11 @@ type HandlerFunc func(*Context) error
// parsing (32 MB), matching net/http.
const DefaultMaxFormSize int64 = 32 << 20
-const maxStreamBodySize = 100 << 20 // 100MB
+// maxBodySize is the maximum request/response body size (100 MB), shared by
+// stream responses (File, Stream), bridge adapter output, and ToHandler input.
+const maxBodySize = 100 << 20
+
+const maxStreamBodySize = maxBodySize
// SameSite controls the SameSite attribute of a cookie.
type SameSite int