Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 39 additions & 50 deletions page.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,11 @@ type Sort struct {
}

func (s Sort) String() string {
if s.Column == "" {
return ""
}
s.Order = sanitizeOrder(s.Order)
return fmt.Sprintf("%s %s", s.Column, s.Order)
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sort.String() no longer normalizes/sanitizes Order (and no longer guards against empty Column). Since Sort and String() are exported, this is a behavioral/security regression for any external callers building SQL using Sort.String(). Consider restoring the previous behavior (normalize order to ASC/DESC, and return "" when Column is empty) or making stringification explicitly unsafe/unexported.

Suggested change
return fmt.Sprintf("%s %s", s.Column, s.Order)
if s.Column == "" {
return ""
}
order := strings.ToUpper(strings.TrimSpace(string(s.Order)))
switch order {
case "DESC":
order = "DESC"
default:
order = "ASC"
}
return fmt.Sprintf("%s %s", s.Column, order)

Copilot uses AI. Check for mistakes.
}

var _MatcherOrderBy = regexp.MustCompile(`-?([a-zA-Z0-9]+)`)

func sanitizeOrder(order Order) Order {
switch strings.ToUpper(strings.TrimSpace(string(order))) {
case string(Desc):
return Desc
case string(Asc):
return Asc
default:
return Asc
}
}

func NewSort(s string) (Sort, bool) {
s = strings.TrimSpace(s)
if s == "" || !_MatcherOrderBy.MatchString(s) {
Expand All @@ -74,12 +59,6 @@ type Page struct {
}

func NewPage(size, page uint32, sort ...Sort) *Page {
if size == 0 {
size = DefaultPageSize
}
if page == 0 {
page = 1
}
return &Page{
Size: size,
Page: page,
Expand All @@ -105,40 +84,48 @@ func (p *Page) SetDefaults(o *PaginatorSettings) {
}
}

func (p *Page) GetOrder(defaultSort ...string) []Sort {
// if page has sort, use it
func (p *Page) GetOrder(columnFunc func(string) string, defaultSort ...string) []Sort {
var sorts []Sort
if p != nil && len(p.Sort) != 0 {
Comment on lines +87 to 89
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Page.GetOrder changed signature to require a columnFunc parameter. Since GetOrder is exported, this is a breaking API change for downstream users. If the goal is to apply ColumnFunc before sanitization, consider keeping the original GetOrder(...defaultSort) API and handling ColumnFunc inside Paginator.getOrder (or add a new method and keep the old one for compatibility).

Copilot uses AI. Check for mistakes.
for i, s := range p.Sort {
s.Column = strings.TrimSpace(s.Column)
s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize()
s.Order = sanitizeOrder(s.Order)
p.Sort[i] = s
// use sort
sorts = p.Sort
}
// fall back to column
if len(sorts) == 0 {
if p != nil && p.Column != "" {
for part := range strings.SplitSeq(p.Column, ",") {
if s, ok := NewSort(part); ok {
sorts = append(sorts, s)
}
}
}
return p.Sort
}
// if page has column, use default sort
if p == nil || p.Column == "" {
sort := make([]Sort, 0, len(defaultSort))
if len(sorts) == 0 {
for _, s := range defaultSort {
if s, ok := NewSort(s); ok {
sort = append(sort, s)
sorts = append(sorts, s)
}
}
return sort
}
// use column
sort := make([]Sort, 0)
for part := range strings.SplitSeq(p.Column, ",") {
part = strings.TrimSpace(part)
if part == "" {
continue

for i := range sorts {
s := &sorts[i]
s.Column = strings.TrimSpace(s.Column)
if columnFunc != nil {
s.Column = columnFunc(s.Column)
}
if s, ok := NewSort(part); ok {
s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize()
sort = append(sort, s)
s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize()

switch strings.ToUpper(strings.TrimSpace(string(s.Order))) {
case string(Desc):
s.Order = Desc
case string(Asc):
s.Order = Asc
default:
s.Order = Asc
}
}
return sort
return sorts
}

func (p *Page) Offset() uint64 {
Expand Down Expand Up @@ -229,13 +216,10 @@ type Paginator[T any] struct {
}

func (p Paginator[T]) getOrder(page *Page) []string {
sort := page.GetOrder(p.settings.Sort...)
sort := page.GetOrder(p.settings.ColumnFunc, p.settings.Sort...)
list := make([]string, len(sort))
for i, s := range sort {
if p.settings.ColumnFunc != nil {
s.Column = p.settings.ColumnFunc(s.Column)
}
list[i] = s.String()
for i := range sort {
list[i] = sort[i].String()
}
return list
}
Expand All @@ -253,6 +237,11 @@ func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.Sele
}

func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string, []any) {
if page == nil {
page = &Page{}
}
page.SetDefaults(&p.settings)

Comment on lines +240 to +244
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PrepareRaw may return a query containing @limit/@offset without adding the corresponding pgx.NamedArgs when args is empty, because the current loop only appends named args inside the for range args body. Add an explicit if len(args)==0 { args = append(args, pgx.NamedArgs{...}) } (or similar) before the loop.

Copilot uses AI. Check for mistakes.
limit, offset := page.Limit(), page.Offset()

q = q + " ORDER BY " + strings.Join(p.getOrder(page), ", ")
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PrepareRaw unconditionally appends ORDER BY ... to the raw query string. If p.getOrder(page) returns an empty slice (e.g., no default sort configured and no page sort/column provided), this generates invalid SQL (... ORDER BY LIMIT ...). Consider only appending the ORDER BY clause when there is at least one sort term.

Suggested change
q = q + " ORDER BY " + strings.Join(p.getOrder(page), ", ")
orders := p.getOrder(page)
if len(orders) > 0 {
q = q + " ORDER BY " + strings.Join(orders, ", ")
}

Copilot uses AI. Check for mistakes.
Expand Down
106 changes: 97 additions & 9 deletions page_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,24 @@ func TestPagination(t *testing.T) {
page := pgkit.NewPage(0, 0)
result, query := paginator.PrepareQuery(sq.Select("*").From("t"), page)
require.Len(t, result, 0)
require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page)
require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page)

sql, args, err := query.ToSql()
require.NoError(t, err)
require.Equal(t, "SELECT * FROM t ORDER BY id ASC LIMIT 6 OFFSET 0", sql)
require.Equal(t, `SELECT * FROM t ORDER BY "id" ASC LIMIT 3 OFFSET 0`, sql)
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With column sanitization now happening after ColumnFunc (via pgx.Identifier(...).Sanitize()), the default sort column should be quoted consistently (as in the other tests). This assertion likely needs to expect ORDER BY "id" ASC ... rather than ORDER BY id ASC ....

Copilot uses AI. Check for mistakes.
require.Empty(t, args)

result = paginator.PrepareResult(make([]T, 0), page)
require.Len(t, result, 0)
require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page)
require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page)

result = paginator.PrepareResult(make([]T, MaxSize), page)
require.Len(t, result, MaxSize)
require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page)
result = paginator.PrepareResult(make([]T, DefaultSize), page)
require.Len(t, result, DefaultSize)
require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page)

result = paginator.PrepareResult(make([]T, MaxSize+2), page)
require.Len(t, result, MaxSize)
require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize, More: true}, page)
result = paginator.PrepareResult(make([]T, DefaultSize+2), page)
require.Len(t, result, DefaultSize)
require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize, More: true}, page)
}

func TestInvalidSort(t *testing.T) {
Expand Down Expand Up @@ -150,3 +150,91 @@ func TestPaginationEdgeCases(t *testing.T) {
require.NoError(t, err4)
require.Equal(t, "SELECT * FROM t LIMIT 21 OFFSET 0", sql4)
}

func TestColumnFunc(t *testing.T) {
fn := func(column string) string {
switch column {
case "id":
return "ID"
case "name":
return "NAME"
default:
return column
}
}
paginator := pgkit.NewPaginator[T](
pgkit.WithColumnFunc(fn),
)
page := &pgkit.Page{
Page: 1,
Size: 10,
Sort: []pgkit.Sort{
{Column: "id", Order: pgkit.Asc},
{Column: "name", Order: pgkit.Desc},
{Column: "created_at", Order: pgkit.Asc},
},
}
_, query := paginator.PrepareQuery(sq.Select("*").From("t"), page)

sql, args, err := query.ToSql()
require.NoError(t, err)
require.Equal(t, `SELECT * FROM t ORDER BY "ID" ASC, "NAME" DESC, "created_at" ASC LIMIT 11 OFFSET 0`, sql)
require.Empty(t, args)
}

func TestColumnFallbackUsesColumnFunc(t *testing.T) {
paginator := pgkit.NewPaginator[T](
pgkit.WithColumnFunc(strings.ToUpper),
pgkit.WithSort("id"),
)
page := &pgkit.Page{
Page: 1,
Size: 10,
Column: "name",
}

_, query := paginator.PrepareQuery(sq.Select("*").From("t"), page)

sql, args, err := query.ToSql()
require.NoError(t, err)
require.Equal(t, `SELECT * FROM t ORDER BY "NAME" ASC LIMIT 11 OFFSET 0`, sql)
require.Empty(t, args)
}

func TestSortTakesPrecedenceOverColumn(t *testing.T) {
paginator := pgkit.NewPaginator[T]()
page := &pgkit.Page{
Page: 1,
Size: 10,
Column: "name",
Sort: []pgkit.Sort{
{Column: "id", Order: pgkit.Desc},
},
}

_, query := paginator.PrepareQuery(sq.Select("*").From("t"), page)

sql, args, err := query.ToSql()
require.NoError(t, err)
require.Equal(t, `SELECT * FROM t ORDER BY "id" DESC LIMIT 11 OFFSET 0`, sql)
require.Empty(t, args)
}

func TestPaginationOffsetAndPageRecompute(t *testing.T) {
paginator := pgkit.NewPaginator[T]()
page := &pgkit.Page{
Page: 3,
Size: 2,
}

_, query := paginator.PrepareQuery(sq.Select("*").From("t"), page)

sql, args, err := query.ToSql()
require.NoError(t, err)
require.Equal(t, "SELECT * FROM t LIMIT 3 OFFSET 4", sql)
require.Empty(t, args)

result := paginator.PrepareResult(make([]T, 3), page)
require.Len(t, result, 2)
require.Equal(t, &pgkit.Page{Page: 3, Size: 2, More: true}, page)
}
Loading