diff --git a/CLAUDE.md b/CLAUDE.md
index 1e764e7..690660d 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -96,4 +96,5 @@ Unit tests (`tests/test_*.py`) cover each feature area per dialect. Integration
 - `size()` dispatches to `ARRAY_LENGTH` for arrays, `LENGTH` for strings
 - Depth tracking: `_visit_child()` increments/decrements `_depth` and checks limits
 - Error types use dual messaging pattern to prevent information disclosure (CWE-209)
+- `validate_schema` parameter: opt-in strict validation on `convert()`/`convert_parameterized()`/`analyze()`. Validates `table.field` references exist in schemas; skips comprehension variables, bare identifiers, and nested JSON keys beyond the first field. Raises `InvalidSchemaError` (with dual messaging). Requires schemas to be provided.
 - Ruff for linting, mypy strict for type checking, line length 100, target Python 3.12+
diff --git a/README.md b/README.md
index 6b52f63..a9050f7 100644
--- a/README.md
+++ b/README.md
@@ -103,6 +103,38 @@ sql = convert(
 # => usr.metadata->>'role' = 'admin'
 ```
 
+## Schema Validation
+
+Enable strict validation to catch typos and references to nonexistent fields:
+
+```python
+from pycel2sql import convert, InvalidSchemaError
+from pycel2sql.schema import Schema, FieldSchema
+
+schemas = {
+    "usr": Schema([
+        FieldSchema("name"),
+        FieldSchema("age", type="integer"),
+        FieldSchema("metadata", is_jsonb=True),
+    ])
+}
+
+# Valid field — works normally
+sql = convert('usr.name == "alice"', schemas=schemas, validate_schema=True)
+
+# Unknown field — raises InvalidSchemaError
+convert('usr.email == "test"', schemas=schemas, validate_schema=True)
+# => InvalidSchemaError: field not found in schema
+```
+
+Validation scope:
+- **Validates**: `table.field` references — table must exist in `schemas`, field must exist in that table's `Schema`
+- **Skips**: Nested JSON paths beyond the first field (e.g., `usr.metadata.settings.theme` validates `metadata` exists, not `settings`)
+- **Skips**: Comprehension variables (`t` in `tags.all(t, t > 0)`)
+- **Skips**: Bare identifiers without a table prefix (`age > 10`)
+
+Works with all three public API functions: `convert()`, `convert_parameterized()`, and `analyze()`.
+
 ## Schema Introspection
 
 Auto-discover table schemas from a live database connection instead of building `Schema` objects manually:
diff --git a/src/pycel2sql/__init__.py b/src/pycel2sql/__init__.py
index cb383e1..55c4451 100644
--- a/src/pycel2sql/__init__.py
+++ b/src/pycel2sql/__init__.py
@@ -13,7 +13,7 @@
 from celpy.celparser import CELParser
 
 from pycel2sql._converter import Converter
-from pycel2sql._errors import ConversionError, IntrospectionError
+from pycel2sql._errors import ConversionError, IntrospectionError, InvalidSchemaError
 from pycel2sql.dialect._base import Dialect
 from pycel2sql.dialect.bigquery import BigQueryDialect
 from pycel2sql.dialect.duckdb import DuckDBDialect
@@ -32,6 +32,7 @@
     "Result",
     "ConversionError",
     "IntrospectionError",
+    "InvalidSchemaError",
     "Dialect",
     "BigQueryDialect",
     "DuckDBDialect",
@@ -58,6 +59,7 @@ def convert(
     schemas: dict[str, Schema] | None = None,
     max_depth: int | None = None,
     max_output_length: int | None = None,
+    validate_schema: bool = False,
 ) -> str:
     """Convert a CEL expression to an inline SQL WHERE clause string.
 
@@ -67,12 +69,15 @@ def convert(
         schemas: Optional table schemas for JSON/array field detection.
         max_depth: Maximum recursion depth. Defaults to 100.
         max_output_length: Maximum SQL output length. Defaults to 50000.
+        validate_schema: If True, raise InvalidSchemaError for unrecognized
+            table or field references. Requires schemas to be provided.
 
     Returns:
         The SQL WHERE clause string.
 
     Raises:
         ConversionError: If conversion fails.
+        InvalidSchemaError: If validate_schema is True and a field is not in the schema.
     """
     if dialect is None:
         dialect = PostgresDialect()
@@ -86,6 +91,8 @@ def convert(
         kwargs["max_depth"] = max_depth
     if max_output_length is not None:
         kwargs["max_output_length"] = max_output_length
+    if validate_schema:
+        kwargs["validate_schema"] = validate_schema
 
     converter = Converter(dialect, **kwargs)
     converter.visit(tree)
@@ -99,6 +106,7 @@ def convert_parameterized(
     schemas: dict[str, Schema] | None = None,
     max_depth: int | None = None,
     max_output_length: int | None = None,
+    validate_schema: bool = False,
 ) -> Result:
     """Convert a CEL expression to a parameterized SQL WHERE clause.
 
@@ -108,12 +116,15 @@ def convert_parameterized(
         schemas: Optional table schemas for JSON/array field detection.
         max_depth: Maximum recursion depth. Defaults to 100.
         max_output_length: Maximum SQL output length. Defaults to 50000.
+        validate_schema: If True, raise InvalidSchemaError for unrecognized
+            table or field references. Requires schemas to be provided.
 
     Returns:
         Result with SQL containing $1, $2, ... placeholders and parameter list.
 
     Raises:
         ConversionError: If conversion fails.
+        InvalidSchemaError: If validate_schema is True and a field is not in the schema.
     """
     if dialect is None:
         dialect = PostgresDialect()
@@ -127,6 +138,8 @@ def convert_parameterized(
         kwargs["max_depth"] = max_depth
     if max_output_length is not None:
         kwargs["max_output_length"] = max_output_length
+    if validate_schema:
+        kwargs["validate_schema"] = validate_schema
 
     converter = Converter(dialect, **kwargs)
     converter.visit(tree)
@@ -148,6 +161,7 @@ def analyze(
     schemas: dict[str, Schema] | None = None,
     max_depth: int | None = None,
     max_output_length: int | None = None,
+    validate_schema: bool = False,
 ) -> AnalysisResult:
     """Analyze a CEL expression for SQL conversion and index recommendations.
 
@@ -157,9 +171,14 @@ def analyze(
         schemas: Optional table schemas for JSON/array field detection.
         max_depth: Maximum recursion depth.
         max_output_length: Maximum SQL output length.
+        validate_schema: If True, raise InvalidSchemaError for unrecognized
+            table or field references. Requires schemas to be provided.
 
     Returns:
         AnalysisResult with SQL and index recommendations.
+
+    Raises:
+        InvalidSchemaError: If validate_schema is True and a field is not in the schema.
     """
     from pycel2sql._analysis import analyze_patterns
     from pycel2sql.dialect._base import get_index_advisor
@@ -177,6 +196,8 @@ def analyze(
         kwargs["max_depth"] = max_depth
     if max_output_length is not None:
         kwargs["max_output_length"] = max_output_length
+    if validate_schema:
+        kwargs["validate_schema"] = validate_schema
 
     converter = Converter(dialect, **kwargs)
     converter.visit(tree)
diff --git a/src/pycel2sql/_converter.py b/src/pycel2sql/_converter.py
index 59477c0..f160ee1 100644
--- a/src/pycel2sql/_converter.py
+++ b/src/pycel2sql/_converter.py
@@ -16,9 +16,11 @@
     MAX_COMPREHENSION_DEPTH,
 )
 from pycel2sql._errors import (
+    ERR_MSG_SCHEMA_VALIDATION_FAILED,
     InvalidArgumentsError,
     InvalidByteArrayLengthError,
     InvalidDurationError,
+    InvalidSchemaError,
     MaxComprehensionDepthExceededError,
     MaxDepthExceededError,
     MaxOutputLengthExceededError,
@@ -145,6 +147,7 @@ def __init__(
         max_depth: int = DEFAULT_MAX_RECURSION_DEPTH,
         max_output_length: int = DEFAULT_MAX_SQL_OUTPUT_LENGTH,
         parameterize: bool = False,
+        validate_schema: bool = False,
     ) -> None:
         self._w = StringIO()
         self._dialect = dialect
@@ -157,6 +160,12 @@ def __init__(
         self._parameters: list[Any] = []
         self._param_count = 0
         self._comprehension_vars: set[str] = set()
+        self._validate_schema = validate_schema
+        if self._validate_schema and not self._schemas:
+            raise InvalidSchemaError(
+                ERR_MSG_SCHEMA_VALIDATION_FAILED,
+                "validate_schema=True requires at least one schema to be provided",
+            )
 
     @property
     def result(self) -> str:
@@ -580,6 +589,12 @@ def member_dot(self, tree: Tree) -> None:
 
         # Check for JSON path
         table_name = self._get_root_ident(obj)
+
+        # Schema validation (before JSON check and SQL writing)
+        if table_name and not self._is_comprehension_var(table_name):
+            first_field = self._get_first_field(obj, field_name)
+            self._validate_field_in_schema(table_name, first_field)
+
         if table_name and self._is_field_json(table_name, self._get_first_field(obj, field_name)):
             self._build_json_path(tree)
             return
@@ -1729,6 +1744,25 @@ def _is_json_text_extraction(self, tree: Tree) -> bool:
 
     # ---- Schema helpers ----
 
+    def _validate_field_in_schema(self, table_name: str, field_name: str) -> None:
+        """Validate that a field exists in the schema for a table.
+
+        No-op if validate_schema is False.
+        """
+        if not self._validate_schema:
+            return
+        schema = self._schemas.get(table_name)
+        if schema is None:
+            raise InvalidSchemaError(
+                ERR_MSG_SCHEMA_VALIDATION_FAILED,
+                f"table '{table_name}' not found in schemas",
+            )
+        if schema.find_field(field_name) is None:
+            raise InvalidSchemaError(
+                ERR_MSG_SCHEMA_VALIDATION_FAILED,
+                f"field '{field_name}' not found in schema for '{table_name}'",
+            )
+
     def _is_member_dot_array_field(self, tree: Tree) -> bool:
         """Check if a tree represents an array field via schema."""
         node = tree
diff --git a/src/pycel2sql/_errors.py b/src/pycel2sql/_errors.py
index a3f9a08..c0fb4b1 100644
--- a/src/pycel2sql/_errors.py
+++ b/src/pycel2sql/_errors.py
@@ -108,3 +108,4 @@ class IntrospectionError(ConversionError):
 ERR_MSG_INVALID_ARGUMENTS = "invalid function arguments"
 ERR_MSG_UNKNOWN_TYPE = "unknown type"
 ERR_MSG_INVALID_PATTERN = "invalid pattern"
+ERR_MSG_SCHEMA_VALIDATION_FAILED = "field not found in schema"
diff --git a/tests/test_validate_schema.py b/tests/test_validate_schema.py
new file mode 100644
index 0000000..a8f9659
--- /dev/null
+++ b/tests/test_validate_schema.py
@@ -0,0 +1,205 @@
+"""Tests for validate_schema parameter."""
+
+import pytest
+
+from pycel2sql import analyze, convert, convert_parameterized
+from pycel2sql._errors import InvalidSchemaError
+from pycel2sql.schema import FieldSchema, Schema
+
+
+def _make_schemas() -> dict[str, Schema]:
+    return {
+        "usr": Schema([
+            FieldSchema(name="name", type="text"),
+            FieldSchema(name="age", type="integer"),
+            FieldSchema(name="metadata", type="jsonb", is_jsonb=True),
+            FieldSchema(name="tags", type="text", repeated=True),
+        ]),
+    }
+
+
+class TestValidateSchemaConfig:
+    """Tests for validate_schema configuration."""
+
+    def test_default_false_allows_unknown_fields(self):
+        """Default validate_schema=False allows unknown fields."""
+        result = convert("usr.nonexistent == 'foo'")
+        assert "usr.nonexistent" in result
+
+    def test_true_with_no_schemas_raises(self):
+        """validate_schema=True with no schemas raises immediately."""
+        with pytest.raises(InvalidSchemaError):
+            convert("usr.name == 'foo'", validate_schema=True)
+
+    def test_true_with_empty_schemas_raises(self):
+        """validate_schema=True with empty schemas dict raises immediately."""
+        with pytest.raises(InvalidSchemaError):
+            convert("usr.name == 'foo'", schemas={}, validate_schema=True)
+
+    def test_false_with_schemas_allows_unknown(self):
+        """validate_schema=False with schemas still allows unknown fields."""
+        schemas = _make_schemas()
+        result = convert("usr.nonexistent == 'foo'", schemas=schemas, validate_schema=False)
+        assert "usr.nonexistent" in result
+
+
+class TestValidateSchemaFieldAccess:
+    """Tests for field access validation."""
+
+    def test_valid_field_succeeds(self):
+        """Known field passes validation."""
+        schemas = _make_schemas()
+        result = convert("usr.name == 'alice'", schemas=schemas, validate_schema=True)
+        assert "usr.name" in result
+
+    def test_unknown_field_raises(self):
+        """Unknown field raises InvalidSchemaError."""
+        schemas = _make_schemas()
+        with pytest.raises(InvalidSchemaError) as exc_info:
+            convert("usr.email == 'test@example.com'", schemas=schemas, validate_schema=True)
+        assert "field not found in schema" in str(exc_info.value)
+        assert "email" in exc_info.value.internal_details
+
+    def test_unknown_table_raises(self):
+        """Unknown table raises InvalidSchemaError."""
+        schemas = _make_schemas()
+        with pytest.raises(InvalidSchemaError) as exc_info:
+            convert("orders.total > 100", schemas=schemas, validate_schema=True)
+        assert "field not found in schema" in str(exc_info.value)
+        assert "orders" in exc_info.value.internal_details
+
+    def test_multiple_valid_fields(self):
+        """Multiple known fields pass validation."""
+        schemas = _make_schemas()
+        result = convert(
+            "usr.name == 'alice' && usr.age > 18",
+            schemas=schemas,
+            validate_schema=True,
+        )
+        assert "usr.name" in result
+        assert "usr.age" in result
+
+    def test_error_dual_messaging(self):
+        """Error uses sanitized user message and detailed internal message."""
+        schemas = _make_schemas()
+        with pytest.raises(InvalidSchemaError) as exc_info:
+            convert("usr.missing == 1", schemas=schemas, validate_schema=True)
+        # User-facing message is sanitized
+        assert str(exc_info.value) == "field not found in schema"
+        # Internal detail has specifics
+        assert "missing" in exc_info.value.internal_details
+        assert "usr" in exc_info.value.internal_details
+
+
+class TestValidateSchemaJSON:
+    """Tests for JSON field validation."""
+
+    def test_json_first_field_validated(self):
+        """JSON field validates that the first field exists in schema."""
+        schemas = _make_schemas()
+        result = convert(
+            "usr.metadata.key == 'val'",
+            schemas=schemas,
+            validate_schema=True,
+        )
+        assert result  # Should succeed since 'metadata' is in schema
+
+    def test_json_nested_keys_not_over_validated(self):
+        """Nested JSON keys beyond first field are not validated."""
+        schemas = _make_schemas()
+        # 'metadata' exists but 'settings' and 'theme' are nested JSON keys — should pass
+        result = convert(
+            "usr.metadata.settings.theme == 'dark'",
+            schemas=schemas,
+            validate_schema=True,
+        )
+        assert result
+
+    def test_json_unknown_first_field_raises(self):
+        """Unknown first field with JSON-like access raises."""
+        schemas = _make_schemas()
+        with pytest.raises(InvalidSchemaError):
+            convert("usr.config.key == 'val'", schemas=schemas, validate_schema=True)
+
+
+class TestValidateSchemaComprehensions:
+    """Tests for comprehension variable handling."""
+
+    def test_comprehension_var_not_validated(self):
+        """Comprehension variables are not validated against schema."""
+        schemas = _make_schemas()
+        result = convert(
+            "usr.tags.all(t, t == 'admin')",
+            schemas=schemas,
+            validate_schema=True,
+        )
+        assert result
+
+    def test_comprehension_with_field_access(self):
+        """Comprehension on a valid field succeeds."""
+        schemas = _make_schemas()
+        result = convert(
+            "usr.tags.exists(t, t == 'admin')",
+            schemas=schemas,
+            validate_schema=True,
+        )
+        assert result
+
+
+class TestValidateSchemaBareIdents:
+    """Tests for bare identifiers (no table prefix)."""
+
+    def test_bare_ident_not_validated(self):
+        """Bare identifiers without table prefix are not validated."""
+        schemas = _make_schemas()
+        result = convert(
+            "age > 10",
+            schemas=schemas,
+            validate_schema=True,
+        )
+        assert "age" in result
+
+
+class TestValidateSchemaConvertParameterized:
+    """Tests for validate_schema with convert_parameterized."""
+
+    def test_valid_field_succeeds(self):
+        schemas = _make_schemas()
+        result = convert_parameterized(
+            "usr.name == 'alice'",
+            schemas=schemas,
+            validate_schema=True,
+        )
+        assert result.sql
+        assert result.parameters
+
+    def test_unknown_field_raises(self):
+        schemas = _make_schemas()
+        with pytest.raises(InvalidSchemaError):
+            convert_parameterized(
+                "usr.email == 'test'",
+                schemas=schemas,
+                validate_schema=True,
+            )
+
+
+class TestValidateSchemaAnalyze:
+    """Tests for validate_schema with analyze."""
+
+    def test_valid_field_succeeds(self):
+        schemas = _make_schemas()
+        result = analyze(
+            "usr.name == 'alice'",
+            schemas=schemas,
+            validate_schema=True,
+        )
+        assert result.sql
+
+    def test_unknown_field_raises(self):
+        schemas = _make_schemas()
+        with pytest.raises(InvalidSchemaError):
+            analyze(
+                "usr.email == 'test'",
+                schemas=schemas,
+                validate_schema=True,
+            )