Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 209 additions & 9 deletions pkg/database/postgres/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,74 @@ func ValidateColumnName(name string) error {
return nil
}

// ValidateJSONBPath validates JSONB path expressions
// Supports expressions like: column->'key'->>'value', column->0->>'nested', etc.
func ValidateJSONBPath(path string) error {
path = strings.TrimSpace(path)

if len(path) == 0 {
return fmt.Errorf("JSONB path cannot be empty")
}

if len(path) > 512 { // Allow longer paths for JSONB
return fmt.Errorf("JSONB path exceeds maximum length (512 chars): %d", len(path))
}

// Check for SQL injection patterns (allow single quotes for JSONB keys, but block double quotes and dangerous SQL)
dangerousPatterns := []string{";", "--", "/*", "*/", "\""}
for _, pattern := range dangerousPatterns {
if strings.Contains(path, pattern) {
return fmt.Errorf("JSONB path contains potentially dangerous characters: %s", path)
}
}

// Validate structure: must start with valid column name, followed by JSONB operators
// Find the base column name (everything before the first ->)
operatorIndex := strings.Index(path, "->")
if operatorIndex == -1 {
return fmt.Errorf("JSONB path must contain -> or ->> operator: %s", path)
}

baseName := path[:operatorIndex]
if !validColumnRegex.MatchString(baseName) {
return fmt.Errorf("invalid base column in JSONB path: %s", baseName)
}

// Rest of the path after base column should contain balanced single quotes and valid JSONB operators
// Check for balanced quotes
singleQuoteCount := strings.Count(path, "'")
if singleQuoteCount%2 != 0 {
return fmt.Errorf("unbalanced quotes in JSONB path: %s", path)
}

return nil
}

// IsJSONBPath checks if a column reference uses JSONB operators
func IsJSONBPath(name string) bool {
return strings.Contains(name, "->") || strings.Contains(name, "->>")
}

// ValidateAndFormatColumn validates a column name or JSONB path and formats it for SQL
// For simple columns, it quotes the identifier; for JSONB paths, it returns the path as-is (already safe)
func ValidateAndFormatColumn(name string) (string, error) {
name = strings.TrimSpace(name)

// Check if it's a JSONB path expression
if IsJSONBPath(name) {
if err := ValidateJSONBPath(name); err != nil {
return "", err
}
return name, nil // Return JSONB path unquoted
}

// Otherwise validate as regular column name
if err := ValidateColumnName(name); err != nil {
return "", err
}
return pq.QuoteIdentifier(name), nil // Quote regular column names
}

// SplitQualifiedName splits a qualified table name on dots while respecting quoted identifiers
func SplitQualifiedName(qualifiedName string) ([]string, error) {
parts := make([]string, 0, 2)
Expand Down Expand Up @@ -234,6 +302,12 @@ var allowedOperators = map[string]bool{
"lt": true, "<": true, "lte": true, "<=": true,
"like": true, "ilike": true, "in": true, "not_in": true,
"is_null": true, "is_not_null": true, "any": true,
// JSONB operators
"jsonb_contains": true, // @>
"jsonb_contained": true, // <@
"jsonb_has_key": true, // ?
"jsonb_has_any_key": true, // ?|
"jsonb_has_all_keys": true, // ?&
}

// ValidateOperator ensures operator is in the whitelist and safe to use
Expand Down Expand Up @@ -721,7 +795,11 @@ func (postgresDbService *PostgresDbService) ToInterfaceSlice(v interface{}) ([]i

// BuildSimpleCondition builds conditions for simple operators (=, !=, <, >, etc.)
func (postgresDbService *PostgresDbService) BuildSimpleCondition(filter models.QueryFilter, operator string, argCounter int) (string, []interface{}, int) {
condition := fmt.Sprintf("%s %s $%d", pq.QuoteIdentifier(filter.Column), operator, argCounter)
formattedColumn, err := ValidateAndFormatColumn(filter.Column)
if err != nil {
return "", nil, argCounter
}
condition := fmt.Sprintf("%s %s $%d", formattedColumn, operator, argCounter)
args := []interface{}{filter.Value}
return condition, args, argCounter + 1
}
Expand All @@ -733,6 +811,11 @@ func (postgresDbService *PostgresDbService) BuildInCondition(filter models.Query
return "", nil, argCounter
}

formattedColumn, err := ValidateAndFormatColumn(filter.Column)
if err != nil {
return "", nil, argCounter
}

placeholders := make([]string, len(values))
args := make([]interface{}, 0, len(values))
for i, val := range values {
Expand All @@ -745,36 +828,151 @@ func (postgresDbService *PostgresDbService) BuildInCondition(filter models.Query
if useNot {
operator = "NOT IN"
}
condition := fmt.Sprintf("%s %s (%s)", pq.QuoteIdentifier(filter.Column), operator, strings.Join(placeholders, ", "))
condition := fmt.Sprintf("%s %s (%s)", formattedColumn, operator, strings.Join(placeholders, ", "))
return condition, args, argCounter
}

// BuildNullCondition builds conditions for IS NULL/IS NOT NULL operators
func (postgresDbService *PostgresDbService) BuildNullCondition(filter models.QueryFilter, useNot bool, argCounter int) (string, []interface{}, int) {
formattedColumn, err := ValidateAndFormatColumn(filter.Column)
if err != nil {
return "", nil, argCounter
}
operator := "IS NULL"
if useNot {
operator = "IS NOT NULL"
}
condition := fmt.Sprintf("%s %s", pq.QuoteIdentifier(filter.Column), operator)
condition := fmt.Sprintf("%s %s", formattedColumn, operator)
return condition, nil, argCounter
}

// BuildAnyCondition builds conditions for ANY operator
func (postgresDbService *PostgresDbService) BuildAnyCondition(filter models.QueryFilter, argCounter int) (string, []interface{}, int) {
condition := fmt.Sprintf("$%d = ANY(%s)", argCounter, pq.QuoteIdentifier(filter.Column))
formattedColumn, err := ValidateAndFormatColumn(filter.Column)
if err != nil {
return "", nil, argCounter
}
condition := fmt.Sprintf("$%d = ANY(%s)", argCounter, formattedColumn)
args := []interface{}{filter.Value}
return condition, args, argCounter + 1
}

// BuildJSONBCondition builds conditions for JSONB path queries
// Example: column=["raw_statement"], json_path=["result", "success"], operator="eq", value="true"
// Produces: raw_statement->'result'->>'success' = $1
func (postgresDbService *PostgresDbService) BuildJSONBCondition(filter models.QueryFilter, argCounter int) (string, []interface{}, int) {
// Validate column name
if err := ValidateColumnName(filter.Column); err != nil {
return "", nil, argCounter
}

if len(filter.JSONPath) == 0 {
return "", nil, argCounter
}

// Build JSONB path expression: column->'key1'->'key2'->>'final_key'
quotedCol := pq.QuoteIdentifier(filter.Column)
pathExpr := quotedCol

// Navigate through the path, using ->> for the last key to extract text
for i, key := range filter.JSONPath {
quotedKey := fmt.Sprintf("'%s'", strings.ReplaceAll(key, "'", "''")) // SQL escape single quotes
if i == len(filter.JSONPath)-1 {
// Last key - use ->> to extract as text for comparison
pathExpr += fmt.Sprintf(" ->> %s", quotedKey)
} else {
// Intermediate keys - use -> to navigate as JSONB
pathExpr += fmt.Sprintf(" -> %s", quotedKey)
}
}

// Now build the condition using the path expression
operator := strings.ToLower(filter.Operator)
var condition string
var args []interface{}

switch operator {
case "eq", "=":
condition = fmt.Sprintf("%s = $%d", pathExpr, argCounter)
args = []interface{}{filter.Value}
argCounter++
case "neq", "!=", "<>":
condition = fmt.Sprintf("%s != $%d", pathExpr, argCounter)
args = []interface{}{filter.Value}
argCounter++
case "gt", ">":
condition = fmt.Sprintf("%s > $%d", pathExpr, argCounter)
args = []interface{}{filter.Value}
argCounter++
case "gte", ">=":
condition = fmt.Sprintf("%s >= $%d", pathExpr, argCounter)
args = []interface{}{filter.Value}
argCounter++
case "lt", "<":
condition = fmt.Sprintf("%s < $%d", pathExpr, argCounter)
args = []interface{}{filter.Value}
argCounter++
case "lte", "<=":
condition = fmt.Sprintf("%s <= $%d", pathExpr, argCounter)
args = []interface{}{filter.Value}
argCounter++
case "like":
condition = fmt.Sprintf("%s LIKE $%d", pathExpr, argCounter)
args = []interface{}{filter.Value}
argCounter++
case "ilike":
condition = fmt.Sprintf("%s ILIKE $%d", pathExpr, argCounter)
args = []interface{}{filter.Value}
argCounter++
case "in":
values, ok := postgresDbService.ToInterfaceSlice(filter.Value)
if !ok || len(values) == 0 {
return "", nil, argCounter
}
placeholders := make([]string, len(values))
for i, val := range values {
placeholders[i] = fmt.Sprintf("$%d", argCounter)
args = append(args, val)
argCounter++
}
condition = fmt.Sprintf("%s IN (%s)", pathExpr, strings.Join(placeholders, ", "))
case "not_in":
values, ok := postgresDbService.ToInterfaceSlice(filter.Value)
if !ok || len(values) == 0 {
return "", nil, argCounter
}
placeholders := make([]string, len(values))
for i, val := range values {
placeholders[i] = fmt.Sprintf("$%d", argCounter)
args = append(args, val)
argCounter++
}
condition = fmt.Sprintf("%s NOT IN (%s)", pathExpr, strings.Join(placeholders, ", "))
case "is_null":
condition = fmt.Sprintf("%s IS NULL", pathExpr)
case "is_not_null":
condition = fmt.Sprintf("%s IS NOT NULL", pathExpr)
default:
// Unknown operator
return "", nil, argCounter
}

return condition, args, argCounter
}

func (postgresDbService *PostgresDbService) BuildFilterCondition(filter models.QueryFilter, argCounter int) (string, []interface{}, int) {
// VALIDATE OPERATOR FIRST - before any SQL string building
if err := ValidateOperator(filter.Operator); err != nil {
// Return empty condition on invalid operator - caller should handle this
// or we could return error as fourth return value (future improvement)
// Return empty condition on invalid operator
return "", nil, argCounter
}

// VALIDATE COLUMN NAME - ensure column is safe from SQL injection
// Check if this is a JSONB path query
if len(filter.JSONPath) > 0 {
return postgresDbService.BuildJSONBCondition(filter, argCounter)
}

// Regular column validation and processing
if err := ValidateColumnName(filter.Column); err != nil {
// Return empty condition on invalid column
return "", nil, argCounter
Expand Down Expand Up @@ -2740,10 +2938,12 @@ func (r *PostgresDbService) RemoveManyToManyRelations(relationship *models.Relat
// - int: Updated argCounter after consuming parameters
//
// Example output for one-to-many:
// SELECT orders.* FROM orders WHERE orders.user_id = $1
//
// SELECT orders.* FROM orders WHERE orders.user_id = $1
//
// Example output for many-to-many:
// SELECT t.* FROM products t INNER JOIN order_items j ON t.id = j.product_id WHERE j.order_id = $1
//
// SELECT t.* FROM products t INNER JOIN order_items j ON t.id = j.product_id WHERE j.order_id = $1
func (r *PostgresDbService) buildRelationshipBaseQuery(relationship *models.RelationshipDefinition, params models.QueryParams, argCounter int) (string, int) {
var query strings.Builder

Expand Down
Loading
Loading