From 5e336063981fb6bb7d1876e0da13575c61b492e8 Mon Sep 17 00:00:00 2001
From: wmoustafa
Date: Fri, 22 May 2026 16:16:02 -0700
Subject: [PATCH] Add nested-type access and more SQL operators to data-generation inference
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Extends `coral-data-generation` so the symbolic-constraint solver from PR #564 covers a wider class of WHERE predicates: more SQL operators, struct and map/array element access, and a predicate-based inference entry point that resolves per-path domains from a DNF query. Also tightens two inference paths whose existing rewrites silently produced wrong results for the new cases.

## New operator coverage

Eight new `DomainTransformer` implementations are wired into `DomainInferenceProgram.withDefaultTransformers()`:

| Transformer | SQL operator |
| --- | --- |
| `AbsIntegerTransformer` | `ABS(x)` |
| `MinusIntegerTransformer` | binary `x - k` and `k - x` |
| `NegateIntegerTransformer` | unary `-x` |
| `UpperRegexTransformer` | `UPPER(x)` |
| `ConcatRegexTransformer` | `CONCAT(x, lit)` / `CONCAT(lit, x)` |
| `TrimRegexTransformer` | `TRIM(x)`; supports both Calcite's 3-operand standard form and Hive's 1-operand form |
| `FieldAccessTransformer` | struct field access (`s.name`) on nested expressions |
| `ItemTransformer` | `ITEM(coll, idx-or-key)` for array indexing and map lookup on nested expressions |

`ConcatRegexTransformer` matches both `SqlStdOperatorTable.CONCAT` (the SQL `||` operator) and the `OTHER_FUNCTION` named `concat` that Hive emits.

Existing transformers (`LowerRegexTransformer`, `PlusIntegerTransformer`, `TimesIntegerTransformer`, `SubstringRegexTransformer`) now accept `RexFieldAccess` as a valid variable operand, so expressions like `LOWER(s.name)`, `s.age + 5`, and `UPPER(sarr[0].name)` flow through. `SubstringRegexTransformer.canHandle` also gained an operand-arity check.
The transformer registration is grouped into string ops → integer ops → cross-domain → structural pass-throughs for readability.

## Nested-type access

A new `AccessPath` value type identifies any value reachable from a root column index through a chain of struct fields (`FIELD`), map lookups (`MAP_KEY`), and array indices (`ARRAY_INDEX`). It's the key type of the new multi-path resolution API (below) and is also used in tests to assert which nested values were resolved.

`DomainInferenceProgram.deriveInputDomain` gained two base cases so inference terminates correctly at nested column references: struct field access on a `RexInputRef` (e.g., `$3.name`) and ITEM access on a `RexInputRef` (e.g., `ITEM($2, 1)` for arrays, `ITEM($4, 'env')` for maps).

## Predicate-based inference: two reductions up the SQL evaluation hierarchy

Master exposed one primitive, `deriveInputDomain(expr, outputDomain) → inputDomain`, which answers the leaf question: given an expression and a constraint on its output, derive the constraint on the input variable. Real callers, though, start higher up the SQL evaluation stack. The PR adds the two reductions that bridge a full WHERE clause down to the primitive:

```
WHERE clause (tree of AND / OR over comparisons)
   │
   │  DnfRewriter (already exists)
   ▼
list of DNF disjuncts                            ── resolveAllPaths (new)
   │
   │  for each disjunct, for each conjunct
   ▼
single comparison predicate (expr OP literal)    ── deriveInputDomainFromPredicate (new)
   │
   │  compute output domain from OP + literal
   ▼
(expression, output domain) pair                 ── deriveInputDomain (primitive)
   │
   │  walk expr, refine via transformers
   ▼
domain on the input variable
```

- **`deriveInputDomainFromPredicate(RexCall predicate)`** is one reduction above the primitive. It takes a comparison `expr OP literal` (`=`, `<`, `>`, `<=`, `>=`), computes the output domain from the operator and literal (`> 5` ⇒ `IntegerDomain([6, ∞))`, `= 'abc'` ⇒ `RegexDomain.literal("abc")`), and reduces to `deriveInputDomain(expr, that)`.
  It also unwraps the `RexCall(UNARY_MINUS, RexLiteral)` shape Calcite uses for negative literals, so `age = -5` is handled the same way as `age = 5`.
- **`resolveAllPaths(List<RexNode> disjuncts)`** is one reduction above that. Given the DNF disjuncts produced by `DnfRewriter`, it walks every disjunct and every conjunct, calls `deriveInputDomainFromPredicate` on each comparison, and combines the per-`AccessPath` results with AND semantics within a disjunct (intersection) and OR semantics across disjuncts (union). Predicates outside the comparison-with-literal shape are silently skipped; this notably includes column-to-column join predicates, which still require per-column literals. For `WHERE (age > 10 AND name = 'foo') OR (age = 0)` the result is roughly `{ $age → IntegerDomain([11,∞) ∪ {0}), $name → RegexDomain("foo") }`.

Nothing else is added: anything more specific belongs in a transformer, and anything less specific (such as converting a WHERE tree to DNF in the first place) was already the caller's job via `DnfRewriter`.

## Tighten `RegexToIntegerDomainConverter`: accept only canonical decimal regexes

- **Input:** `R = ^[0-9]{3}$`.
- **Master returns:** `IntegerDomain{0..999}`.
- **Should return:** `IntegerDomain{100..999}`. SQL `CAST(integer AS VARCHAR)` produces canonical decimal (`0 → "0"`, never `"000"`), so `0` through `99`, whose canonical forms have fewer than three digits, do not belong.
- **Fix:** narrow the converter's contract to canonical-decimal regexes only. The accept rule changes from "finite + digit-only" to "finite + subset of `^(0|[1-9][0-9]*)$`". Non-canonical inputs (`^[0-9]{3}$`, `^009$`, empty regex, …) are now rejected with `NonConvertibleDomainException`. `CastRegexTransformer`'s `CAST(int AS VARCHAR)` branch keeps calling `convert(outputRegex)` directly and relies on this strict contract.

## ProjectPullUpRewriter: remap the join condition when a left Project changes field count

Concrete scenario: tables `T1(a, b, c)` (3 cols) and `T2(x, y)` (2 cols).
Plan before pull-up:

```
Join(condition: b = x)
├── Project(a, b)        keeps 2 of T1's 3 columns
│   └── Scan(T1)
└── Scan(T2)
```

The join's row type is `[Project-output | T2] = [a, b, x, y]`, so inside the condition `b` resolves to `$1` and `x` to `$2`. The condition is `$1 = $2`.

After pull-up, the `Project` moves above the `Join`, and the new join's left input is the raw `Scan(T1)`:

```
Project(...)
└── Join(condition: ???)
    ├── Scan(T1)
    └── Scan(T2)
```

The new join's row type is `[T1 | T2] = [a, b, c, x, y]`. `b` is still `$1`, but `x` is now `$3` because the left input grew from 2 columns back to 3. The rewritten condition must be `$1 = $3`.

Master inlined left-side `InputRef`s through the removed `Project` but left right-side `InputRef`s at their old positions. The rewritten condition came out as `$1 = $2`, which in the new frame points at `T1.c` (`VARCHAR`), not `T2.x` (`INTEGER`). Wrong column, and a type mismatch that breaks join evaluation.

The fix replaces the two side-specific helpers (`inlineLeftSide`, `inlineRightSide`) with a single `remapJoinCondition` pass. For every `InputRef` in the old condition it computes the position in the new frame using `oldLeftCount` (Project-output width) and `newLeftCount` (unprojected-left width): right-side references shift by `newLeftCount - oldLeftCount`; left-side references are remapped through the lifted projection expressions.

## IntegerDomain

- New `negate()` method (returns `multiply(-1)`), used by the new `NegateIntegerTransformer`.
- `Interval.isAdjacent` refactored to make the overflow guard explicit in two named booleans, matching the original behavior.

## Build

`coral-data-generation/build.gradle` now applies the `java-library` plugin so the module exposes proper `api`/`implementation` configurations.
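
The index arithmetic behind the join-condition remap described in the ProjectPullUpRewriter section can be sketched in isolation. The snippet below is not the PR's `remapJoinCondition`; it is a hypothetical, Calcite-free illustration that models input references as plain integers and assumes the removed `Project` only forwards columns, so each projection expression is just a source column index. The class name, method name, and parameters are all made up for this sketch.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical, Calcite-free sketch of the index remap; not the PR's code. */
final class JoinRefRemapSketch {

  /**
   * Maps an input-ref index from the old join frame [Project-output | right]
   * to the new join frame [unprojected-left | right].
   *
   * @param oldIndex index of the reference in the old join row type
   * @param projectedSourceColumns for each Project output, the left-input column it forwards
   *        (assumption: the Project is a pure column selection)
   * @param newLeftCount width of the unprojected left input
   */
  static int remap(int oldIndex, List<Integer> projectedSourceColumns, int newLeftCount) {
    int oldLeftCount = projectedSourceColumns.size();
    if (oldIndex < oldLeftCount) {
      // Left-side reference: look through the lifted projection.
      return projectedSourceColumns.get(oldIndex);
    }
    // Right-side reference: shift by the change in left-input width.
    return oldIndex + (newLeftCount - oldLeftCount);
  }

  public static void main(String[] args) {
    // Project(a, b) over T1(a, b, c): outputs 0 and 1 forward columns 0 and 1.
    List<Integer> project = Arrays.asList(0, 1);
    System.out.println("$1 -> $" + remap(1, project, 3));
    System.out.println("$2 -> $" + remap(2, project, 3));
  }
}
```

For the `T1`/`T2` scenario, `$1` stays `$1` and `$2` becomes `$3`, which matches the expected rewritten condition `$1 = $3`.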
## Tests

`RegexDomainInferenceProgramTest` is the main integration suite and grows substantially: it exercises every new operator individually, every new nested-type access pattern, and combined SQL queries with AND/OR over struct/map/array paths against four test tables (`test.T`, `test.complex`, `test.deep`, `test.interleaved`). Notable coverage areas:

- single-operator tests for `SUBSTRING`, `LOWER`, `UPPER`, `CAST(int→str)`, `CAST(str→int)`, `CAST(str→date)`, arithmetic, `MINUS`, `ABS`, unary minus, `CONCAT`, `TRIM`, comparison operators with and without arithmetic
- multi-column AND/OR with same-column intersection, disjoint ranges, range-with-equality, contradictory ranges, mixed regex/integer domains
- struct field equality and arithmetic, map-element equality, array of structs, nested struct (`nested_struct.sub.value`), map of structs (`map_of_structs['key'].score`), and interleaved combinations
- CAST cross-domain on struct fields, OR disjunction on struct fields, per-column union semantics

`RegexTransformerTest` is a new dedicated unit-test class for `Concat`: prefix/suffix stripping, prefix/suffix mismatch (empty domain), empty suffix as identity, non-literal output passthrough.

`IntegerTransformerTest` adds rigorous-style cases for `Minus`, `Negate`, and `Abs`: each test constructs the `RexCall` via `RexBuilder` and calls `transformer.refineInputDomain` directly, then asserts containment and boundaries, including the empty case for `ABS` over an all-negative output interval.

`RegexToIntegerDomainConverterTest` is updated to match the new contract: tests that previously passed non-canonical regexes (e.g., `^[0-9]{3}$`, `^009$`, `^[0-9]?$`) now assert the converter rejects them with `NonConvertibleDomainException`. Parallel positive tests use canonical-form inputs (`^[1-9][0-9]{2}$` instead of `^[0-9]{3}$`).
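
To make the canonical-decimal contract concrete, the sketch below checks individual strings against the pattern `^(0|[1-9][0-9]*)$` using `java.util.regex`. It is only an approximation under stated assumptions: the real converter tests language inclusion on automata (`subsetOf`) rather than enumerating strings, and the class and method names here are hypothetical.

```java
import java.util.Locale;
import java.util.regex.Pattern;

/** Hypothetical sketch of the canonical-decimal rule; not the converter itself. */
final class CanonicalDecimalSketch {

  // Same pattern the PR description names as the accept rule.
  private static final Pattern CANONICAL = Pattern.compile("^(0|[1-9][0-9]*)$");

  /** Returns true if s is "0" or a digit string with a non-zero leading digit. */
  static boolean isCanonical(String s) {
    return CANONICAL.matcher(s).matches();
  }

  /** Counts how many of the strings "000".."999" are canonical decimals. */
  static int countCanonicalThreeDigit() {
    int count = 0;
    for (int i = 0; i < 1000; i++) {
      // Locale.ROOT keeps the digits ASCII regardless of the default locale.
      if (isCanonical(String.format(Locale.ROOT, "%03d", i))) {
        count++;
      }
    }
    return count;
  }

  public static void main(String[] args) {
    System.out.println("\"000\" canonical? " + isCanonical("000"));
    System.out.println("\"100\" canonical? " + isCanonical("100"));
    System.out.println("canonical 3-digit strings: " + countCanonicalThreeDigit());
  }
}
```

Only the 900 strings `"100"` through `"999"` pass, which is why `^[0-9]{3}$` as a whole falls outside the contract while `^[1-9][0-9]{2}$` stays inside it.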

`CastRegexTransformerTest` adds concrete accept/reject probes for the returned regex (e.g., `getAutomaton().run("100")`), pins the canonical behavior of `CAST(int AS VARCHAR)` with a canonical 3-digit output, and documents the non-canonical fallback path.

`ProjectPullUpRewriterTest` asserts row-type field-name and type preservation across pull-ups, and pins the rewritten join condition to `=($1, $3)` for the case described above.

## Verification

Full module pipeline (`build`, `javadoc`, `spotlessJavaCheck`) passes; all tests in the module pass.
---
 coral-data-generation/build.gradle                 |    2 +
 .../coral/datagen/domain/AccessPath.java           |  187 +++
 .../domain/DomainInferenceProgram.java             |  297 +++-
 .../coral/datagen/domain/IntegerDomain.java        |   14 +-
 .../domain/RegexToIntegerDomainConverter.java      |   63 +-
 .../transformer/AbsIntegerTransformer.java         |   99 ++
 .../transformer/CastRegexTransformer.java          |   22 +-
 .../transformer/ConcatRegexTransformer.java        |  120 ++
 .../transformer/FieldAccessTransformer.java        |   52 +
 .../domain/transformer/ItemTransformer.java        |   63 +
 .../transformer/LowerRegexTransformer.java         |    3 +-
 .../transformer/MinusIntegerTransformer.java       |   89 ++
 .../transformer/NegateIntegerTransformer.java      |   63 +
 .../transformer/PlusIntegerTransformer.java        |    9 +-
 .../SubstringRegexTransformer.java                 |   13 +-
 .../transformer/TimesIntegerTransformer.java       |    9 +-
 .../transformer/TrimRegexTransformer.java          |   95 ++
 .../transformer/UpperRegexTransformer.java         |   84 +
 .../datagen/rel/ProjectPullUpRewriter.java         |  102 +-
 .../domain/CastRegexTransformerTest.java           |   74 +-
 .../datagen/domain/IntegerDomainTest.java          |   18 +-
 .../domain/IntegerTransformerTest.java             |  224 ++-
 .../RegexDomainInferenceProgramTest.java           | 1385 ++++++++++++++++-
 .../RegexToIntegerDomainConverterTest.java         |  306 ++--
 .../datagen/domain/RegexTransformerTest.java       |  136 ++
 .../rel/CanonicalPredicateExtractorTest.java       |    8 +-
 .../rel/ProjectPullUpRewriterTest.java             |   56 +-
 27 files changed, 3271 insertions(+), 322 deletions(-)
 create
mode 100644 coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/AccessPath.java create mode 100644 coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/AbsIntegerTransformer.java create mode 100644 coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/ConcatRegexTransformer.java create mode 100644 coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/FieldAccessTransformer.java create mode 100644 coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/ItemTransformer.java create mode 100644 coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/MinusIntegerTransformer.java create mode 100644 coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/NegateIntegerTransformer.java create mode 100644 coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/TrimRegexTransformer.java create mode 100644 coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/UpperRegexTransformer.java create mode 100644 coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexTransformerTest.java diff --git a/coral-data-generation/build.gradle b/coral-data-generation/build.gradle index 83f06aa5c..395fb0f55 100644 --- a/coral-data-generation/build.gradle +++ b/coral-data-generation/build.gradle @@ -1,3 +1,5 @@ +apply plugin: 'java-library' + dependencies { implementation project(path: ':coral-common') implementation project(path: ':coral-hive') diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/AccessPath.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/AccessPath.java new file mode 100644 index 000000000..f05b01fbf --- /dev/null +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/AccessPath.java @@ -0,0 +1,187 @@ +/** + * Copyright 2025-2026 LinkedIn 
Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.datagen.domain; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + + +/** + * Path to a value reachable from a top-level column via struct field, map key, and array + * index accesses. + * + * Examples: + * - Flat column: AccessPath.of(0) represents column $0 + * - Struct field access: AccessPath.ofField(3, "name") represents $3.name + * - Array element access: AccessPath.ofArrayIndex(2, 1) represents $2[1] + * - Map element access: AccessPath.ofMapKey(4, "key1") represents $4['key1'] + * - Nested: AccessPath.of(5).append(arrayIndex(1)).append(field("name")) represents $5[1].name + */ +public class AccessPath { + private final int columnIndex; + private final List path; + + /** + * Represents a single step in a nested access path. + */ + public static class PathElement { + public enum Kind { + FIELD, + ARRAY_INDEX, + MAP_KEY + } + + private final Kind kind; + private final String fieldName; + private final int arrayIndex; + private final String mapKey; + + private PathElement(Kind kind, String fieldName, int arrayIndex, String mapKey) { + this.kind = kind; + this.fieldName = fieldName; + this.arrayIndex = arrayIndex; + this.mapKey = mapKey; + } + + public static PathElement field(String name) { + return new PathElement(Kind.FIELD, name, -1, null); + } + + public static PathElement arrayIndex(int index) { + return new PathElement(Kind.ARRAY_INDEX, null, index, null); + } + + public static PathElement mapKey(String key) { + return new PathElement(Kind.MAP_KEY, null, -1, key); + } + + public Kind getKind() { + return kind; + } + + public String getFieldName() { + return fieldName; + } + + public int getArrayIndex() { + return arrayIndex; + } + + public String getMapKey() { + return mapKey; + } + + @Override + public boolean 
equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + PathElement that = (PathElement) o; + return kind == that.kind && arrayIndex == that.arrayIndex && Objects.equals(fieldName, that.fieldName) + && Objects.equals(mapKey, that.mapKey); + } + + @Override + public int hashCode() { + return Objects.hash(kind, fieldName, arrayIndex, mapKey); + } + + @Override + public String toString() { + switch (kind) { + case FIELD: + return "." + fieldName; + case ARRAY_INDEX: + return "[" + arrayIndex + "]"; + case MAP_KEY: + return "['" + mapKey + "']"; + default: + return "?"; + } + } + } + + private AccessPath(int columnIndex, List path) { + this.columnIndex = columnIndex; + this.path = Collections.unmodifiableList(path); + } + + /** + * Creates a path for a flat column reference. + */ + public static AccessPath of(int colIdx) { + return new AccessPath(colIdx, Collections.emptyList()); + } + + /** + * Creates a path for a struct field access (e.g., $3.name). + */ + public static AccessPath ofField(int colIdx, String fieldName) { + return new AccessPath(colIdx, Collections.singletonList(PathElement.field(fieldName))); + } + + /** + * Creates a path for an array element access (e.g., $2[1]). + */ + public static AccessPath ofArrayIndex(int colIdx, int index) { + return new AccessPath(colIdx, Collections.singletonList(PathElement.arrayIndex(index))); + } + + /** + * Creates a path for a map element access (e.g., $4['key1']). + */ + public static AccessPath ofMapKey(int colIdx, String key) { + return new AccessPath(colIdx, Collections.singletonList(PathElement.mapKey(key))); + } + + /** + * Creates a new AccessPath by appending a path element to this path. 
+ */ + public AccessPath append(PathElement element) { + List newPath = new ArrayList<>(path); + newPath.add(element); + return new AccessPath(columnIndex, newPath); + } + + public int getColumnIndex() { + return columnIndex; + } + + public List getPath() { + return path; + } + + public boolean isFlat() { + return path.isEmpty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + AccessPath that = (AccessPath) o; + return columnIndex == that.columnIndex && path.equals(that.path); + } + + @Override + public int hashCode() { + return Objects.hash(columnIndex, path); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("$" + columnIndex); + for (PathElement elem : path) { + sb.append(elem); + } + return sb.toString(); + } +} diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/DomainInferenceProgram.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/DomainInferenceProgram.java index 9450656c8..a80e35dcc 100644 --- a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/DomainInferenceProgram.java +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/DomainInferenceProgram.java @@ -6,16 +6,32 @@ package com.linkedin.coral.datagen.domain; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import com.linkedin.coral.datagen.domain.transformer.AbsIntegerTransformer; import com.linkedin.coral.datagen.domain.transformer.CastRegexTransformer; +import 
com.linkedin.coral.datagen.domain.transformer.ConcatRegexTransformer; +import com.linkedin.coral.datagen.domain.transformer.FieldAccessTransformer; +import com.linkedin.coral.datagen.domain.transformer.ItemTransformer; import com.linkedin.coral.datagen.domain.transformer.LowerRegexTransformer; +import com.linkedin.coral.datagen.domain.transformer.MinusIntegerTransformer; +import com.linkedin.coral.datagen.domain.transformer.NegateIntegerTransformer; import com.linkedin.coral.datagen.domain.transformer.PlusIntegerTransformer; import com.linkedin.coral.datagen.domain.transformer.SubstringRegexTransformer; import com.linkedin.coral.datagen.domain.transformer.TimesIntegerTransformer; +import com.linkedin.coral.datagen.domain.transformer.TrimRegexTransformer; +import com.linkedin.coral.datagen.domain.transformer.UpperRegexTransformer; /** @@ -54,8 +70,17 @@ public DomainInferenceProgram(List transformers) { * This is the recommended way to create an instance for production use. */ public static DomainInferenceProgram withDefaultTransformers() { - return new DomainInferenceProgram(Arrays.asList(new LowerRegexTransformer(), new SubstringRegexTransformer(), - new PlusIntegerTransformer(), new TimesIntegerTransformer(), new CastRegexTransformer())); + return new DomainInferenceProgram(Arrays.asList( + // String/regex transformers + new LowerRegexTransformer(), new UpperRegexTransformer(), new SubstringRegexTransformer(), + new ConcatRegexTransformer(), new TrimRegexTransformer(), + // Integer transformers + new PlusIntegerTransformer(), new MinusIntegerTransformer(), new TimesIntegerTransformer(), + new NegateIntegerTransformer(), new AbsIntegerTransformer(), + // Cross-domain + new CastRegexTransformer(), + // Structural pass-throughs + new ItemTransformer(), new FieldAccessTransformer())); } /** @@ -72,6 +97,19 @@ public static DomainInferenceProgram withDefaultTransformers() { return outputDomain; } + // Base case: struct field access on a column ref (e.g., $3.name) + if 
(expr instanceof RexFieldAccess) {
+      RexFieldAccess fa = (RexFieldAccess) expr;
+      if (fa.getReferenceExpr() instanceof RexInputRef) {
+        return outputDomain;
+      }
+    }
+
+    // Base case: ITEM access on a column ref (e.g., ITEM($2, 1) for arrays, ITEM($4, 'key') for maps)
+    if (isTerminalItemAccess(expr)) {
+      return outputDomain;
+    }
+
     // Find a transformer that can handle this expression
     for (DomainTransformer transformer : transformers) {
       if (transformer.canHandle(expr) && transformer.isVariableOperandPositionValid(expr)) {
@@ -93,6 +131,261 @@ public static DomainInferenceProgram withDefaultTransformers() {
     throw new IllegalStateException("No applicable transformer for expression: " + expr);
   }
 
+  /**
+   * Derives the domain constraint on the input variable from a comparison predicate.
+   *
+   * <p>Handles {@code =}, {@code <}, {@code >}, {@code <=}, {@code >=} comparisons where
+   * one side is a literal and the other is an expression containing the variable.
+   *
+   * @param predicate the comparison RexCall (e.g., {@code expr > 5})
+   * @return the refined domain constraint on the input variable
+   */
+  public Domain deriveInputDomainFromPredicate(RexCall predicate) {
+    SqlOperator op = predicate.getOperator();
+    RexNode lhs = predicate.getOperands().get(0);
+    RexNode rhs = predicate.getOperands().get(1);
+
+    // Calcite often represents negative literals as RexCall(UNARY_MINUS, positiveLiteral).
+    // Unwrap that here so the rest of the predicate handler can treat the RHS as a literal.
+    boolean negateRhs = false;
+    if (rhs instanceof RexCall) {
+      RexCall rhsCall = (RexCall) rhs;
+      if (rhsCall.getKind() == SqlKind.MINUS_PREFIX && rhsCall.getOperands().size() == 1
+          && rhsCall.getOperands().get(0) instanceof RexLiteral) {
+        rhs = rhsCall.getOperands().get(0);
+        negateRhs = true;
+      }
+    }
+
+    if (!(rhs instanceof RexLiteral)) {
+      throw new IllegalArgumentException("RHS of comparison must be a literal, got: " + rhs);
+    }
+
+    RexLiteral literal = (RexLiteral) rhs;
+    Domain outputDomain = createDomainFromComparison(op, literal, negateRhs);
+    return deriveInputDomain(lhs, outputDomain);
+  }
+
+  /**
+   * Resolves domains for all access paths constrained by a list of DNF disjuncts.
+   *
+   * Each disjunct may be a single comparison or a conjunction (AND) of comparisons.
+   * Within a disjunct, predicates on the same path are intersected (AND semantics).
+   * Across disjuncts, domains for the same path are unioned (OR semantics).
+ * + * Example: {@code (a > 5 AND a < 10) OR (a = 20 AND b = 3)} + * produces: {@code {$0: [6,9] ∪ {20}, $1: {3}}} + * + * @param disjuncts the DNF disjuncts from {@link com.linkedin.coral.datagen.rel.DnfRewriter} + * @return map from {@link AccessPath} (root column index plus nested struct/map/array accesses) + * to the resolved domain for that path + */ + public Map> resolveAllPaths(List disjuncts) { + Map> result = new HashMap<>(); + + for (RexNode disjunct : disjuncts) { + Map> disjunctDomains = resolveDisjunct(disjunct); + + // Union with existing domains (OR semantics across disjuncts) + for (Map.Entry> entry : disjunctDomains.entrySet()) { + AccessPath colPath = entry.getKey(); + Domain domain = entry.getValue(); + if (result.containsKey(colPath)) { + result.put(colPath, unionDomains(result.get(colPath), domain)); + } else { + result.put(colPath, domain); + } + } + } + + return result; + } + + /** + * Resolves domains for all columns from a single disjunct. + * A disjunct may be a conjunction (AND) of multiple comparisons, or a single comparison. + * Predicates on the same column within a disjunct are intersected. 
+ */ + private Map> resolveDisjunct(RexNode disjunct) { + List conjuncts = RelOptUtil.conjunctions(disjunct); + Map> pathDomains = new HashMap<>(); + + for (RexNode conjunct : conjuncts) { + if (!(conjunct instanceof RexCall)) { + continue; + } + RexCall call = (RexCall) conjunct; + + // Find which column this conjunct constrains + RexNode lhs = call.getOperands().get(0); + AccessPath colPath = findAccessPath(lhs); + if (colPath == null) { + continue; + } + + try { + Domain domain = deriveInputDomainFromPredicate(call); + if (pathDomains.containsKey(colPath)) { + // Intersect within the same disjunct (AND semantics) + pathDomains.put(colPath, intersectDomains(pathDomains.get(colPath), domain)); + } else { + pathDomains.put(colPath, domain); + } + } catch (UnsupportedOperationException | IllegalArgumentException | IllegalStateException e) { + // Skip predicates we can't resolve + } + } + + return pathDomains; + } + + /** + * Finds the AccessPath for the column referenced within an expression tree. + * Handles flat columns (RexInputRef), struct field access (RexFieldAccess), + * and array/map element access (ITEM operator). + * Returns null if no column reference is found. 
+ */ + private AccessPath findAccessPath(RexNode expr) { + if (expr instanceof RexInputRef) { + return AccessPath.of(((RexInputRef) expr).getIndex()); + } + + // Struct field access: $3.name + if (expr instanceof RexFieldAccess) { + RexFieldAccess fa = (RexFieldAccess) expr; + AccessPath inner = findAccessPath(fa.getReferenceExpr()); + if (inner != null) { + return inner.append(AccessPath.PathElement.field(fa.getField().getName())); + } + return null; + } + + if (expr instanceof RexCall) { + RexCall call = (RexCall) expr; + + // ITEM access: ITEM($2, 1) for arrays, ITEM($4, 'key') for maps + if (ItemTransformer.isItemOperator(call) && call.getOperands().size() == 2 + && call.getOperands().get(1) instanceof RexLiteral) { + AccessPath inner = findAccessPath(call.getOperands().get(0)); + if (inner != null) { + RexLiteral indexOrKey = (RexLiteral) call.getOperands().get(1); + Object value = indexOrKey.getValue2(); + if (value instanceof Number) { + return inner.append(AccessPath.PathElement.arrayIndex(((Number) value).intValue())); + } else { + return inner.append(AccessPath.PathElement.mapKey(value.toString())); + } + } + } + + // General case: search operands for a column reference + for (RexNode operand : call.getOperands()) { + AccessPath path = findAccessPath(operand); + if (path != null) { + return path; + } + } + } + + return null; + } + + /** + * Checks if an expression is a terminal ITEM access directly on a RexInputRef. + */ + private boolean isTerminalItemAccess(RexNode expr) { + if (!(expr instanceof RexCall)) { + return false; + } + RexCall call = (RexCall) expr; + return ItemTransformer.isItemOperator(call) && call.getOperands().size() == 2 + && call.getOperands().get(0) instanceof RexInputRef && call.getOperands().get(1) instanceof RexLiteral; + } + + /** + * Unions two domains of the same type. 
+ */ + @SuppressWarnings("unchecked") + private Domain unionDomains(Domain a, Domain b) { + if (a instanceof RegexDomain && b instanceof RegexDomain) { + return ((RegexDomain) a).union((RegexDomain) b); + } else if (a instanceof IntegerDomain && b instanceof IntegerDomain) { + return ((IntegerDomain) a).union((IntegerDomain) b); + } + // Mixed domain types: return the broader one (can't union across types) + return a; + } + + /** + * Intersects two domains of the same type. + */ + @SuppressWarnings("unchecked") + private Domain intersectDomains(Domain a, Domain b) { + if (a instanceof RegexDomain && b instanceof RegexDomain) { + return ((RegexDomain) a).intersect((RegexDomain) b); + } else if (a instanceof IntegerDomain && b instanceof IntegerDomain) { + return ((IntegerDomain) a).intersect((IntegerDomain) b); + } + // Mixed domain types: return the more constrained one + return a; + } + + /** + * Creates an output domain from a comparison operator and literal value. + * For EQUALS, produces a singleton domain. + * For inequalities, produces a range domain. + */ + private Domain createDomainFromComparison(SqlOperator op, RexLiteral literal, boolean negate) { + String rhsValue = literal.getValue2().toString(); + SqlKind kind = op.getKind(); + + if (isNumericType(literal.getType().getSqlTypeName())) { + long numericValue = Long.parseLong(rhsValue); + if (negate) { + numericValue = -numericValue; + } + + switch (kind) { + case EQUALS: + return IntegerDomain.of(numericValue); + case GREATER_THAN: + return IntegerDomain.of(numericValue + 1, Long.MAX_VALUE); + case GREATER_THAN_OR_EQUAL: + return IntegerDomain.of(numericValue, Long.MAX_VALUE); + case LESS_THAN: + return IntegerDomain.of(Long.MIN_VALUE, numericValue - 1); + case LESS_THAN_OR_EQUAL: + return IntegerDomain.of(Long.MIN_VALUE, numericValue); + default: + break; + } + } + + if (kind == SqlKind.EQUALS) { + // String literals are never produced by the UNARY_MINUS unwrap above. 
+ return RegexDomain.literal(rhsValue); + } + + throw new UnsupportedOperationException("Comparison operator " + op + " (kind=" + kind + ") not supported for type " + + literal.getType().getSqlTypeName()); + } + + private boolean isNumericType(org.apache.calcite.sql.type.SqlTypeName typeName) { + switch (typeName) { + case TINYINT: + case SMALLINT: + case INTEGER: + case BIGINT: + case DECIMAL: + case FLOAT: + case REAL: + case DOUBLE: + return true; + default: + return false; + } + } + /** * Creates an empty domain matching the type of the given domain. */ diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/IntegerDomain.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/IntegerDomain.java index 975500fa7..0fa7fde4f 100644 --- a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/IntegerDomain.java +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/IntegerDomain.java @@ -65,8 +65,10 @@ public boolean overlaps(Interval other) { } public boolean isAdjacent(Interval other) { - return this.max != Long.MAX_VALUE && this.max + 1 == other.min - || other.max != Long.MAX_VALUE && other.max + 1 == this.min; + // Guard against overflow: MAX + 1 wraps to MIN, which would falsely match + boolean thisAdjacentToOther = this.max != Long.MAX_VALUE && this.max + 1 == other.min; + boolean otherAdjacentToThis = other.max != Long.MAX_VALUE && other.max + 1 == this.min; + return thisAdjacentToOther || otherAdjacentToThis; } public Interval merge(Interval other) { @@ -267,6 +269,14 @@ public IntegerDomain add(long constant) { return new IntegerDomain(shifted); } + /** + * Negates all values in this domain. + * For each interval [a, b], produces [-b, -a]. + */ + public IntegerDomain negate() { + return multiply(-1); + } + /** * Multiplies all values in this domain by a constant. 
*/ diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/RegexToIntegerDomainConverter.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/RegexToIntegerDomainConverter.java index a62f522e5..3c2787169 100644 --- a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/RegexToIntegerDomainConverter.java +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/RegexToIntegerDomainConverter.java @@ -13,23 +13,46 @@ /** - * Converts RegexDomain patterns to IntegerDomain when the regex represents - * only bounded numeric patterns. + * Converts a {@link RegexDomain} of canonical-decimal integer strings into the corresponding + * {@link IntegerDomain}. * - * Conversion is allowed only if the automaton: - * 1. Accepts a finite language - * 2. All transitions are over digit characters [0-9] only + *

<p>This converter intentionally answers only one question: "Given a regex whose accepted
+ * strings are all canonical decimal representations of non-negative integers, what is the integer set?".
+ * Canonical decimal means the form produced by {@code String.valueOf(int)} / SQL
+ * {@code CAST(integer AS VARCHAR)}: {@code "0"} or a non-zero leading digit followed by
+ * additional digits. Strings like {@code "000"}, {@code "01"}, {@code "-0"} are non-canonical
+ * and any regex that admits them is rejected.
  *
- * Examples:
- * - {@code /^2024$/} converts to {@code [2024, 2024]}
- * - {@code /^[0-9]{3}$/} converts to {@code [0, 999]}
- * - {@code /^19[0-9]{2}$/} converts to {@code [1900, 1999]}
- * - {@code /^(10|11|12)$/} converts to {@code [10, 10], [11, 11], [12, 12]}
+ * <p>Allowed:
+ * <ul>
+ *   <li>{@code /^2024$/} → {@code [2024, 2024]}</li>
+ *   <li>{@code /^19[0-9]{2}$/} → {@code [1900, 1999]}</li>
+ *   <li>{@code /^(10|11|12)$/} → {@code [10, 12]}</li>
+ *   <li>{@code /^[0-9]$/} → {@code [0, 9]} (every accepted string is canonical)</li>
+ * </ul>
+ *
+ * <p>Rejected (throws {@link NonConvertibleDomainException}):
+ * <ul>
+ *   <li>{@code /^[0-9]{3}$/} — admits {@code "000"}, {@code "001"}, … which are not canonical</li>
+ *   <li>{@code /^009$/} — non-canonical literal</li>
+ *   <li>{@code /^[0-9]+$/} — infinite language</li>
+ *   <li>{@code /^abc$/} — not a subset of canonical integer strings</li>
+ * </ul>
+ *
+ * <p>
If a caller has a non-canonical regex but wants the canonical inverse, it must intersect + * with the canonical-strings regex first; the converter will then accept the result. */ public class RegexToIntegerDomainConverter { private static final int MAX_ENUMERATION_SIZE = 5000; + /** + * Automaton accepting exactly the canonical decimal string form of every non-negative + * integer ({@code "0"}, {@code "1"}, {@code "2"}, …, {@code "10"}, …). This is the precondition + * the converter requires its inputs to satisfy. + */ + private static final Automaton CANONICAL_INTEGER_STRINGS = new RegexDomain("^(0|[1-9][0-9]*)$").getAutomaton(); + /** * Checks if the given RegexDomain can be converted to IntegerDomain. * @@ -38,7 +61,7 @@ public class RegexToIntegerDomainConverter { */ public boolean isConvertible(RegexDomain regexDomain) { Automaton a = regexDomain.getAutomaton(); - return !a.isEmpty() && a.isFinite() && isDigitOnly(a); + return !a.isEmpty() && a.isFinite() && a.subsetOf(CANONICAL_INTEGER_STRINGS); } /** @@ -57,8 +80,8 @@ public IntegerDomain convert(RegexDomain regexDomain) { if (!a.isFinite()) { throw new NonConvertibleDomainException("Infinite language"); } - if (!isDigitOnly(a)) { - throw new NonConvertibleDomainException("Non-digit characters in automaton"); + if (!a.subsetOf(CANONICAL_INTEGER_STRINGS)) { + throw new NonConvertibleDomainException("Regex accepts strings outside canonical decimal form"); } // Try enumeration first @@ -80,20 +103,6 @@ public IntegerDomain convert(RegexDomain regexDomain) { } } - /** - * Checks that all transitions in the automaton are over digit characters only. - */ - private boolean isDigitOnly(Automaton a) { - for (State s : a.getStates()) { - for (Transition t : s.getTransitions()) { - if (t.getMin() < '0' || t.getMax() > '9') { - return false; - } - } - } - return true; - } - /** * Computes the minimum numeric value accepted by the automaton. 
* Uses BFS to find the shortest accepted string (which for digit-only strings diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/AbsIntegerTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/AbsIntegerTransformer.java new file mode 100644 index 000000000..66671fbec --- /dev/null +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/AbsIntegerTransformer.java @@ -0,0 +1,99 @@ +/** + * Copyright 2025-2026 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.datagen.domain.transformer; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; + +import com.linkedin.coral.datagen.domain.Domain; +import com.linkedin.coral.datagen.domain.DomainTransformer; +import com.linkedin.coral.datagen.domain.IntegerDomain; + + +/** + * Integer domain transformer for ABS (absolute value) operations. 
+ * + * Inverts ABS(x) = output_value + * to produce input constraint x = output_value OR x = -output_value + * + * Example: + * ABS(x) = 5 + * produces input constraint: x in {-5} ∪ {5} + * + * ABS(x) = 0 + * produces input constraint: x in {0} + */ +public class AbsIntegerTransformer implements DomainTransformer { + + @Override + public boolean canHandle(RexNode expr) { + if (!(expr instanceof RexCall)) { + return false; + } + RexCall call = (RexCall) expr; + return call.getOperator() == SqlStdOperatorTable.ABS && call.getOperands().size() == 1; + } + + @Override + public boolean isVariableOperandPositionValid(RexNode expr) { + RexCall call = (RexCall) expr; + RexNode operand = call.getOperands().get(0); + return operand instanceof RexInputRef || operand instanceof RexCall || operand instanceof RexFieldAccess; + } + + @Override + public RexNode getChildForVariable(RexNode expr) { + RexCall call = (RexCall) expr; + return call.getOperands().get(0); + } + + @Override + public Domain refineInputDomain(RexNode expr, Domain outputDomain) { + if (!(outputDomain instanceof IntegerDomain)) { + throw new IllegalArgumentException( + getClass().getSimpleName() + " expects IntegerDomain but got " + outputDomain.getClass().getSimpleName()); + } + + IntegerDomain intDomain = (IntegerDomain) outputDomain; + + // ABS(x) = v means x = v OR x = -v (for v >= 0) + // For each interval [a, b] where a >= 0: input can be [a, b] or [-b, -a] + List<IntegerDomain.Interval> inputIntervals = new ArrayList<>(); + + for (IntegerDomain.Interval interval : intDomain.getIntervals()) { + long min = interval.getMin(); + long max = interval.getMax(); + + // ABS output must be >= 0; skip negative parts + if (max < 0) { + continue; + } + if (min < 0) { + min = 0; + } + + // Positive side: [min, max] + inputIntervals.add(new IntegerDomain.Interval(min, max)); + + // Negative side: [-max, -min] + if (max > 0 || min > 0) { + inputIntervals.add(new IntegerDomain.Interval(-max, -min)); + } + } + + if (inputIntervals.isEmpty()) { 
+ return IntegerDomain.empty(); + } + + return IntegerDomain.of(inputIntervals); + } +} diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/CastRegexTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/CastRegexTransformer.java index 1ec6485e3..a1fc1cb5d 100644 --- a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/CastRegexTransformer.java +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/CastRegexTransformer.java @@ -10,6 +10,7 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlKind; @@ -37,6 +38,9 @@ */ public class CastRegexTransformer implements DomainTransformer { + /** Any integer-shaped string (canonical or with leading zeros). Used by the fallback path. */ + private static final RegexDomain ANY_INTEGER_STRING = new RegexDomain("^-?[0-9]+$"); + private final RegexToIntegerDomainConverter regexToIntegerConverter; public CastRegexTransformer() { @@ -56,7 +60,7 @@ public boolean canHandle(RexNode expr) { public boolean isVariableOperandPositionValid(RexNode expr) { RexCall call = (RexCall) expr; RexNode operand = call.getOperands().get(0); - return operand instanceof RexInputRef || operand instanceof RexCall; + return operand instanceof RexInputRef || operand instanceof RexCall || operand instanceof RexFieldAccess; } @Override @@ -117,21 +121,21 @@ public RexNode getChildForVariable(RexNode expr) { if (isIntegerType(sourceTypeName) && isStringType(targetTypeName) && outputDomain instanceof RegexDomain) { RegexDomain outputRegex = (RegexDomain) outputDomain; - // Try to convert RegexDomain to IntegerDomain + // SQL CAST(integer AS VARCHAR) produces canonical decimal: no leading zeros and no "-0". 
+ // RegexToIntegerDomainConverter accepts only canonical-decimal-string regexes by contract; + // it rejects (NonConvertibleDomainException) anything else. Real inference at this branch + // produces canonical outputs (literal "500", alternations like "100|999", etc.). Non-canonical + // shapes fall through to the regex-intersection fallback below. try { if (regexToIntegerConverter.isConvertible(outputRegex)) { - IntegerDomain intDomain = regexToIntegerConverter.convert(outputRegex); - // Return the IntegerDomain as input constraint - return intDomain; + return regexToIntegerConverter.convert(outputRegex); } } catch (NonConvertibleDomainException e) { // Fall through to default handling } - // Intersect with valid integer format - String integerFormatRegex = "^-?[0-9]+$"; - RegexDomain integerFormatDomain = new RegexDomain(integerFormatRegex); - return outputRegex.intersect(integerFormatDomain); + // Fall back: intersection of the original regex with any integer-shaped string. + return outputRegex.intersect(ANY_INTEGER_STRING); } // ========== CAST(integer AS string) with IntegerDomain output ========== diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/ConcatRegexTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/ConcatRegexTransformer.java new file mode 100644 index 000000000..2f3f4ee5b --- /dev/null +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/ConcatRegexTransformer.java @@ -0,0 +1,120 @@ +/** + * Copyright 2025-2026 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.coral.datagen.domain.transformer; + +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; + +import com.linkedin.coral.datagen.domain.Domain; +import com.linkedin.coral.datagen.domain.DomainTransformer; +import com.linkedin.coral.datagen.domain.RegexDomain; + + +/** + * Regex-based transformer for CONCAT operations. + * + * Inverts CONCAT(x, literal) = output or CONCAT(literal, x) = output + * by stripping the known literal prefix/suffix from the output constraint. + * + * Example: + * CONCAT(x, 'World') = 'HelloWorld' + * produces input constraint: literal 'Hello' + * + * CONCAT('Hello', x) = 'HelloWorld' + * produces input constraint: literal 'World' + */ +public class ConcatRegexTransformer implements DomainTransformer { + + @Override + public boolean canHandle(RexNode expr) { + if (!(expr instanceof RexCall)) { + return false; + } + RexCall call = (RexCall) expr; + if (call.getOperands().size() != 2) { + return false; + } + SqlOperator op = call.getOperator(); + // Standard SQL ||, or Hive's concat() function (which arrives as OTHER_FUNCTION). 
+ return op == SqlStdOperatorTable.CONCAT || "concat".equalsIgnoreCase(op.getName()); + } + + @Override + public boolean isVariableOperandPositionValid(RexNode expr) { + RexCall call = (RexCall) expr; + RexNode left = call.getOperands().get(0); + RexNode right = call.getOperands().get(1); + boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall || left instanceof RexFieldAccess); + boolean rightVar = (right instanceof RexInputRef || right instanceof RexCall || right instanceof RexFieldAccess); + boolean leftLit = left instanceof RexLiteral; + boolean rightLit = right instanceof RexLiteral; + return (leftVar && rightLit) || (rightVar && leftLit); + } + + @Override + public RexNode getChildForVariable(RexNode expr) { + RexCall call = (RexCall) expr; + RexNode left = call.getOperands().get(0); + RexNode right = call.getOperands().get(1); + if (left instanceof RexInputRef || left instanceof RexCall || left instanceof RexFieldAccess) { + return left; + } else { + return right; + } + } + + @Override + public Domain refineInputDomain(RexNode expr, Domain outputDomain) { + if (!(outputDomain instanceof RegexDomain)) { + throw new IllegalArgumentException( + "ConcatRegexTransformer expects RegexDomain but got " + outputDomain.getClass().getSimpleName()); + } + + RegexDomain outputRegex = (RegexDomain) outputDomain; + RexCall call = (RexCall) expr; + RexNode left = call.getOperands().get(0); + RexNode right = call.getOperands().get(1); + boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall || left instanceof RexFieldAccess); + + // Only handle literal concatenation for now + if (!outputRegex.isLiteral()) { + return outputRegex; + } + + String outputLiteral = outputRegex.getLiteralValue(); + + if (leftVar) { + // CONCAT(x, suffix) = output => x = output with suffix stripped + RexLiteral suffixLiteral = (RexLiteral) right; + String suffix = suffixLiteral.getValueAs(String.class); + + if (outputLiteral.endsWith(suffix)) { + String prefix 
= outputLiteral.substring(0, outputLiteral.length() - suffix.length()); + return RegexDomain.literal(prefix); + } else { + // Contradiction: output doesn't end with the suffix + return RegexDomain.empty(); + } + } else { + // CONCAT(prefix, x) = output => x = output with prefix stripped + RexLiteral prefixLiteral = (RexLiteral) left; + String prefix = prefixLiteral.getValueAs(String.class); + + if (outputLiteral.startsWith(prefix)) { + String remainder = outputLiteral.substring(prefix.length()); + return RegexDomain.literal(remainder); + } else { + // Contradiction: output doesn't start with the prefix + return RegexDomain.empty(); + } + } + } +} diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/FieldAccessTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/FieldAccessTransformer.java new file mode 100644 index 000000000..3047e508d --- /dev/null +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/FieldAccessTransformer.java @@ -0,0 +1,52 @@ +/** + * Copyright 2025-2026 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.datagen.domain.transformer; + +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; + +import com.linkedin.coral.datagen.domain.Domain; +import com.linkedin.coral.datagen.domain.DomainTransformer; + + +/** + * Domain transformer for struct field access (RexFieldAccess). + * + * Handles nested struct access patterns where the inner expression is not a bare + * RexInputRef (e.g., accessing a field on the result of an ITEM call for array-of-structs). + * Terminal field access (directly on a RexInputRef) is handled by the base case + * in DomainInferenceProgram. 
+ * + * Field access simply selects a field — the constraint propagates inward unchanged. + */ +public class FieldAccessTransformer implements DomainTransformer { + + @Override + public boolean canHandle(RexNode expr) { + return expr instanceof RexFieldAccess; + } + + @Override + public boolean isVariableOperandPositionValid(RexNode expr) { + RexFieldAccess fa = (RexFieldAccess) expr; + RexNode ref = fa.getReferenceExpr(); + return ref instanceof RexInputRef || ref instanceof RexCall || ref instanceof RexFieldAccess; + } + + @Override + public RexNode getChildForVariable(RexNode expr) { + RexFieldAccess fa = (RexFieldAccess) expr; + return fa.getReferenceExpr(); + } + + @Override + public Domain refineInputDomain(RexNode expr, Domain outputDomain) { + // Field access just selects a field — pass the constraint through unchanged + return outputDomain; + } +} diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/ItemTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/ItemTransformer.java new file mode 100644 index 000000000..dc7db49c0 --- /dev/null +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/ItemTransformer.java @@ -0,0 +1,63 @@ +/** + * Copyright 2025-2026 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.datagen.domain.transformer; + +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; + +import com.linkedin.coral.datagen.domain.Domain; +import com.linkedin.coral.datagen.domain.DomainTransformer; + + +/** + * Domain transformer for the ITEM operator (array/map element access). 
+ * + * Handles nested access patterns like ITEM(ITEM($5, 1), 'name') by passing + * the domain constraint through to the inner expression. Terminal ITEM access + * (ITEM on a bare RexInputRef) is handled by the base case in DomainInferenceProgram. + * + * The ITEM operator simply selects an element — the constraint propagates inward unchanged. + */ +public class ItemTransformer implements DomainTransformer { + + @Override + public boolean canHandle(RexNode expr) { + if (!(expr instanceof RexCall)) { + return false; + } + RexCall call = (RexCall) expr; + return isItemOperator(call) && call.getOperands().size() == 2; + } + + @Override + public boolean isVariableOperandPositionValid(RexNode expr) { + RexCall call = (RexCall) expr; + RexNode operand0 = call.getOperands().get(0); + RexNode operand1 = call.getOperands().get(1); + // Variable must be in operand 0 (the collection); operand 1 must be a literal index/key + return (operand0 instanceof RexInputRef || operand0 instanceof RexCall || operand0 instanceof RexFieldAccess) + && operand1 instanceof RexLiteral; + } + + @Override + public RexNode getChildForVariable(RexNode expr) { + RexCall call = (RexCall) expr; + return call.getOperands().get(0); + } + + @Override + public Domain refineInputDomain(RexNode expr, Domain outputDomain) { + // ITEM just selects an element — pass the constraint through unchanged + return outputDomain; + } + + public static boolean isItemOperator(RexCall call) { + return "ITEM".equalsIgnoreCase(call.getOperator().getName()); + } +} diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/LowerRegexTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/LowerRegexTransformer.java index 4e53e0b62..f2aae3cac 100644 --- a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/LowerRegexTransformer.java +++ 
b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/LowerRegexTransformer.java @@ -9,6 +9,7 @@ import java.util.List; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.fun.SqlStdOperatorTable; @@ -41,7 +42,7 @@ public boolean canHandle(RexNode expr) { public boolean isVariableOperandPositionValid(RexNode expr) { RexCall call = (RexCall) expr; RexNode arg = call.getOperands().get(0); - return arg instanceof RexInputRef || arg instanceof RexCall; + return arg instanceof RexInputRef || arg instanceof RexCall || arg instanceof RexFieldAccess; } @Override diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/MinusIntegerTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/MinusIntegerTransformer.java new file mode 100644 index 000000000..6cabd1a82 --- /dev/null +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/MinusIntegerTransformer.java @@ -0,0 +1,89 @@ +/** + * Copyright 2025-2026 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.datagen.domain.transformer; + +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; + +import com.linkedin.coral.datagen.domain.Domain; +import com.linkedin.coral.datagen.domain.DomainTransformer; +import com.linkedin.coral.datagen.domain.IntegerDomain; + + +/** + * Integer domain transformer for MINUS (subtraction) operations. 
+ * + * Inverts x - literal = output_value + * to produce input constraint x = output_value + literal + * + * Also handles literal - x = output_value + * producing input constraint x = literal - output_value + * + * Example: + * x - 3 = 7 + * produces input constraint: x = 10 + */ +public class MinusIntegerTransformer implements DomainTransformer { + + @Override + public boolean canHandle(RexNode expr) { + return expr instanceof RexCall && ((RexCall) expr).getOperator() == SqlStdOperatorTable.MINUS; + } + + @Override + public boolean isVariableOperandPositionValid(RexNode expr) { + RexCall call = (RexCall) expr; + RexNode left = call.getOperands().get(0); + RexNode right = call.getOperands().get(1); + boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall || left instanceof RexFieldAccess); + boolean rightVar = (right instanceof RexInputRef || right instanceof RexCall || right instanceof RexFieldAccess); + boolean leftLit = left instanceof RexLiteral; + boolean rightLit = right instanceof RexLiteral; + return (leftVar && rightLit) || (rightVar && leftLit); + } + + @Override + public RexNode getChildForVariable(RexNode expr) { + RexCall call = (RexCall) expr; + RexNode left = call.getOperands().get(0); + RexNode right = call.getOperands().get(1); + if (left instanceof RexInputRef || left instanceof RexCall || left instanceof RexFieldAccess) { + return left; + } else { + return right; + } + } + + @Override + public Domain refineInputDomain(RexNode expr, Domain outputDomain) { + if (!(outputDomain instanceof IntegerDomain)) { + throw new IllegalArgumentException( + getClass().getSimpleName() + " expects IntegerDomain but got " + outputDomain.getClass().getSimpleName()); + } + + IntegerDomain intDomain = (IntegerDomain) outputDomain; + RexCall call = (RexCall) expr; + RexNode left = call.getOperands().get(0); + RexNode right = call.getOperands().get(1); + boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall || left instanceof 
RexFieldAccess); + + RexLiteral literalNode = (RexLiteral) (leftVar ? right : left); + long literal = literalNode.getValueAs(Long.class); + + if (leftVar) { + // x - literal = output => x = output + literal + return intDomain.add(literal); + } else { + // literal - x = output => x = literal - output + // For each interval [a, b] in output: x in [literal - b, literal - a] + return intDomain.negate().add(literal); + } + } +} diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/NegateIntegerTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/NegateIntegerTransformer.java new file mode 100644 index 000000000..9675be9b6 --- /dev/null +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/NegateIntegerTransformer.java @@ -0,0 +1,63 @@ +/** + * Copyright 2025-2026 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.datagen.domain.transformer; + +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; + +import com.linkedin.coral.datagen.domain.Domain; +import com.linkedin.coral.datagen.domain.DomainTransformer; +import com.linkedin.coral.datagen.domain.IntegerDomain; + + +/** + * Integer domain transformer for UNARY_MINUS (negation) operations. 
+ * + * Inverts -(x) = output_value + * to produce input constraint x = -(output_value) + * + * Example: + * -(x) = -5 + * produces input constraint: x = 5 + */ +public class NegateIntegerTransformer implements DomainTransformer { + + @Override + public boolean canHandle(RexNode expr) { + if (!(expr instanceof RexCall)) { + return false; + } + RexCall call = (RexCall) expr; + return call.getOperator() == SqlStdOperatorTable.UNARY_MINUS && call.getOperands().size() == 1; + } + + @Override + public boolean isVariableOperandPositionValid(RexNode expr) { + RexCall call = (RexCall) expr; + RexNode operand = call.getOperands().get(0); + return operand instanceof RexInputRef || operand instanceof RexCall || operand instanceof RexFieldAccess; + } + + @Override + public RexNode getChildForVariable(RexNode expr) { + RexCall call = (RexCall) expr; + return call.getOperands().get(0); + } + + @Override + public Domain refineInputDomain(RexNode expr, Domain outputDomain) { + if (!(outputDomain instanceof IntegerDomain)) { + throw new IllegalArgumentException( + getClass().getSimpleName() + " expects IntegerDomain but got " + outputDomain.getClass().getSimpleName()); + } + + IntegerDomain intDomain = (IntegerDomain) outputDomain; + return intDomain.negate(); + } +} diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/PlusIntegerTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/PlusIntegerTransformer.java index 8456f2427..bf5b59f8f 100644 --- a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/PlusIntegerTransformer.java +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/PlusIntegerTransformer.java @@ -6,6 +6,7 @@ package com.linkedin.coral.datagen.domain.transformer; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import 
org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; @@ -43,8 +44,8 @@ public boolean isVariableOperandPositionValid(RexNode expr) { RexCall call = (RexCall) expr; RexNode left = call.getOperands().get(0); RexNode right = call.getOperands().get(1); - boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall); - boolean rightVar = (right instanceof RexInputRef || right instanceof RexCall); + boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall || left instanceof RexFieldAccess); + boolean rightVar = (right instanceof RexInputRef || right instanceof RexCall || right instanceof RexFieldAccess); boolean leftLit = left instanceof RexLiteral; boolean rightLit = right instanceof RexLiteral; return (leftVar && rightLit) || (rightVar && leftLit); @@ -55,7 +56,7 @@ public RexNode getChildForVariable(RexNode expr) { RexCall call = (RexCall) expr; RexNode left = call.getOperands().get(0); RexNode right = call.getOperands().get(1); - if (left instanceof RexInputRef || left instanceof RexCall) { + if (left instanceof RexInputRef || left instanceof RexCall || left instanceof RexFieldAccess) { return left; } else { return right; @@ -73,7 +74,7 @@ public RexNode getChildForVariable(RexNode expr) { RexCall call = (RexCall) expr; RexNode left = call.getOperands().get(0); RexNode right = call.getOperands().get(1); - boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall); + boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall || left instanceof RexFieldAccess); // Get the literal value RexLiteral literalNode = (RexLiteral) (leftVar ? 
right : left); diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/SubstringRegexTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/SubstringRegexTransformer.java index 1f550afd8..982ee90bf 100644 --- a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/SubstringRegexTransformer.java +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/SubstringRegexTransformer.java @@ -8,6 +8,7 @@ import java.util.Arrays; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; @@ -37,8 +38,13 @@ public class SubstringRegexTransformer implements DomainTransformer { @Override public boolean canHandle(RexNode expr) { - return expr instanceof RexCall && ((RexCall) expr).getOperator().getKind() == SqlKind.OTHER_FUNCTION - && ((RexCall) expr).getOperator().getName().equals("substr"); + if (!(expr instanceof RexCall)) { + return false; + } + RexCall call = (RexCall) expr; + // SUBSTRING has 3 operands: (string, start, length) + return call.getOperator().getKind() == SqlKind.OTHER_FUNCTION && call.getOperator().getName().equals("substr") + && call.getOperands().size() == 3; } @Override @@ -51,7 +57,8 @@ public boolean isVariableOperandPositionValid(RexNode expr) { RexNode lengthArg = call.getOperands().get(2); // String arg must be variable (RexInputRef) or another call (nested expression) - boolean stringIsVariable = (stringArg instanceof RexInputRef) || (stringArg instanceof RexCall); + boolean stringIsVariable = + (stringArg instanceof RexInputRef) || (stringArg instanceof RexCall) || (stringArg instanceof RexFieldAccess); // Start and length must be literals boolean startIsLiteral = startArg instanceof RexLiteral; diff --git 
a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/TimesIntegerTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/TimesIntegerTransformer.java index 1b6507b82..f9ac76cb8 100644 --- a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/TimesIntegerTransformer.java +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/TimesIntegerTransformer.java @@ -9,6 +9,7 @@ import java.util.List; import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; @@ -48,8 +49,8 @@ public boolean isVariableOperandPositionValid(RexNode expr) { RexCall call = (RexCall) expr; RexNode left = call.getOperands().get(0); RexNode right = call.getOperands().get(1); - boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall); - boolean rightVar = (right instanceof RexInputRef || right instanceof RexCall); + boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall || left instanceof RexFieldAccess); + boolean rightVar = (right instanceof RexInputRef || right instanceof RexCall || right instanceof RexFieldAccess); boolean leftLit = left instanceof RexLiteral; boolean rightLit = right instanceof RexLiteral; return (leftVar && rightLit) || (rightVar && leftLit); @@ -60,7 +61,7 @@ public RexNode getChildForVariable(RexNode expr) { RexCall call = (RexCall) expr; RexNode left = call.getOperands().get(0); RexNode right = call.getOperands().get(1); - if (left instanceof RexInputRef || left instanceof RexCall) { + if (left instanceof RexInputRef || left instanceof RexCall || left instanceof RexFieldAccess) { return left; } else { return right; @@ -78,7 +79,7 @@ public RexNode getChildForVariable(RexNode expr) { RexCall call = (RexCall) expr; RexNode left = call.getOperands().get(0); 
RexNode right = call.getOperands().get(1); - boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall); + boolean leftVar = (left instanceof RexInputRef || left instanceof RexCall || left instanceof RexFieldAccess); // Get the literal value RexLiteral literalNode = (RexLiteral) (leftVar ? right : left); diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/TrimRegexTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/TrimRegexTransformer.java new file mode 100644 index 000000000..918aabdcf --- /dev/null +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/TrimRegexTransformer.java @@ -0,0 +1,95 @@ +/** + * Copyright 2025-2026 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.datagen.domain.transformer; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; + +import com.linkedin.coral.datagen.domain.Domain; +import com.linkedin.coral.datagen.domain.DomainTransformer; +import com.linkedin.coral.datagen.domain.RegexDomain; + +import dk.brics.automaton.Automaton; + + +/** + * Regex-based transformer for TRIM operations. + * + * Inverts TRIM(x) = output_pattern + * to produce a constraint allowing optional leading/trailing whitespace. 
+ * + * Example: + * TRIM(x) = 'abc' + * produces input constraint: ' *abc *' + */ +public class TrimRegexTransformer implements DomainTransformer { + + @Override + public boolean canHandle(RexNode expr) { + if (!(expr instanceof RexCall)) { + return false; + } + RexCall call = (RexCall) expr; + SqlOperator op = call.getOperator(); + int arity = call.getOperands().size(); + // Calcite's standard TRIM has 3 operands: (flag, trim_char, source). + // Hive's trim() arrives as an OTHER_FUNCTION with a single source operand. + boolean isCalciteStandard = op == SqlStdOperatorTable.TRIM && arity == 3; + boolean isHiveStyle = "trim".equalsIgnoreCase(op.getName()) && arity == 1; + return isCalciteStandard || isHiveStyle; + } + + @Override + public boolean isVariableOperandPositionValid(RexNode expr) { + RexCall call = (RexCall) expr; + RexNode source = sourceOperand(call); + return source instanceof RexInputRef || source instanceof RexCall || source instanceof RexFieldAccess; + } + + @Override + public RexNode getChildForVariable(RexNode expr) { + return sourceOperand((RexCall) expr); + } + + private static RexNode sourceOperand(RexCall call) { + // Calcite's standard TRIM puts source at operand 2; Hive's 1-operand trim puts it at 0. + return call.getOperands().size() == 3 ? call.getOperands().get(2) : call.getOperands().get(0); + } + + @Override + public Domain refineInputDomain(RexNode expr, Domain outputDomain) { + if (!(outputDomain instanceof RegexDomain)) { + throw new IllegalArgumentException( + "TrimRegexTransformer expects RegexDomain but got " + outputDomain.getClass().getSimpleName()); + } + + RegexDomain outputRegex = (RegexDomain) outputDomain; + + // Wrap the inner pattern with optional leading and trailing spaces so any + // value matching it after trimming becomes valid before trimming. 
+ Automaton optionalSpaces = Automaton.makeChar(' ').repeat(); + + if (outputRegex.isLiteral()) { + String literalValue = outputRegex.getLiteralValue(); + List<Automaton> parts = new ArrayList<>(); + for (char c : literalValue.toCharArray()) { + parts.add(Automaton.makeChar(c)); + } + Automaton inner = parts.isEmpty() ? Automaton.makeEmptyString() : Automaton.concatenate(parts); + return new RegexDomain(Automaton.concatenate(java.util.Arrays.asList(optionalSpaces, inner, optionalSpaces))); + } + + // For complex patterns, conservatively return as-is. + return outputRegex; + } +} diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/UpperRegexTransformer.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/UpperRegexTransformer.java new file mode 100644 index 000000000..28d23294e --- /dev/null +++ b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/domain/transformer/UpperRegexTransformer.java @@ -0,0 +1,84 @@ +/** + * Copyright 2025-2026 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.datagen.domain.transformer; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; + +import com.linkedin.coral.datagen.domain.Domain; +import com.linkedin.coral.datagen.domain.DomainTransformer; +import com.linkedin.coral.datagen.domain.RegexDomain; + +import dk.brics.automaton.Automaton; + + +/** + * Regex-based transformer for UPPER operations. + * + * Inverts UPPER(input) = output_pattern + * to produce a case-insensitive regex constraint on the input. 
+ * + * Example: + * UPPER(input) = "ABC" + * produces input constraint: [aA][bB][cC] + */ +public class UpperRegexTransformer implements DomainTransformer { + + @Override + public boolean canHandle(RexNode expr) { + return expr instanceof RexCall && ((RexCall) expr).getOperator() == SqlStdOperatorTable.UPPER; + } + + @Override + public boolean isVariableOperandPositionValid(RexNode expr) { + RexCall call = (RexCall) expr; + RexNode arg = call.getOperands().get(0); + return arg instanceof RexInputRef || arg instanceof RexCall || arg instanceof RexFieldAccess; + } + + @Override + public RexNode getChildForVariable(RexNode expr) { + RexCall call = (RexCall) expr; + return call.getOperands().get(0); + } + + @Override + public Domain refineInputDomain(RexNode expr, Domain outputDomain) { + if (!(outputDomain instanceof RegexDomain)) { + throw new IllegalArgumentException( + "UpperRegexTransformer expects RegexDomain but got " + outputDomain.getClass().getSimpleName()); + } + + RegexDomain outputRegex = (RegexDomain) outputDomain; + + // If output is a literal uppercase string, make it case-insensitive + if (outputRegex.isLiteral()) { + String literalValue = outputRegex.getLiteralValue(); + List<Automaton> parts = new ArrayList<>(); + + for (char c : literalValue.toCharArray()) { + if (Character.isLetter(c)) { + char lower = Character.toLowerCase(c); + char upper = Character.toUpperCase(c); + parts.add(Automaton.makeChar(lower).union(Automaton.makeChar(upper))); + } else { + parts.add(Automaton.makeChar(c)); + } + } + + return new RegexDomain(Automaton.concatenate(parts)); + } + + // For complex patterns, return as-is + return outputRegex; + } +} diff --git a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/rel/ProjectPullUpRewriter.java b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/rel/ProjectPullUpRewriter.java index 8400c9fb8..d2a849004 100644 --- 
a/coral-data-generation/src/main/java/com/linkedin/coral/datagen/rel/ProjectPullUpRewriter.java +++ 
b/coral-data-generation/src/main/java/com/linkedin/coral/datagen/rel/ProjectPullUpRewriter.java @@ -193,27 +193,21 @@ private static RelNode pullProjectAboveFilter(Filter filter, Project project) { * where cond' has Project expressions inlined. */ private static RelNode pullProjectsAboveJoin(Join join, Project leftProj, Project rightProj) { - RelNode newLeft = join.getLeft(); - RelNode newRight = join.getRight(); - RexNode newCondition = join.getCondition(); - - // Capture left field count before reassigning newLeft (item #5 fix) - int leftFieldCount = - (leftProj != null) ? leftProj.getRowType().getFieldCount() : newLeft.getRowType().getFieldCount(); - - // Inline left Project if present - if (leftProj != null) { - newCondition = inlineLeftSide(newCondition, leftProj.getProjects(), leftFieldCount); - newLeft = leftProj.getInput(); - } - - // Inline right Project if present - if (rightProj != null) { - newCondition = inlineRightSide(newCondition, rightProj.getProjects(), leftFieldCount); - newRight = rightProj.getInput(); - } - - // Create new Join with inlined condition + RelNode newLeft = (leftProj != null) ? leftProj.getInput() : join.getLeft(); + RelNode newRight = (rightProj != null) ? rightProj.getInput() : join.getRight(); + + // The condition is currently expressed against the old frame: + // [ leftProj-output (or newLeft) | rightProj-output (or newRight) ] + // We need to remap it to the new frame: + // [ newLeft | newRight ] + // The left boundary moves from oldLeftCount to newLeftCount; right-side refs must shift + // by (newLeftCount - oldLeftCount) when no rightProj absorbs them. + int oldLeftCount = + (leftProj != null) ? 
leftProj.getRowType().getFieldCount() : join.getLeft().getRowType().getFieldCount(); + int newLeftCount = newLeft.getRowType().getFieldCount(); + RexNode newCondition = remapJoinCondition(join.getCondition(), leftProj, rightProj, oldLeftCount, newLeftCount); + + // Create new Join with remapped condition Join newJoin = join.copy(join.getTraitSet(), newCondition, newLeft, newRight, join.getJoinType(), join.isSemiJoinDone()); @@ -233,12 +227,10 @@ private static RelNode pullProjectsAboveJoin(Join join, Project leftProj, Projec return leftProj.copy(leftProj.getTraitSet(), newJoin, combinedExprs, join.getRowType()); } else if (leftProj != null) { - // Only left had Project + // Only left had Project — add pass-through for right side using newLeftCount List exprs = new ArrayList<>(); exprs.addAll(leftProj.getProjects()); - // Use the new left side's field count for pass-through offsets (item #6 fix) - int newLeftCount = newLeft.getRowType().getFieldCount(); int rightCount = newJoin.getRowType().getFieldCount() - newLeftCount; for (int i = 0; i < rightCount; i++) { exprs.add( @@ -247,11 +239,9 @@ private static RelNode pullProjectsAboveJoin(Join join, Project leftProj, Projec return leftProj.copy(leftProj.getTraitSet(), newJoin, exprs, join.getRowType()); } else { - // Only right had Project + // Only right had Project — add pass-through for left side List exprs = new ArrayList<>(); - int newLeftCount = newLeft.getRowType().getFieldCount(); - // Add pass-through for left side for (int i = 0; i < newLeftCount; i++) { exprs.add(new RexInputRef(i, newJoin.getRowType().getFieldList().get(i).getType())); } @@ -288,41 +278,43 @@ public RexNode visitInputRef(RexInputRef ref) { } /** - * Inline expressions for left side of join condition. - * Only rewrites input refs < leftFieldCount. 
- */ - private static RexNode inlineLeftSide(RexNode condition, List leftProjects, int leftFieldCount) { - return condition.accept(new RexShuttle() { - @Override - public RexNode visitInputRef(RexInputRef ref) { - int idx = ref.getIndex(); - if (idx < leftFieldCount && idx >= 0 && idx < leftProjects.size()) { - // Return directly - no recursive inlining - return leftProjects.get(idx); - } - return ref; - } - }); - } - - /** - * Inline expressions for right side of join condition. - * Only rewrites input refs >= leftFieldCount. + * Remaps a join condition from the old frame [leftProj-output | rightProj-output] to the + * new frame [newLeft | newRight] in a single pass: + * + *

<ul>
+ *   <li>Left-side refs ({@code idx < oldLeftCount}): + * inlined to {@code leftProj.projects[idx]} if present, else passed through.</li>
+ *   <li>Right-side refs ({@code idx >= oldLeftCount}): + * inlined to {@code adjustOffsets(rightProj.projects[idx - oldLeftCount], newLeftCount)} + * if rightProj present; otherwise shifted to {@code newLeftCount + (idx - oldLeftCount)}.</li>
+ * </ul>
+ * + * The right-side shift is what previously broke when a left Project changed the field count + * but no right Project absorbed the right-side refs — they were left at their old offsets + * and pointed at the wrong column in the new frame. */ - private static RexNode inlineRightSide(RexNode condition, List<RexNode> rightProjects, int leftFieldCount) { + private static RexNode remapJoinCondition(RexNode condition, Project leftProj, Project rightProj, int oldLeftCount, + int newLeftCount) { + final List<RexNode> leftProjects = (leftProj != null) ? leftProj.getProjects() : null; + final List<RexNode> rightProjects = (rightProj != null) ? rightProj.getProjects() : null; return condition.accept(new RexShuttle() { @Override public RexNode visitInputRef(RexInputRef ref) { int idx = ref.getIndex(); - if (idx >= leftFieldCount) { - int rightIdx = idx - leftFieldCount; - if (rightIdx >= 0 && rightIdx < rightProjects.size()) { - RexNode replacement = rightProjects.get(rightIdx); - // Return directly with adjusted offsets - no recursive inlining - return adjustOffsets(replacement, leftFieldCount); + if (idx < oldLeftCount) { + // Left-side ref + if (leftProjects != null && idx >= 0 && idx < leftProjects.size()) { + return leftProjects.get(idx); } + return ref; } - return ref; + // Right-side ref + int rightIdx = idx - oldLeftCount; + if (rightProjects != null && rightIdx >= 0 && rightIdx < rightProjects.size()) { + return adjustOffsets(rightProjects.get(rightIdx), newLeftCount); + } + // No right Project to absorb this ref — shift to the new left/right boundary. 
+ return new RexInputRef(newLeftCount + rightIdx, ref.getType()); } }); } diff --git a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/CastRegexTransformerTest.java b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/CastRegexTransformerTest.java index ea030c536..aee3168fc 100644 --- a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/CastRegexTransformerTest.java +++ b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/CastRegexTransformerTest.java @@ -44,39 +44,45 @@ public void setup() { @Test public void testCastStringToIntegerWithIntegerDomainOutput() { - // CAST(x AS INTEGER) where output is IntegerDomain [100, 200] - // Input should be RegexDomain matching "100", "101", ..., "200" + // CAST(x AS INTEGER) where output is IntegerDomain [100, 200]. + // Input regex should accept canonical strings "100".."200" and reject anything outside. RexNode inputRef = rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0); RexNode castExpr = rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.INTEGER), inputRef); IntegerDomain outputDomain = IntegerDomain.of(100, 200); - Domain inputDomain = transformer.refineInputDomain(castExpr, outputDomain); - // Input should be RegexDomain assertTrue(inputDomain instanceof RegexDomain); RegexDomain regexInput = (RegexDomain) inputDomain; - // Should match numeric values in the range - assertFalse(regexInput.isEmpty()); + assertTrue(regexInput.getAutomaton().run("100"), "should accept '100'"); + assertTrue(regexInput.getAutomaton().run("150"), "should accept '150'"); + assertTrue(regexInput.getAutomaton().run("200"), "should accept '200'"); + assertFalse(regexInput.getAutomaton().run("99"), "should reject '99'"); + assertFalse(regexInput.getAutomaton().run("201"), "should reject '201'"); + assertFalse(regexInput.getAutomaton().run("abc"), "should reject non-numeric"); } @Test public void testCastStringToIntegerSmallRange() { - // 
CAST(x AS INTEGER) where output is [10, 12] - // Input should be RegexDomain matching "10|11|12" + // CAST(x AS INTEGER) where output is [10, 12]. Input regex should accept exactly {"10","11","12"}. RexNode inputRef = rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0); RexNode castExpr = rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.INTEGER), inputRef); IntegerDomain outputDomain = IntegerDomain.of(10, 12); - Domain inputDomain = transformer.refineInputDomain(castExpr, outputDomain); assertTrue(inputDomain instanceof RegexDomain); RegexDomain regexInput = (RegexDomain) inputDomain; - assertFalse(regexInput.isEmpty()); + + assertTrue(regexInput.getAutomaton().run("10"), "should accept '10'"); + assertTrue(regexInput.getAutomaton().run("11"), "should accept '11'"); + assertTrue(regexInput.getAutomaton().run("12"), "should accept '12'"); + assertFalse(regexInput.getAutomaton().run("9"), "should reject '9'"); + assertFalse(regexInput.getAutomaton().run("13"), "should reject '13'"); + assertFalse(regexInput.getAutomaton().run("100"), "should reject '100'"); } @Test @@ -100,27 +106,43 @@ public void testCastStringToIntegerWithRegexDomainOutput() { // ==================== Integer to String Conversion Tests ==================== @Test - public void testCastIntegerToStringWithRegexDomainOutput() { - // CAST(x AS VARCHAR) where output is RegexDomain "^[0-9]{3}$" - // Input should be IntegerDomain [0, 999] - + public void testCastIntegerToStringWithCanonicalRegexOutput() { + // CAST(x AS VARCHAR) where the output regex is canonical (^[1-9][0-9]{2}$ = 100..999). + // SQL CAST(integer AS VARCHAR) produces canonical decimal, so the inverted integer set + // is exactly the integers whose canonical decimal form lies in the regex. 
RexNode inputRef = rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.INTEGER), 0); RexNode castExpr = rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.VARCHAR), inputRef); - RegexDomain outputDomain = new RegexDomain("^[0-9]{3}$"); - + RegexDomain outputDomain = new RegexDomain("^[1-9][0-9]{2}$"); Domain inputDomain = transformer.refineInputDomain(castExpr, outputDomain); - // Input should be IntegerDomain (converted from regex) assertTrue(inputDomain instanceof IntegerDomain); IntegerDomain intInput = (IntegerDomain) inputDomain; - assertTrue(intInput.contains(0)); + assertTrue(intInput.contains(100)); assertTrue(intInput.contains(500)); assertTrue(intInput.contains(999)); + assertFalse(intInput.contains(99)); assertFalse(intInput.contains(1000)); } + @Test + public void testCastIntegerToStringWithNonCanonicalRegexFallsBackToRegex() { + // CAST(x AS VARCHAR) with a non-canonical output regex like ^[0-9]{3}$ (admits "000"): + // the converter contract requires canonical-only input and rejects this, so the + // transformer falls through to the regex-format fallback. The result is therefore a + // RegexDomain over integer-shaped strings, not an IntegerDomain. Real SQL inference + // does not produce non-canonical regexes at this branch; the test pins the fallback. 
+ RexNode inputRef = rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.INTEGER), 0); + RexNode castExpr = rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.VARCHAR), inputRef); + + RegexDomain outputDomain = new RegexDomain("^[0-9]{3}$"); + Domain inputDomain = transformer.refineInputDomain(castExpr, outputDomain); + + assertTrue(inputDomain instanceof RegexDomain, "non-canonical input falls back to regex form"); + assertFalse(inputDomain.isEmpty()); + } + @Test public void testCastIntegerToStringSpecificValues() { // CAST(x AS VARCHAR) where output is RegexDomain "^(10|20|30)$" @@ -222,18 +244,24 @@ public void testCastStringToIntegerEmptyDomain() { @Test public void testCastStringToIntegerNonConvertibleRegex() { - // Output regex contains non-numeric patterns + // When the output regex isn't convertible to an IntegerDomain (here `.*`), the fallback + // intersects with the integer-shaped regex /^-?[0-9]+$/, so the result must accept + // integer strings and reject non-integer strings. 
RexNode inputRef = rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0); RexNode castExpr = rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.INTEGER), inputRef); - // This regex has wildcards, not convertible to IntegerDomain RegexDomain outputDomain = new RegexDomain("^.*$"); - Domain inputDomain = transformer.refineInputDomain(castExpr, outputDomain); - // Should fall back to integer format constraint assertTrue(inputDomain instanceof RegexDomain); - assertFalse(inputDomain.isEmpty()); + RegexDomain regexInput = (RegexDomain) inputDomain; + + assertTrue(regexInput.getAutomaton().run("0"), "should accept '0'"); + assertTrue(regexInput.getAutomaton().run("42"), "should accept '42'"); + assertTrue(regexInput.getAutomaton().run("-7"), "should accept '-7'"); + assertFalse(regexInput.getAutomaton().run("abc"), "should reject non-numeric 'abc'"); + assertFalse(regexInput.getAutomaton().run("1.5"), "should reject decimal '1.5'"); + assertFalse(regexInput.getAutomaton().run(""), "should reject empty"); } // ==================== canHandle Tests ==================== diff --git a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/IntegerDomainTest.java b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/IntegerDomainTest.java index 8fb6884b8..084729813 100644 --- a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/IntegerDomainTest.java +++ b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/IntegerDomainTest.java @@ -31,7 +31,7 @@ public void testSingleValue() { List samples = domain.sample(5); assertFalse(samples.isEmpty()); for (long v : samples) { - assertTrue(domain.contains(v)); + assertEquals(v, 42L, "every sample from a singleton must be the value"); } } @@ -63,8 +63,9 @@ public void testMultipleIntervals() { assertTrue(domain.contains(25)); assertFalse(domain.contains(16)); + // Domain has 5 + 6 + 11 = 22 distinct values, so sample(10) should return 10 
items. List samples = domain.sample(10); - assertFalse(samples.isEmpty()); + assertEquals(samples.size(), 10, "sample(10) should return 10 items when domain has >= 10 values"); for (long v : samples) { assertTrue(domain.contains(v)); } @@ -128,17 +129,22 @@ public void testAddConstant() { @Test public void testMultiplyConstant() { + // [10, 20] * 2: the exact image is the sparse set {20, 22, …, 40}, but the domain + // tracks the convex hull [20, 40] as a sound over-approximation. Verify both bounds + // and the rejection of values outside the hull. IntegerDomain domain = IntegerDomain.of(10, 20); IntegerDomain scaled = domain.multiply(2); - assertTrue(scaled.contains(20)); - assertTrue(scaled.contains(40)); - assertFalse(scaled.contains(19)); - assertFalse(scaled.contains(41)); + assertTrue(scaled.contains(20), "lower bound 20 = 10*2"); + assertTrue(scaled.contains(30), "midpoint must be in convex hull"); + assertTrue(scaled.contains(40), "upper bound 40 = 20*2"); + assertFalse(scaled.contains(19), "below lower bound"); + assertFalse(scaled.contains(41), "above upper bound"); List samples = scaled.sample(5); for (long v : samples) { assertTrue(scaled.contains(v)); + assertTrue(v >= 20 && v <= 40, "every sample must be in [20, 40]: " + v); } } diff --git a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/IntegerTransformerTest.java b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/IntegerTransformerTest.java index 2d7991a47..9f9483945 100644 --- a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/IntegerTransformerTest.java +++ b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/IntegerTransformerTest.java @@ -8,16 +8,44 @@ import java.util.Arrays; import java.util.List; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import 
org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; +import com.linkedin.coral.datagen.domain.transformer.AbsIntegerTransformer; +import com.linkedin.coral.datagen.domain.transformer.MinusIntegerTransformer; +import com.linkedin.coral.datagen.domain.transformer.NegateIntegerTransformer; + import static org.testng.Assert.*; /** - * Tests for integer domain transformers (Plus and Times). + * Tests for integer domain transformers (Plus, Times, Minus, Negate, Abs). */ public class IntegerTransformerTest { + private RexBuilder rexBuilder; + private RelDataTypeFactory typeFactory; + + @BeforeMethod + public void setup() { + rexBuilder = TestHelper.createRexBuilder(); + typeFactory = rexBuilder.getTypeFactory(); + } + + private RexNode intRef() { + return rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.INTEGER), 0); + } + + private RexLiteral intLit(long value) { + return (RexLiteral) rexBuilder.makeLiteral(value, typeFactory.createSqlType(SqlTypeName.BIGINT), false); + } + // ==================== Plus Transformer ==================== @Test @@ -151,4 +179,198 @@ public void testIntersectionWithArithmetic() { assertEquals(v, 15); } } + + // ==================== Minus Transformer ==================== + + @Test + public void testMinusTransformerVariableMinusLiteral() { + // x - 3 = 7 => x = 10 + MinusIntegerTransformer transformer = new MinusIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, intRef(), intLit(3)); + + assertTrue(transformer.canHandle(expr)); + assertTrue(transformer.isVariableOperandPositionValid(expr)); + + Domain result = transformer.refineInputDomain(expr, IntegerDomain.of(7)); + assertTrue(result instanceof IntegerDomain); + IntegerDomain intResult = (IntegerDomain) result; + assertTrue(intResult.contains(10)); + assertTrue(intResult.isSingleton()); + } + + @Test + public void 
testMinusTransformerLiteralMinusVariable() { + // 10 - x = 3 => x = 7 + MinusIntegerTransformer transformer = new MinusIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, intLit(10), intRef()); + + assertTrue(transformer.canHandle(expr)); + assertTrue(transformer.isVariableOperandPositionValid(expr)); + + Domain result = transformer.refineInputDomain(expr, IntegerDomain.of(3)); + assertTrue(result instanceof IntegerDomain); + IntegerDomain intResult = (IntegerDomain) result; + assertTrue(intResult.contains(7)); + assertTrue(intResult.isSingleton()); + } + + @Test + public void testMinusTransformerInterval() { + // x - 5 in [20, 30] => x in [25, 35] + MinusIntegerTransformer transformer = new MinusIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, intRef(), intLit(5)); + + Domain result = transformer.refineInputDomain(expr, IntegerDomain.of(20, 30)); + assertTrue(result instanceof IntegerDomain); + IntegerDomain intResult = (IntegerDomain) result; + assertTrue(intResult.contains(25)); + assertTrue(intResult.contains(35)); + assertFalse(intResult.contains(24)); + assertFalse(intResult.contains(36)); + } + + @Test + public void testMinusTransformerLiteralMinusVariableInterval() { + // 50 - x in [10, 20] => x in [30, 40] + MinusIntegerTransformer transformer = new MinusIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, intLit(50), intRef()); + + Domain result = transformer.refineInputDomain(expr, IntegerDomain.of(10, 20)); + assertTrue(result instanceof IntegerDomain); + IntegerDomain intResult = (IntegerDomain) result; + assertTrue(intResult.contains(30)); + assertTrue(intResult.contains(40)); + assertFalse(intResult.contains(29)); + assertFalse(intResult.contains(41)); + } + + // ==================== Negate Transformer ==================== + + @Test + public void testNegateTransformerSingleValue() { + // -x = -5 => x = 5 + NegateIntegerTransformer transformer = new 
NegateIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.UNARY_MINUS, intRef()); + + assertTrue(transformer.canHandle(expr)); + assertTrue(transformer.isVariableOperandPositionValid(expr)); + + Domain result = transformer.refineInputDomain(expr, IntegerDomain.of(-5)); + assertTrue(result instanceof IntegerDomain); + IntegerDomain intResult = (IntegerDomain) result; + assertTrue(intResult.contains(5)); + assertTrue(intResult.isSingleton()); + } + + @Test + public void testNegateTransformerInterval() { + // -x in [10, 20] => x in [-20, -10] + NegateIntegerTransformer transformer = new NegateIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.UNARY_MINUS, intRef()); + + Domain result = transformer.refineInputDomain(expr, IntegerDomain.of(10, 20)); + assertTrue(result instanceof IntegerDomain); + IntegerDomain intResult = (IntegerDomain) result; + assertTrue(intResult.contains(-20)); + assertTrue(intResult.contains(-10)); + assertFalse(intResult.contains(-21)); + assertFalse(intResult.contains(-9)); + } + + @Test + public void testNegateTransformerCrossesZero() { + // -x in [-5, 5] => x in [-5, 5] (symmetric) + NegateIntegerTransformer transformer = new NegateIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.UNARY_MINUS, intRef()); + + Domain result = transformer.refineInputDomain(expr, IntegerDomain.of(-5, 5)); + assertTrue(result instanceof IntegerDomain); + IntegerDomain intResult = (IntegerDomain) result; + assertTrue(intResult.contains(-5)); + assertTrue(intResult.contains(0)); + assertTrue(intResult.contains(5)); + assertFalse(intResult.contains(-6)); + assertFalse(intResult.contains(6)); + } + + // ==================== Abs Transformer ==================== + + @Test + public void testAbsTransformerPositiveValue() { + // ABS(x) = 5 => x in {-5, 5} + AbsIntegerTransformer transformer = new AbsIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.ABS, 
intRef()); + + assertTrue(transformer.canHandle(expr)); + assertTrue(transformer.isVariableOperandPositionValid(expr)); + + Domain result = transformer.refineInputDomain(expr, IntegerDomain.of(5)); + assertTrue(result instanceof IntegerDomain); + IntegerDomain intResult = (IntegerDomain) result; + assertTrue(intResult.contains(-5)); + assertTrue(intResult.contains(5)); + assertFalse(intResult.contains(0)); + assertFalse(intResult.contains(4)); + assertFalse(intResult.contains(-4)); + } + + @Test + public void testAbsTransformerZero() { + // ABS(x) = 0 => x = 0 + AbsIntegerTransformer transformer = new AbsIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.ABS, intRef()); + + Domain result = transformer.refineInputDomain(expr, IntegerDomain.of(0)); + assertTrue(result instanceof IntegerDomain); + IntegerDomain intResult = (IntegerDomain) result; + assertTrue(intResult.contains(0)); + assertTrue(intResult.isSingleton()); + } + + @Test + public void testAbsTransformerPositiveInterval() { + // ABS(x) in [3, 7] => x in [-7, -3] ∪ [3, 7] + AbsIntegerTransformer transformer = new AbsIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.ABS, intRef()); + + Domain result = transformer.refineInputDomain(expr, IntegerDomain.of(3, 7)); + assertTrue(result instanceof IntegerDomain); + IntegerDomain intResult = (IntegerDomain) result; + assertTrue(intResult.contains(3)); + assertTrue(intResult.contains(7)); + assertTrue(intResult.contains(-3)); + assertTrue(intResult.contains(-7)); + assertFalse(intResult.contains(0)); + assertFalse(intResult.contains(2)); + assertFalse(intResult.contains(-2)); + assertFalse(intResult.contains(8)); + assertFalse(intResult.contains(-8)); + } + + @Test + public void testAbsTransformerIntervalIncludingZero() { + // ABS(x) in [0, 5] => x in [-5, 5] + AbsIntegerTransformer transformer = new AbsIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.ABS, intRef()); + + Domain 
result = transformer.refineInputDomain(expr, IntegerDomain.of(0, 5)); + assertTrue(result instanceof IntegerDomain); + IntegerDomain intResult = (IntegerDomain) result; + assertTrue(intResult.contains(0)); + assertTrue(intResult.contains(-5)); + assertTrue(intResult.contains(5)); + assertFalse(intResult.contains(-6)); + assertFalse(intResult.contains(6)); + } + + @Test + public void testAbsTransformerNegativeIntervalProducesEmpty() { + // ABS(x) in [-5, -1] => no valid input (ABS is always non-negative) + AbsIntegerTransformer transformer = new AbsIntegerTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.ABS, intRef()); + + Domain result = transformer.refineInputDomain(expr, IntegerDomain.of(-5, -1)); + assertTrue(result instanceof IntegerDomain); + assertTrue(result.isEmpty()); + } } diff --git a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexDomainInferenceProgramTest.java b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexDomainInferenceProgramTest.java index 984c784ef..04b392b3f 100644 --- a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexDomainInferenceProgramTest.java +++ b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexDomainInferenceProgramTest.java @@ -9,6 +9,7 @@ import java.io.IOException; import java.io.InputStream; import java.util.List; +import java.util.Map; import java.util.UUID; import org.apache.calcite.rel.RelNode; @@ -59,6 +60,20 @@ public void setup() { Driver driver = new Driver(conf); run(driver, String.join("\n", "", "CREATE DATABASE IF NOT EXISTS test")); run(driver, String.join("\n", "", "CREATE TABLE IF NOT EXISTS test.T (name STRING, age INT, birthdate DATE)")); + run(driver, String.join("\n", "", + "CREATE TABLE IF NOT EXISTS test.complex (a INT, b STRING, c ARRAY, s STRUCT, m MAP, sarr ARRAY>)")); + run(driver, + String.join("\n", "", "CREATE TABLE IF NOT EXISTS test.deep (" + "id INT, " + + "nested_struct 
STRUCT>, " + + "map_of_structs MAP>, " + "str_map MAP" + ")")); + // Table with interleaved complex types: struct containing map, struct containing array, + // array of maps of structs, etc. + run(driver, + String.join("\n", "", + "CREATE TABLE IF NOT EXISTS test.interleaved (" + "id INT, " + + "sm STRUCT, scores:ARRAY, name:STRING>, " + + "ams ARRAY>, " + "msa MAP>, " + + "amss ARRAY, label:STRING, value:INT>>" + ")")); run(driver, String.join("\n", "", "USE test")); try { java.util.List dbs = Hive.get(conf).getMSC().getAllDatabases(); @@ -142,6 +157,7 @@ private interface DomainAssertion { /** * Common test logic for domain inference. * Extracts the predicate, derives the input domain, and runs optional assertions. + * Supports EQUALS, GREATER_THAN, LESS_THAN, GREATER_THAN_OR_EQUAL, LESS_THAN_OR_EQUAL. */ private void testDomainInference(String testName, String sql, DomainAssertion assertion) { RelNode normalized = convertAndNormalizeQuery(sql); @@ -152,27 +168,9 @@ private void testDomainInference(String testName, String sql, DomainAssertion as assertTrue(disjunct instanceof org.apache.calcite.rex.RexCall, testName + ": disjunct should be a RexCall"); org.apache.calcite.rex.RexCall call = (org.apache.calcite.rex.RexCall) disjunct; - assertEquals(call.getOperator(), org.apache.calcite.sql.fun.SqlStdOperatorTable.EQUALS, - testName + ": operator should be EQUALS"); - - RexNode lhs = call.getOperands().get(0); - RexNode rhs = call.getOperands().get(1); - - assertTrue(rhs instanceof org.apache.calcite.rex.RexLiteral, testName + ": RHS should be a literal"); - org.apache.calcite.rex.RexLiteral literal = (org.apache.calcite.rex.RexLiteral) rhs; - String rhsValue = literal.getValue2().toString(); - - // Create appropriate domain based on RHS literal type - Domain outputDomain; - if (isNumericType(literal.getType().getSqlTypeName())) { - long numericValue = Long.parseLong(rhsValue); - outputDomain = IntegerDomain.of(numericValue); - } else { - outputDomain = 
RegexDomain.literal(rhsValue); - } - // Derive input domain constraint - Domain inputDomain = program.deriveInputDomain(lhs, outputDomain); + // Use the program's predicate-based inference which handles all comparison operators + Domain inputDomain = program.deriveInputDomainFromPredicate(call); assertFalse(inputDomain.isEmpty(), testName + ": derived input domain should not be empty"); @@ -365,22 +363,1341 @@ public void testNestedSubstringDisjoint() { }); } + @Test + public void testSimpleUpper() { + testDomainInference("Simple UPPER Test", "SELECT * FROM test.T WHERE UPPER(name) = 'ABC'", inputDomain -> { + assertTrue(inputDomain instanceof RegexDomain, "Should be RegexDomain"); + List examples = inputDomain.sample(5); + assertEquals(5, examples.size(), "Should generate 5 examples"); + for (Object ex : examples) { + String s = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertEquals("ABC", s.toUpperCase(), "Should be case-insensitive 'ABC': " + s); + } + }); + } + + @Test + public void testSimpleMinus() { + testDomainInference("Simple MINUS Test", "SELECT * FROM test.T WHERE age - 3 = 7", inputDomain -> { + assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain"); + IntegerDomain intDomain = (IntegerDomain) inputDomain; + assertTrue(intDomain.contains(10), "Should contain 10 (since 10 - 3 = 7)"); + assertTrue(intDomain.isSingleton(), "Should be singleton"); + }); + } + + @Test + public void testMinusWithArithmetic() { + testDomainInference("MINUS with Arithmetic Test", "SELECT * FROM test.T WHERE age * 2 - 5 = 15", inputDomain -> { + assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain"); + IntegerDomain intDomain = (IntegerDomain) inputDomain; + assertTrue(intDomain.contains(10), "Should contain 10 (since 10 * 2 - 5 = 15)"); + assertTrue(intDomain.isSingleton(), "Should be singleton"); + }); + } + + @Test + public void testNestedUpperSubstring() { + testDomainInference("Nested UPPER(SUBSTRING) Test", + 
"SELECT * FROM test.T WHERE UPPER(SUBSTRING(name, 1, 3)) = 'ABC'", inputDomain -> { + assertTrue(inputDomain instanceof RegexDomain, "Should be RegexDomain"); + List examples = inputDomain.sample(3); + for (Object ex : examples) { + String s = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertTrue(s.length() >= 3, "Should have at least 3 characters"); + assertEquals("ABC", s.substring(0, 3).toUpperCase(), "First 3 chars should be 'ABC': " + s); + } + }); + } + + @Test + public void testAbsSimple() { + testDomainInference("Simple ABS Test", "SELECT * FROM test.T WHERE ABS(age) = 5", inputDomain -> { + assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain"); + IntegerDomain intDomain = (IntegerDomain) inputDomain; + assertTrue(intDomain.contains(5), "Should contain 5"); + assertTrue(intDomain.contains(-5), "Should contain -5"); + assertFalse(intDomain.contains(0), "Should not contain 0"); + }); + } + + @Test + public void testAbsZero() { + testDomainInference("ABS Zero Test", "SELECT * FROM test.T WHERE ABS(age) = 0", inputDomain -> { + assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain"); + IntegerDomain intDomain = (IntegerDomain) inputDomain; + assertTrue(intDomain.contains(0), "Should contain 0"); + assertTrue(intDomain.isSingleton(), "Should be singleton"); + }); + } + + @Test + public void testAbsWithArithmetic() { + testDomainInference("ABS with Arithmetic Test", "SELECT * FROM test.T WHERE ABS(age + 1) = 5", inputDomain -> { + assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain"); + IntegerDomain intDomain = (IntegerDomain) inputDomain; + assertTrue(intDomain.contains(4), "Should contain 4 (since ABS(4 + 1) = 5)"); + assertTrue(intDomain.contains(-6), "Should contain -6 (since ABS(-6 + 1) = 5)"); + }); + } + + @Test + public void testSimpleNegate() { + testDomainInference("Simple NEGATE Test", "SELECT * FROM test.T WHERE -age = -5", inputDomain -> { + assertTrue(inputDomain 
instanceof IntegerDomain, "Should be IntegerDomain"); + IntegerDomain intDomain = (IntegerDomain) inputDomain; + assertTrue(intDomain.contains(5), "Should contain 5 (since -5 negated = 5)"); + assertTrue(intDomain.isSingleton(), "Should be singleton"); + }); + } + + @Test + public void testNegateWithArithmetic() { + testDomainInference("NEGATE with Arithmetic Test", "SELECT * FROM test.T WHERE -(age + 2) = -10", inputDomain -> { + assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain"); + IntegerDomain intDomain = (IntegerDomain) inputDomain; + assertTrue(intDomain.contains(8), "Should contain 8 (since -(8 + 2) = -10)"); + assertTrue(intDomain.isSingleton(), "Should be singleton"); + }); + } + + @Test + public void testSimpleConcatSuffix() { + testDomainInference("Simple CONCAT(x, suffix) Test", + "SELECT * FROM test.T WHERE CONCAT(name, 'World') = 'HelloWorld'", inputDomain -> { + assertTrue(inputDomain instanceof RegexDomain, "Should be RegexDomain"); + RegexDomain regex = (RegexDomain) inputDomain; + assertTrue(regex.isLiteral(), "Should be literal"); + assertEquals(regex.getLiteralValue(), "Hello", "Should strip 'World' suffix"); + }); + } + + @Test + public void testSimpleConcatPrefix() { + testDomainInference("Simple CONCAT(prefix, x) Test", + "SELECT * FROM test.T WHERE CONCAT('Hello', name) = 'HelloWorld'", inputDomain -> { + assertTrue(inputDomain instanceof RegexDomain, "Should be RegexDomain"); + RegexDomain regex = (RegexDomain) inputDomain; + assertTrue(regex.isLiteral(), "Should be literal"); + assertEquals(regex.getLiteralValue(), "World", "Should strip 'Hello' prefix"); + }); + } + + @Test + public void testSimpleTrim() { + testDomainInference("Simple TRIM Test", "SELECT * FROM test.T WHERE TRIM(name) = 'abc'", inputDomain -> { + assertTrue(inputDomain instanceof RegexDomain, "Should be RegexDomain"); + RegexDomain regex = (RegexDomain) inputDomain; + // The result should accept 'abc' with optional surrounding spaces. 
+ List samples = regex.sample(5); + assertFalse(samples.isEmpty(), "Should generate samples"); + for (Object ex : samples) { + String s = ex.toString(); + assertEquals(s.trim(), "abc", "Sample should trim to 'abc': '" + s + "'"); + } + }); + } + + @Test + public void testGreaterThan() { + testDomainInference("Greater Than Test", "SELECT * FROM test.T WHERE age > 10", inputDomain -> { + assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain"); + IntegerDomain intDomain = (IntegerDomain) inputDomain; + assertTrue(intDomain.contains(11), "Should contain 11"); + assertTrue(intDomain.contains(100), "Should contain 100"); + assertFalse(intDomain.contains(10), "Should not contain 10"); + assertFalse(intDomain.contains(5), "Should not contain 5"); + }); + } + + @Test + public void testGreaterThanOrEqual() { + testDomainInference("Greater Than Or Equal Test", "SELECT * FROM test.T WHERE age >= 10", inputDomain -> { + assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain"); + IntegerDomain intDomain = (IntegerDomain) inputDomain; + assertTrue(intDomain.contains(10), "Should contain 10"); + assertTrue(intDomain.contains(100), "Should contain 100"); + assertFalse(intDomain.contains(9), "Should not contain 9"); + }); + } + + @Test + public void testLessThan() { + testDomainInference("Less Than Test", "SELECT * FROM test.T WHERE age < 10", inputDomain -> { + assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain"); + IntegerDomain intDomain = (IntegerDomain) inputDomain; + assertTrue(intDomain.contains(9), "Should contain 9"); + assertTrue(intDomain.contains(0), "Should contain 0"); + assertFalse(intDomain.contains(10), "Should not contain 10"); + }); + } + + @Test + public void testGreaterThanWithArithmetic() { + testDomainInference("Greater Than with Arithmetic Test", "SELECT * FROM test.T WHERE age + 3 > 10", inputDomain -> { + assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain"); + IntegerDomain 
intDomain = (IntegerDomain) inputDomain; + assertTrue(intDomain.contains(8), "Should contain 8 (since 8 + 3 = 11 > 10)"); + assertFalse(intDomain.contains(7), "Should not contain 7 (since 7 + 3 = 10, which is not > 10)"); + }); + } + + // ========== Multi-column resolution helpers ========== + + /** + * Helper interface for assertions on resolved multi-column domains. + */ + @FunctionalInterface + private interface MultiColumnAssertion { + void accept(Map> pathDomains); + } + /** - * Helper method to determine if a SQL type is numeric. + * Resolves domains for all columns using resolveAllPaths and runs assertions. + * Table test.T has columns: name(0:STRING), age(1:INT), birthdate(2:DATE). */ - private boolean isNumericType(org.apache.calcite.sql.type.SqlTypeName typeName) { - switch (typeName) { - case TINYINT: - case SMALLINT: - case INTEGER: - case BIGINT: - case DECIMAL: - case FLOAT: - case REAL: - case DOUBLE: - return true; - default: - return false; + private void testMultiColumnResolution(String testName, String sql, MultiColumnAssertion assertion) { + RelNode normalized = convertAndNormalizeQuery(sql); + DnfRewriter.Output dnfOut = extractPredicatesAsDnf(normalized); + assertFalse(dnfOut.disjuncts.isEmpty(), testName + ": should have at least one disjunct"); + + Map> pathDomains = program.resolveAllPaths(dnfOut.disjuncts); + assertFalse(pathDomains.isEmpty(), testName + ": should resolve at least one column"); + + assertion.accept(pathDomains); + } + + // ========== Conjunction (AND) tests ========== + + @Test + public void testAndTwoColumnsDifferent() { + // AND of predicates on two different columns: name and age + testMultiColumnResolution("AND two different columns", + "SELECT * FROM test.T WHERE LOWER(name) = 'abc' AND age = 25", pathDomains -> { + // Column 0: name + assertTrue(pathDomains.containsKey(AccessPath.of(0)), "Should resolve column 0 (name)"); + Domain nameDomain = pathDomains.get(AccessPath.of(0)); + assertTrue(nameDomain instanceof RegexDomain, "name
should be RegexDomain"); + assertFalse(nameDomain.isEmpty(), "name domain should not be empty"); + + // Column 1: age + assertTrue(pathDomains.containsKey(AccessPath.of(1)), "Should resolve column 1 (age)"); + Domain ageDomain = pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain instanceof IntegerDomain, "age should be IntegerDomain"); + IntegerDomain ageInt = (IntegerDomain) ageDomain; + assertTrue(ageInt.contains(25), "age should contain 25"); + assertTrue(ageInt.isSingleton(), "age should be singleton {25}"); + }); + } + + @Test + public void testAndSameColumnIntersection() { + // AND of two predicates on the same column: age > 5 AND age < 10 + testMultiColumnResolution("AND same column intersection", "SELECT * FROM test.T WHERE age > 5 AND age < 10", + pathDomains -> { + assertTrue(pathDomains.containsKey(AccessPath.of(1)), "Should resolve column 1 (age)"); + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(6), "Should contain 6"); + assertTrue(ageDomain.contains(9), "Should contain 9"); + assertFalse(ageDomain.contains(5), "Should not contain 5"); + assertFalse(ageDomain.contains(10), "Should not contain 10"); + }); + } + + @Test + public void testAndSameColumnTightRange() { + // AND that produces a very narrow range: age >= 10 AND age <= 12 + testMultiColumnResolution("AND same column tight range", "SELECT * FROM test.T WHERE age >= 10 AND age <= 12", + pathDomains -> { + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(10), "Should contain 10"); + assertTrue(ageDomain.contains(11), "Should contain 11"); + assertTrue(ageDomain.contains(12), "Should contain 12"); + assertFalse(ageDomain.contains(9), "Should not contain 9"); + assertFalse(ageDomain.contains(13), "Should not contain 13"); + }); + } + + @Test + public void testAndThreeColumns() { + // AND of predicates on all three columns + testMultiColumnResolution("AND three columns", 
+ "SELECT * FROM test.T WHERE LOWER(name) = 'abc' AND age > 18 AND SUBSTRING(CAST(birthdate AS STRING), 1, 4) = '2000'", + pathDomains -> { + // Column 0: name + assertTrue(pathDomains.containsKey(AccessPath.of(0)), "Should resolve column 0 (name)"); + assertTrue(pathDomains.get(AccessPath.of(0)) instanceof RegexDomain, "name should be RegexDomain"); + + // Column 1: age + assertTrue(pathDomains.containsKey(AccessPath.of(1)), "Should resolve column 1 (age)"); + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(19), "age should contain 19"); + assertFalse(ageDomain.contains(18), "age should not contain 18"); + + // Column 2: birthdate + assertTrue(pathDomains.containsKey(AccessPath.of(2)), "Should resolve column 2 (birthdate)"); + assertTrue(pathDomains.get(AccessPath.of(2)) instanceof RegexDomain, "birthdate should be RegexDomain"); + }); + } + + @Test + public void testAndWithExpressions() { + // AND with expressions on different columns + testMultiColumnResolution("AND with expressions", + "SELECT * FROM test.T WHERE LOWER(SUBSTRING(name, 1, 3)) = 'abc' AND age * 2 + 5 = 25", pathDomains -> { + // Column 0: name via LOWER(SUBSTRING(...)) + assertTrue(pathDomains.containsKey(AccessPath.of(0)), "Should resolve column 0 (name)"); + RegexDomain nameDomain = (RegexDomain) pathDomains.get(AccessPath.of(0)); + List nameExamples = nameDomain.sample(3); + for (Object ex : nameExamples) { + String s = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertTrue(s.length() >= 3, "Should have at least 3 characters"); + assertEquals("abc", s.substring(0, 3).toLowerCase(), "First 3 chars should be 'abc'"); + } + + // Column 1: age via age * 2 + 5 = 25 → age = 10 + assertTrue(pathDomains.containsKey(AccessPath.of(1)), "Should resolve column 1 (age)"); + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(10), "age should contain 10"); + 
assertTrue(ageDomain.isSingleton(), "age should be singleton {10}"); + }); + } + + // ========== Disjunction (OR) tests ========== + + @Test + public void testOrSameColumnUnion() { + // OR of two predicates on the same column: age = 10 OR age = 20 + testMultiColumnResolution("OR same column union", "SELECT * FROM test.T WHERE age = 10 OR age = 20", + pathDomains -> { + assertTrue(pathDomains.containsKey(AccessPath.of(1)), "Should resolve column 1 (age)"); + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(10), "Should contain 10"); + assertTrue(ageDomain.contains(20), "Should contain 20"); + assertFalse(ageDomain.contains(15), "Should not contain 15"); + }); + } + + @Test + public void testOrSameColumnRanges() { + // OR of two ranges: age < 5 OR age > 95 + testMultiColumnResolution("OR same column ranges", "SELECT * FROM test.T WHERE age < 5 OR age > 95", + pathDomains -> { + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(0), "Should contain 0"); + assertTrue(ageDomain.contains(4), "Should contain 4"); + assertTrue(ageDomain.contains(96), "Should contain 96"); + assertTrue(ageDomain.contains(1000), "Should contain 1000"); + assertFalse(ageDomain.contains(5), "Should not contain 5"); + assertFalse(ageDomain.contains(50), "Should not contain 50"); + assertFalse(ageDomain.contains(95), "Should not contain 95"); + }); + } + + @Test + public void testOrThreeValues() { + // OR of three discrete values + testMultiColumnResolution("OR three values", "SELECT * FROM test.T WHERE age = 1 OR age = 2 OR age = 3", + pathDomains -> { + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(1), "Should contain 1"); + assertTrue(ageDomain.contains(2), "Should contain 2"); + assertTrue(ageDomain.contains(3), "Should contain 3"); + assertFalse(ageDomain.contains(0), "Should not contain 0"); + 
assertFalse(ageDomain.contains(4), "Should not contain 4"); + }); + } + + @Test + public void testOrDifferentColumns() { + // OR on different columns: name = 'abc' OR age = 25. + // Both columns should resolve; per-column union semantics give name a regex matching + // case-insensitive 'abc' (from disjunct 1) and age a domain containing 25 (from disjunct 2). + testMultiColumnResolution("OR different columns", "SELECT * FROM test.T WHERE LOWER(name) = 'abc' OR age = 25", + pathDomains -> { + assertTrue(pathDomains.containsKey(AccessPath.of(0)), "Should resolve column 0 (name)"); + assertTrue(pathDomains.containsKey(AccessPath.of(1)), "Should resolve column 1 (age)"); + + RegexDomain nameDomain = (RegexDomain) pathDomains.get(AccessPath.of(0)); + for (Object ex : nameDomain.sample(3)) { + String s = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertEquals(s.toLowerCase(), "abc", "name sample should lower to 'abc': " + s); + } + + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(25), "age domain should contain 25"); + }); + } + + // ========== Combined AND/OR tests ========== + + @Test + public void testOrOfConjunctions() { + // (age > 10 AND age < 20) OR (age > 50 AND age < 60) + // Each disjunct is a conjunction that produces a range, then union across disjuncts + testMultiColumnResolution("OR of conjunctions", + "SELECT * FROM test.T WHERE (age > 10 AND age < 20) OR (age > 50 AND age < 60)", pathDomains -> { + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(15), "Should contain 15 (in first range)"); + assertTrue(ageDomain.contains(55), "Should contain 55 (in second range)"); + assertFalse(ageDomain.contains(10), "Should not contain 10"); + assertFalse(ageDomain.contains(20), "Should not contain 20"); + assertFalse(ageDomain.contains(30), "Should not contain 30 (between ranges)"); + assertFalse(ageDomain.contains(50), "Should 
not contain 50"); + assertFalse(ageDomain.contains(60), "Should not contain 60"); + }); + } + + @Test + public void testOrOfConjunctionsMultiColumn() { + // (name = 'alice' AND age = 25) OR (name = 'bob' AND age = 30) + // Column 0 (name): union of 'alice' and 'bob' domains + // Column 1 (age): union of {25} and {30} + testMultiColumnResolution("OR of conjunctions multi-column", + "SELECT * FROM test.T WHERE (LOWER(name) = 'alice' AND age = 25) OR (LOWER(name) = 'bob' AND age = 30)", + pathDomains -> { + // Column 1: age should be {25} ∪ {30} + assertTrue(pathDomains.containsKey(AccessPath.of(1)), "Should resolve column 1 (age)"); + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(25), "age should contain 25"); + assertTrue(ageDomain.contains(30), "age should contain 30"); + assertFalse(ageDomain.contains(27), "age should not contain 27"); + + // Column 0: name should be union of 'alice' and 'bob' regex domains + assertTrue(pathDomains.containsKey(AccessPath.of(0)), "Should resolve column 0 (name)"); + assertTrue(pathDomains.get(AccessPath.of(0)) instanceof RegexDomain, "name should be RegexDomain"); + }); + } + + @Test + public void testAndRangeWithEquality() { + // Combine a range and an equality on the same column: age >= 20 AND age = 25 + // Intersection should produce {25} + testMultiColumnResolution("AND range with equality", "SELECT * FROM test.T WHERE age >= 20 AND age = 25", + pathDomains -> { + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(25), "Should contain 25"); + assertTrue(ageDomain.isSingleton(), "Should be singleton {25}"); + }); + } + + @Test + public void testOrWithExpressions() { + // OR with expressions: age + 5 = 15 OR age * 2 = 40 + // First disjunct: age = 10, Second: age = 20, Union: {10, 20} + testMultiColumnResolution("OR with expressions", "SELECT * FROM test.T WHERE age + 5 = 15 OR age * 2 = 40", + pathDomains -> { 
+ IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(10), "Should contain 10 (since 10 + 5 = 15)"); + assertTrue(ageDomain.contains(20), "Should contain 20 (since 20 * 2 = 40)"); + assertFalse(ageDomain.contains(15), "Should not contain 15"); + }); + } + + @Test + public void testAndMixedDomainTypes() { + // String column AND integer column together + testMultiColumnResolution("AND mixed domain types", + "SELECT * FROM test.T WHERE SUBSTRING(name, 1, 3) = 'abc' AND age >= 0 AND age <= 100", pathDomains -> { + // Column 0: name + RegexDomain nameDomain = (RegexDomain) pathDomains.get(AccessPath.of(0)); + List nameExamples = nameDomain.sample(3); + for (Object ex : nameExamples) { + String s = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertTrue(s.startsWith("abc"), "Should start with 'abc': " + s); + } + + // Column 1: age in [0, 100] + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(0), "Should contain 0"); + assertTrue(ageDomain.contains(50), "Should contain 50"); + assertTrue(ageDomain.contains(100), "Should contain 100"); + assertFalse(ageDomain.contains(-1), "Should not contain -1"); + assertFalse(ageDomain.contains(101), "Should not contain 101"); + }); + } + + @Test + public void testSinglePredicateViaResolveAll() { + // Verify that resolveAllPaths works correctly for a single simple predicate + testMultiColumnResolution("Single predicate via resolveAll", "SELECT * FROM test.T WHERE age = 42", pathDomains -> { + assertEquals(1, pathDomains.size(), "Should resolve exactly one column"); + assertTrue(pathDomains.containsKey(AccessPath.of(1)), "Should resolve column 1 (age)"); + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(42), "Should contain 42"); + assertTrue(ageDomain.isSingleton(), "Should be singleton"); + }); + } + + @Test + public void 
testSingleStringPredicateViaResolveAll() { + // Single string predicate through resolveAllPaths + testMultiColumnResolution("Single string predicate via resolveAll", + "SELECT * FROM test.T WHERE LOWER(name) = 'hello'", pathDomains -> { + assertEquals(1, pathDomains.size(), "Should resolve exactly one column"); + assertTrue(pathDomains.containsKey(AccessPath.of(0)), "Should resolve column 0 (name)"); + RegexDomain nameDomain = (RegexDomain) pathDomains.get(AccessPath.of(0)); + List examples = nameDomain.sample(5); + for (Object ex : examples) { + String s = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertEquals("hello", s.toLowerCase(), "Should be case-insensitive 'hello'"); + } + }); + } + + @Test + public void testOrOverlappingRanges() { + // OR of overlapping ranges: age > 5 OR age > 3 + // Union should be age > 3 (i.e., [4, MAX]) + testMultiColumnResolution("OR overlapping ranges", "SELECT * FROM test.T WHERE age > 5 OR age > 3", pathDomains -> { + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.contains(4), "Should contain 4 (from age > 3)"); + assertTrue(ageDomain.contains(6), "Should contain 6 (from both)"); + assertTrue(ageDomain.contains(100), "Should contain 100"); + assertFalse(ageDomain.contains(3), "Should not contain 3"); + }); + } + + @Test + public void testAndContradictoryRange() { + // AND that produces empty intersection: age > 10 AND age < 5 + // The intersection should be empty + RelNode normalized = convertAndNormalizeQuery("SELECT * FROM test.T WHERE age > 10 AND age < 5"); + DnfRewriter.Output dnfOut = extractPredicatesAsDnf(normalized); + Map> pathDomains = program.resolveAllPaths(dnfOut.disjuncts); + + if (pathDomains.containsKey(AccessPath.of(1))) { + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(AccessPath.of(1)); + assertTrue(ageDomain.isEmpty(), "Contradictory range should be empty"); } + // It's also acceptable if the column isn't present at all 
(empty domain omitted) + } + + // ========== Complex type (struct, array, map) tests ========== + + /** + * Helper for multi-column resolution on test.complex table. + * Columns: a(0:INT), b(1:STRING), c(2:ARRAY), s(3:STRUCT), + * m(4:MAP), sarr(5:ARRAY>) + */ + private void testComplexTypeResolution(String testName, String sql, MultiColumnAssertion assertion) { + RelNode normalized = convertAndNormalizeQuery(sql); + DnfRewriter.Output dnfOut = extractPredicatesAsDnf(normalized); + assertFalse(dnfOut.disjuncts.isEmpty(), testName + ": should have at least one disjunct"); + + Map> pathDomains = program.resolveAllPaths(dnfOut.disjuncts); + assertFalse(pathDomains.isEmpty(), testName + ": should resolve at least one column"); + + assertion.accept(pathDomains); + } + + @Test + public void testStructFieldEquality() { + // Struct field access: s.name = 'alice' + testComplexTypeResolution("Struct field equality", "SELECT * FROM test.complex WHERE s.name = 'alice'", + pathDomains -> { + AccessPath sName = AccessPath.ofField(3, "name"); + assertTrue(pathDomains.containsKey(sName), "Should resolve s.name at " + sName); + RegexDomain domain = (RegexDomain) pathDomains.get(sName); + assertTrue(domain.isLiteral(), "s.name should be the literal 'alice'"); + assertEquals(domain.getLiteralValue(), "alice"); + }); + } + + @Test + public void testStructFieldWithFunction() { + // LOWER(s.name) = 'alice' + testComplexTypeResolution("Struct field with LOWER", "SELECT * FROM test.complex WHERE LOWER(s.name) = 'alice'", + pathDomains -> { + AccessPath sName = AccessPath.ofField(3, "name"); + assertTrue(pathDomains.containsKey(sName), "Should resolve s.name"); + RegexDomain domain = (RegexDomain) pathDomains.get(sName); + List examples = domain.sample(5); + for (Object ex : examples) { + String str = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertEquals("alice", str.toLowerCase(), "Should be case-insensitive 'alice': " + str); + } + }); + } + + @Test + public void 
testStructIntField() { + // s.age > 21 + testComplexTypeResolution("Struct int field", "SELECT * FROM test.complex WHERE s.age > 21", pathDomains -> { + AccessPath sAge = AccessPath.ofField(3, "age"); + assertTrue(pathDomains.containsKey(sAge), "Should resolve s.age"); + IntegerDomain domain = (IntegerDomain) pathDomains.get(sAge); + assertTrue(domain.contains(22), "Should contain 22"); + assertFalse(domain.contains(21), "Should not contain 21"); + }); + } + + @Test + public void testMapElementEquality() { + // m['key1'] = 'val' + testComplexTypeResolution("Map element equality", "SELECT * FROM test.complex WHERE m['key1'] = 'val'", + pathDomains -> { + AccessPath mKey1 = AccessPath.ofMapKey(4, "key1"); + assertTrue(pathDomains.containsKey(mKey1), "Should resolve m['key1']"); + RegexDomain domain = (RegexDomain) pathDomains.get(mKey1); + assertTrue(domain.isLiteral(), "m['key1'] should be the literal 'val'"); + assertEquals(domain.getLiteralValue(), "val"); + }); + } + + @Test + public void testMapElementWithFunction() { + // UPPER(m['k']) = 'ABC' + testComplexTypeResolution("Map element with UPPER", "SELECT * FROM test.complex WHERE UPPER(m['k']) = 'ABC'", + pathDomains -> { + AccessPath mK = AccessPath.ofMapKey(4, "k"); + assertTrue(pathDomains.containsKey(mK), "Should resolve m['k']"); + RegexDomain domain = (RegexDomain) pathDomains.get(mK); + List examples = domain.sample(5); + for (Object ex : examples) { + String str = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertEquals("abc", str.toLowerCase(), "Should be case-insensitive 'abc': " + str); + } + }); + } + + @Test + public void testMultiColumnWithStruct() { + // s.name = 'x' AND s.age > 10: two different struct fields → two AccessPath entries + testComplexTypeResolution("Multi-column with struct", + "SELECT * FROM test.complex WHERE s.name = 'x' AND s.age > 10", pathDomains -> { + AccessPath sName = AccessPath.ofField(3, "name"); + AccessPath sAge = AccessPath.ofField(3, "age"); + 
assertTrue(pathDomains.containsKey(sName), "Should resolve s.name"); + assertTrue(pathDomains.containsKey(sAge), "Should resolve s.age"); + RegexDomain nameDomain = (RegexDomain) pathDomains.get(sName); + assertTrue(nameDomain.isLiteral(), "s.name should be the literal 'x'"); + assertEquals(nameDomain.getLiteralValue(), "x"); + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(sAge); + assertTrue(ageDomain.contains(11), "s.age should contain 11"); + assertFalse(ageDomain.contains(10), "s.age should not contain 10"); + }); + } + + @Test + public void testStructAndFlatColumn() { + // a > 5 AND s.name = 'z': flat column + struct field + testComplexTypeResolution("Struct and flat column", "SELECT * FROM test.complex WHERE a > 5 AND s.name = 'z'", + pathDomains -> { + AccessPath colA = AccessPath.of(0); + AccessPath sName = AccessPath.ofField(3, "name"); + assertTrue(pathDomains.containsKey(colA), "Should resolve column a"); + assertTrue(pathDomains.containsKey(sName), "Should resolve s.name"); + IntegerDomain aDomain = (IntegerDomain) pathDomains.get(colA); + assertTrue(aDomain.contains(6), "a should contain 6"); + assertFalse(aDomain.contains(5), "a should not contain 5"); + RegexDomain nameDomain = (RegexDomain) pathDomains.get(sName); + assertTrue(nameDomain.isLiteral(), "s.name should be the literal 'z'"); + assertEquals(nameDomain.getLiteralValue(), "z"); + }); + } + + @Test + public void testNestedArrayOfStructs() { + // sarr[0].name = 'bob': array element → struct field + testComplexTypeResolution("Nested array of structs", "SELECT * FROM test.complex WHERE sarr[0].name = 'bob'", + pathDomains -> { + // Hive 0-based index is shifted to Calcite 1-based: sarr[0] becomes ITEM($5, 1). + // The resolved path should end at the "name" struct field and the domain should be + // the literal 'bob'. 
+ boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 5 && !path.isFlat() + && "name".equals(path.getPath().get(path.getPath().size() - 1).getFieldName())) { + found = true; + RegexDomain domain = (RegexDomain) entry.getValue(); + assertTrue(domain.isLiteral(), "sarr[].name should be the literal 'bob'"); + assertEquals(domain.getLiteralValue(), "bob"); + } + } + assertTrue(found, "Should resolve sarr[0].name to literal 'bob' (column 5, path ending at 'name')"); + }); + } + + // ========== Deeply nested types (test.deep) ========== + + /** + * Helper for multi-column resolution on test.deep table. + * Columns: id(0:INT), + * nested_struct(1:STRUCT>), + * map_of_structs(2:MAP>), + * str_map(3:MAP) + */ + private void testDeepTypeResolution(String testName, String sql, MultiColumnAssertion assertion) { + RelNode normalized = convertAndNormalizeQuery(sql); + DnfRewriter.Output dnfOut = extractPredicatesAsDnf(normalized); + assertFalse(dnfOut.disjuncts.isEmpty(), testName + ": should have at least one disjunct"); + + Map> pathDomains = program.resolveAllPaths(dnfOut.disjuncts); + assertFalse(pathDomains.isEmpty(), testName + ": should resolve at least one column"); + + assertion.accept(pathDomains); + } + + // --- Deeply nested struct: struct.inner.value --- + + @Test + public void testDoubleNestedStructField() { + // nested_struct.sub.value = 'hello': struct → inner struct → string field + testDeepTypeResolution("Double-nested struct field", + "SELECT * FROM test.deep WHERE nested_struct.sub.value = 'hello'", pathDomains -> { + // Path: $1.sub.value + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 1 && path.getPath().size() == 2) { + found = true; + assertEquals("sub", path.getPath().get(0).getFieldName(), "First element should be 'sub'"); + assertEquals("value", 
path.getPath().get(1).getFieldName(), "Second element should be 'value'"); + RegexDomain domain = (RegexDomain) entry.getValue(); + assertTrue(domain.isLiteral(), "nested_struct.sub.value should be literal 'hello'"); + assertEquals(domain.getLiteralValue(), "hello"); + } + } + assertTrue(found, "Should resolve nested_struct.sub.value"); + }); + } + + @Test + public void testDoubleNestedStructIntField() { + // nested_struct.sub.count > 100: struct → inner struct → int field with comparison + testDeepTypeResolution("Double-nested struct int field", + "SELECT * FROM test.deep WHERE nested_struct.sub.count > 100", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 1 && path.getPath().size() == 2) { + found = true; + assertEquals("count", path.getPath().get(1).getFieldName()); + IntegerDomain domain = (IntegerDomain) entry.getValue(); + assertTrue(domain.contains(101), "Should contain 101"); + assertFalse(domain.contains(100), "Should not contain 100"); + } + } + assertTrue(found, "Should resolve nested_struct.sub.count"); + }); + } + + // --- Function on deeply nested struct field --- + + @Test + public void testFunctionOnDoubleNestedStruct() { + // LOWER(SUBSTRING(nested_struct.sub.value, 1, 3)) = 'abc' + // Chains: LOWER → SUBSTRING → RexFieldAccess(RexFieldAccess($1, sub), value) + testDeepTypeResolution("Function on double-nested struct", + "SELECT * FROM test.deep WHERE LOWER(SUBSTRING(nested_struct.sub.value, 1, 3)) = 'abc'", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 1 && path.getPath().size() == 2 + && "value".equals(path.getPath().get(1).getFieldName())) { + found = true; + RegexDomain domain = (RegexDomain) entry.getValue(); + List examples = domain.sample(5); + for (Object ex : examples) { + String s = ex.toString().replaceAll("^\\^", 
"").replaceAll("\\$$", ""); + assertTrue(s.length() >= 3, "Should have at least 3 chars"); + assertEquals("abc", s.substring(0, 3).toLowerCase(), "First 3 chars should be 'abc': " + s); + } + } + } + assertTrue(found, "Should resolve nested_struct.inner.value through LOWER(SUBSTRING(...))"); + }); + } + + // --- Map of structs: map_of_structs['key'].label --- + + @Test + public void testMapOfStructsFieldAccess() { + // map_of_structs['key1'].label = 'important' + testDeepTypeResolution("Map of structs field access", + "SELECT * FROM test.deep WHERE map_of_structs['key1'].label = 'important'", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 2 && !path.isFlat() + && "label".equals(path.getPath().get(path.getPath().size() - 1).getFieldName())) { + found = true; + RegexDomain domain = (RegexDomain) entry.getValue(); + assertTrue(domain.isLiteral(), "map_of_structs['key1'].label should be literal 'important'"); + assertEquals(domain.getLiteralValue(), "important"); + } + } + assertTrue(found, "Should resolve map_of_structs['key1'].label to literal 'important'"); + }); + } + + @Test + public void testMapOfStructsIntFieldComparison() { + // map_of_structs['key1'].score >= 80 AND map_of_structs['key1'].score <= 100 + testDeepTypeResolution("Map of structs int field range", + "SELECT * FROM test.deep WHERE map_of_structs['key1'].score >= 80 AND map_of_structs['key1'].score <= 100", + pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 2 && !path.isFlat()) { + // Check that the path includes 'score' field + List elements = path.getPath(); + boolean hasScore = elements.stream() + .anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.FIELD && "score".equals(e.getFieldName())); + if (hasScore) { + found = true; + IntegerDomain domain = (IntegerDomain) 
entry.getValue(); + assertTrue(domain.contains(80), "Should contain 80"); + assertTrue(domain.contains(90), "Should contain 90"); + assertTrue(domain.contains(100), "Should contain 100"); + assertFalse(domain.contains(79), "Should not contain 79"); + assertFalse(domain.contains(101), "Should not contain 101"); + } + } + } + assertTrue(found, "Should resolve map_of_structs['key1'].score range"); + }); + } + + // --- Flat struct fields alongside nested struct fields --- + + @Test + public void testFlatAndNestedStructFieldsTogether() { + // nested_struct.inner_val = 'hello' AND nested_struct.sub.count > 50 + // Tests flat struct field (depth=1) alongside nested struct field (depth=2) on same column + testDeepTypeResolution("Flat and nested struct fields together", + "SELECT * FROM test.deep WHERE nested_struct.inner_val = 'hello' AND nested_struct.sub.count > 50", + pathDomains -> { + // Flat field: $1.inner_val + boolean foundFlat = false; + boolean foundNested = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 1) { + if (path.getPath().size() == 1 && "inner_val".equals(path.getPath().get(0).getFieldName())) { + foundFlat = true; + RegexDomain domain = (RegexDomain) entry.getValue(); + assertTrue(domain.isLiteral(), "inner_val should be literal 'hello'"); + assertEquals(domain.getLiteralValue(), "hello"); + } + if (path.getPath().size() == 2 && "count".equals(path.getPath().get(1).getFieldName())) { + foundNested = true; + IntegerDomain domain = (IntegerDomain) entry.getValue(); + assertTrue(domain.contains(51), "sub.count should contain 51"); + assertFalse(domain.contains(50), "sub.count should not contain 50"); + } + } + } + assertTrue(foundFlat, "Should resolve nested_struct.inner_val"); + assertTrue(foundNested, "Should resolve nested_struct.sub.count"); + }); + } + + // ========== Combined multi-aspect tests ========== + + @Test + public void testFlatStructMapAllColumnsAndOr() { + // 
Combines: flat col, struct field, map access, AND within disjuncts, OR across disjuncts + // (a > 5 AND s.name = 'alice' AND m['role'] = 'admin') OR (a = 1 AND s.age <= 30) + testComplexTypeResolution("Flat + struct + map with AND/OR", "SELECT * FROM test.complex WHERE " + + "(a > 5 AND s.name = 'alice' AND m['role'] = 'admin') OR (a = 1 AND s.age <= 30)", pathDomains -> { + // Column 0 (a): union of [6, MAX] and {1} + AccessPath colA = AccessPath.of(0); + assertTrue(pathDomains.containsKey(colA), "Should resolve column a"); + IntegerDomain aDomain = (IntegerDomain) pathDomains.get(colA); + assertTrue(aDomain.contains(1), "a should contain 1 (from second disjunct)"); + assertTrue(aDomain.contains(6), "a should contain 6 (from first disjunct)"); + assertTrue(aDomain.contains(100), "a should contain 100 (from first disjunct)"); + assertFalse(aDomain.contains(2), "a should not contain 2 (not in either disjunct)"); + assertFalse(aDomain.contains(5), "a should not contain 5 (not in either disjunct)"); + }); + } + + @Test + public void testStructFieldWithArithmeticAndComparison() { + // Arithmetic on struct int field + comparison: s.age * 2 + 5 = 25 → s.age = 10 + testComplexTypeResolution("Struct field arithmetic + comparison", + "SELECT * FROM test.complex WHERE s.age * 2 + 5 = 25", pathDomains -> { + AccessPath sAge = AccessPath.ofField(3, "age"); + assertTrue(pathDomains.containsKey(sAge), "Should resolve s.age"); + IntegerDomain domain = (IntegerDomain) pathDomains.get(sAge); + assertTrue(domain.contains(10), "s.age should contain 10 (since 10*2+5=25)"); + assertTrue(domain.isSingleton(), "s.age should be singleton {10}"); + }); } + + @Test + public void testCastStructFieldCrossDomain() { + // CAST(s.age AS STRING) = '42': cross-domain from RegexDomain → IntegerDomain through struct field + testComplexTypeResolution("CAST struct field cross-domain", + "SELECT * FROM test.complex WHERE CAST(s.age AS STRING) = '42'", pathDomains -> { + AccessPath sAge = 
AccessPath.ofField(3, "age"); + assertTrue(pathDomains.containsKey(sAge), "Should resolve s.age"); + IntegerDomain domain = (IntegerDomain) pathDomains.get(sAge); + assertTrue(domain.contains(42), "s.age should contain 42"); + assertTrue(domain.isSingleton(), "s.age should be singleton {42}"); + }); + } + + @Test + public void testFunctionOnArrayOfStructsField() { + // UPPER(sarr[0].name) = 'ALICE': function → ITEM → field access + testComplexTypeResolution("Function on array-of-structs field", + "SELECT * FROM test.complex WHERE UPPER(sarr[0].name) = 'ALICE'", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 5 && !path.isFlat()) { + found = true; + RegexDomain domain = (RegexDomain) entry.getValue(); + List examples = domain.sample(5); + for (Object ex : examples) { + String s = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertEquals("alice", s.toLowerCase(), "Should be case-insensitive 'alice': " + s); + } + } + } + assertTrue(found, "Should resolve UPPER(sarr[0].name)"); + }); + } + + @Test + public void testArrayOfStructsArithmeticOnIntField() { + // sarr[0].age + 10 > 30 → sarr[0].age > 20 + testComplexTypeResolution("Array-of-structs int field arithmetic", + "SELECT * FROM test.complex WHERE sarr[0].age + 10 > 30", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 5 && !path.isFlat()) { + List elems = path.getPath(); + boolean hasAge = elems.stream() + .anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.FIELD && "age".equals(e.getFieldName())); + if (hasAge) { + found = true; + IntegerDomain domain = (IntegerDomain) entry.getValue(); + assertTrue(domain.contains(21), "Should contain 21"); + assertFalse(domain.contains(20), "Should not contain 20"); + } + } + } + assertTrue(found, "Should resolve sarr[0].age through 
arithmetic"); + }); + } + + @Test + public void testMultipleMapKeysWithDifferentDomains() { + // m['x'] = 'hello' AND m['y'] = 'world': two different map keys → two AccessPath entries + testComplexTypeResolution("Multiple map keys", + "SELECT * FROM test.complex WHERE m['x'] = 'hello' AND m['y'] = 'world'", pathDomains -> { + AccessPath mX = AccessPath.ofMapKey(4, "x"); + AccessPath mY = AccessPath.ofMapKey(4, "y"); + assertTrue(pathDomains.containsKey(mX), "Should resolve m['x']"); + assertTrue(pathDomains.containsKey(mY), "Should resolve m['y']"); + RegexDomain xDomain = (RegexDomain) pathDomains.get(mX); + RegexDomain yDomain = (RegexDomain) pathDomains.get(mY); + assertTrue(xDomain.isLiteral(), "m['x'] should be literal 'hello'"); + assertEquals(xDomain.getLiteralValue(), "hello"); + assertTrue(yDomain.isLiteral(), "m['y'] should be literal 'world'"); + assertEquals(yDomain.getLiteralValue(), "world"); + }); + } + + @Test + public void testStructFieldOrDisjunction() { + // s.name = 'alice' OR s.name = 'bob': same struct field in OR → union of regex domains + testComplexTypeResolution("Struct field OR disjunction", + "SELECT * FROM test.complex WHERE s.name = 'alice' OR s.name = 'bob'", pathDomains -> { + AccessPath sName = AccessPath.ofField(3, "name"); + assertTrue(pathDomains.containsKey(sName), "Should resolve s.name"); + RegexDomain domain = (RegexDomain) pathDomains.get(sName); + // The union should accept exactly 'alice' and 'bob' — both must be in, others out. 
+ assertTrue(domain.getAutomaton().run("alice"), "domain should accept 'alice'"); + assertTrue(domain.getAutomaton().run("bob"), "domain should accept 'bob'"); + assertFalse(domain.getAutomaton().run("carol"), "domain should not accept 'carol'"); + assertFalse(domain.getAutomaton().run(""), "domain should not accept empty string"); + }); + } + + @Test + public void testMixedTypesAcrossAllColumnKinds() { + // Combines flat, struct, map, array-of-struct in one query with AND + // a >= 0 AND a <= 1000 AND LOWER(s.name) = 'test' AND s.age > 0 AND m['env'] = 'prod' + testComplexTypeResolution("Mixed all column kinds", "SELECT * FROM test.complex WHERE " + + "a >= 0 AND a <= 1000 AND LOWER(s.name) = 'test' AND s.age > 0 AND m['env'] = 'prod'", pathDomains -> { + // Flat column: a in [0, 1000] + AccessPath colA = AccessPath.of(0); + assertTrue(pathDomains.containsKey(colA), "Should resolve column a"); + IntegerDomain aDomain = (IntegerDomain) pathDomains.get(colA); + assertTrue(aDomain.contains(0), "a should contain 0"); + assertTrue(aDomain.contains(500), "a should contain 500"); + assertTrue(aDomain.contains(1000), "a should contain 1000"); + assertFalse(aDomain.contains(-1), "a should not contain -1"); + assertFalse(aDomain.contains(1001), "a should not contain 1001"); + + // Struct string field: s.name (case-insensitive 'test') + AccessPath sName = AccessPath.ofField(3, "name"); + assertTrue(pathDomains.containsKey(sName), "Should resolve s.name"); + RegexDomain nameDomain = (RegexDomain) pathDomains.get(sName); + List nameExamples = nameDomain.sample(3); + for (Object ex : nameExamples) { + String str = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertEquals("test", str.toLowerCase(), "Should be case-insensitive 'test'"); + } + + // Struct int field: s.age > 0 + AccessPath sAge = AccessPath.ofField(3, "age"); + assertTrue(pathDomains.containsKey(sAge), "Should resolve s.age"); + IntegerDomain ageDomain = (IntegerDomain) pathDomains.get(sAge); + 
assertTrue(ageDomain.contains(1), "s.age should contain 1"); + assertFalse(ageDomain.contains(0), "s.age should not contain 0"); + + // Map access: m['env'] = 'prod' + AccessPath mEnv = AccessPath.ofMapKey(4, "env"); + assertTrue(pathDomains.containsKey(mEnv), "Should resolve m['env']"); + assertTrue(pathDomains.get(mEnv) instanceof RegexDomain, "m['env'] should be RegexDomain"); + }); + } + + @Test + public void testDeepNestedStructWithArithmeticAndFunction() { + // CAST(nested_struct.sub.count * 3 - 10 AS STRING) = '50' + // Chains: CAST → MINUS → TIMES → RexFieldAccess(RexFieldAccess($1, sub), count) + // → count * 3 - 10 = 50 → count * 3 = 60 → count = 20 + testDeepTypeResolution("Deep struct with arithmetic + CAST", + "SELECT * FROM test.deep WHERE CAST(nested_struct.sub.count * 3 - 10 AS STRING) = '50'", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 1 && path.getPath().size() == 2 + && "count".equals(path.getPath().get(1).getFieldName())) { + found = true; + IntegerDomain domain = (IntegerDomain) entry.getValue(); + assertTrue(domain.contains(20), "count should contain 20 (since 20*3-10=50)"); + assertTrue(domain.isSingleton(), "count should be singleton {20}"); + } + } + assertTrue(found, "Should resolve nested_struct.sub.count through CAST + arithmetic"); + }); + } + + @Test + public void testDeepMixedWithOrConjunctions() { + // (nested_struct.sub.count > 50 AND id < 100) OR (nested_struct.sub.count = 10 AND id > 200) + // Tests: double-nested struct field, flat column, OR of conjunctions, union + intersection + testDeepTypeResolution("Deep struct OR conjunctions", + "SELECT * FROM test.deep WHERE " + + "(nested_struct.sub.count > 50 AND id < 100) OR (nested_struct.sub.count = 10 AND id > 200)", + pathDomains -> { + // id: union of [MIN, 99] and [201, MAX] + AccessPath colId = AccessPath.of(0); + assertTrue(pathDomains.containsKey(colId), 
"Should resolve id"); + IntegerDomain idDomain = (IntegerDomain) pathDomains.get(colId); + assertTrue(idDomain.contains(50), "id should contain 50 (from first disjunct)"); + assertTrue(idDomain.contains(99), "id should contain 99 (from first disjunct)"); + assertTrue(idDomain.contains(201), "id should contain 201 (from second disjunct)"); + assertTrue(idDomain.contains(500), "id should contain 500 (from second disjunct)"); + assertFalse(idDomain.contains(100), "id should not contain 100"); + assertFalse(idDomain.contains(150), "id should not contain 150"); + assertFalse(idDomain.contains(200), "id should not contain 200"); + }); + } + + @Test + public void testFunctionOnMapOfStructsField() { + // LOWER(map_of_structs['team'].label) = 'engineering' + testDeepTypeResolution("Function on map-of-structs field", + "SELECT * FROM test.deep WHERE LOWER(map_of_structs['team'].label) = 'engineering'", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 2 && !path.isFlat()) { + found = true; + RegexDomain domain = (RegexDomain) entry.getValue(); + List examples = domain.sample(5); + for (Object ex : examples) { + String s = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertEquals("engineering", s.toLowerCase(), "Should be case-insensitive 'engineering': " + s); + } + } + } + assertTrue(found, "Should resolve LOWER(map_of_structs['team'].label)"); + }); + } + + @Test + public void testAbsOnStructFieldWithOrDisjunction() { + // ABS(s.age) = 5 OR s.age > 100: struct field in both disjuncts + testComplexTypeResolution("ABS on struct field with OR", + "SELECT * FROM test.complex WHERE ABS(s.age) = 5 OR s.age > 100", pathDomains -> { + AccessPath sAge = AccessPath.ofField(3, "age"); + assertTrue(pathDomains.containsKey(sAge), "Should resolve s.age"); + IntegerDomain domain = (IntegerDomain) pathDomains.get(sAge); + // ABS(age) = 5 → age = 5 or age = -5; 
union with age > 100 + assertTrue(domain.contains(5), "Should contain 5"); + assertTrue(domain.contains(-5), "Should contain -5"); + assertTrue(domain.contains(101), "Should contain 101 (from age > 100)"); + assertFalse(domain.contains(0), "Should not contain 0"); + assertFalse(domain.contains(50), "Should not contain 50"); + }); + } + + @Test + public void testSubstringOnMapValueWithCast() { + // SUBSTRING(m['date'], 1, 4) = '2024': function chain on map value + testComplexTypeResolution("SUBSTRING on map value", + "SELECT * FROM test.complex WHERE SUBSTRING(m['date'], 1, 4) = '2024'", pathDomains -> { + AccessPath mDate = AccessPath.ofMapKey(4, "date"); + assertTrue(pathDomains.containsKey(mDate), "Should resolve m['date']"); + RegexDomain domain = (RegexDomain) pathDomains.get(mDate); + List examples = domain.sample(5); + for (Object ex : examples) { + String s = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertTrue(s.startsWith("2024"), "Should start with '2024': " + s); + } + }); + } + + // ========== Interleaved complex types (struct↔map↔array) ========== + + /** + * Helper for test.interleaved table. 
+ * Columns: id(0:INT), + * sm(1:STRUCT, scores:ARRAY, name:STRING>), + * ams(2:ARRAY>), + * msa(3:MAP>), + * amss(4:ARRAY, label:STRING, value:INT>>) + */ + private void testInterleavedResolution(String testName, String sql, MultiColumnAssertion assertion) { + RelNode normalized = convertAndNormalizeQuery(sql); + DnfRewriter.Output dnfOut = extractPredicatesAsDnf(normalized); + assertFalse(dnfOut.disjuncts.isEmpty(), testName + ": should have at least one disjunct"); + + Map> pathDomains = program.resolveAllPaths(dnfOut.disjuncts); + assertFalse(pathDomains.isEmpty(), testName + ": should resolve at least one column"); + + assertion.accept(pathDomains); + } + + @Test + public void testStructContainingMapAccess() { + // sm.tags['color'] = 'red': struct → map key + // Path: $1.tags → ITEM(..., 'color') + testInterleavedResolution("Struct containing map access", + "SELECT * FROM test.interleaved WHERE sm.tags['color'] = 'red'", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 1 && path.getPath().size() >= 2) { + // Should have FIELD:tags then MAP_KEY:color + boolean hasTags = path.getPath().stream() + .anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.FIELD && "tags".equals(e.getFieldName())); + boolean hasColor = path.getPath().stream() + .anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.MAP_KEY && "color".equals(e.getMapKey())); + if (hasTags && hasColor) { + found = true; + RegexDomain domain = (RegexDomain) entry.getValue(); + assertTrue(domain.isLiteral(), "sm.tags['color'] should be literal 'red'"); + assertEquals(domain.getLiteralValue(), "red"); + } + } + } + assertTrue(found, "Should resolve sm.tags['color'] (struct → map)"); + }); + } + + @Test + public void testStructContainingMapWithFunction() { + // LOWER(sm.tags['city']) = 'seattle': LOWER → struct → map + testInterleavedResolution("Function on struct→map", + "SELECT * FROM 
test.interleaved WHERE LOWER(sm.tags['city']) = 'seattle'", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 1 && !path.isFlat()) { + boolean hasTags = path.getPath().stream() + .anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.FIELD && "tags".equals(e.getFieldName())); + boolean hasCity = path.getPath().stream() + .anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.MAP_KEY && "city".equals(e.getMapKey())); + if (hasTags && hasCity) { + found = true; + RegexDomain domain = (RegexDomain) entry.getValue(); + List examples = domain.sample(5); + for (Object ex : examples) { + String s = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertEquals("seattle", s.toLowerCase(), "Should be case-insensitive 'seattle': " + s); + } + } + } + } + assertTrue(found, "Should resolve LOWER(sm.tags['city']) (LOWER → struct → map)"); + }); + } + + @Test + public void testArrayOfMapsAccess() { + // ams[0]['status'] = 'active': array → map key + // Path: ITEM($2, 1) → ITEM(..., 'status') + testInterleavedResolution("Array of maps access", + "SELECT * FROM test.interleaved WHERE ams[0]['status'] = 'active'", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 2 && path.getPath().size() >= 2) { + boolean hasArrayIdx = + path.getPath().stream().anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.ARRAY_INDEX); + boolean hasMapKey = path.getPath().stream() + .anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.MAP_KEY && "status".equals(e.getMapKey())); + if (hasArrayIdx && hasMapKey) { + found = true; + RegexDomain domain = (RegexDomain) entry.getValue(); + assertTrue(domain.isLiteral(), "ams[0]['status'] should be literal 'active'"); + assertEquals(domain.getLiteralValue(), "active"); + } + } + } + assertTrue(found, "Should resolve 
ams[0]['status'] (array → map)"); + }); + } + + @Test + public void testArrayOfStructsContainingMapAccess() { + // amss[0].props['env'] = 'prod': array → struct → map + // Path: ITEM($4, 1).props → ITEM(..., 'env') + testInterleavedResolution("Array → struct → map", + "SELECT * FROM test.interleaved WHERE amss[0].props['env'] = 'prod'", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 4 && path.getPath().size() >= 3) { + boolean hasArrayIdx = + path.getPath().stream().anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.ARRAY_INDEX); + boolean hasProps = path.getPath().stream() + .anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.FIELD && "props".equals(e.getFieldName())); + boolean hasEnv = path.getPath().stream() + .anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.MAP_KEY && "env".equals(e.getMapKey())); + if (hasArrayIdx && hasProps && hasEnv) { + found = true; + RegexDomain domain = (RegexDomain) entry.getValue(); + assertTrue(domain.isLiteral(), "amss[0].props['env'] should be literal 'prod'"); + assertEquals(domain.getLiteralValue(), "prod"); + } + } + } + assertTrue(found, "Should resolve amss[0].props['env'] (array → struct → map)"); + }); + } + + @Test + public void testArrayOfStructsMapWithFunction() { + // UPPER(amss[0].props['tag']) = 'CRITICAL': UPPER → array → struct → map + testInterleavedResolution("Function on array → struct → map", + "SELECT * FROM test.interleaved WHERE UPPER(amss[0].props['tag']) = 'CRITICAL'", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 4 && path.getPath().size() >= 3) { + boolean hasTag = path.getPath().stream() + .anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.MAP_KEY && "tag".equals(e.getMapKey())); + if (hasTag) { + found = true; + RegexDomain domain = (RegexDomain) 
entry.getValue(); + List examples = domain.sample(5); + for (Object ex : examples) { + String s = ex.toString().replaceAll("^\\^", "").replaceAll("\\$$", ""); + assertEquals("critical", s.toLowerCase(), "Should be case-insensitive 'critical': " + s); + } + } + } + } + assertTrue(found, "Should resolve UPPER(amss[0].props['tag'])"); + }); + } + + @Test + public void testArrayOfStructsIntFieldWithArithmetic() { + // amss[0].value * 2 + 1 = 101: arithmetic → array → struct (int field) + // → value * 2 = 100 → value = 50 + testInterleavedResolution("Arithmetic on array → struct int field", + "SELECT * FROM test.interleaved WHERE amss[0].value * 2 + 1 = 101", pathDomains -> { + boolean found = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 4 && !path.isFlat()) { + boolean hasValue = path.getPath().stream() + .anyMatch(e -> e.getKind() == AccessPath.PathElement.Kind.FIELD && "value".equals(e.getFieldName())); + if (hasValue) { + found = true; + IntegerDomain domain = (IntegerDomain) entry.getValue(); + assertTrue(domain.contains(50), "value should contain 50 (since 50*2+1=101)"); + assertTrue(domain.isSingleton(), "value should be singleton {50}"); + } + } + } + assertTrue(found, "Should resolve amss[0].value through arithmetic"); + }); + } + + @Test + public void testInterleavedMultiPathAndOr() { + // Combines struct→map, array→struct→map, and flat column with AND/OR: + // (sm.tags['env'] = 'prod' AND amss[0].props['tier'] = 'p0' AND id > 100) + // OR (sm.name = 'fallback' AND id = 0) + testInterleavedResolution("Interleaved multi-path AND/OR", + "SELECT * FROM test.interleaved WHERE " + + "(sm.tags['env'] = 'prod' AND amss[0].props['tier'] = 'p0' AND id > 100) " + + "OR (sm.name = 'fallback' AND id = 0)", + pathDomains -> { + // id: union of [101, MAX] and {0} + AccessPath colId = AccessPath.of(0); + assertTrue(pathDomains.containsKey(colId), "Should resolve id"); + IntegerDomain idDomain = 
(IntegerDomain) pathDomains.get(colId); + assertTrue(idDomain.contains(0), "id should contain 0 (from second disjunct)"); + assertTrue(idDomain.contains(101), "id should contain 101 (from first disjunct)"); + assertTrue(idDomain.contains(999), "id should contain 999 (from first disjunct)"); + assertFalse(idDomain.contains(1), "id should not contain 1"); + assertFalse(idDomain.contains(100), "id should not contain 100"); + }); + } + + @Test + public void testStructFieldAndStructMapSameColumn() { + // sm.name = 'test' AND sm.tags['role'] = 'admin': both paths on same root column (1), + // but different AccessPaths (FIELD:name vs FIELD:tags + MAP_KEY:role) + testInterleavedResolution("Struct field and struct→map on same column", + "SELECT * FROM test.interleaved WHERE sm.name = 'test' AND sm.tags['role'] = 'admin'", pathDomains -> { + // Should produce two distinct AccessPath entries both rooted at column 1 + boolean foundName = false; + boolean foundTagsRole = false; + for (Map.Entry> entry : pathDomains.entrySet()) { + AccessPath path = entry.getKey(); + if (path.getColumnIndex() == 1) { + if (path.getPath().size() == 1 && "name".equals(path.getPath().get(0).getFieldName())) { + foundName = true; + RegexDomain nameDomain = (RegexDomain) entry.getValue(); + assertTrue(nameDomain.isLiteral(), "sm.name should be literal 'test'"); + assertEquals(nameDomain.getLiteralValue(), "test"); + } + if (path.getPath().size() == 2) { + boolean hasTags = path.getPath().get(0).getKind() == AccessPath.PathElement.Kind.FIELD + && "tags".equals(path.getPath().get(0).getFieldName()); + boolean hasRole = path.getPath().get(1).getKind() == AccessPath.PathElement.Kind.MAP_KEY + && "role".equals(path.getPath().get(1).getMapKey()); + if (hasTags && hasRole) { + foundTagsRole = true; + RegexDomain roleDomain = (RegexDomain) entry.getValue(); + assertTrue(roleDomain.isLiteral(), "sm.tags['role'] should be literal 'admin'"); + assertEquals(roleDomain.getLiteralValue(), "admin"); + } + } + } + } 
+ assertTrue(foundName, "Should resolve sm.name"); + assertTrue(foundTagsRole, "Should resolve sm.tags['role']"); + }); + } + } diff --git a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexToIntegerDomainConverterTest.java b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexToIntegerDomainConverterTest.java index 03d1f15e7..70b83e465 100644 --- a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexToIntegerDomainConverterTest.java +++ b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexToIntegerDomainConverterTest.java @@ -66,16 +66,11 @@ public void testLiteralZero() { assertFalse(result.contains(1)); } - @Test - public void testLiteralWithLeadingZeros() { - // ^009$ should convert to integer 9 - RegexDomain regex = new RegexDomain("^009$"); - assertTrue(converter.isConvertible(regex)); - - IntegerDomain result = converter.convert(regex); - assertTrue(result.contains(9)); - assertFalse(result.contains(8)); - assertFalse(result.contains(10)); + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testLiteralWithLeadingZerosIsNonCanonical() { + // ^009$ only accepts the string "009", which is not the canonical decimal form of any integer + // (9's canonical form is "9"). Strict converter rejects. + converter.convert(new RegexDomain("^009$")); } // ==================== Character Class Tests ==================== @@ -92,59 +87,72 @@ public void testCharClassSingleDigit() { assertFalse(result.contains(10)); } + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testCharClassTwoDigitsIsNonCanonical() { + // ^[0-9][0-9]$ admits "00", "01", … which are not canonical decimal for any integer. 
+ converter.convert(new RegexDomain("^[0-9][0-9]$")); + } + @Test - public void testCharClassTwoDigits() { - RegexDomain regex = new RegexDomain("^[0-9][0-9]$"); + public void testCanonicalTwoDigit() { + // The canonical form of all 2-digit integers: leading digit 1-9, then any digit. + RegexDomain regex = new RegexDomain("^[1-9][0-9]$"); assertTrue(converter.isConvertible(regex)); IntegerDomain result = converter.convert(regex); - assertTrue(result.contains(0)); + assertTrue(result.contains(10)); assertTrue(result.contains(50)); assertTrue(result.contains(99)); + assertFalse(result.contains(9)); assertFalse(result.contains(100)); } + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testCharClassThreeDigitsIsNonCanonical() { + // ^[0-9]{3}$ admits "000", "001", … — none of those are canonical decimals. + converter.convert(new RegexDomain("^[0-9]{3}$")); + } + @Test - public void testCharClassThreeDigits() { - RegexDomain regex = new RegexDomain("^[0-9]{3}$"); + public void testCanonicalThreeDigit() { + // Canonical 3-digit integers: leading digit 1-9, then two digits. + RegexDomain regex = new RegexDomain("^[1-9][0-9]{2}$"); assertTrue(converter.isConvertible(regex)); IntegerDomain result = converter.convert(regex); - assertTrue(result.contains(0)); + assertTrue(result.contains(100)); assertTrue(result.contains(500)); assertTrue(result.contains(999)); + assertFalse(result.contains(99)); assertFalse(result.contains(1000)); } // ==================== Bounded Repetition Tests ==================== + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testExactRepetitionIsNonCanonical() { + // ^[0-9]{4}$ admits "0000", "0001", … which are not canonical. + converter.convert(new RegexDomain("^[0-9]{4}$")); + } + @Test - public void testExactRepetition() { - RegexDomain regex = new RegexDomain("^[0-9]{4}$"); + public void testCanonicalFourDigit() { + // Canonical 4-digit integers: 1000..9999. 
+ RegexDomain regex = new RegexDomain("^[1-9][0-9]{3}$"); assertTrue(converter.isConvertible(regex)); IntegerDomain result = converter.convert(regex); - assertTrue(result.contains(0)); + assertTrue(result.contains(1000)); assertTrue(result.contains(1234)); assertTrue(result.contains(9999)); + assertFalse(result.contains(999)); assertFalse(result.contains(10000)); } - @Test - public void testRangeRepetition() { - // {2,3} means 2 or 3 digits - RegexDomain regex = new RegexDomain("^[0-9]{2,3}$"); - assertTrue(converter.isConvertible(regex)); - - IntegerDomain result = converter.convert(regex); - // Note: 00-09 are included (which are 0-9 as integers) - assertTrue(result.contains(0)); // 00 -> 0 - assertTrue(result.contains(9)); // 09 -> 9 - assertTrue(result.contains(10)); // 2 digits - assertTrue(result.contains(99)); // 2 digits - assertTrue(result.contains(100)); // 3 digits - assertTrue(result.contains(999)); // 3 digits - assertFalse(result.contains(1000)); // 4 digits + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testRangeRepetitionIsNonCanonical() { + // ^[0-9]{2,3}$ admits "00", "000", "01", … (non-canonical). + converter.convert(new RegexDomain("^[0-9]{2,3}$")); } @Test @@ -174,16 +182,23 @@ public void testPrefixPattern() { assertFalse(result.contains(2000)); } + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testSuffixPatternIsNonCanonical() { + // ^[0-9]{2}99$ admits "0099", which is not canonical (canonical for 99 is "99"). + converter.convert(new RegexDomain("^[0-9]{2}99$")); + } + @Test - public void testSuffixPattern() { - // ^[0-9]{2}99$ matches 0099, 0199, ..., 9999 - RegexDomain regex = new RegexDomain("^[0-9]{2}99$"); + public void testCanonicalSuffixPattern() { + // Canonical 4-digit values ending in 99: ^[1-9][0-9]99$ matches 1099, 1199, …, 9999. 
+ RegexDomain regex = new RegexDomain("^[1-9][0-9]99$"); assertTrue(converter.isConvertible(regex)); IntegerDomain result = converter.convert(regex); - assertTrue(result.contains(99)); + assertTrue(result.contains(1099)); assertTrue(result.contains(1999)); assertTrue(result.contains(9999)); + assertFalse(result.contains(99)); assertFalse(result.contains(1998)); assertFalse(result.contains(2000)); } @@ -276,15 +291,10 @@ public void testUnboundedPlus() { converter.convert(regex); } - @Test - public void testOptionalDigit() { - // [0-9]? accepts "" and "0"-"9". The automaton is finite and digit-only, - // so it is convertible. Empty string maps to nothing; digits 0-9 convert normally. - RegexDomain regex = new RegexDomain("^[0-9]?$"); - assertTrue(converter.isConvertible(regex)); - IntegerDomain result = converter.convert(regex); - assertTrue(result.contains(0)); - assertTrue(result.contains(9)); + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testOptionalDigitIsNonCanonical() { + // ^[0-9]?$ admits the empty string, which is not a canonical decimal integer. + converter.convert(new RegexDomain("^[0-9]?$")); } @Test(expectedExceptions = NonConvertibleDomainException.class) @@ -324,23 +334,24 @@ public void testDecimalPoint() { } @Test - public void testMissingStartAnchor() { - // Anchors are stripped by RegexDomain; brics matches whole string implicitly. - // So [0-9]{3}$ is equivalent to [0-9]{3} which is convertible. - RegexDomain regex = new RegexDomain("[0-9]{3}$"); - assertTrue(converter.isConvertible(regex)); - } + public void testAnchorsAreIgnored() { + // brics matches the whole string implicitly, so ^...$ anchors are no-ops. + // The three forms below should produce identical results on a canonical input. 
+ RegexDomain anchored = new RegexDomain("^(100|999)$"); + RegexDomain missingStart = new RegexDomain("(100|999)$"); + RegexDomain missingEnd = new RegexDomain("^(100|999)"); + RegexDomain noAnchors = new RegexDomain("(100|999)"); - @Test - public void testMissingEndAnchor() { - RegexDomain regex = new RegexDomain("^[0-9]{3}"); - assertTrue(converter.isConvertible(regex)); - } + IntegerDomain a = converter.convert(anchored); + IntegerDomain b = converter.convert(missingStart); + IntegerDomain c = converter.convert(missingEnd); + IntegerDomain d = converter.convert(noAnchors); - @Test - public void testNoAnchors() { - RegexDomain regex = new RegexDomain("[0-9]{3}"); - assertTrue(converter.isConvertible(regex)); + for (IntegerDomain domain : new IntegerDomain[] { a, b, c, d }) { + assertTrue(domain.contains(100)); + assertTrue(domain.contains(999)); + assertFalse(domain.contains(101)); + } } @Test(expectedExceptions = NonConvertibleDomainException.class) @@ -357,28 +368,34 @@ public void testLookbehind() { @Test public void testBackreference() { - // Backreferences like \1 may not be detectable by our simple parser - // This is a known limitation - we'll test that the pattern doesn't cause a crash + // dk.brics.automaton does not support backreferences (\1 etc.); we expect either a + // construction-time failure from the regex parser, or a downstream conversion failure. 
try { RegexDomain regex = new RegexDomain("^([0-9])\\1$"); - boolean convertible = converter.isConvertible(regex); - // Either it's rejected (good) or accepted (parser limitation) - // As long as it doesn't crash, we're OK - } catch (Exception e) { - // Expected - backreference should ideally be rejected - assertTrue(e instanceof NonConvertibleDomainException || e instanceof IllegalArgumentException, - "Expected NonConvertibleDomainException or IllegalArgumentException but got: " + e.getClass()); + if (converter.isConvertible(regex)) { + // If reported convertible, the converted domain must at least be a valid IntegerDomain + // (i.e., not crash on use). We don't pin specific contents — that depends on how the + // parser ultimately handled the \1. + IntegerDomain domain = converter.convert(regex); + assertNotNull(domain); + } else { + // Reported non-convertible: convert() should throw. + try { + converter.convert(regex); + fail("convert() should throw when isConvertible reports false"); + } catch (NonConvertibleDomainException expected) { + // ok + } + } + } catch (IllegalArgumentException e) { + // Acceptable: regex parser rejected the backreference at construction time. } } - @Test - public void testShorthandDigitClass() { - // \d is now translated to [0-9] by RegexDomain.parseRegex, so this is convertible. - RegexDomain regex = new RegexDomain("^\\d{3}$"); - assertTrue(converter.isConvertible(regex)); - IntegerDomain result = converter.convert(regex); - assertTrue(result.contains(0)); - assertTrue(result.contains(999)); + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testShorthandDigitClassIsNonCanonical() { + // \d is translated to [0-9], so ^\d{3}$ is equivalent to ^[0-9]{3}$ — admits "000" etc. 
+ converter.convert(new RegexDomain("^\\d{3}$")); } @Test(expectedExceptions = NonConvertibleDomainException.class) @@ -401,14 +418,10 @@ public void testInvalidCharClassWithDigits() { converter.convert(regex); } - @Test - public void testEmptyPattern() { - // Empty regex accepts only the empty string "". - // The converter treats it as convertible (finite, digit-only vacuously) - // but produces an empty IntegerDomain since "" is not a valid integer. - RegexDomain regex = new RegexDomain(""); - IntegerDomain result = converter.convert(regex); - assertTrue(result.isEmpty()); + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testEmptyPatternIsNonCanonical() { + // Empty regex accepts only "", which is not a canonical decimal integer. + converter.convert(new RegexDomain("")); } @Test(expectedExceptions = { NonConvertibleDomainException.class, NullPointerException.class }) @@ -429,30 +442,24 @@ public void testSingleZero() { assertFalse(result.contains(1)); } - @Test - public void testLeadingZeros() { - // ^00[0-9]$ matches 000-009, which are 0-9 as integers - RegexDomain regex = new RegexDomain("^00[0-9]$"); - assertTrue(converter.isConvertible(regex)); - - IntegerDomain result = converter.convert(regex); - assertTrue(result.contains(0)); - assertTrue(result.contains(5)); - assertTrue(result.contains(9)); - assertFalse(result.contains(10)); + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testLeadingZerosIsNonCanonical() { + // ^00[0-9]$ admits "000", "001", … (non-canonical). + converter.convert(new RegexDomain("^00[0-9]$")); } @Test - public void testLargeSingleRange() { - // ^[0-9]{6}$ is a large domain (1M values) - // Should use bounds computation instead of enumeration - RegexDomain regex = new RegexDomain("^[0-9]{6}$"); + public void testLargeCanonicalRange() { + // Canonical 6-digit values: ^[1-9][0-9]{5}$ matches 100000..999999. 
+ // Large enough that the converter uses bounds computation instead of enumeration. + RegexDomain regex = new RegexDomain("^[1-9][0-9]{5}$"); assertTrue(converter.isConvertible(regex)); IntegerDomain result = converter.convert(regex); - assertTrue(result.contains(0)); + assertTrue(result.contains(100000)); assertTrue(result.contains(500000)); assertTrue(result.contains(999999)); + assertFalse(result.contains(99999)); assertFalse(result.contains(1000000)); } @@ -533,48 +540,70 @@ public void testComplexYearRange() { assertFalse(result.contains(2100)); } - @Test - public void testPhoneNumberLastFourDigits() { - // Last 4 digits of phone number - RegexDomain regex = new RegexDomain("^[0-9]{4}$"); - assertTrue(converter.isConvertible(regex)); - - IntegerDomain result = converter.convert(regex); - assertTrue(result.contains(0)); - assertTrue(result.contains(5555)); - assertTrue(result.contains(9999)); - assertFalse(result.contains(10000)); + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testZeroPaddedFourDigitFieldIsNonCanonical() { + // A fixed-width 4-digit field (e.g., phone-number last-4) is naturally a string-domain + // constraint, not an integer one, because "0042" is a valid value for the field. The + // converter rejects this regex so callers don't accidentally interpret it as an integer. + converter.convert(new RegexDomain("^[0-9]{4}$")); } - @Test - public void testZipCodeFirstThree() { - // First 3 digits of US zip code - RegexDomain regex = new RegexDomain("^[0-9]{3}$"); - assertTrue(converter.isConvertible(regex)); - - IntegerDomain result = converter.convert(regex); - assertTrue(result.contains(0)); - assertTrue(result.contains(500)); - assertTrue(result.contains(999)); + @Test(expectedExceptions = NonConvertibleDomainException.class) + public void testZipCodeFirstThreeIsNonCanonical() { + // ZIP-3 prefixes like "007" are valid strings but not canonical integer decimals. 
+ converter.convert(new RegexDomain("^[0-9]{3}$")); } // ==================== isConvertible Tests ==================== @Test public void testIsConvertibleValidPatterns() { + // isConvertible should agree with convert() succeeding and producing the expected domain. + // Every example here is canonical-only (no leading-zero strings, no empty string). assertTrue(converter.isConvertible(new RegexDomain("^123$"))); + assertEquals(converter.convert(new RegexDomain("^123$")), IntegerDomain.of(123)); + assertTrue(converter.isConvertible(new RegexDomain("^[0-9]$"))); - assertTrue(converter.isConvertible(new RegexDomain("^[0-9]{3}$"))); + IntegerDomain singleDigit = converter.convert(new RegexDomain("^[0-9]$")); + assertTrue(singleDigit.contains(0)); + assertTrue(singleDigit.contains(9)); + assertFalse(singleDigit.contains(10)); + + assertTrue(converter.isConvertible(new RegexDomain("^[1-9][0-9]{2}$"))); + IntegerDomain threeDigit = converter.convert(new RegexDomain("^[1-9][0-9]{2}$")); + assertTrue(threeDigit.contains(100)); + assertTrue(threeDigit.contains(999)); + assertFalse(threeDigit.contains(99)); + assertFalse(threeDigit.contains(1000)); + assertTrue(converter.isConvertible(new RegexDomain("^(10|20|30)$"))); + IntegerDomain alts = converter.convert(new RegexDomain("^(10|20|30)$")); + assertTrue(alts.contains(10) && alts.contains(20) && alts.contains(30)); + assertFalse(alts.contains(15)); + assertTrue(converter.isConvertible(new RegexDomain("^19[0-9]{2}$"))); + IntegerDomain years = converter.convert(new RegexDomain("^19[0-9]{2}$")); + assertTrue(years.contains(1900) && years.contains(1999)); + assertFalse(years.contains(1899) || years.contains(2000)); } @Test public void testIsConvertibleInvalidPatterns() { - assertFalse(converter.isConvertible(new RegexDomain("^[0-9]+$"))); // infinite - assertFalse(converter.isConvertible(new RegexDomain("^[0-9]*$"))); // infinite - assertFalse(converter.isConvertible(new RegexDomain("^abc$"))); // non-digit - 
assertFalse(converter.isConvertible(new RegexDomain("^[0-9]{3,}$"))); // infinite + // isConvertible should agree with convert() throwing NonConvertibleDomainException. + RegexDomain[] invalid = { new RegexDomain("^[0-9]+$"), // unbounded + new RegexDomain("^[0-9]*$"), // unbounded + new RegexDomain("^abc$"), // non-digit + new RegexDomain("^[0-9]{3,}$") // unbounded + }; + for (RegexDomain regex : invalid) { + assertFalse(converter.isConvertible(regex), "should report not convertible: " + regex); + try { + converter.convert(regex); + fail("convert() should have thrown for non-convertible regex: " + regex); + } catch (NonConvertibleDomainException expected) { + // ok + } + } } // ==================== Integration Tests ==================== @@ -597,25 +626,26 @@ public void testRoundTrip_SmallDomain() { @Test public void testRoundTrip_MediumDomain() { - // 100 values: 0-99 - RegexDomain regex = new RegexDomain("^[0-9]{2}$"); + // Canonical 2-digit range: 10-99 (90 values). + RegexDomain regex = new RegexDomain("^[1-9][0-9]$"); IntegerDomain domain = converter.convert(regex); - assertTrue(domain.contains(0)); + assertTrue(domain.contains(10)); assertTrue(domain.contains(50)); assertTrue(domain.contains(99)); + assertFalse(domain.contains(9)); assertFalse(domain.contains(100)); } @Test public void testBoundaryValues() { - RegexDomain regex = new RegexDomain("^[0-9]{3}$"); + // Canonical 3-digit range: 100-999. 
+ RegexDomain regex = new RegexDomain("^[1-9][0-9]{2}$"); IntegerDomain domain = converter.convert(regex); - // Test boundaries - assertTrue(domain.contains(0)); // Min - assertTrue(domain.contains(999)); // Max - assertFalse(domain.contains(-1)); // Below min - assertFalse(domain.contains(1000)); // Above max + assertTrue(domain.contains(100), "Min"); + assertTrue(domain.contains(999), "Max"); + assertFalse(domain.contains(99), "Below min"); + assertFalse(domain.contains(1000), "Above max"); } } diff --git a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexTransformerTest.java b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexTransformerTest.java new file mode 100644 index 000000000..9e03de10e --- /dev/null +++ b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/domain/RegexTransformerTest.java @@ -0,0 +1,136 @@ +/** + * Copyright 2025-2026 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.datagen.domain; + +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import com.linkedin.coral.datagen.domain.transformer.ConcatRegexTransformer; + +import static org.testng.Assert.*; + + +/** + * Tests for regex domain transformers (Concat). Trim is covered end-to-end via + * the SQL-driven integration tests in {@link RegexDomainInferenceProgramTest}, since + * constructing a Calcite TRIM RexCall requires a SqlTrimFunction.Flag operand that is + * awkward to build without a parser. Lower/Upper/Substring are also covered there. 
+ */ +public class RegexTransformerTest { + + private RexBuilder rexBuilder; + private RelDataTypeFactory typeFactory; + + @BeforeMethod + public void setup() { + rexBuilder = TestHelper.createRexBuilder(); + typeFactory = rexBuilder.getTypeFactory(); + } + + private RexNode stringRef() { + return rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0); + } + + private RexNode stringLit(String value) { + return rexBuilder.makeLiteral(value); + } + + // ==================== Concat Transformer ==================== + + @Test + public void testConcatVariablePlusLiteralSuffix() { + // CONCAT(x, 'World') = 'HelloWorld' => x = 'Hello' + ConcatRegexTransformer transformer = new ConcatRegexTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.CONCAT, stringRef(), stringLit("World")); + + assertTrue(transformer.canHandle(expr)); + assertTrue(transformer.isVariableOperandPositionValid(expr)); + assertEquals(transformer.getChildForVariable(expr), ((RexCall) expr).getOperands().get(0)); + + Domain result = transformer.refineInputDomain(expr, RegexDomain.literal("HelloWorld")); + assertTrue(result instanceof RegexDomain); + RegexDomain regexResult = (RegexDomain) result; + assertTrue(regexResult.isLiteral()); + assertEquals(regexResult.getLiteralValue(), "Hello"); + } + + @Test + public void testConcatLiteralPrefixPlusVariable() { + // CONCAT('Hello', x) = 'HelloWorld' => x = 'World' + ConcatRegexTransformer transformer = new ConcatRegexTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.CONCAT, stringLit("Hello"), stringRef()); + + assertTrue(transformer.canHandle(expr)); + assertTrue(transformer.isVariableOperandPositionValid(expr)); + + Domain result = transformer.refineInputDomain(expr, RegexDomain.literal("HelloWorld")); + assertTrue(result instanceof RegexDomain); + RegexDomain regexResult = (RegexDomain) result; + assertTrue(regexResult.isLiteral()); + assertEquals(regexResult.getLiteralValue(), "World"); + } + + 
@Test + public void testConcatSuffixMismatchProducesEmpty() { + // CONCAT(x, 'World') = 'HelloEarth' => no valid input (suffix mismatch) + ConcatRegexTransformer transformer = new ConcatRegexTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.CONCAT, stringRef(), stringLit("World")); + + Domain result = transformer.refineInputDomain(expr, RegexDomain.literal("HelloEarth")); + assertTrue(result instanceof RegexDomain); + assertTrue(result.isEmpty()); + } + + @Test + public void testConcatPrefixMismatchProducesEmpty() { + // CONCAT('Hello', x) = 'GreetingsWorld' => no valid input (prefix mismatch) + ConcatRegexTransformer transformer = new ConcatRegexTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.CONCAT, stringLit("Hello"), stringRef()); + + Domain result = transformer.refineInputDomain(expr, RegexDomain.literal("GreetingsWorld")); + assertTrue(result instanceof RegexDomain); + assertTrue(result.isEmpty()); + } + + @Test + public void testConcatEmptySuffixIsIdentity() { + // CONCAT(x, '') = 'Hello' => x = 'Hello' + ConcatRegexTransformer transformer = new ConcatRegexTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.CONCAT, stringRef(), stringLit("")); + + Domain result = transformer.refineInputDomain(expr, RegexDomain.literal("Hello")); + assertTrue(result instanceof RegexDomain); + RegexDomain regexResult = (RegexDomain) result; + assertTrue(regexResult.isLiteral()); + assertEquals(regexResult.getLiteralValue(), "Hello"); + } + + @Test + public void testConcatNonLiteralOutputReturnsAsIs() { + // CONCAT(x, 'World') with a non-literal regex output: ConcatRegexTransformer falls + // back to a sound but lossy passthrough (returns the output regex unchanged). Verify + // the passthrough by checking the returned automaton accepts the same strings as the input. 
+ ConcatRegexTransformer transformer = new ConcatRegexTransformer(); + RexNode expr = rexBuilder.makeCall(SqlStdOperatorTable.CONCAT, stringRef(), stringLit("World")); + + RegexDomain nonLiteral = new RegexDomain("^[A-Z][a-z]+World$"); + Domain result = transformer.refineInputDomain(expr, nonLiteral); + assertTrue(result instanceof RegexDomain); + RegexDomain regexResult = (RegexDomain) result; + + // Passthrough: should accept whatever the input pattern accepted. + assertTrue(regexResult.getAutomaton().run("HelloWorld"), "should still accept 'HelloWorld'"); + assertTrue(regexResult.getAutomaton().run("AbcWorld"), "should still accept 'AbcWorld'"); + assertFalse(regexResult.getAutomaton().run("helloWorld"), "should reject lowercase first letter"); + assertFalse(regexResult.getAutomaton().run("Hello"), "should reject missing 'World' suffix"); + } +} diff --git a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/rel/CanonicalPredicateExtractorTest.java b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/rel/CanonicalPredicateExtractorTest.java index 223d61d0b..0bcfe049a 100644 --- a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/rel/CanonicalPredicateExtractorTest.java +++ b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/rel/CanonicalPredicateExtractorTest.java @@ -180,7 +180,8 @@ public void testExtractorTraversesProject() { @Test public void testDnfRewriterWithExtractedPredicates() { - // Filter(a = 'x') -> Scan(T1) + // Filter(a = 'x') -> Scan(T1). DNF of a single equality is one conjunctive disjunct + // referencing field $0 (T1.a). 
RelNode tree = builder.scan("T1") .filter(builder.call(SqlStdOperatorTable.EQUALS, builder.field("a"), builder.literal("x"))).build(); @@ -188,7 +189,10 @@ public void testDnfRewriterWithExtractedPredicates() { DnfRewriter.Output dnf = DnfRewriter.convert(extracted, tree.getCluster().getRexBuilder()); assertEquals(dnf.sequentialScans.size(), 1); - assertFalse(dnf.disjuncts.isEmpty(), "Should have at least one disjunct"); + assertEquals(dnf.disjuncts.size(), 1, "single equality should yield exactly one DNF disjunct"); + String disjunctStr = dnf.disjuncts.get(0).toString(); + assertTrue(disjunctStr.contains("$0"), "disjunct should reference $0 (T1.a), got: " + disjunctStr); + assertTrue(disjunctStr.contains("'x'"), "disjunct should reference literal 'x', got: " + disjunctStr); } @Test diff --git a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/rel/ProjectPullUpRewriterTest.java b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/rel/ProjectPullUpRewriterTest.java index 19a61ccd0..1f5b3e06e 100644 --- a/coral-data-generation/src/test/java/com/linkedin/coral/datagen/rel/ProjectPullUpRewriterTest.java +++ b/coral-data-generation/src/test/java/com/linkedin/coral/datagen/rel/ProjectPullUpRewriterTest.java @@ -112,6 +112,7 @@ public void testFilterProjectConditionRewrite() { @Test public void testProjectPreservesRowType() { // After pull-up, the pulled-up Project should have the same row type as the original + // (same field names and types, in the same order). 
RelNode tree = builder.scan("T1").project(builder.field("a"), builder.field("b")) .filter(builder.call(SqlStdOperatorTable.GREATER_THAN, builder.field("b"), builder.literal(5))).build(); @@ -120,6 +121,12 @@ public void testProjectPreservesRowType() { assertEquals(result.getRowType().getFieldCount(), tree.getRowType().getFieldCount(), "Row type field count should be preserved"); + assertEquals(result.getRowType().getFieldNames(), tree.getRowType().getFieldNames(), + "Row type field names should be preserved (in order)"); + for (int i = 0; i < result.getRowType().getFieldCount(); i++) { + assertEquals(result.getRowType().getFieldList().get(i).getType(), + tree.getRowType().getFieldList().get(i).getType(), "Field type at index " + i + " should match"); + } } // ==================== Join->Project Pull-Up ==================== @@ -144,6 +151,12 @@ public void testJoinLeftProjectPullUpFieldCountPreserved() { assertTrue(joinNode instanceof Join, "Project child should be Join"); Join newJoin = (Join) joinNode; assertFalse(newJoin.getLeft() instanceof Project, "Left should no longer be Project"); + + // The pulled-up Project must reproduce the original row type (names + types). + assertEquals(result.getRowType().getFieldNames(), tree.getRowType().getFieldNames(), + "pulled-up row type should match the original"); + // Both join inputs are now raw scans (3-col T1 + 2-col T2 = 5 cols at join). 
+ assertEquals(newJoin.getRowType().getFieldCount(), 5, "join below pulled-up Project sees raw 3+2 columns"); } @Test @@ -163,6 +176,12 @@ public void testJoinRightProjectPullUpFieldCountPreserved() { assertTrue(result instanceof Project, "Root should be Project after pull-up"); RelNode joinNode = ((Project) result).getInput(); assertTrue(joinNode instanceof Join, "Project child should be Join"); + Join newJoin = (Join) joinNode; + assertFalse(newJoin.getRight() instanceof Project, "Right should no longer be Project"); + + assertEquals(result.getRowType().getFieldNames(), tree.getRowType().getFieldNames(), + "pulled-up row type should match the original"); + assertEquals(newJoin.getRowType().getFieldCount(), 5, "join below pulled-up Project sees raw 3+2 columns"); } @Test @@ -185,16 +204,20 @@ public void testJoinBothProjectsPullUpFieldCountPreserved() { Join newJoin = (Join) joinNode; assertFalse(newJoin.getLeft() instanceof Project, "Left should no longer be Project"); assertFalse(newJoin.getRight() instanceof Project, "Right should no longer be Project"); + + assertEquals(result.getRowType().getFieldNames(), tree.getRowType().getFieldNames(), + "pulled-up row type should match the original"); + assertEquals(newJoin.getRowType().getFieldCount(), 5, "join below pulled-up Project sees raw 3+2 columns"); } // ==================== Field Count Change ==================== @Test public void testJoinLeftProjectPullUpFieldCountChanged() { - // Build: Join with left = Project(a, b) -> Scan(T1: a,b,c), right = Scan(T2) - // Left Project reduces field count 3->2. Previously failed with type mismatch - // because RexInputRef offsets used the Project's output field count (2) instead - // of the new left side's scan field count (3). Fixed in items #5/#6. + // Build: Join with left = Project(a, b) -> Scan(T1: a,b,c), right = Scan(T2: x,y) + // Left Project reduces field count 3->2. 
 The right-side reference $2 ("x" in + // the projected join frame) must shift to $3 once the left expands back to 3 columns; + // verify the rewritten condition uses the corrected offsets ($1 for b, $3 for x). builder.scan("T1").project(builder.field("a"), builder.field("b")); builder.scan("T2"); RelNode tree = builder.join(JoinRelType.INNER, @@ -206,9 +229,18 @@ public void testJoinLeftProjectPullUpFieldCountChanged() { RelNode result = ProjectPullUpRewriter.rewriteOneStep(tree); assertTrue(result instanceof Project, "Root should be Project after pull-up"); - RelNode joinNode = ((Project) result).getInput(); - assertTrue(joinNode instanceof Join, "Project child should be Join"); - assertFalse(((Join) joinNode).getLeft() instanceof Project, "Left should no longer be Project"); + Join newJoin = (Join) ((Project) result).getInput(); + assertFalse(newJoin.getLeft() instanceof Project, "Left should no longer be Project"); + assertEquals(newJoin.getRowType().getFieldCount(), 5, "join below pulled-up Project sees raw 3+2 columns"); + + // The join condition must now reference $1 (T1.b) and $3 (T2.x at offset 3 after left grew + // from 2 to 3 columns). If the rewriter had kept the old offset $2, that would now point + // at T1.c, which is a column type mismatch.
+ String condStr = newJoin.getCondition().toString(); + assertTrue(condStr.contains("$1"), "condition should reference $1 (T1.b), got: " + condStr); + assertTrue(condStr.contains("$3"), "condition should reference $3 (T2.x at shifted offset), got: " + condStr); + assertFalse(condStr.contains("$2"), + "condition should NOT reference stale $2 (T2.x at old offset), got: " + condStr); } // ==================== Controller (Fixed Point) ==================== @@ -241,8 +273,8 @@ public void testControllerNoChange() { public void testControllerNestedPullUp() { // Build a tree requiring multiple iterations: // Filter -> Project -> Filter -> Project -> Scan(T1) - // Iteration 1: pulls inner Project above inner Filter - // Iteration 2: inner Project merges/outer pattern matches + // At fixed point, all Projects should be above all Filters and the row type should + // still equal the original (the output field count and field types are preserved). RelNode tree = builder.scan("T1").project(builder.field("a"), builder.field("b")) .filter(builder.call(SqlStdOperatorTable.GREATER_THAN, builder.field("b"), builder.literal(5))) .project(builder.field("a")) @@ -250,9 +282,11 @@ public void testControllerNestedPullUp() { RelNode result = ProjectPullUpController.applyUntilFixedPoint(tree, 100); - // All Projects should be above all Filters at fixed point - // Verify no Filter has a Project child assertNoFilterProjectPattern(result); + assertTrue(result instanceof Project, "Root should be a Project after fixed-point pull-up"); + assertEquals(result.getRowType().getFieldNames(), tree.getRowType().getFieldNames(), + "Row type field names must be preserved by the pull-up"); + assertEquals(result.getRowType().getFieldCount(), 1, "Outer Project keeps only column 'a'"); } // ==================== Helpers ====================