Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,5 @@ Unit tests (`tests/test_*.py`) cover each feature area per dialect. Integration
- `size()` dispatches to `ARRAY_LENGTH` for arrays, `LENGTH` for strings
- Depth tracking: `_visit_child()` increments/decrements `_depth` and checks limits
- Error types use dual messaging pattern to prevent information disclosure (CWE-209)
- `validate_schema` parameter: opt-in strict validation on `convert()`/`convert_parameterized()`/`analyze()`. Validates `table.field` references exist in schemas; skips comprehension variables, bare identifiers, and nested JSON keys beyond the first field. Raises `InvalidSchemaError` (with dual messaging). Requires schemas to be provided.
- Ruff for linting, mypy strict for type checking, line length 100, target Python 3.12+
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,38 @@ sql = convert(
# => usr.metadata->>'role' = 'admin'
```

## Schema Validation

Enable strict validation to catch typos and references to nonexistent fields:

```python
from pycel2sql import convert, InvalidSchemaError
from pycel2sql.schema import Schema, FieldSchema

schemas = {
"usr": Schema([
FieldSchema("name"),
FieldSchema("age", type="integer"),
FieldSchema("metadata", is_jsonb=True),
])
}

# Valid field — works normally
sql = convert('usr.name == "alice"', schemas=schemas, validate_schema=True)

# Unknown field — raises InvalidSchemaError
convert('usr.email == "test"', schemas=schemas, validate_schema=True)
# => InvalidSchemaError: field not found in schema
```

Validation scope:
- **Validates**: `table.field` references — table must exist in `schemas`, field must exist in that table's `Schema`
- **Skips**: Nested JSON paths beyond the first field (e.g., `usr.metadata.settings.theme` validates `metadata` exists, not `settings`)
- **Skips**: Comprehension variables (`t` in `tags.all(t, t > 0)`)
- **Skips**: Bare identifiers without a table prefix (`age > 10`)

Works with all three public API functions: `convert()`, `convert_parameterized()`, and `analyze()`.

## Schema Introspection

Auto-discover table schemas from a live database connection instead of building `Schema` objects manually:
Expand Down
23 changes: 22 additions & 1 deletion src/pycel2sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from celpy.celparser import CELParser

from pycel2sql._converter import Converter
from pycel2sql._errors import ConversionError, IntrospectionError
from pycel2sql._errors import ConversionError, IntrospectionError, InvalidSchemaError
from pycel2sql.dialect._base import Dialect
from pycel2sql.dialect.bigquery import BigQueryDialect
from pycel2sql.dialect.duckdb import DuckDBDialect
Expand All @@ -32,6 +32,7 @@
"Result",
"ConversionError",
"IntrospectionError",
"InvalidSchemaError",
"Dialect",
"BigQueryDialect",
"DuckDBDialect",
Expand All @@ -58,6 +59,7 @@ def convert(
schemas: dict[str, Schema] | None = None,
max_depth: int | None = None,
max_output_length: int | None = None,
validate_schema: bool = False,
) -> str:
"""Convert a CEL expression to an inline SQL WHERE clause string.

Expand All @@ -67,12 +69,15 @@ def convert(
schemas: Optional table schemas for JSON/array field detection.
max_depth: Maximum recursion depth. Defaults to 100.
max_output_length: Maximum SQL output length. Defaults to 50000.
validate_schema: If True, raise InvalidSchemaError for unrecognized
table or field references. Requires schemas to be provided.

Returns:
The SQL WHERE clause string.

Raises:
ConversionError: If conversion fails.
InvalidSchemaError: If validate_schema is True and a field is not in the schema.
"""
if dialect is None:
dialect = PostgresDialect()
Expand All @@ -86,6 +91,8 @@ def convert(
kwargs["max_depth"] = max_depth
if max_output_length is not None:
kwargs["max_output_length"] = max_output_length
if validate_schema:
kwargs["validate_schema"] = validate_schema

converter = Converter(dialect, **kwargs)
converter.visit(tree)
Expand All @@ -99,6 +106,7 @@ def convert_parameterized(
schemas: dict[str, Schema] | None = None,
max_depth: int | None = None,
max_output_length: int | None = None,
validate_schema: bool = False,
) -> Result:
"""Convert a CEL expression to a parameterized SQL WHERE clause.

Expand All @@ -108,12 +116,15 @@ def convert_parameterized(
schemas: Optional table schemas for JSON/array field detection.
max_depth: Maximum recursion depth. Defaults to 100.
max_output_length: Maximum SQL output length. Defaults to 50000.
validate_schema: If True, raise InvalidSchemaError for unrecognized
table or field references. Requires schemas to be provided.

Returns:
Result with SQL containing $1, $2, ... placeholders and parameter list.

Raises:
ConversionError: If conversion fails.
InvalidSchemaError: If validate_schema is True and a field is not in the schema.
"""
if dialect is None:
dialect = PostgresDialect()
Expand All @@ -127,6 +138,8 @@ def convert_parameterized(
kwargs["max_depth"] = max_depth
if max_output_length is not None:
kwargs["max_output_length"] = max_output_length
if validate_schema:
kwargs["validate_schema"] = validate_schema

converter = Converter(dialect, **kwargs)
converter.visit(tree)
Expand All @@ -148,6 +161,7 @@ def analyze(
schemas: dict[str, Schema] | None = None,
max_depth: int | None = None,
max_output_length: int | None = None,
validate_schema: bool = False,
) -> AnalysisResult:
"""Analyze a CEL expression for SQL conversion and index recommendations.

Expand All @@ -157,9 +171,14 @@ def analyze(
schemas: Optional table schemas for JSON/array field detection.
max_depth: Maximum recursion depth.
max_output_length: Maximum SQL output length.
validate_schema: If True, raise InvalidSchemaError for unrecognized
table or field references. Requires schemas to be provided.

Returns:
AnalysisResult with SQL and index recommendations.

Raises:
InvalidSchemaError: If validate_schema is True and a field is not in the schema.
"""
from pycel2sql._analysis import analyze_patterns
from pycel2sql.dialect._base import get_index_advisor
Expand All @@ -177,6 +196,8 @@ def analyze(
kwargs["max_depth"] = max_depth
if max_output_length is not None:
kwargs["max_output_length"] = max_output_length
if validate_schema:
kwargs["validate_schema"] = validate_schema

converter = Converter(dialect, **kwargs)
converter.visit(tree)
Expand Down
34 changes: 34 additions & 0 deletions src/pycel2sql/_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
MAX_COMPREHENSION_DEPTH,
)
from pycel2sql._errors import (
ERR_MSG_SCHEMA_VALIDATION_FAILED,
InvalidArgumentsError,
InvalidByteArrayLengthError,
InvalidDurationError,
InvalidSchemaError,
MaxComprehensionDepthExceededError,
MaxDepthExceededError,
MaxOutputLengthExceededError,
Expand Down Expand Up @@ -145,6 +147,7 @@ def __init__(
max_depth: int = DEFAULT_MAX_RECURSION_DEPTH,
max_output_length: int = DEFAULT_MAX_SQL_OUTPUT_LENGTH,
parameterize: bool = False,
validate_schema: bool = False,
) -> None:
self._w = StringIO()
self._dialect = dialect
Expand All @@ -157,6 +160,12 @@ def __init__(
self._parameters: list[Any] = []
self._param_count = 0
self._comprehension_vars: set[str] = set()
self._validate_schema = validate_schema
if self._validate_schema and not self._schemas:
raise InvalidSchemaError(
ERR_MSG_SCHEMA_VALIDATION_FAILED,
"validate_schema=True requires at least one schema to be provided",
)

@property
def result(self) -> str:
Expand Down Expand Up @@ -580,6 +589,12 @@ def member_dot(self, tree: Tree) -> None:

# Check for JSON path
table_name = self._get_root_ident(obj)

# Schema validation (before JSON check and SQL writing)
if table_name and not self._is_comprehension_var(table_name):
first_field = self._get_first_field(obj, field_name)
self._validate_field_in_schema(table_name, first_field)

if table_name and self._is_field_json(table_name, self._get_first_field(obj, field_name)):
self._build_json_path(tree)
return
Expand Down Expand Up @@ -1729,6 +1744,25 @@ def _is_json_text_extraction(self, tree: Tree) -> bool:

# ---- Schema helpers ----

def _validate_field_in_schema(self, table_name: str, field_name: str) -> None:
"""Validate that a field exists in the schema for a table.

No-op if validate_schema is False.
"""
if not self._validate_schema:
return
schema = self._schemas.get(table_name)
if schema is None:
raise InvalidSchemaError(
ERR_MSG_SCHEMA_VALIDATION_FAILED,
f"table '{table_name}' not found in schemas",
)
if schema.find_field(field_name) is None:
raise InvalidSchemaError(
ERR_MSG_SCHEMA_VALIDATION_FAILED,
f"field '{field_name}' not found in schema for '{table_name}'",
)

def _is_member_dot_array_field(self, tree: Tree) -> bool:
"""Check if a tree represents an array field via schema."""
node = tree
Expand Down
1 change: 1 addition & 0 deletions src/pycel2sql/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,4 @@ class IntrospectionError(ConversionError):
# Generic public-facing error messages. Part of the dual-messaging pattern
# (see CLAUDE.md): these deliberately vague strings are shown externally to
# prevent information disclosure (CWE-209), while a detailed internal
# message is passed separately to the error constructor.
ERR_MSG_INVALID_ARGUMENTS = "invalid function arguments"
ERR_MSG_UNKNOWN_TYPE = "unknown type"
ERR_MSG_INVALID_PATTERN = "invalid pattern"
ERR_MSG_SCHEMA_VALIDATION_FAILED = "field not found in schema"
Loading
Loading