From 3dca69f6bd6d1659fff4eb612097c321541ec3bc Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Thu, 12 Feb 2026 13:02:30 +0100 Subject: [PATCH 1/3] fix: sanitize column names in pagination and add test for column function --- page.go | 2 +- page_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/page.go b/page.go index 58a6002..140e6ec 100644 --- a/page.go +++ b/page.go @@ -134,7 +134,6 @@ func (p *Page) GetOrder(defaultSort ...string) []Sort { continue } if s, ok := NewSort(part); ok { - s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize() sort = append(sort, s) } } @@ -235,6 +234,7 @@ func (p Paginator[T]) getOrder(page *Page) []string { if p.settings.ColumnFunc != nil { s.Column = p.settings.ColumnFunc(s.Column) } + s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize() list[i] = s.String() } return list diff --git a/page_test.go b/page_test.go index fa2e0ff..df300a2 100644 --- a/page_test.go +++ b/page_test.go @@ -150,3 +150,34 @@ func TestPaginationEdgeCases(t *testing.T) { require.NoError(t, err4) require.Equal(t, "SELECT * FROM t LIMIT 21 OFFSET 0", sql4) } + +func TestColumnFunc(t *testing.T) { + fn := func(column string) string { + switch column { + case "id": + return "ID" + case "name": + return "NAME" + default: + return column + } + } + paginator := pgkit.NewPaginator[T]( + pgkit.WithColumnFunc(fn), + ) + page := &pgkit.Page{ + Page: 1, + Size: 10, + Sort: []pgkit.Sort{ + {Column: "id", Order: pgkit.Asc}, + {Column: "name", Order: pgkit.Desc}, + {Column: "created_at", Order: pgkit.Asc}, + }, + } + _, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) + + sql, args, err := query.ToSql() + require.NoError(t, err) + require.Equal(t, `SELECT * FROM t ORDER BY "ID" ASC, "NAME" DESC, "created_at" ASC LIMIT 11 OFFSET 0`, sql) + require.Empty(t, args) +} From c2e052e293385b8b1ce9da567c9b9285e2da180e Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Thu, 12 Feb 2026 13:39:57 +0100 Subject: [PATCH 2/3] fix: update pagination defaults and enhance column sorting functionality --- page.go | 89 +++++++++++++++++++++++----------------------------- page_test.go | 75 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 105 insertions(+), 59 deletions(-) diff --git a/page.go b/page.go index 140e6ec..a43580e 100644 --- a/page.go +++ b/page.go @@ -29,26 +29,11 @@ type Sort struct { } func (s Sort) String() string { - if s.Column == "" { - return "" - } - s.Order = sanitizeOrder(s.Order) return fmt.Sprintf("%s %s", s.Column, s.Order) } var _MatcherOrderBy = regexp.MustCompile(`-?([a-zA-Z0-9]+)`) -func sanitizeOrder(order Order) Order { - switch strings.ToUpper(strings.TrimSpace(string(order))) { - case string(Desc): - return Desc - case string(Asc): - return Asc - default: - return Asc - } -} - func NewSort(s string) (Sort, bool) { s = strings.TrimSpace(s) if s == "" || !_MatcherOrderBy.MatchString(s) { @@ -74,12 +59,6 @@ type Page struct { } func NewPage(size, page uint32, sort ...Sort) *Page { - if size == 0 { - size = DefaultPageSize - } - if page == 0 { - page = 1 - } return &Page{ Size: size, Page: page, @@ -105,39 +84,48 @@ func (p *Page) SetDefaults(o *PaginatorSettings) { } } -func (p *Page) GetOrder(defaultSort ...string) []Sort { - // if page has sort, use it +func (p *Page) GetOrder(columnFunc func(string) string, defaultSort ...string) []Sort { + var sorts []Sort if p != nil && len(p.Sort) != 0 { - for i, s := range p.Sort { - s.Column = strings.TrimSpace(s.Column) - s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize() - s.Order = sanitizeOrder(s.Order) - p.Sort[i] = s + // use sort + sorts = p.Sort + } + // fall back to column + if len(sorts) == 0 { + if p != nil && p.Column != "" { + for part := range strings.SplitSeq(p.Column, ",") { + if s, ok := NewSort(part); ok { + sorts = append(sorts, s) + } + } } - return p.Sort } - // if page has column, use default sort - if p == nil || p.Column == "" { - sort := make([]Sort, 0, len(defaultSort)) + if len(sorts) == 0 { for _, s := range defaultSort { if s, ok := NewSort(s); ok { - sort = append(sort, s) + sorts = append(sorts, s) } } - return sort } - // use column - sort := make([]Sort, 0) - for part := range strings.SplitSeq(p.Column, ",") { - part = strings.TrimSpace(part) - if part == "" { - continue + + for i := range sorts { + s := &sorts[i] + s.Column = strings.TrimSpace(s.Column) + if columnFunc != nil { + s.Column = columnFunc(s.Column) } - if s, ok := NewSort(part); ok { - sort = append(sort, s) + s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize() + + switch strings.ToUpper(strings.TrimSpace(string(s.Order))) { + case string(Desc): + s.Order = Desc + case string(Asc): + s.Order = Asc + default: + s.Order = Asc } } - return sort + return sorts } func (p *Page) Offset() uint64 { @@ -228,14 +216,10 @@ type Paginator[T any] struct { } func (p Paginator[T]) getOrder(page *Page) []string { - sort := page.GetOrder(p.settings.Sort...) + sort := page.GetOrder(p.settings.ColumnFunc, p.settings.Sort...) list := make([]string, len(sort)) - for i, s := range sort { - if p.settings.ColumnFunc != nil { - s.Column = p.settings.ColumnFunc(s.Column) - } - s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize() - list[i] = s.String() + for i := range sort { + list[i] = sort[i].String() } return list } @@ -253,6 +237,11 @@ func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.Sele } func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string, []any) { + if page == nil { + page = &Page{} + } + page.SetDefaults(&p.settings) + limit, offset := page.Limit(), page.Offset() q = q + " ORDER BY " + strings.Join(p.getOrder(page), ", ") diff --git a/page_test.go b/page_test.go index df300a2..8b2ec72 100644 --- a/page_test.go +++ b/page_test.go @@ -26,24 +26,24 @@ func TestPagination(t *testing.T) { page := pgkit.NewPage(0, 0) result, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) require.Len(t, result, 0) - require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) + require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page) sql, args, err := query.ToSql() require.NoError(t, err) - require.Equal(t, "SELECT * FROM t ORDER BY id ASC LIMIT 6 OFFSET 0", sql) + require.Equal(t, "SELECT * FROM t ORDER BY id ASC LIMIT 3 OFFSET 0", sql) require.Empty(t, args) result = paginator.PrepareResult(make([]T, 0), page) require.Len(t, result, 0) - require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) + require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page) - result = paginator.PrepareResult(make([]T, MaxSize), page) - require.Len(t, result, MaxSize) - require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) + result = paginator.PrepareResult(make([]T, DefaultSize), page) + require.Len(t, result, DefaultSize) + require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page) - result = paginator.PrepareResult(make([]T, MaxSize+2), page) - require.Len(t, result, MaxSize) - require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize, More: true}, page) + result = paginator.PrepareResult(make([]T, DefaultSize+2), page) + require.Len(t, result, DefaultSize) + require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize, More: true}, page) } func TestInvalidSort(t *testing.T) { @@ -181,3 +181,60 @@ func TestColumnFunc(t *testing.T) { require.Equal(t, `SELECT * FROM t ORDER BY "ID" ASC, "NAME" DESC, "created_at" ASC LIMIT 11 OFFSET 0`, sql) require.Empty(t, args) } + +func TestColumnFallbackUsesColumnFunc(t *testing.T) { + paginator := pgkit.NewPaginator[T]( + pgkit.WithColumnFunc(strings.ToUpper), + pgkit.WithSort("id"), + ) + page := &pgkit.Page{ + Page: 1, + Size: 10, + Column: "name", + } + + _, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) + + sql, args, err := query.ToSql() + require.NoError(t, err) + require.Equal(t, `SELECT * FROM t ORDER BY "NAME" ASC LIMIT 11 OFFSET 0`, sql) + require.Empty(t, args) +} + +func TestSortTakesPrecedenceOverColumn(t *testing.T) { + paginator := pgkit.NewPaginator[T]() + page := &pgkit.Page{ + Page: 1, + Size: 10, + Column: "name", + Sort: []pgkit.Sort{ + {Column: "id", Order: pgkit.Desc}, + }, + } + + _, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) + + sql, args, err := query.ToSql() + require.NoError(t, err) + require.Equal(t, `SELECT * FROM t ORDER BY "id" DESC LIMIT 11 OFFSET 0`, sql) + require.Empty(t, args) +} + +func TestPaginationOffsetAndPageRecompute(t *testing.T) { + paginator := pgkit.NewPaginator[T]() + page := &pgkit.Page{ + Page: 3, + Size: 2, + } + + _, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) + + sql, args, err := query.ToSql() + require.NoError(t, err) + require.Equal(t, "SELECT * FROM t LIMIT 3 OFFSET 4", sql) + require.Empty(t, args) + + result := paginator.PrepareResult(make([]T, 3), page) + require.Len(t, result, 2) + require.Equal(t, &pgkit.Page{Page: 3, Size: 2, More: true}, page) +} From ccadff4a618989bf8a9182bf80309c58cfb951f7 Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Thu, 12 Feb 2026 13:41:48 +0100 Subject: [PATCH 3/3] fix: correct SQL query string formatting in pagination test --- page_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/page_test.go b/page_test.go index 8b2ec72..96b3cb4 100644 --- a/page_test.go +++ b/page_test.go @@ -30,7 +30,7 @@ func TestPagination(t *testing.T) { sql, args, err := query.ToSql() require.NoError(t, err) - require.Equal(t, "SELECT * FROM t ORDER BY id ASC LIMIT 3 OFFSET 0", sql) + require.Equal(t, `SELECT * FROM t ORDER BY "id" ASC LIMIT 3 OFFSET 0`, sql) require.Empty(t, args) result = paginator.PrepareResult(make([]T, 0), page)