-
Notifications
You must be signed in to change notification settings - Fork 6
Sanitize column names after applying column name transformation #38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3dca69f
c2e052e
ccadff4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -29,26 +29,11 @@ type Sort struct { | |||||||||||
| } | ||||||||||||
|
|
||||||||||||
| func (s Sort) String() string { | ||||||||||||
| if s.Column == "" { | ||||||||||||
| return "" | ||||||||||||
| } | ||||||||||||
| s.Order = sanitizeOrder(s.Order) | ||||||||||||
| return fmt.Sprintf("%s %s", s.Column, s.Order) | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| var _MatcherOrderBy = regexp.MustCompile(`-?([a-zA-Z0-9]+)`) | ||||||||||||
|
|
||||||||||||
| func sanitizeOrder(order Order) Order { | ||||||||||||
| switch strings.ToUpper(strings.TrimSpace(string(order))) { | ||||||||||||
| case string(Desc): | ||||||||||||
| return Desc | ||||||||||||
| case string(Asc): | ||||||||||||
| return Asc | ||||||||||||
| default: | ||||||||||||
| return Asc | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| func NewSort(s string) (Sort, bool) { | ||||||||||||
| s = strings.TrimSpace(s) | ||||||||||||
| if s == "" || !_MatcherOrderBy.MatchString(s) { | ||||||||||||
|
|
@@ -74,12 +59,6 @@ type Page struct { | |||||||||||
| } | ||||||||||||
|
|
||||||||||||
| func NewPage(size, page uint32, sort ...Sort) *Page { | ||||||||||||
| if size == 0 { | ||||||||||||
| size = DefaultPageSize | ||||||||||||
| } | ||||||||||||
| if page == 0 { | ||||||||||||
| page = 1 | ||||||||||||
| } | ||||||||||||
| return &Page{ | ||||||||||||
| Size: size, | ||||||||||||
| Page: page, | ||||||||||||
|
|
@@ -105,40 +84,48 @@ func (p *Page) SetDefaults(o *PaginatorSettings) { | |||||||||||
| } | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| func (p *Page) GetOrder(defaultSort ...string) []Sort { | ||||||||||||
| // if page has sort, use it | ||||||||||||
| func (p *Page) GetOrder(columnFunc func(string) string, defaultSort ...string) []Sort { | ||||||||||||
| var sorts []Sort | ||||||||||||
| if p != nil && len(p.Sort) != 0 { | ||||||||||||
|
Comment on lines
+87
to
89
|
||||||||||||
| for i, s := range p.Sort { | ||||||||||||
| s.Column = strings.TrimSpace(s.Column) | ||||||||||||
| s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize() | ||||||||||||
| s.Order = sanitizeOrder(s.Order) | ||||||||||||
| p.Sort[i] = s | ||||||||||||
| // use sort | ||||||||||||
| sorts = p.Sort | ||||||||||||
| } | ||||||||||||
| // fall back to column | ||||||||||||
| if len(sorts) == 0 { | ||||||||||||
| if p != nil && p.Column != "" { | ||||||||||||
| for part := range strings.SplitSeq(p.Column, ",") { | ||||||||||||
| if s, ok := NewSort(part); ok { | ||||||||||||
| sorts = append(sorts, s) | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| return p.Sort | ||||||||||||
| } | ||||||||||||
| // if page has column, use default sort | ||||||||||||
| if p == nil || p.Column == "" { | ||||||||||||
| sort := make([]Sort, 0, len(defaultSort)) | ||||||||||||
| if len(sorts) == 0 { | ||||||||||||
| for _, s := range defaultSort { | ||||||||||||
| if s, ok := NewSort(s); ok { | ||||||||||||
| sort = append(sort, s) | ||||||||||||
| sorts = append(sorts, s) | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| return sort | ||||||||||||
| } | ||||||||||||
| // use column | ||||||||||||
| sort := make([]Sort, 0) | ||||||||||||
| for part := range strings.SplitSeq(p.Column, ",") { | ||||||||||||
| part = strings.TrimSpace(part) | ||||||||||||
| if part == "" { | ||||||||||||
| continue | ||||||||||||
|
|
||||||||||||
| for i := range sorts { | ||||||||||||
| s := &sorts[i] | ||||||||||||
| s.Column = strings.TrimSpace(s.Column) | ||||||||||||
| if columnFunc != nil { | ||||||||||||
| s.Column = columnFunc(s.Column) | ||||||||||||
| } | ||||||||||||
| if s, ok := NewSort(part); ok { | ||||||||||||
| s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize() | ||||||||||||
| sort = append(sort, s) | ||||||||||||
| s.Column = pgx.Identifier(strings.Split(s.Column, ".")).Sanitize() | ||||||||||||
|
|
||||||||||||
| switch strings.ToUpper(strings.TrimSpace(string(s.Order))) { | ||||||||||||
| case string(Desc): | ||||||||||||
| s.Order = Desc | ||||||||||||
| case string(Asc): | ||||||||||||
| s.Order = Asc | ||||||||||||
| default: | ||||||||||||
| s.Order = Asc | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| return sort | ||||||||||||
| return sorts | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| func (p *Page) Offset() uint64 { | ||||||||||||
|
|
@@ -229,13 +216,10 @@ type Paginator[T any] struct { | |||||||||||
| } | ||||||||||||
|
|
||||||||||||
| func (p Paginator[T]) getOrder(page *Page) []string { | ||||||||||||
| sort := page.GetOrder(p.settings.Sort...) | ||||||||||||
| sort := page.GetOrder(p.settings.ColumnFunc, p.settings.Sort...) | ||||||||||||
| list := make([]string, len(sort)) | ||||||||||||
| for i, s := range sort { | ||||||||||||
| if p.settings.ColumnFunc != nil { | ||||||||||||
| s.Column = p.settings.ColumnFunc(s.Column) | ||||||||||||
| } | ||||||||||||
| list[i] = s.String() | ||||||||||||
| for i := range sort { | ||||||||||||
| list[i] = sort[i].String() | ||||||||||||
| } | ||||||||||||
| return list | ||||||||||||
| } | ||||||||||||
|
|
@@ -253,6 +237,11 @@ func (p Paginator[T]) PrepareQuery(q sq.SelectBuilder, page *Page) ([]T, sq.Sele | |||||||||||
| } | ||||||||||||
|
|
||||||||||||
| func (p Paginator[T]) PrepareRaw(q string, args []any, page *Page) ([]T, string, []any) { | ||||||||||||
| if page == nil { | ||||||||||||
| page = &Page{} | ||||||||||||
| } | ||||||||||||
| page.SetDefaults(&p.settings) | ||||||||||||
|
|
||||||||||||
|
Comment on lines
+240
to
+244
|
||||||||||||
| limit, offset := page.Limit(), page.Offset() | ||||||||||||
|
|
||||||||||||
| q = q + " ORDER BY " + strings.Join(p.getOrder(page), ", ") | ||||||||||||
|
||||||||||||
| q = q + " ORDER BY " + strings.Join(p.getOrder(page), ", ") | |
| orders := p.getOrder(page) | |
| if len(orders) > 0 { | |
| q = q + " ORDER BY " + strings.Join(orders, ", ") | |
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,24 +26,24 @@ func TestPagination(t *testing.T) { | |
| page := pgkit.NewPage(0, 0) | ||
| result, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) | ||
| require.Len(t, result, 0) | ||
| require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) | ||
| require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page) | ||
|
|
||
| sql, args, err := query.ToSql() | ||
| require.NoError(t, err) | ||
| require.Equal(t, "SELECT * FROM t ORDER BY id ASC LIMIT 6 OFFSET 0", sql) | ||
| require.Equal(t, `SELECT * FROM t ORDER BY "id" ASC LIMIT 3 OFFSET 0`, sql) | ||
|
||
| require.Empty(t, args) | ||
|
|
||
| result = paginator.PrepareResult(make([]T, 0), page) | ||
| require.Len(t, result, 0) | ||
| require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) | ||
| require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page) | ||
|
|
||
| result = paginator.PrepareResult(make([]T, MaxSize), page) | ||
| require.Len(t, result, MaxSize) | ||
| require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize}, page) | ||
| result = paginator.PrepareResult(make([]T, DefaultSize), page) | ||
| require.Len(t, result, DefaultSize) | ||
| require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize}, page) | ||
|
|
||
| result = paginator.PrepareResult(make([]T, MaxSize+2), page) | ||
| require.Len(t, result, MaxSize) | ||
| require.Equal(t, &pgkit.Page{Page: 1, Size: MaxSize, More: true}, page) | ||
| result = paginator.PrepareResult(make([]T, DefaultSize+2), page) | ||
| require.Len(t, result, DefaultSize) | ||
| require.Equal(t, &pgkit.Page{Page: 1, Size: DefaultSize, More: true}, page) | ||
| } | ||
|
|
||
| func TestInvalidSort(t *testing.T) { | ||
|
|
@@ -150,3 +150,91 @@ func TestPaginationEdgeCases(t *testing.T) { | |
| require.NoError(t, err4) | ||
| require.Equal(t, "SELECT * FROM t LIMIT 21 OFFSET 0", sql4) | ||
| } | ||
|
|
||
| func TestColumnFunc(t *testing.T) { | ||
| fn := func(column string) string { | ||
| switch column { | ||
| case "id": | ||
| return "ID" | ||
| case "name": | ||
| return "NAME" | ||
| default: | ||
| return column | ||
| } | ||
| } | ||
| paginator := pgkit.NewPaginator[T]( | ||
| pgkit.WithColumnFunc(fn), | ||
| ) | ||
| page := &pgkit.Page{ | ||
| Page: 1, | ||
| Size: 10, | ||
| Sort: []pgkit.Sort{ | ||
| {Column: "id", Order: pgkit.Asc}, | ||
| {Column: "name", Order: pgkit.Desc}, | ||
| {Column: "created_at", Order: pgkit.Asc}, | ||
| }, | ||
| } | ||
| _, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) | ||
|
|
||
| sql, args, err := query.ToSql() | ||
| require.NoError(t, err) | ||
| require.Equal(t, `SELECT * FROM t ORDER BY "ID" ASC, "NAME" DESC, "created_at" ASC LIMIT 11 OFFSET 0`, sql) | ||
| require.Empty(t, args) | ||
klaidliadon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| func TestColumnFallbackUsesColumnFunc(t *testing.T) { | ||
| paginator := pgkit.NewPaginator[T]( | ||
| pgkit.WithColumnFunc(strings.ToUpper), | ||
| pgkit.WithSort("id"), | ||
| ) | ||
| page := &pgkit.Page{ | ||
| Page: 1, | ||
| Size: 10, | ||
| Column: "name", | ||
| } | ||
|
|
||
| _, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) | ||
|
|
||
| sql, args, err := query.ToSql() | ||
| require.NoError(t, err) | ||
| require.Equal(t, `SELECT * FROM t ORDER BY "NAME" ASC LIMIT 11 OFFSET 0`, sql) | ||
| require.Empty(t, args) | ||
| } | ||
|
|
||
| func TestSortTakesPrecedenceOverColumn(t *testing.T) { | ||
| paginator := pgkit.NewPaginator[T]() | ||
| page := &pgkit.Page{ | ||
| Page: 1, | ||
| Size: 10, | ||
| Column: "name", | ||
| Sort: []pgkit.Sort{ | ||
| {Column: "id", Order: pgkit.Desc}, | ||
| }, | ||
| } | ||
|
|
||
| _, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) | ||
|
|
||
| sql, args, err := query.ToSql() | ||
| require.NoError(t, err) | ||
| require.Equal(t, `SELECT * FROM t ORDER BY "id" DESC LIMIT 11 OFFSET 0`, sql) | ||
| require.Empty(t, args) | ||
| } | ||
|
|
||
| func TestPaginationOffsetAndPageRecompute(t *testing.T) { | ||
| paginator := pgkit.NewPaginator[T]() | ||
| page := &pgkit.Page{ | ||
| Page: 3, | ||
| Size: 2, | ||
| } | ||
|
|
||
| _, query := paginator.PrepareQuery(sq.Select("*").From("t"), page) | ||
|
|
||
| sql, args, err := query.ToSql() | ||
| require.NoError(t, err) | ||
| require.Equal(t, "SELECT * FROM t LIMIT 3 OFFSET 4", sql) | ||
| require.Empty(t, args) | ||
|
|
||
| result := paginator.PrepareResult(make([]T, 3), page) | ||
| require.Len(t, result, 2) | ||
| require.Equal(t, &pgkit.Page{Page: 3, Size: 2, More: true}, page) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sort.String()no longer normalizes/sanitizesOrder(and no longer guards against emptyColumn). SinceSortandString()are exported, this is a behavioral/security regression for any external callers building SQL usingSort.String(). Consider restoring the previous behavior (normalize order to ASC/DESC, and return "" whenColumnis empty) or making stringification explicitly unsafe/unexported.