diff --git a/bridge.go b/bridge.go index 6dce906..910cc36 100644 --- a/bridge.go +++ b/bridge.go @@ -89,7 +89,7 @@ func buildHTTPRequest(c *Context) (*http.Request, error) { return req, nil } -const maxBridgeResponseBytes = 100 << 20 // 100 MB +const maxBridgeResponseBytes = maxBodySize var errBridgeResponseTooLarge = errors.New("bridge: response body exceeds 100MB limit") diff --git a/celeristest/celeristest.go b/celeristest/celeristest.go index fe49083..517b191 100644 --- a/celeristest/celeristest.go +++ b/celeristest/celeristest.go @@ -13,6 +13,7 @@ import ( "encoding/base64" "net/url" "strings" + "sync" "testing" "github.com/goceleris/celeris" @@ -48,18 +49,32 @@ func (r *ResponseRecorder) BodyString() string { return string(r.Body) } +// recorderCombo bundles a ResponseRecorder and its recorderWriter in a +// single allocation so they can be pooled together. +type recorderCombo struct { + rec ResponseRecorder + rw recorderWriter +} + +var recorderPool = sync.Pool{New: func() any { + combo := &recorderCombo{} + combo.rw.rec = &combo.rec + combo.rw.combo = combo + return combo +}} + // recorderWriter adapts a ResponseRecorder to the internal // stream.ResponseWriter interface without leaking internal types // in the public API. type recorderWriter struct { - rec *ResponseRecorder + rec *ResponseRecorder + combo *recorderCombo } func (w *recorderWriter) WriteResponse(_ *stream.Stream, status int, headers [][2]string, body []byte) error { w.rec.StatusCode = status w.rec.Headers = headers - w.rec.Body = make([]byte, len(body)) - copy(w.rec.Body, body) + w.rec.Body = append(w.rec.Body[:0], body...) return nil } @@ -77,13 +92,38 @@ var _ stream.ResponseWriter = (*recorderWriter)(nil) type Option func(*config) type config struct { - body []byte - headers [][2]string - queries [][2]string - params [][2]string - cookies [][2]string - remoteAddr string - handlers []any + body []byte + headers [][2]string + queries [][2]string + params [][2]string + cookies [][2]string + remoteAddr string + handlers []any + headersBuf [4][2]string + handlersBuf [4]any +} + +var configPool = sync.Pool{New: func() any { + c := &config{} + c.headers = c.headersBuf[:0] + c.handlers = c.handlersBuf[:0] + return c +}} + +func (c *config) reset() { + c.body = nil + for i := range c.headers { + c.headers[i] = [2]string{} + } + c.headers = c.headersBuf[:0] + c.queries = nil + c.params = nil + c.cookies = nil + c.remoteAddr = "" + for i := range c.handlers { + c.handlers[i] = nil + } + c.handlers = c.handlersBuf[:0] } // WithBody sets the request body. @@ -134,9 +174,17 @@ func WithRemoteAddr(addr string) Option { // Pass celeris.HandlerFunc values; they are stored as []any to avoid import cycles. func WithHandlers(handlers ...celeris.HandlerFunc) Option { return func(c *config) { - c.handlers = make([]any, len(handlers)) - for i, h := range handlers { - c.handlers[i] = h + n := len(handlers) + if n <= len(c.handlersBuf) { + for i, h := range handlers { + c.handlersBuf[i] = h + } + c.handlers = c.handlersBuf[:n] + } else { + c.handlers = make([]any, n) + for i, h := range handlers { + c.handlers[i] = h + } } } } @@ -144,7 +192,38 @@ func WithHandlers(handlers ...celeris.HandlerFunc) Option { // ReleaseContext returns a [celeris.Context] to the pool. The context must not // be used after this call. func ReleaseContext(ctx *celeris.Context) { + // Extract stream and recorder before releasing context (reset nils them). + var s *stream.Stream + if raw := ctxkit.GetStream(ctx); raw != nil { + s = raw.(*stream.Stream) + } + var combo *recorderCombo + if rw := ctxkit.GetResponseWriter(ctx); rw != nil { + if w, ok := rw.(*recorderWriter); ok && w.combo != nil { + combo = w.combo + } + } + ctxkit.ReleaseContext(ctx) + + // Return stream to pool so NewStream reuses it on the next call. + if s != nil { + if s.HasDoneCh() { + // A derived context (e.g. context.WithTimeout) spawned a + // goroutine that holds a reference to this stream. Cancel to + // terminate it, but don't pool the stream — the goroutine may + // still read Err()/Done() after the pool recycles the struct. + s.Cancel() + } else { + stream.ResetForPool(s) + } + } + if combo != nil { + combo.rec.StatusCode = 0 + combo.rec.Headers = nil + combo.rec.Body = combo.rec.Body[:0] + recorderPool.Put(combo) + } } // NewContextT is like [NewContext] but registers an automatic cleanup with @@ -160,7 +239,7 @@ func NewContextT(t *testing.T, method, path string, opts ...Option) (*celeris.Co // The returned context has the given method and path, plus any options applied. // Call [ReleaseContext] when done to clean up pooled resources. func NewContext(method, path string, opts ...Option) (*celeris.Context, *ResponseRecorder) { - cfg := &config{} + cfg := configPool.Get().(*config) for _, o := range opts { o(cfg) } @@ -175,12 +254,12 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons } s := stream.NewStream(1) - s.Headers = [][2]string{ - {":method", method}, - {":path", fullPath}, - {":scheme", "http"}, - {":authority", "localhost"}, - } + s.Headers = append(s.Headers, + [2]string{":method", method}, + [2]string{":path", fullPath}, + [2]string{":scheme", "http"}, + [2]string{":authority", "localhost"}, + ) s.Headers = append(s.Headers, cfg.headers...) if len(cfg.cookies) > 0 { parts := make([]string, 0, len(cfg.cookies)) @@ -190,11 +269,15 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons s.Headers = append(s.Headers, [2]string{"cookie", strings.Join(parts, "; ")}) } if len(cfg.body) > 0 { - s.Data.Write(cfg.body) + s.GetBuf().Write(cfg.body) } - rec := &ResponseRecorder{} - s.ResponseWriter = &recorderWriter{rec: rec} + combo := recorderPool.Get().(*recorderCombo) + combo.rec.StatusCode = 0 + combo.rec.Headers = nil + combo.rec.Body = combo.rec.Body[:0] + rec := &combo.rec + s.ResponseWriter = &combo.rw if cfg.remoteAddr != "" { s.RemoteAddr = cfg.remoteAddr @@ -207,5 +290,9 @@ func NewContext(method, path string, opts ...Option) (*celeris.Context, *Respons if len(cfg.handlers) > 0 { ctxkit.SetHandlers(ctx, cfg.handlers) } + + cfg.reset() + configPool.Put(cfg) + return ctx, rec } diff --git a/celeristest/celeristest_test.go b/celeristest/celeristest_test.go index e5190f0..ccb86c1 100644 --- a/celeristest/celeristest_test.go +++ b/celeristest/celeristest_test.go @@ -205,3 +205,111 @@ func TestWithRemoteAddr(t *testing.T) { t.Fatalf("expected 192.168.1.1:54321, got %s", ctx.RemoteAddr()) } } + +func TestWithHandlers(t *testing.T) { + var order []string + mw := func(c *celeris.Context) error { + order = append(order, "mw") + return c.Next() + } + handler := func(c *celeris.Context) error { + order = append(order, "handler") + return c.String(200, "ok") + } + ctx, rec := NewContext("GET", "/test", WithHandlers(mw, handler)) + defer ReleaseContext(ctx) + + err := ctx.Next() + if err != nil { + t.Fatal(err) + } + if rec.StatusCode != 200 { + t.Fatalf("expected 200, got %d", rec.StatusCode) + } + if rec.BodyString() != "ok" { + t.Fatalf("expected body 'ok', got %q", rec.BodyString()) + } + if len(order) != 2 || order[0] != "mw" || order[1] != "handler" { + t.Fatalf("unexpected execution order: %v", order) + } +} + +func TestWithHandlersErrorPropagation(t *testing.T) { + mw := func(c *celeris.Context) error { + return c.Next() + } + handler := func(_ *celeris.Context) error { + return celeris.NewHTTPError(403, "forbidden") + } + ctx, _ := NewContext("GET", "/test", WithHandlers(mw, handler)) + defer ReleaseContext(ctx) + + err := ctx.Next() + if err == nil { + t.Fatal("expected error from handler chain") + } + he, ok := err.(*celeris.HTTPError) + if !ok { + t.Fatalf("expected *HTTPError, got %T", err) + } + if he.Code != 403 { + t.Fatalf("expected 403, got %d", he.Code) + } +} + +func TestWithHandlersManyHandlers(t *testing.T) { + var order []string + makeMW := func(name string) celeris.HandlerFunc { + return func(c *celeris.Context) error { + order = append(order, name) + return c.Next() + } + } + handler := func(c *celeris.Context) error { + order = append(order, "handler") + return c.String(200, "ok") + } + // 5 middleware + 1 handler = 6 total, exceeds the 4-element handlersBuf. + ctx, rec := NewContext("GET", "/test", WithHandlers( + makeMW("mw1"), makeMW("mw2"), makeMW("mw3"), makeMW("mw4"), makeMW("mw5"), handler, + )) + defer ReleaseContext(ctx) + + err := ctx.Next() + if err != nil { + t.Fatal(err) + } + if rec.StatusCode != 200 { + t.Fatalf("expected 200, got %d", rec.StatusCode) + } + expected := []string{"mw1", "mw2", "mw3", "mw4", "mw5", "handler"} + if len(order) != len(expected) { + t.Fatalf("expected %d entries, got %d: %v", len(expected), len(order), order) + } + for i, e := range expected { + if order[i] != e { + t.Fatalf("order[%d] = %q, want %q", i, order[i], e) + } + } +} + +func TestWithHandlersAbort(t *testing.T) { + var reached bool + mw := func(c *celeris.Context) error { + return c.AbortWithStatus(401) + } + handler := func(c *celeris.Context) error { + reached = true + return c.String(200, "ok") + } + ctx, rec := NewContext("GET", "/test", WithHandlers(mw, handler)) + defer ReleaseContext(ctx) + + _ = ctx.Next() + if reached { + t.Fatal("handler should not have been reached after abort") + } + if rec.StatusCode != 401 { + t.Fatalf("expected 401, got %d", rec.StatusCode) + } +} diff --git a/context.go b/context.go index 80099df..8eb926c 100644 --- a/context.go +++ b/context.go @@ -25,24 +25,49 @@ func init() { } ctxkit.SetHandlers = func(c any, handlers []any) { ctx := c.(*Context) - chain := make([]HandlerFunc, len(handlers)) - for i, h := range handlers { - chain[i] = h.(HandlerFunc) + n := len(handlers) + var chain []HandlerFunc + if n <= len(ctx.handlerBuf) { + for i, h := range handlers { + ctx.handlerBuf[i] = h.(HandlerFunc) + } + chain = ctx.handlerBuf[:n] + } else { + chain = make([]HandlerFunc, n) + for i, h := range handlers { + chain[i] = h.(HandlerFunc) + } } ctx.handlers = chain ctx.index = -1 } + ctxkit.GetResponseWriter = func(c any) any { + ctx := c.(*Context) + if ctx.stream != nil { + return ctx.stream.ResponseWriter + } + return nil + } + ctxkit.GetStream = func(c any) any { + ctx := c.(*Context) + if ctx.stream != nil { + return ctx.stream + } + return nil + } } // Context is the request context passed to handlers. It is pooled via sync.Pool. // A Context is obtained from the pool and must not be retained after the handler returns. type Context struct { - stream *stream.Stream - index int16 - handlers []HandlerFunc - params Params - keys map[string]any - ctx context.Context + stream *stream.Stream + index int16 + handlers []HandlerFunc + handlerBuf [8]HandlerFunc + params Params + paramBuf [4]Param + keys map[string]any + ctx context.Context method string path string @@ -86,7 +111,12 @@ type Context struct { respHdrBuf [8][2]string // reusable buffer for response headers (avoids heap escape) } -var contextPool = sync.Pool{New: func() any { return &Context{} }} +var contextPool = sync.Pool{New: func() any { + c := &Context{keys: make(map[string]any, 4)} + c.params = c.paramBuf[:0] + c.respHeaders = c.respHdrBuf[:0] + return c +}} const abortIndex int16 = math.MaxInt16 / 2 @@ -213,17 +243,11 @@ func (c *Context) SetContext(ctx context.Context) { // Set stores a key-value pair for this request. func (c *Context) Set(key string, value any) { c.extended = true - if c.keys == nil { - c.keys = make(map[string]any) - } c.keys[key] = value } // Get returns the value for a key. func (c *Context) Get(key string) (any, bool) { - if c.keys == nil { - return nil, false - } v, ok := c.keys[key] return v, ok } @@ -231,7 +255,7 @@ func (c *Context) Get(key string) (any, bool) { // Keys returns a copy of all key-value pairs stored on this context. // Returns nil if no values have been set. func (c *Context) Keys() map[string]any { - if c.keys == nil { + if len(c.keys) == 0 { return nil } cp := make(map[string]any, len(c.keys)) @@ -244,25 +268,37 @@ func (c *Context) Keys() map[string]any { func (c *Context) reset() { c.stream = nil c.index = -1 + // Clear handler references so closures can be GCed, but only when + // the slice is owned by this context (backed by handlerBuf or a + // SetHandlers allocation). The production path assigns the router's + // pre-composed chain directly; nilling those would corrupt shared state. + if len(c.handlers) > 0 && cap(c.handlers) <= len(c.handlerBuf) && + &c.handlers[:cap(c.handlers)][0] == &c.handlerBuf[0] { + for i := range c.handlers { + c.handlers[i] = nil + } + } if cap(c.handlers) > 64 { c.handlers = nil } else { c.handlers = c.handlers[:0] } - c.params = c.params[:0] + clear(c.paramBuf[:]) + c.params = c.paramBuf[:0] c.ctx = nil c.method = "" c.path = "" c.rawQuery = "" c.fullPath = "" c.statusCode = 200 - c.respHeaders = c.respHeaders[:0] + clear(c.respHdrBuf[:]) + c.respHeaders = c.respHdrBuf[:0] c.written = false c.aborted = false c.bytesWritten = 0 c.maxFormSize = 0 if c.extended { - c.keys = nil + clear(c.keys) c.queryCache = nil c.queryCached = false c.cookieCache = c.cookieCache[:0] diff --git a/context_request.go b/context_request.go index 4cbe383..b19ee11 100644 --- a/context_request.go +++ b/context_request.go @@ -13,6 +13,7 @@ import ( "net/url" "strconv" "strings" + "unsafe" "github.com/goceleris/celeris/internal/negotiate" ) @@ -60,9 +61,21 @@ func (c *Context) ContentLength() int64 { return n } -// Header returns the value of the named request header. Keys must be lowercase -// (HTTP/2 mandates lowercase; the H1 parser normalizes to lowercase). +// Header returns the value of the named request header. Keys are normalized +// to lowercase automatically (HTTP/2 mandates lowercase; the H1 parser +// normalizes to lowercase). func (c *Context) Header(key string) string { + // Fast path: most programmatic keys are already lowercase. + needsLower := false + for i := range len(key) { + if key[i] >= 'A' && key[i] <= 'Z' { + needsLower = true + break + } + } + if needsLower { + key = strings.ToLower(key) + } for _, h := range c.stream.Headers { if h[0] == key { return h[1] @@ -97,6 +110,15 @@ func (c *Context) ParamInt64(key string) (int64, error) { return strconv.ParseInt(v, 10, 64) } +// ParamDefault returns the value of a URL parameter, or the default if absent or empty. +func (c *Context) ParamDefault(key, defaultValue string) string { + v := c.Param(key) + if v == "" { + return defaultValue + } + return v +} + // Query returns the value of a query parameter by name. Results are cached // so repeated calls for different keys do not re-parse the query string. func (c *Context) Query(key string) string { @@ -135,6 +157,38 @@ func (c *Context) QueryInt(key string, defaultValue int) int { return n } +// QueryInt64 returns a query parameter parsed as an int64. +// Returns the provided default value if the key is absent or not a valid integer. +func (c *Context) QueryInt64(key string, defaultValue int64) int64 { + v := c.Query(key) + if v == "" { + return defaultValue + } + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return defaultValue + } + return n +} + +// QueryBool returns a query parameter parsed as a bool. +// Returns the provided default value if the key is absent or not a valid bool. +// Recognizes "true", "1", "yes" as true and "false", "0", "no" as false. +func (c *Context) QueryBool(key string, defaultValue bool) bool { + v := c.Query(key) + if v == "" { + return defaultValue + } + switch strings.ToLower(v) { + case "true", "1", "yes": + return true + case "false", "0", "no": + return false + default: + return defaultValue + } +} + // QueryValues returns all values for the given query parameter key. // Returns nil if the key is not present. func (c *Context) QueryValues(key string) []string { @@ -282,6 +336,9 @@ func (c *Context) Scheme() string { return c.schemeOverride } if proto := c.Header("x-forwarded-proto"); proto != "" { + if proto == "https" || proto == "http" { + return proto + } return strings.ToLower(strings.TrimSpace(proto)) } if scheme := c.Header(":scheme"); scheme != "" { @@ -340,16 +397,22 @@ func (c *Context) BasicAuth() (username, password string, ok bool) { if len(auth) < len(prefix) || auth[:len(prefix)] != prefix { return } - decoded, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) + payload := auth[len(prefix):] + var buf [128]byte + if base64.StdEncoding.DecodedLen(len(payload)) > len(buf) { + return + } + n, err := base64.StdEncoding.Decode(buf[:], + unsafe.Slice(unsafe.StringData(payload), len(payload))) if err != nil { return } - s := string(decoded) - colon := strings.IndexByte(s, ':') - if colon < 0 { + i := bytes.IndexByte(buf[:n], ':') + if i < 0 { return } - return s[:colon], s[colon+1:], true + decoded := string(buf[:n]) + return decoded[:i], decoded[i+1:], true } // FormValue returns the first value for the named form field. @@ -361,10 +424,10 @@ func (c *Context) FormValue(name string) string { return c.formValues.Get(name) } -// FormValueOk returns the first value for the named form field plus a boolean +// FormValueOK returns the first value for the named form field plus a boolean // indicating whether the field was present. Unlike FormValue, callers can // distinguish a missing field from an empty value. -func (c *Context) FormValueOk(name string) (string, bool) { +func (c *Context) FormValueOK(name string) (string, bool) { if err := c.parseForm(); err != nil { return "", false } @@ -375,6 +438,13 @@ func (c *Context) FormValueOk(name string) (string, bool) { return vs[0], true } +// FormValueOk is a deprecated alias for [Context.FormValueOK]. +// +// Deprecated: Use [Context.FormValueOK] instead. +func (c *Context) FormValueOk(name string) (string, bool) { + return c.FormValueOK(name) +} + // FormValues returns all values for the named form field. func (c *Context) FormValues(name string) []string { if err := c.parseForm(); err != nil { diff --git a/context_request_test.go b/context_request_test.go index 23aac6f..98f04b7 100644 --- a/context_request_test.go +++ b/context_request_test.go @@ -15,7 +15,7 @@ import ( func TestContextBind(t *testing.T) { s, _ := newTestStream("POST", "/bind") - s.Data.Write([]byte(`{"name":"test"}`)) + s.GetBuf().Write([]byte(`{"name":"test"}`)) defer s.Release() c := acquireContext(s) @@ -221,7 +221,7 @@ func TestContextQueryInt(t *testing.T) { func TestBindJSON(t *testing.T) { s, _ := newTestStream("POST", "/bind-json") - s.Data.Write([]byte(`{"name":"test"}`)) + s.GetBuf().Write([]byte(`{"name":"test"}`)) defer s.Release() c := acquireContext(s) @@ -238,7 +238,7 @@ func TestBindJSON(t *testing.T) { func TestBindXML(t *testing.T) { s, _ := newTestStream("POST", "/bind-xml") - s.Data.Write([]byte(`test`)) + s.GetBuf().Write([]byte(`test`)) defer s.Release() c := acquireContext(s) @@ -259,7 +259,7 @@ func TestBindXML(t *testing.T) { func TestBindContentTypeDetection(t *testing.T) { // JSON (default, no Content-Type). s, _ := newTestStream("POST", "/bind-auto") - s.Data.Write([]byte(`{"name":"json"}`)) + s.GetBuf().Write([]byte(`{"name":"json"}`)) defer s.Release() c := acquireContext(s) @@ -276,7 +276,7 @@ func TestBindContentTypeDetection(t *testing.T) { // XML with Content-Type header. s2, _ := newTestStream("POST", "/bind-auto-xml") s2.Headers = append(s2.Headers, [2]string{"content-type", "application/xml"}) - s2.Data.Write([]byte(`xml`)) + s2.GetBuf().Write([]byte(`xml`)) defer s2.Release() c2 := acquireContext(s2) @@ -440,7 +440,7 @@ func TestContextBasicAuthPasswordWithColons(t *testing.T) { func TestContextFormValueURLEncoded(t *testing.T) { s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("name=alice&age=30")) + s.GetBuf().Write([]byte("name=alice&age=30")) defer s.Release() c := acquireContext(s) @@ -460,7 +460,7 @@ func TestContextFormValueURLEncoded(t *testing.T) { func TestContextFormValuesMultiple(t *testing.T) { s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("color=red&color=blue&color=green")) + s.GetBuf().Write([]byte("color=red&color=blue&color=green")) defer s.Release() c := acquireContext(s) @@ -484,7 +484,7 @@ func TestContextMultipartFormValue(t *testing.T) { s, _ := newTestStream("POST", "/upload") s.Headers = append(s.Headers, [2]string{"content-type", w.FormDataContentType()}) - s.Data.Write(buf.Bytes()) + s.GetBuf().Write(buf.Bytes()) defer s.Release() c := acquireContext(s) @@ -513,7 +513,7 @@ func TestContextFormFile(t *testing.T) { s, _ := newTestStream("POST", "/upload") s.Headers = append(s.Headers, [2]string{"content-type", w.FormDataContentType()}) - s.Data.Write(buf.Bytes()) + s.GetBuf().Write(buf.Bytes()) defer s.Release() c := acquireContext(s) @@ -538,7 +538,7 @@ func TestContextFormFile(t *testing.T) { func TestContextFormFileNonMultipart(t *testing.T) { s, _ := newTestStream("POST", "/upload") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("key=value")) + s.GetBuf().Write([]byte("key=value")) defer s.Release() c := acquireContext(s) @@ -565,7 +565,7 @@ func TestContextFormEmptyBody(t *testing.T) { func TestContextFormCaching(t *testing.T) { s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("a=1&b=2")) + s.GetBuf().Write([]byte("a=1&b=2")) defer s.Release() c := acquireContext(s) @@ -584,7 +584,7 @@ func TestContextFormCaching(t *testing.T) { func TestContextFormResetClearsForm(t *testing.T) { s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("key=value")) + s.GetBuf().Write([]byte("key=value")) defer s.Release() c := acquireContext(s) @@ -607,7 +607,7 @@ func TestContextMaxFormSizeUnlimited(t *testing.T) { s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", w.FormDataContentType()}) - s.Data.Write(buf.Bytes()) + s.GetBuf().Write(buf.Bytes()) defer s.Release() c := acquireContext(s) @@ -624,7 +624,7 @@ func TestContextMaxFormSizeUnlimited(t *testing.T) { func TestContextBodyCopy(t *testing.T) { s, _ := newTestStream("POST", "/data") - s.Data.Write([]byte("original")) + s.GetBuf().Write([]byte("original")) defer s.Release() c := acquireContext(s) @@ -655,7 +655,7 @@ func TestContextBodyCopyEmpty(t *testing.T) { func TestContextBodyReader(t *testing.T) { s, _ := newTestStream("POST", "/data") - s.Data.Write([]byte("read me")) + s.GetBuf().Write([]byte("read me")) defer s.Release() c := acquireContext(s) @@ -1070,7 +1070,7 @@ func TestAcceptsLanguagesNoMatch(t *testing.T) { func TestFormFileNotMultipart(t *testing.T) { s, _ := newTestStream("POST", "/upload") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("key=value")) + s.GetBuf().Write([]byte("key=value")) defer s.Release() c := acquireContext(s) @@ -1092,7 +1092,7 @@ func TestFormFileNotMultipart(t *testing.T) { func TestMultipartFormNotMultipart(t *testing.T) { s, _ := newTestStream("POST", "/upload") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte("key=value")) + s.GetBuf().Write([]byte("key=value")) defer s.Release() c := acquireContext(s) @@ -1166,22 +1166,22 @@ func TestRequestHeaders(t *testing.T) { } } -func TestContextFormValueOk(t *testing.T) { +func TestContextFormValueOK(t *testing.T) { body := "name=alice&empty=" s, _ := newTestStream("POST", "/form") s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) - s.Data.Write([]byte(body)) + s.GetBuf().Write([]byte(body)) defer s.Release() c := acquireContext(s) defer releaseContext(c) - v, ok := c.FormValueOk("name") + v, ok := c.FormValueOK("name") if !ok || v != "alice" { t.Fatalf("expected (alice, true), got (%s, %v)", v, ok) } - v, ok = c.FormValueOk("empty") + v, ok = c.FormValueOK("empty") if !ok { t.Fatal("expected field 'empty' to be present") } @@ -1189,7 +1189,7 @@ func TestContextFormValueOk(t *testing.T) { t.Fatalf("expected empty string, got %q", v) } - _, ok = c.FormValueOk("missing") + _, ok = c.FormValueOK("missing") if ok { t.Fatal("expected missing field to return false") } @@ -1373,3 +1373,149 @@ func TestContextOverridesResetOnReuse(t *testing.T) { t.Fatal("expected hostOverride to be cleared") } } + +func TestQueryBool(t *testing.T) { + tests := []struct { + query string + key string + def bool + expected bool + }{ + {"debug=true", "debug", false, true}, + {"debug=1", "debug", false, true}, + {"debug=yes", "debug", false, true}, + {"debug=false", "debug", true, false}, + {"debug=0", "debug", true, false}, + {"debug=no", "debug", true, false}, + {"debug=TRUE", "debug", false, true}, + {"debug=FALSE", "debug", true, false}, + {"debug=Yes", "debug", false, true}, + {"debug=No", "debug", true, false}, + {"debug=invalid", "debug", true, true}, + {"debug=invalid", "debug", false, false}, + {"", "debug", true, true}, + {"", "debug", false, false}, + {"other=1", "debug", false, false}, + {"other=1", "debug", true, true}, + } + for _, tt := range tests { + name := tt.query + "_" + tt.key + if tt.def { + name += "_def=true" + } else { + name += "_def=false" + } + t.Run(name, func(t *testing.T) { + path := "/test" + if tt.query != "" { + path += "?" + tt.query + } + s, _ := newTestStream("GET", path) + defer s.Release() + + c := acquireContext(s) + defer releaseContext(c) + + got := c.QueryBool(tt.key, tt.def) + if got != tt.expected { + t.Fatalf("QueryBool(%q, %v) = %v, want %v", tt.key, tt.def, got, tt.expected) + } + }) + } +} + +func TestQueryInt64(t *testing.T) { + tests := []struct { + query string + key string + def int64 + expected int64 + }{ + {"page=42", "page", 0, 42}, + {"page=9999999999", "page", 0, 9999999999}, + {"page=-100", "page", 0, -100}, + {"page=0", "page", 99, 0}, + {"page=abc", "page", 10, 10}, + {"", "page", 5, 5}, + {"other=1", "page", 7, 7}, + {"page=9223372036854775807", "page", 0, 9223372036854775807}, + } + for _, tt := range tests { + name := tt.query + "_" + tt.key + t.Run(name, func(t *testing.T) { + path := "/test" + if tt.query != "" { + path += "?" + tt.query + } + s, _ := newTestStream("GET", path) + defer s.Release() + + c := acquireContext(s) + defer releaseContext(c) + + got := c.QueryInt64(tt.key, tt.def) + if got != tt.expected { + t.Fatalf("QueryInt64(%q, %d) = %d, want %d", tt.key, tt.def, got, tt.expected) + } + }) + } +} + +func TestParamDefault(t *testing.T) { + s, _ := newTestStream("GET", "/users/alice") + defer s.Release() + + c := acquireContext(s) + defer releaseContext(c) + + c.params = Params{ + {Key: "name", Value: "alice"}, + {Key: "empty", Value: ""}, + } + + if got := c.ParamDefault("name", "fallback"); got != "alice" { + t.Fatalf("expected alice, got %s", got) + } + if got := c.ParamDefault("empty", "fallback"); got != "fallback" { + t.Fatalf("expected fallback for empty param, got %s", got) + } + if got := c.ParamDefault("missing", "fallback"); got != "fallback" { + t.Fatalf("expected fallback for missing param, got %s", got) + } +} + +func TestFormValueOkDeprecated(t *testing.T) { + body := "name=alice&empty=" + s, _ := newTestStream("POST", "/form") + s.Headers = append(s.Headers, [2]string{"content-type", "application/x-www-form-urlencoded"}) + s.GetBuf().Write([]byte(body)) + defer s.Release() + + c := acquireContext(s) + defer releaseContext(c) + + // FormValueOk (deprecated) should return the same results as FormValueOK. + v1, ok1 := c.FormValueOK("name") + v2, ok2 := c.FormValueOk("name") + if v1 != v2 || ok1 != ok2 { + t.Fatalf("FormValueOk != FormValueOK: (%q,%v) vs (%q,%v)", v1, ok1, v2, ok2) + } + if !ok1 || v1 != "alice" { + t.Fatalf("expected (alice, true), got (%s, %v)", v1, ok1) + } + + v1, ok1 = c.FormValueOK("empty") + v2, ok2 = c.FormValueOk("empty") + if v1 != v2 || ok1 != ok2 { + t.Fatalf("FormValueOk != FormValueOK for empty: (%q,%v) vs (%q,%v)", v1, ok1, v2, ok2) + } + + _, ok1 = c.FormValueOK("missing") + _, ok2 = c.FormValueOk("missing") + if ok1 != ok2 { + t.Fatalf("FormValueOk != FormValueOK for missing: %v vs %v", ok1, ok2) + } + if ok1 { + t.Fatal("expected missing field to return false") + } +} diff --git a/context_response.go b/context_response.go index 9358243..c9542c4 100644 --- a/context_response.go +++ b/context_response.go @@ -125,13 +125,25 @@ func (c *Context) Blob(code int, contentType string, data []byte) error { c.capturedStatus = code c.capturedType = contentType } - headers := c.respHdrBuf[:0:8] - if len(c.respHeaders)+2 > 8 { - headers = make([][2]string, 0, len(c.respHeaders)+2) + nUser := len(c.respHeaders) + total := nUser + 2 + var headers [][2]string + if total <= len(c.respHdrBuf) { + // respHeaders shares backing array with respHdrBuf — copy user + // headers to a stack temporary before overwriting the buffer. + // Max user headers in fast path: len(respHdrBuf) - 2 = 6. + var tmp [6][2]string + copy(tmp[:nUser], c.respHeaders) + headers = c.respHdrBuf[:0:len(c.respHdrBuf)] + headers = append(headers, [2]string{"content-type", stripCRLF(contentType)}) + headers = append(headers, [2]string{"content-length", itoa(len(data))}) + headers = append(headers, tmp[:nUser]...) + } else { + headers = make([][2]string, 0, total) + headers = append(headers, [2]string{"content-type", stripCRLF(contentType)}) + headers = append(headers, [2]string{"content-length", itoa(len(data))}) + headers = append(headers, c.respHeaders...) } - headers = append(headers, [2]string{"content-type", stripCRLF(contentType)}) - headers = append(headers, [2]string{"content-length", itoa(len(data))}) - headers = append(headers, c.respHeaders...) if c.stream.ResponseWriter != nil { return c.stream.ResponseWriter.WriteResponse(c.stream, code, headers, data) } @@ -175,7 +187,14 @@ func (c *Context) SetHeader(key, value string) { break } } - v := stripCRLF(value) + v := value + for i := range len(value) { + b := value[i] + if b == '\r' || b == '\n' || b == 0 { + v = stripCRLF(value) + break + } + } for i, h := range c.respHeaders { if h[0] == k { c.respHeaders[i][1] = v @@ -189,10 +208,23 @@ func (c *Context) SetHeader(key, value string) { // replace existing values — use this for headers that allow multiple values // (e.g. set-cookie). func (c *Context) AddHeader(key, value string) { - c.respHeaders = append(c.respHeaders, [2]string{ - sanitizeHeaderKey(key), - stripCRLF(value), - }) + k := key + for i := range len(key) { + b := key[i] + if b >= 'A' && b <= 'Z' || b == '\r' || b == '\n' || b == 0 { + k = sanitizeHeaderKey(key) + break + } + } + v := value + for i := range len(value) { + b := value[i] + if b == '\r' || b == '\n' || b == 0 { + v = stripCRLF(value) + break + } + } + c.respHeaders = append(c.respHeaders, [2]string{k, v}) } // sanitizeHeaderKey lowercases and strips CRLF/null bytes. Fast path avoids @@ -207,11 +239,13 @@ func sanitizeHeaderKey(s string) string { return s } +var crlfReplacer = strings.NewReplacer("\r", "", "\n", "", "\x00", "") + // stripCRLF removes \r, \n, and \x00 to prevent HTTP response splitting // (CWE-113) and null-byte header smuggling. func stripCRLF(s string) string { if strings.ContainsAny(s, "\r\n\x00") { - return strings.NewReplacer("\r", "", "\n", "", "\x00", "").Replace(s) + return crlfReplacer.Replace(s) } return s } @@ -245,11 +279,13 @@ func (c *Context) Redirect(code int, url string) error { return c.NoContent(code) } +var cookieUnsafeReplacer = strings.NewReplacer(";", "", "\r", "", "\n", "") + // stripCookieUnsafe strips characters that could inject cookie attributes // (semicolons) or cause header injection (CRLF) from cookie field values. func stripCookieUnsafe(s string) string { if strings.ContainsAny(s, ";\r\n") { - return strings.NewReplacer(";", "", "\r", "", "\n", "").Replace(s) + return cookieUnsafeReplacer.Replace(s) } return s } @@ -261,6 +297,7 @@ func stripCookieUnsafe(s string) string { // cookie attribute injection. func (c *Context) SetCookie(cookie *Cookie) { var b strings.Builder + b.Grow(128) b.WriteString(stripCookieUnsafe(cookie.Name)) b.WriteByte('=') b.WriteString(stripCookieUnsafe(cookie.Value)) @@ -357,15 +394,34 @@ func (c *Context) File(filePath string) error { // FileFromDir safely serves a file from within baseDir. The userPath is // cleaned and joined with baseDir; if the result escapes baseDir, a 400 -// error is returned. This prevents directory traversal when serving -// user-supplied paths. +// error is returned. Symlinks are resolved and rechecked to prevent a +// symlink under baseDir from escaping the directory boundary. func (c *Context) FileFromDir(baseDir, userPath string) error { abs := filepath.Clean(filepath.Join(baseDir, filepath.FromSlash(userPath))) base := filepath.Clean(baseDir) if abs != base && !strings.HasPrefix(abs, base+string(filepath.Separator)) { return NewHTTPError(400, "invalid file path") } - return c.File(abs) + // Resolve symlinks and recheck prefix to prevent symlink escape. + resolved, err := filepath.EvalSymlinks(abs) + if err != nil { + return err + } + resolvedBase, err := filepath.EvalSymlinks(base) + if err != nil { + return err + } + if resolved != resolvedBase && !strings.HasPrefix(resolved, resolvedBase+string(filepath.Separator)) { + return NewHTTPError(400, "invalid file path") + } + info, err := os.Stat(resolved) + if err != nil { + return err + } + if info.IsDir() { + return NewHTTPError(400, "invalid file path") + } + return c.File(resolved) } func parseRange(header string, size int64) (start, end int64, ok bool) { diff --git a/context_test.go b/context_test.go index aa762f7..cf60ebf 100644 --- a/context_test.go +++ b/context_test.go @@ -19,7 +19,8 @@ type mockResponseWriter struct { func (m *mockResponseWriter) WriteResponse(_ *stream.Stream, status int, headers [][2]string, body []byte) error { m.status = status - m.headers = headers + m.headers = make([][2]string, len(headers)) + copy(m.headers, headers) m.body = make([]byte, len(body)) copy(m.body, body) return nil diff --git a/doc.go b/doc.go index 32aae88..88d06f4 100644 --- a/doc.go +++ b/doc.go @@ -431,9 +431,9 @@ // // # Form Presence // -// FormValueOk distinguishes a missing field from an empty value: +// FormValueOK distinguishes a missing field from an empty value: // -// val, ok := c.FormValueOk("name") +// val, ok := c.FormValueOK("name") // if !ok { // // field was not submitted // } diff --git a/handler.go b/handler.go index e277e4a..81bcb63 100644 --- a/handler.go +++ b/handler.go @@ -12,7 +12,9 @@ import ( ) type routerAdapter struct { - server *Server + server *Server + notFoundChain []HandlerFunc + methodNotAllowedChain []HandlerFunc } func (a *routerAdapter) HandleStream(_ context.Context, s *stream.Stream) error { @@ -111,9 +113,13 @@ func (a *routerAdapter) handleUnmatched(c *Context, s *stream.Stream) { if len(allowed) > 0 { c.statusCode = 405 allowVal := strings.Join(allowed, ", ") - if a.server.methodNotAllowedHandler != nil { + chain := a.methodNotAllowedChain + if chain == nil && a.server.methodNotAllowedHandler != nil { + chain = []HandlerFunc{a.server.methodNotAllowedHandler} + } + if chain != nil { c.SetHeader("allow", allowVal) - c.handlers = []HandlerFunc{a.server.methodNotAllowedHandler} + c.handlers = chain a.handleError(c, s, c.Next()) } if !c.written && s.ResponseWriter != nil { @@ -125,8 +131,12 @@ func (a *routerAdapter) handleUnmatched(c *Context, s *stream.Stream) { } } else { c.statusCode = 404 - if a.server.notFoundHandler != nil { - c.handlers = []HandlerFunc{a.server.notFoundHandler} + chain := a.notFoundChain + if chain == nil && a.server.notFoundHandler != nil { + chain = []HandlerFunc{a.server.notFoundHandler} + } + if chain != nil { + c.handlers = chain a.handleError(c, s, c.Next()) } if !c.written && s.ResponseWriter != nil { diff --git a/internal/ctxkit/ctxkit.go b/internal/ctxkit/ctxkit.go index e0af1a9..2a8f06f 100644 --- a/internal/ctxkit/ctxkit.go +++ b/internal/ctxkit/ctxkit.go @@ -7,8 +7,10 @@ import "github.com/goceleris/celeris/protocol/h2/stream" // Hooks registered by the celeris package at init time. var ( - NewContext func(s *stream.Stream) any - ReleaseContext func(c any) - AddParam func(c any, key, value string) - SetHandlers func(c any, handlers []any) + NewContext func(s *stream.Stream) any + ReleaseContext func(c any) + AddParam func(c any, key, value string) + SetHandlers func(c any, handlers []any) + GetResponseWriter func(c any) any + GetStream func(c any) any ) diff --git a/internal/timer/wheel.go b/internal/timer/wheel.go deleted file mode 100644 index 040a491..0000000 --- a/internal/timer/wheel.go +++ /dev/null @@ -1,145 +0,0 @@ -// Package timer implements a hierarchical timer wheel for efficient -// timeout management in event-driven I/O engines. -package timer - -import ( - "sync" - "time" -) - -const ( - // WheelSize is the number of slots in the timer wheel (power of 2 for fast modulo). - WheelSize = 8192 - // TickInterval is the granularity of the timer wheel. - TickInterval = 100 * time.Millisecond - - wheelMask = WheelSize - 1 -) - -// TimeoutKind identifies the type of timeout. -type TimeoutKind uint8 - -// Timeout kinds for the timer wheel. -const ( - ReadTimeout TimeoutKind = iota - WriteTimeout - IdleTimeout -) - -// Entry represents a scheduled timeout. -type Entry struct { - FD int - Deadline int64 // unix nano - Kind TimeoutKind - next *Entry -} - -var entryPool = sync.Pool{New: func() any { return &Entry{} }} - -func getEntry() *Entry { - return entryPool.Get().(*Entry) -} - -func putEntry(e *Entry) { - e.FD = 0 - e.Deadline = 0 - e.Kind = 0 - e.next = nil - entryPool.Put(e) -} - -// Wheel is a hashed timer wheel with O(1) schedule and cancel. -type Wheel struct { - slots [WheelSize]*Entry - current uint64 - startNs int64 - onExpire func(fd int, kind TimeoutKind) - // fdSlots tracks the slot index for each FD to enable O(1) cancel. - fdSlots map[int]int64 // fd → deadline -} - -// New creates a timer wheel that calls onExpire for each expired entry. -func New(onExpire func(fd int, kind TimeoutKind)) *Wheel { - return &Wheel{ - startNs: time.Now().UnixNano(), - onExpire: onExpire, - fdSlots: make(map[int]int64), - } -} - -func (w *Wheel) slotIndex(deadline int64) uint64 { - tick := uint64(deadline-w.startNs) / uint64(TickInterval) - return tick & wheelMask -} - -// Schedule registers a timeout for the given FD. Any previous timeout for -// the same FD is logically cancelled (the old entry is ignored on expiry). -func (w *Wheel) Schedule(fd int, timeout time.Duration, kind TimeoutKind) { - w.ScheduleAt(fd, timeout, kind, time.Now().UnixNano()) -} - -// ScheduleAt is like Schedule but accepts a pre-computed now timestamp (UnixNano) -// to avoid redundant time.Now() calls when scheduling multiple timeouts in the -// same request processing cycle. -func (w *Wheel) ScheduleAt(fd int, timeout time.Duration, kind TimeoutKind, nowNs int64) { - deadline := nowNs + int64(timeout) - slot := w.slotIndex(deadline) - - e := getEntry() - e.FD = fd - e.Deadline = deadline - e.Kind = kind - e.next = w.slots[slot] - w.slots[slot] = e - - w.fdSlots[fd] = deadline -} - -// Cancel removes any pending timeout for the given FD. -func (w *Wheel) Cancel(fd int) { - delete(w.fdSlots, fd) -} - -// Tick advances the wheel and fires expired entries. Returns the number -// of entries that expired. Call this from the event loop after processing -// I/O events. -func (w *Wheel) Tick() int { - now := time.Now().UnixNano() - currentTick := uint64(now-w.startNs) / uint64(TickInterval) - count := 0 - - for w.current <= currentTick { - slot := w.current & wheelMask - var prev *Entry - e := w.slots[slot] - for e != nil { - next := e.next - if e.Deadline <= now { - // Check if this entry is still the active one for this FD. - if active, ok := w.fdSlots[e.FD]; ok && active == e.Deadline { - delete(w.fdSlots, e.FD) - w.onExpire(e.FD, e.Kind) - count++ - } - // Remove from list. - if prev == nil { - w.slots[slot] = next - } else { - prev.next = next - } - putEntry(e) - } else { - prev = e - } - e = next - } - w.current++ - } - - return count -} - -// Len returns the number of tracked FDs (approximate). -func (w *Wheel) Len() int { - return len(w.fdSlots) -} diff --git a/internal/timer/wheel_test.go b/internal/timer/wheel_test.go deleted file mode 100644 index 3e3d3eb..0000000 --- a/internal/timer/wheel_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package timer - -import ( - "sync/atomic" - "testing" - "time" -) - -func TestWheelScheduleAndExpire(t *testing.T) { - var expired atomic.Int32 - var expiredFD int - var expiredKind TimeoutKind - - w := New(func(fd int, kind TimeoutKind) { - expiredFD = fd - expiredKind = kind - expired.Add(1) - }) - - w.Schedule(42, 150*time.Millisecond, IdleTimeout) - - // Should not expire immediately. - n := w.Tick() - if n != 0 { - t.Fatalf("expected 0 expired, got %d", n) - } - - time.Sleep(200 * time.Millisecond) - n = w.Tick() - if n != 1 { - t.Fatalf("expected 1 expired, got %d", n) - } - if expiredFD != 42 { - t.Fatalf("expected FD 42, got %d", expiredFD) - } - if expiredKind != IdleTimeout { - t.Fatalf("expected IdleTimeout, got %d", expiredKind) - } -} - -func TestWheelCancel(t *testing.T) { - var expired atomic.Int32 - w := New(func(_ int, _ TimeoutKind) { - expired.Add(1) - }) - - w.Schedule(10, 150*time.Millisecond, ReadTimeout) - w.Cancel(10) - - time.Sleep(200 * time.Millisecond) - w.Tick() - if expired.Load() != 0 { - t.Fatalf("expected 0 expired after cancel, got %d", expired.Load()) - } -} - -func TestWheelMultipleEntries(t *testing.T) { - expired := make(map[int]bool) - w := New(func(fd int, _ TimeoutKind) { - expired[fd] = true - }) - - w.Schedule(1, 150*time.Millisecond, IdleTimeout) - w.Schedule(2, 150*time.Millisecond, IdleTimeout) - w.Schedule(3, 150*time.Millisecond, IdleTimeout) - - time.Sleep(200 * time.Millisecond) - n := w.Tick() - if n != 3 { - t.Fatalf("expected 3 expired, got %d", n) - } - for _, fd := range []int{1, 2, 3} { - if !expired[fd] { - t.Fatalf("expected FD %d to expire", fd) - } - } -} - -func TestWheelReschedule(t *testing.T) { - var lastKind TimeoutKind - count := 0 - w := New(func(_ int, kind TimeoutKind) { - lastKind = kind - count++ - }) - - // Schedule then reschedule with different kind. - w.Schedule(5, 150*time.Millisecond, ReadTimeout) - w.Schedule(5, 150*time.Millisecond, WriteTimeout) - - time.Sleep(200 * time.Millisecond) - w.Tick() - // Only the latest schedule should fire. - if count != 1 { - t.Fatalf("expected 1 expiry, got %d", count) - } - if lastKind != WriteTimeout { - t.Fatalf("expected WriteTimeout, got %d", lastKind) - } -} - -func TestWheelLen(t *testing.T) { - w := New(func(_ int, _ TimeoutKind) {}) - w.Schedule(1, time.Second, IdleTimeout) - w.Schedule(2, time.Second, IdleTimeout) - if w.Len() != 2 { - t.Fatalf("expected Len=2, got %d", w.Len()) - } - w.Cancel(1) - if w.Len() != 1 { - t.Fatalf("expected Len=1, got %d", w.Len()) - } -} - -func BenchmarkWheelSchedule(b *testing.B) { - w := New(func(_ int, _ TimeoutKind) {}) - b.ResetTimer() - for i := range b.N { - w.Schedule(i%10000, 5*time.Second, IdleTimeout) - } -} - -func BenchmarkWheelTick(b *testing.B) { - w := New(func(_ int, _ TimeoutKind) {}) - // Pre-fill with entries that won't expire. - for i := range 1000 { - w.Schedule(i, time.Hour, IdleTimeout) - } - b.ResetTimer() - for range b.N { - w.Tick() - } -} diff --git a/protocol/h1/parser.go b/protocol/h1/parser.go index bf18908..c9b4b82 100644 --- a/protocol/h1/parser.go +++ b/protocol/h1/parser.go @@ -7,13 +7,14 @@ import ( // H1 parser sentinel errors. var ( - ErrBufferExhausted = errors.New("buffer exhausted") - ErrInvalidRequestLine = errors.New("invalid request line") - ErrInvalidHeader = errors.New("invalid header line") - ErrMissingHost = errors.New("missing Host header") - ErrUnsupportedVersion = errors.New("unsupported HTTP version") - ErrHeadersTooLarge = errors.New("headers too large") - ErrInvalidContentLength = errors.New("invalid content-length") + ErrBufferExhausted = errors.New("buffer exhausted") + ErrInvalidRequestLine = errors.New("invalid request line") + ErrInvalidHeader = errors.New("invalid header line") + ErrMissingHost = errors.New("missing Host header") + ErrUnsupportedVersion = errors.New("unsupported HTTP version") + ErrHeadersTooLarge = errors.New("headers too large") + ErrInvalidContentLength = errors.New("invalid content-length") + ErrDuplicateContentLength = errors.New("duplicate content-length with conflicting values") ) // Parser is a zero-allocation HTTP/1.x request parser. @@ -171,6 +172,9 @@ func (p *Parser) appendHeader(req *Request, rawName, rawValue []byte) error { if !ok { return ErrInvalidContentLength } + if req.ContentLength >= 0 && req.ContentLength != cl { + return ErrDuplicateContentLength + } req.ContentLength = cl case "transfer-encoding": if asciiContainsFoldString(value, "chunked") { @@ -206,11 +210,14 @@ func (p *Parser) appendHeader(req *Request, rawName, rawValue []byte) error { if req.ChunkedEncoding { return nil } - if cl, ok := parseInt64Bytes(rawValue); ok { - req.ContentLength = cl - } else { + cl, ok := parseInt64Bytes(rawValue) + if !ok { return ErrInvalidContentLength } + if req.ContentLength >= 0 && req.ContentLength != cl { + return ErrDuplicateContentLength + } + req.ContentLength = cl return nil } if asciiEqualFold(rawName, "Connection") { diff --git a/protocol/h1/parser_test.go b/protocol/h1/parser_test.go index 7586d6b..f0b6aa2 100644 --- a/protocol/h1/parser_test.go +++ b/protocol/h1/parser_test.go @@ -530,6 +530,81 @@ func TestFindHeaderEnd_MultipleCR(t *testing.T) { } } +func TestParseRequest_DuplicateContentLength_Conflicting(t *testing.T) { + for _, zeroCopy := range []bool{false, true} { + name := "standard" + if zeroCopy { + name = "zerocopy" + } + t.Run(name, func(t *testing.T) { + raw := "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\nContent-Length: 50\r\n\r\n" + p := NewParser() + if zeroCopy { + p.noStringHeaders = true + } + p.Reset([]byte(raw)) + var req Request + _, err := p.ParseRequest(&req) + if !errors.Is(err, ErrDuplicateContentLength) { + t.Fatalf("got error %v, want %v", err, ErrDuplicateContentLength) + } + }) + } +} + +func TestParseRequest_DuplicateContentLength_Identical(t *testing.T) { + for _, zeroCopy := range []bool{false, true} { + name := "standard" + if zeroCopy { + name = "zerocopy" + } + t.Run(name, func(t *testing.T) { + raw := "POST / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 42\r\nContent-Length: 42\r\n\r\n" + p := NewParser() + if zeroCopy { + p.noStringHeaders = true + } + p.Reset([]byte(raw)) + var req Request + _, err := p.ParseRequest(&req) + if err != nil { + t.Fatalf("identical CL should be accepted, got: %v", err) + } + if req.ContentLength != 42 { + t.Fatalf("content-length = %d, want 42", req.ContentLength) + } + }) + } +} + +func TestParseRequest_DuplicateContentLength_ChunkedIgnored(t *testing.T) { + for _, zeroCopy := range []bool{false, true} { + name := "standard" + if zeroCopy { + name = "zerocopy" + } + t.Run(name, func(t *testing.T) { + raw := "POST / HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\nContent-Length: 0\r\nContent-Length: 50\r\n\r\n" + p := NewParser() + if zeroCopy { + p.noStringHeaders = true + } + p.Reset([]byte(raw)) + var req Request + _, err := p.ParseRequest(&req) + if err != nil { + t.Fatalf("conflicting CL with chunked TE should be ignored, got: %v", err) + } + if !req.ChunkedEncoding { + t.Fatal("chunked encoding not detected") + } + if req.ContentLength != -1 { + t.Fatalf("content-length = %d, want -1 for chunked", req.ContentLength) + } + }) + } +} + // TestFindHeaderEnd_PartialSequence ensures partial \r\n sequences // (without the full \r\n\r\n) are not mistakenly matched. func TestFindHeaderEnd_PartialSequence(t *testing.T) { diff --git a/protocol/h2/stream/stream.go b/protocol/h2/stream/stream.go index 2ccd4b0..9ba93fd 100644 --- a/protocol/h2/stream/stream.go +++ b/protocol/h2/stream/stream.go @@ -61,6 +61,7 @@ type Stream struct { phase Phase CachedCtx any // per-connection cached context (avoids pool Get/Put per request) OnDetach func() // called by Context.Detach to install write-thread safety + hdrBuf [16][2]string } // streamContext is a zero-alloc context.Context for streams. @@ -103,13 +104,9 @@ func NewStream(id uint32) *Stream { s := streamPool.Get().(*Stream) s.ID = id s.state.Store(int32(StateIdle)) - s.Data = getBuf() - s.OutboundBuffer = getBuf() s.windowSize.Store(65535) s.phase = PhaseInit - if cap(s.Headers) < 8 { - s.Headers = make([][2]string, 0, 8) - } + s.Headers = s.hdrBuf[:0] return s } @@ -119,9 +116,7 @@ func NewH1Stream(id uint32) *Stream { s.ID = id s.state.Store(int32(StateIdle)) s.h1Mode = true - if cap(s.Headers) < 16 { - s.Headers = make([][2]string, 0, 16) - } + s.Headers = s.hdrBuf[:0] return s } @@ -162,12 +157,30 @@ func (s *Stream) IsCancelled() bool { return s.flags.Load()&flagCancelled != 0 } +// HasDoneCh reports whether a Done channel was created, indicating +// a derived context (e.g. context.WithTimeout) is watching this stream. +func (s *Stream) HasDoneCh() bool { + return s.doneCh.Load() != nil +} + // Release returns pooled buffers, cancels the context, and returns the stream // to its pool. Safe to call multiple times; subsequent calls are no-ops. func (s *Stream) Release() { if !s.h1Mode { s.Cancel() } + s.resetAndPool() +} + +// ResetForPool returns pooled buffers and returns the stream to its pool +// WITHOUT cancelling the context. Use this in test harnesses where derived +// contexts (e.g. from context.WithTimeout) may have propagation goroutines +// that race with cancellation flag clearing. +func ResetForPool(s *Stream) { + s.resetAndPool() +} + +func (s *Stream) resetAndPool() { if s.Data != nil { s.Data.Reset() bufferPool.Put(s.Data) @@ -181,7 +194,10 @@ func (s *Stream) Release() { s.ID = 0 s.state.Store(0) s.manager = nil - s.Headers = s.Headers[:0] + clear(s.hdrBuf[:]) + s.Headers = s.hdrBuf[:0] + s.Trailers = s.Trailers[:cap(s.Trailers)] + clear(s.Trailers) s.Trailers = s.Trailers[:0] s.OutboundEndStream = false s.headersSent.Store(false) @@ -218,7 +234,7 @@ func ResetH1Stream(s *Stream) { bufferPool.Put(s.Data) s.Data = nil } - s.Headers = s.Headers[:0] + s.Headers = s.hdrBuf[:0] s.headersSent.Store(false) s.EndStream = false s.ResponseWriter = nil @@ -241,7 +257,7 @@ func ResetH2StreamInline(s *Stream, id uint32) { s.OutboundBuffer = getBuf() } s.ID = id - s.Headers = s.Headers[:0] + s.Headers = s.hdrBuf[:0] s.Trailers = s.Trailers[:0] s.OutboundEndStream = false s.headersSent.Store(false) @@ -402,7 +418,7 @@ func (s *Stream) SetHandlerStarted() { func (s *Stream) BufferOutbound(data []byte, endStream bool) { s.mu.Lock() if s.OutboundBuffer == nil { - s.OutboundBuffer = new(bytes.Buffer) + s.OutboundBuffer = getBuf() } s.OutboundBuffer.Write(data) s.OutboundEndStream = endStream diff --git a/server.go b/server.go index 7b32daa..c2da332 100644 --- a/server.go +++ b/server.go @@ -375,7 +375,14 @@ func (s *Server) doPrepare(configureFn func(cfg *resource.Config)) (engine.Engin s.collector = observe.NewCollector() } - var handler stream.Handler = &routerAdapter{server: s} + ra := &routerAdapter{server: s} + if s.notFoundHandler != nil { + ra.notFoundChain = []HandlerFunc{s.notFoundHandler} + } + if s.methodNotAllowedHandler != nil { + ra.methodNotAllowedChain = []HandlerFunc{s.methodNotAllowedHandler} + } + var handler stream.Handler = ra var err error eng, err = createEngine(cfg, handler) if err != nil { diff --git a/server_test.go b/server_test.go index dea10fb..8f9cbc4 100644 --- a/server_test.go +++ b/server_test.go @@ -65,7 +65,7 @@ func TestServerRouting(t *testing.T) { // Test POST /echo. st2, rw2 := newTestStream("POST", "/echo") - st2.Data.Write([]byte("payload")) + st2.GetBuf().Write([]byte("payload")) if err := adapter.HandleStream(context.Background(), st2); err != nil { t.Fatal(err) } diff --git a/stdlib.go b/stdlib.go index b743981..f82bf19 100644 --- a/stdlib.go +++ b/stdlib.go @@ -11,7 +11,7 @@ import ( "golang.org/x/net/http2" ) -const maxToHandlerBodySize = 100 << 20 // 100 MB +const maxToHandlerBodySize = maxBodySize // ToHandler wraps a celeris HandlerFunc as an http.Handler for use with // net/http routers, middleware, or test infrastructure. The returned handler @@ -57,7 +57,7 @@ func ToHandler(h HandlerFunc) http.Handler { return } if len(body) > 0 { - _, _ = s.Data.Write(body) + _, _ = s.GetBuf().Write(body) } } diff --git a/types.go b/types.go index 31a29d7..859035a 100644 --- a/types.go +++ b/types.go @@ -11,7 +11,11 @@ type HandlerFunc func(*Context) error // parsing (32 MB), matching net/http. const DefaultMaxFormSize int64 = 32 << 20 -const maxStreamBodySize = 100 << 20 // 100MB +// maxBodySize is the maximum request/response body size (100 MB), shared by +// stream responses (File, Stream), bridge adapter output, and ToHandler input. +const maxBodySize = 100 << 20 + +const maxStreamBodySize = maxBodySize // SameSite controls the SameSite attribute of a cookie. type SameSite int