From e39ec69ddb8eeb73123054b6b00e0c27b5a0abdf Mon Sep 17 00:00:00 2001 From: Richard Wooding Date: Fri, 27 Feb 2026 17:11:22 +0200 Subject: [PATCH] Add validate_schema parameter for strict field validation When validate_schema=True, convert()/convert_parameterized()/analyze() raise InvalidSchemaError for unrecognized table or field references in table.field expressions. This catches typos and references to nonexistent fields that would otherwise silently produce incorrect SQL. Validation skips comprehension variables, bare identifiers, and nested JSON keys beyond the first field. Requires schemas to be provided. Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 1 + README.md | 32 ++++++ src/pycel2sql/__init__.py | 23 +++- src/pycel2sql/_converter.py | 34 ++++++ src/pycel2sql/_errors.py | 1 + tests/test_validate_schema.py | 205 ++++++++++++++++++++++++++++++++++ 6 files changed, 295 insertions(+), 1 deletion(-) create mode 100644 tests/test_validate_schema.py diff --git a/CLAUDE.md b/CLAUDE.md index 1e764e7..690660d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -96,4 +96,5 @@ Unit tests (`tests/test_*.py`) cover each feature area per dialect. Integration - `size()` dispatches to `ARRAY_LENGTH` for arrays, `LENGTH` for strings - Depth tracking: `_visit_child()` increments/decrements `_depth` and checks limits - Error types use dual messaging pattern to prevent information disclosure (CWE-209) +- `validate_schema` parameter: opt-in strict validation on `convert()`/`convert_parameterized()`/`analyze()`. Validates `table.field` references exist in schemas; skips comprehension variables, bare identifiers, and nested JSON keys beyond the first field. Raises `InvalidSchemaError` (with dual messaging). Requires schemas to be provided. - Ruff for linting, mypy strict for type checking, line length 100, target Python 3.12+ diff --git a/README.md b/README.md index 6b52f63..a9050f7 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,38 @@ sql = convert( # => usr.metadata->>'role' = 'admin' ``` +## Schema Validation + +Enable strict validation to catch typos and references to nonexistent fields: + +```python +from pycel2sql import convert, InvalidSchemaError +from pycel2sql.schema import Schema, FieldSchema + +schemas = { + "usr": Schema([ + FieldSchema("name"), + FieldSchema("age", type="integer"), + FieldSchema("metadata", is_jsonb=True), + ]) +} + +# Valid field — works normally +sql = convert('usr.name == "alice"', schemas=schemas, validate_schema=True) + +# Unknown field — raises InvalidSchemaError +convert('usr.email == "test"', schemas=schemas, validate_schema=True) +# => InvalidSchemaError: field not found in schema +``` + +Validation scope: +- **Validates**: `table.field` references — table must exist in `schemas`, field must exist in that table's `Schema` +- **Skips**: Nested JSON paths beyond the first field (e.g., `usr.metadata.settings.theme` validates `metadata` exists, not `settings`) +- **Skips**: Comprehension variables (`t` in `tags.all(t, t > 0)`) +- **Skips**: Bare identifiers without a table prefix (`age > 10`) + +Works with all three public API functions: `convert()`, `convert_parameterized()`, and `analyze()`. + ## Schema Introspection Auto-discover table schemas from a live database connection instead of building `Schema` objects manually: diff --git a/src/pycel2sql/__init__.py b/src/pycel2sql/__init__.py index cb383e1..55c4451 100644 --- a/src/pycel2sql/__init__.py +++ b/src/pycel2sql/__init__.py @@ -13,7 +13,7 @@ from celpy.celparser import CELParser from pycel2sql._converter import Converter -from pycel2sql._errors import ConversionError, IntrospectionError +from pycel2sql._errors import ConversionError, IntrospectionError, InvalidSchemaError from pycel2sql.dialect._base import Dialect from pycel2sql.dialect.bigquery import BigQueryDialect from pycel2sql.dialect.duckdb import DuckDBDialect @@ -32,6 +32,7 @@ "Result", "ConversionError", "IntrospectionError", + "InvalidSchemaError", "Dialect", "BigQueryDialect", "DuckDBDialect", @@ -58,6 +59,7 @@ def convert( schemas: dict[str, Schema] | None = None, max_depth: int | None = None, max_output_length: int | None = None, + validate_schema: bool = False, ) -> str: """Convert a CEL expression to an inline SQL WHERE clause string. @@ -67,12 +69,15 @@ def convert( schemas: Optional table schemas for JSON/array field detection. max_depth: Maximum recursion depth. Defaults to 100. max_output_length: Maximum SQL output length. Defaults to 50000. + validate_schema: If True, raise InvalidSchemaError for unrecognized + table or field references. Requires schemas to be provided. Returns: The SQL WHERE clause string. Raises: ConversionError: If conversion fails. + InvalidSchemaError: If validate_schema is True and a field is not in the schema. """ if dialect is None: dialect = PostgresDialect() @@ -86,6 +91,8 @@ def convert( kwargs["max_depth"] = max_depth if max_output_length is not None: kwargs["max_output_length"] = max_output_length + if validate_schema: + kwargs["validate_schema"] = validate_schema converter = Converter(dialect, **kwargs) converter.visit(tree) @@ -99,6 +106,7 @@ def convert_parameterized( schemas: dict[str, Schema] | None = None, max_depth: int | None = None, max_output_length: int | None = None, + validate_schema: bool = False, ) -> Result: """Convert a CEL expression to a parameterized SQL WHERE clause. @@ -108,12 +116,15 @@ def convert_parameterized( schemas: Optional table schemas for JSON/array field detection. max_depth: Maximum recursion depth. Defaults to 100. max_output_length: Maximum SQL output length. Defaults to 50000. + validate_schema: If True, raise InvalidSchemaError for unrecognized + table or field references. Requires schemas to be provided. Returns: Result with SQL containing $1, $2, ... placeholders and parameter list. Raises: ConversionError: If conversion fails. + InvalidSchemaError: If validate_schema is True and a field is not in the schema. """ if dialect is None: dialect = PostgresDialect() @@ -127,6 +138,8 @@ def convert_parameterized( kwargs["max_depth"] = max_depth if max_output_length is not None: kwargs["max_output_length"] = max_output_length + if validate_schema: + kwargs["validate_schema"] = validate_schema converter = Converter(dialect, **kwargs) converter.visit(tree) @@ -148,6 +161,7 @@ def analyze( schemas: dict[str, Schema] | None = None, max_depth: int | None = None, max_output_length: int | None = None, + validate_schema: bool = False, ) -> AnalysisResult: """Analyze a CEL expression for SQL conversion and index recommendations. @@ -157,9 +171,14 @@ def analyze( schemas: Optional table schemas for JSON/array field detection. max_depth: Maximum recursion depth. max_output_length: Maximum SQL output length. + validate_schema: If True, raise InvalidSchemaError for unrecognized + table or field references. Requires schemas to be provided. Returns: AnalysisResult with SQL and index recommendations. + + Raises: + InvalidSchemaError: If validate_schema is True and a field is not in the schema. """ from pycel2sql._analysis import analyze_patterns from pycel2sql.dialect._base import get_index_advisor @@ -177,6 +196,8 @@ def analyze( kwargs["max_depth"] = max_depth if max_output_length is not None: kwargs["max_output_length"] = max_output_length + if validate_schema: + kwargs["validate_schema"] = validate_schema converter = Converter(dialect, **kwargs) converter.visit(tree) diff --git a/src/pycel2sql/_converter.py b/src/pycel2sql/_converter.py index 59477c0..f160ee1 100644 --- a/src/pycel2sql/_converter.py +++ b/src/pycel2sql/_converter.py @@ -16,9 +16,11 @@ MAX_COMPREHENSION_DEPTH, ) from pycel2sql._errors import ( + ERR_MSG_SCHEMA_VALIDATION_FAILED, InvalidArgumentsError, InvalidByteArrayLengthError, InvalidDurationError, + InvalidSchemaError, MaxComprehensionDepthExceededError, MaxDepthExceededError, MaxOutputLengthExceededError, @@ -145,6 +147,7 @@ def __init__( max_depth: int = DEFAULT_MAX_RECURSION_DEPTH, max_output_length: int = DEFAULT_MAX_SQL_OUTPUT_LENGTH, parameterize: bool = False, + validate_schema: bool = False, ) -> None: self._w = StringIO() self._dialect = dialect @@ -157,6 +160,12 @@ def __init__( self._parameters: list[Any] = [] self._param_count = 0 self._comprehension_vars: set[str] = set() + self._validate_schema = validate_schema + if self._validate_schema and not self._schemas: + raise InvalidSchemaError( + ERR_MSG_SCHEMA_VALIDATION_FAILED, + "validate_schema=True requires at least one schema to be provided", + ) @property def result(self) -> str: @@ -580,6 +589,12 @@ def member_dot(self, tree: Tree) -> None: # Check for JSON path table_name = self._get_root_ident(obj) + + # Schema validation (before JSON check and SQL writing) + if table_name and not self._is_comprehension_var(table_name): + first_field = self._get_first_field(obj, field_name) + self._validate_field_in_schema(table_name, first_field) + if table_name and self._is_field_json(table_name, self._get_first_field(obj, field_name)): self._build_json_path(tree) return @@ -1729,6 +1744,25 @@ def _is_json_text_extraction(self, tree: Tree) -> bool: # ---- Schema helpers ---- + def _validate_field_in_schema(self, table_name: str, field_name: str) -> None: + """Validate that a field exists in the schema for a table. + + No-op if validate_schema is False. + """ + if not self._validate_schema: + return + schema = self._schemas.get(table_name) + if schema is None: + raise InvalidSchemaError( + ERR_MSG_SCHEMA_VALIDATION_FAILED, + f"table '{table_name}' not found in schemas", + ) + if schema.find_field(field_name) is None: + raise InvalidSchemaError( + ERR_MSG_SCHEMA_VALIDATION_FAILED, + f"field '{field_name}' not found in schema for '{table_name}'", + ) + def _is_member_dot_array_field(self, tree: Tree) -> bool: """Check if a tree represents an array field via schema.""" node = tree diff --git a/src/pycel2sql/_errors.py b/src/pycel2sql/_errors.py index a3f9a08..c0fb4b1 100644 --- a/src/pycel2sql/_errors.py +++ b/src/pycel2sql/_errors.py @@ -108,3 +108,4 @@ class IntrospectionError(ConversionError): ERR_MSG_INVALID_ARGUMENTS = "invalid function arguments" ERR_MSG_UNKNOWN_TYPE = "unknown type" ERR_MSG_INVALID_PATTERN = "invalid pattern" +ERR_MSG_SCHEMA_VALIDATION_FAILED = "field not found in schema" diff --git a/tests/test_validate_schema.py b/tests/test_validate_schema.py new file mode 100644 index 0000000..a8f9659 --- /dev/null +++ b/tests/test_validate_schema.py @@ -0,0 +1,205 @@ +"""Tests for validate_schema parameter.""" + +import pytest + +from pycel2sql import analyze, convert, convert_parameterized +from pycel2sql._errors import InvalidSchemaError +from pycel2sql.schema import FieldSchema, Schema + + +def _make_schemas() -> dict[str, Schema]: + return { + "usr": Schema([ + FieldSchema(name="name", type="text"), + FieldSchema(name="age", type="integer"), + FieldSchema(name="metadata", type="jsonb", is_jsonb=True), + FieldSchema(name="tags", type="text", repeated=True), + ]), + } + + +class TestValidateSchemaConfig: + """Tests for validate_schema configuration.""" + + def test_default_false_allows_unknown_fields(self): + """Default validate_schema=False allows unknown fields.""" + result = convert("usr.nonexistent == 'foo'") + assert "usr.nonexistent" in result + + def test_true_with_no_schemas_raises(self): + """validate_schema=True with no schemas raises immediately.""" + with pytest.raises(InvalidSchemaError): + convert("usr.name == 'foo'", validate_schema=True) + + def test_true_with_empty_schemas_raises(self): + """validate_schema=True with empty schemas dict raises immediately.""" + with pytest.raises(InvalidSchemaError): + convert("usr.name == 'foo'", schemas={}, validate_schema=True) + + def test_false_with_schemas_allows_unknown(self): + """validate_schema=False with schemas still allows unknown fields.""" + schemas = _make_schemas() + result = convert("usr.nonexistent == 'foo'", schemas=schemas, validate_schema=False) + assert "usr.nonexistent" in result + + +class TestValidateSchemaFieldAccess: + """Tests for field access validation.""" + + def test_valid_field_succeeds(self): + """Known field passes validation.""" + schemas = _make_schemas() + result = convert("usr.name == 'alice'", schemas=schemas, validate_schema=True) + assert "usr.name" in result + + def test_unknown_field_raises(self): + """Unknown field raises InvalidSchemaError.""" + schemas = _make_schemas() + with pytest.raises(InvalidSchemaError) as exc_info: + convert("usr.email == 'test@example.com'", schemas=schemas, validate_schema=True) + assert "field not found in schema" in str(exc_info.value) + assert "email" in exc_info.value.internal_details + + def test_unknown_table_raises(self): + """Unknown table raises InvalidSchemaError.""" + schemas = _make_schemas() + with pytest.raises(InvalidSchemaError) as exc_info: + convert("orders.total > 100", schemas=schemas, validate_schema=True) + assert "field not found in schema" in str(exc_info.value) + assert "orders" in exc_info.value.internal_details + + def test_multiple_valid_fields(self): + """Multiple known fields pass validation.""" + schemas = _make_schemas() + result = convert( + "usr.name == 'alice' && usr.age > 18", + schemas=schemas, + validate_schema=True, + ) + assert "usr.name" in result + assert "usr.age" in result + + def test_error_dual_messaging(self): + """Error uses sanitized user message and detailed internal message.""" + schemas = _make_schemas() + with pytest.raises(InvalidSchemaError) as exc_info: + convert("usr.missing == 1", schemas=schemas, validate_schema=True) + # User-facing message is sanitized + assert str(exc_info.value) == "field not found in schema" + # Internal detail has specifics + assert "missing" in exc_info.value.internal_details + assert "usr" in exc_info.value.internal_details + + +class TestValidateSchemaJSON: + """Tests for JSON field validation.""" + + def test_json_first_field_validated(self): + """JSON field validates that the first field exists in schema.""" + schemas = _make_schemas() + result = convert( + "usr.metadata.key == 'val'", + schemas=schemas, + validate_schema=True, + ) + assert result # Should succeed since 'metadata' is in schema + + def test_json_nested_keys_not_over_validated(self): + """Nested JSON keys beyond first field are not validated.""" + schemas = _make_schemas() + # 'metadata' exists but 'settings' and 'theme' are nested JSON keys — should pass + result = convert( + "usr.metadata.settings.theme == 'dark'", + schemas=schemas, + validate_schema=True, + ) + assert result + + def test_json_unknown_first_field_raises(self): + """Unknown first field with JSON-like access raises.""" + schemas = _make_schemas() + with pytest.raises(InvalidSchemaError): + convert("usr.config.key == 'val'", schemas=schemas, validate_schema=True) + + +class TestValidateSchemaComprehensions: + """Tests for comprehension variable handling.""" + + def test_comprehension_var_not_validated(self): + """Comprehension variables are not validated against schema.""" + schemas = _make_schemas() + result = convert( + "usr.tags.all(t, t == 'admin')", + schemas=schemas, + validate_schema=True, + ) + assert result + + def test_comprehension_with_field_access(self): + """Comprehension on a valid field succeeds.""" + schemas = _make_schemas() + result = convert( + "usr.tags.exists(t, t == 'admin')", + schemas=schemas, + validate_schema=True, + ) + assert result + + +class TestValidateSchemaBareIdents: + """Tests for bare identifiers (no table prefix).""" + + def test_bare_ident_not_validated(self): + """Bare identifiers without table prefix are not validated.""" + schemas = _make_schemas() + result = convert( + "age > 10", + schemas=schemas, + validate_schema=True, + ) + assert "age" in result + + +class TestValidateSchemaConvertParameterized: + """Tests for validate_schema with convert_parameterized.""" + + def test_valid_field_succeeds(self): + schemas = _make_schemas() + result = convert_parameterized( + "usr.name == 'alice'", + schemas=schemas, + validate_schema=True, + ) + assert result.sql + assert result.parameters + + def test_unknown_field_raises(self): + schemas = _make_schemas() + with pytest.raises(InvalidSchemaError): + convert_parameterized( + "usr.email == 'test'", + schemas=schemas, + validate_schema=True, + ) + + +class TestValidateSchemaAnalyze: + """Tests for validate_schema with analyze.""" + + def test_valid_field_succeeds(self): + schemas = _make_schemas() + result = analyze( + "usr.name == 'alice'", + schemas=schemas, + validate_schema=True, + ) + assert result.sql + + def test_unknown_field_raises(self): + schemas = _make_schemas() + with pytest.raises(InvalidSchemaError): + analyze( + "usr.email == 'test'", + schemas=schemas, + validate_schema=True, + )