diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java index 186ec6903f91..5874f2b9aa31 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java @@ -17,6 +17,8 @@ package org.apache.calcite.adapter.enumerable; import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.linq4j.Enumerable; +import org.apache.calcite.linq4j.Enumerator; import org.apache.calcite.linq4j.tree.BlockBuilder; import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; @@ -28,17 +30,23 @@ import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.TableModify; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.ModifiableTable; import org.apache.calcite.util.BuiltInMethod; import org.checkerframework.checker.nullness.qual.Nullable; -import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.Deque; +import java.util.HashMap; import java.util.List; +import java.util.ListIterator; +import java.util.Map; import static com.google.common.base.Preconditions.checkArgument; @@ -80,100 +88,394 @@ public EnumerableTableModify(RelOptCluster cluster, RelTraitSet traits, final BlockBuilder builder = new BlockBuilder(); final Result result = implementor.visitChild(this, 0, (EnumerableRel) getInput(), pref); - Expression childExp = - builder.append( - "child", result.block); + + // Enumerable produced by the input relational expression. + final Expression sourceExp = + builder.append("source", result.block); + + // Variable that will hold the table's mutable backing collection. final ParameterExpression collectionParameter = Expressions.parameter(Collection.class, builder.newName("collection")); + + // Expression that yields the ModifiableTable instance at runtime. final Expression expression = table.getExpression(ModifiableTable.class); requireNonNull(expression, "expression"); // TODO: user error in validator checkArgument( ModifiableTable.class.isAssignableFrom( Types.toClass(expression.getType())), "not assignable from type %s", expression.getType()); + + // collection = table.getModifiableCollection() builder.add( Expressions.declare( Modifier.FINAL, collectionParameter, Expressions.call( expression, - BuiltInMethod.MODIFIABLE_TABLE_GET_MODIFIABLE_COLLECTION - .method))); + BuiltInMethod.MODIFIABLE_TABLE_GET_MODIFIABLE_COLLECTION.method))); + + // Physical row representation of this TableModify node's output + // (a single ROWCOUNT field). + final PhysType physType = + PhysTypeImpl.of( + implementor.getTypeFactory(), + getRowType(), + pref == Prefer.ARRAY ? JavaRowFormat.ARRAY : JavaRowFormat.SCALAR); + + switch (getOperation()) { + case INSERT: + return implementInsert(implementor, result, builder, sourceExp, collectionParameter, physType); + case UPDATE: + return implementUpdate(implementor, builder, sourceExp, collectionParameter, physType); + case DELETE: + return implementDelete(implementor, builder, sourceExp, collectionParameter, physType); + default: + throw new AssertionError("unsupported operation: " + getOperation()); + } + } + + /** Generates code for an UPDATE statement. + * + *

Applies updates to matching rows in the backing collection and returns + * the number of updated rows as a single-element enumerable. + * + * @param implementor code-generation context + * @param builder block under construction + * @param sourceExp enumerable of source rows produced by the input + * @param collectionParameter the modifiable backing collection of the table + * @param physType physical type of this node's output row + */ + private Result implementUpdate( + EnumerableRelImplementor implementor, + BlockBuilder builder, + Expression sourceExp, + ParameterExpression collectionParameter, + PhysType physType) { + final List updateCols = requireNonNull(getUpdateColumnList()); + final List tableFields = table.getRowType().getFieldList(); + final int tableFieldCount = tableFields.size(); + + // Child row layout for UPDATE: + // [originalField_0, ..., originalField_N-1, newValue_0, ..., newValue_M-1] + // where N = tableFieldCount and M = updateCols.size(). + + // Resolve each SET-column name to its 0-based position in the table row. + final int[] updateColumnIndices = new int[updateCols.size()]; + for (int i = 0; i < updateCols.size(); i++) { + final String colName = updateCols.get(i); + int found = -1; + for (int j = 0; j < tableFields.size(); j++) { + if (tableFields.get(j).getName().equals(colName)) { + found = j; + break; + } + } + if (found < 0) { + throw new AssertionError("column '" + colName + "' not found in table"); + } + updateColumnIndices[i] = found; + } + + // Generate code that applies one-to-one update consumption: + // each source row updates at most one matching sink row. + final Expression updateCountExp = + builder.append( + "updateCount", + Expressions.call( + EnumerableTableModify.class, + "applyUpdateOneToOne", + // Source rows are produced by the child relational expression. + sourceExp, + // Sink is the table's mutable backing collection. + Expressions.convert_(collectionParameter, List.class), + // Number of original-row fields in each source payload. + Expressions.constant(tableFieldCount), + // Table column positions to overwrite from trailing source values. + Expressions.constant(updateColumnIndices))); + + // Return the number of updated rows as the single output row. + builder.add( + Expressions.return_( + null, + Expressions.call( + BuiltInMethod.SINGLETON_ENUMERABLE.method, + Expressions.convert_( + updateCountExp, + long.class)))); + + return implementor.result(physType, builder.toBlock()); + } + + /** + * Applies UPDATE with one-to-one, first-match consumption semantics. + * + *

Each source row contributes one replacement row keyed by the original + * row content. As sink rows are scanned in order, the first matching row for + * each queued source update is replaced and consumed, so duplicate keys update + * only as many rows as appear in the source. + */ + @SuppressWarnings({"unchecked"}) + public static long applyUpdateOneToOne(Enumerable source, List sink, + int tableFieldCount, int[] updateColumnIndices) { + final Map, Deque> updatesByKey = new HashMap<>(); + try (Enumerator e = source.enumerator()) { + while (e.moveNext()) { + final Object[] sourceRow = (Object[]) e.current(); + final List key = Arrays.asList(Arrays.copyOf(sourceRow, tableFieldCount)); + final Object[] newRow = applyUpdate(sourceRow, tableFieldCount, updateColumnIndices); + updatesByKey.computeIfAbsent(key, k -> new ArrayDeque<>()).addLast(newRow); + } + } + + long updateCount = 0; + final ListIterator it = ((List) sink).listIterator(); + while (it.hasNext()) { + final Object[] current = it.next(); + final List key = Arrays.asList(current); + final Deque pending = updatesByKey.get(key); + if (pending == null || pending.isEmpty()) { + continue; + } + it.set(pending.removeFirst()); + updateCount++; + if (pending.isEmpty()) { + updatesByKey.remove(key); + } + } + return updateCount; + } + + /** Generates code for a DELETE statement. + * + *

The source produces every row that matches the WHERE clause. Those rows + * are removed from the backing collection and the number of deleted rows is + * returned as a single-element enumerable. + * + * @param implementor code-generation context + * @param builder block under construction + * @param sourceExp enumerable of rows to delete — every row matched + * by the WHERE clause, as produced by the input + * relational expression + * @param collectionParameter the modifiable backing collection of the table + * @param physType physical type of this node's output row + */ + private Result implementDelete( + EnumerableRelImplementor implementor, + BlockBuilder builder, + Expression sourceExp, + ParameterExpression collectionParameter, + PhysType physType) { + + // Snapshot the collection size before the delete so we can compute the + // number of rows removed as (sizeBefore - sizeAfter). final Expression countParameter = builder.append( "count", Expressions.call(collectionParameter, "size"), false); - Expression convertedChildExp; + + final String deleteMethodName = + EnumerableTableScan.deduceFormat(table) == JavaRowFormat.ARRAY + ? "applyDeleteArrayRows" + : "applyDeleteScalarRows"; + + // Remove every row in sourceExp from the backing collection. + builder.add( + Expressions.statement( + Expressions.call( + EnumerableTableModify.class, + deleteMethodName, + sourceExp, + collectionParameter))); + + // Snapshot the size again and return (sizeBefore - sizeAfter) as the delete count. + final Expression updatedCountParameter = + builder.append( + "updatedCount", + Expressions.call(collectionParameter, "size"), + false); + + builder.add( + Expressions.return_( + null, + Expressions.call( + BuiltInMethod.SINGLETON_ENUMERABLE.method, + Expressions.convert_( + Expressions.subtract(countParameter, updatedCountParameter), + long.class)))); + + return implementor.result(physType, builder.toBlock()); + } + + /** Generates code for an INSERT statement. + * + *

All rows produced by the source are added to the backing collection and + * the number of inserted rows is returned as a single-element enumerable. + * + * @param implementor code-generation context + * @param result compiled result of the input relational expression + * @param builder block under construction + * @param sourceExp enumerable of rows to insert — the output of the input + * relational expression (VALUES, SELECT, etc.) with any + * upstream filtering or projection already applied + * @param collectionParameter the modifiable backing collection of the table + * @param physType physical type of this node's output row + */ + private Result implementInsert( + EnumerableRelImplementor implementor, + Result result, + BlockBuilder builder, + Expression sourceExp, + ParameterExpression collectionParameter, + PhysType physType) { + + // Snapshot the collection size before the insert so we can compute the + // number of rows added as (sizeAfter - sizeBefore). + final Expression countParameter = + builder.append( + "count", + Expressions.call(collectionParameter, "size"), + false); + + // insertExp is the enumerable that will actually be streamed into the + // collection. When the source row type matches the table's row type + // exactly it is the same as sourceExp; otherwise it is sourceExp wrapped + // in a field-by-field cast projection so that every value lands in the + // Java type the table's backing collection expects. + Expression insertExp; if (!getInput().getRowType().equals(getRowType())) { + // The source row type doesn't match the table's row type (e.g. types + // differ in nullability or precision), so wrap the source in a projection + // that casts each field to the exact Java type the table expects. final JavaTypeFactory typeFactory = (JavaTypeFactory) getCluster().getTypeFactory(); final JavaRowFormat format = EnumerableTableScan.deduceFormat(table); - PhysType physType = + PhysType tablePhysType = PhysTypeImpl.of(typeFactory, table.getRowType(), format); + // One cast expression per field: sourceField -> tableFieldType. List expressionList = new ArrayList<>(); - final PhysType childPhysType = result.physType; + final PhysType sourcePhysType = result.physType; final ParameterExpression o_ = - Expressions.parameter(childPhysType.getJavaRowType(), "o"); - final int fieldCount = - childPhysType.getRowType().getFieldCount(); + Expressions.parameter(sourcePhysType.getJavaRowType(), "o"); + final int fieldCount = sourcePhysType.getRowType().getFieldCount(); for (int i = 0; i < fieldCount; i++) { expressionList.add( - childPhysType.fieldReference(o_, i, physType.getJavaFieldType(i))); + sourcePhysType.fieldReference(o_, i, tablePhysType.getJavaFieldType(i))); } - convertedChildExp = + // insertExp = sourceExp.select(o -> new TableRow(cast(o.f0), cast(o.f1), ...)) + insertExp = builder.append( - "convertedChild", + "insertRows", Expressions.call( - childExp, + sourceExp, BuiltInMethod.SELECT.method, - Expressions.lambda( - physType.record(expressionList), o_))); + Expressions.lambda(tablePhysType.record(expressionList), o_))); } else { - convertedChildExp = childExp; - } - final Method method; - switch (getOperation()) { - case INSERT: - method = BuiltInMethod.INTO.method; - break; - case DELETE: - method = BuiltInMethod.REMOVE_ALL.method; - break; - default: - throw new AssertionError(getOperation()); + insertExp = sourceExp; } + + // Stream all rows from insertExp into the backing collection. builder.add( Expressions.statement( Expressions.call( - convertedChildExp, method, collectionParameter))); + insertExp, BuiltInMethod.INTO.method, collectionParameter))); + + // Snapshot the size again and return (sizeAfter - sizeBefore) as the insert count. final Expression updatedCountParameter = builder.append( "updatedCount", Expressions.call(collectionParameter, "size"), false); + builder.add( Expressions.return_( null, Expressions.call( BuiltInMethod.SINGLETON_ENUMERABLE.method, Expressions.convert_( - Expressions.condition( - Expressions.greaterThanOrEqual( - updatedCountParameter, countParameter), - Expressions.subtract( - updatedCountParameter, countParameter), - Expressions.subtract( - countParameter, updatedCountParameter)), + Expressions.subtract(updatedCountParameter, countParameter), long.class)))); - final PhysType physType = - PhysTypeImpl.of( - implementor.getTypeFactory(), - getRowType(), - pref == Prefer.ARRAY - ? JavaRowFormat.ARRAY : JavaRowFormat.SCALAR); + return implementor.result(physType, builder.toBlock()); } + /** + * Removes from {@code sink} one occurrence per occurrence in {@code source} + * when rows are represented as {@code Object[]}. + */ + @SuppressWarnings({"unchecked"}) + public static void applyDeleteArrayRows(Enumerable source, Collection sink) { + final Collection sinkRows = (Collection) sink; + final List sourceRows = new ArrayList<>(); + try (Enumerator e = source.enumerator()) { + while (e.moveNext()) { + sourceRows.add((Object[]) e.current()); + } + } + // Drain source first, then mutate sink, to avoid iterator interference + // when source and sink share the same backing collection. + for (Object[] row : sourceRows) { + removeFirstArrayMatch(sinkRows, row); + } + } + + /** + * Removes from {@code sink} one occurrence per occurrence in {@code source} + * for scalar/custom row representations. + */ + @SuppressWarnings({"unchecked"}) + public static void applyDeleteScalarRows(Enumerable source, Collection sink) { + final Collection sinkRows = (Collection) sink; + final List sourceRows = new ArrayList<>(); + try (Enumerator e = source.enumerator()) { + while (e.moveNext()) { + sourceRows.add(e.current()); + } + } + // Drain source first, then mutate sink, to avoid iterator interference + // when source and sink share the same backing collection. + for (Object row : sourceRows) { + sinkRows.remove(row); + } + } + + /** + * Removes the first row in {@code sinkRows} whose values match {@code target}. + * + *

This helper is used for delete semantics that remove at most one sink + * row per source row when rows are represented as {@code Object[]}. + */ + private static void removeFirstArrayMatch(Collection sinkRows, + Object[] target) { + for (java.util.Iterator it = sinkRows.iterator(); it.hasNext();) { + if (Arrays.equals(it.next(), target)) { + it.remove(); + return; + } + } + } + + /** + * Builds the replacement row for an UPDATE source row. + * + * @param row source row produced by the child expression + * @param tableFieldCount number of fields in the original table row + * @param updateColumnIndices 0-based indices of the columns being updated + * @return the replacement row + */ + public static Object[] applyUpdate( + Object[] row, + int tableFieldCount, + int[] updateColumnIndices) { + // Source row layout: [originalField_0, ..., originalField_N-1, newValue_0, ..., newValue_M-1] + // where N = tableFieldCount and M = updateColumnIndices.length. + // Copy the first N fields and overwrite the positions named in the SET clause. + final Object[] newRow = Arrays.copyOf(row, tableFieldCount); + for (int i = 0; i < updateColumnIndices.length; i++) { + newRow[updateColumnIndices[i]] = row[tableFieldCount + i]; + } + return newRow; + } + } diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 98b67a02bbf3..98c334799506 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -181,6 +181,8 @@ public enum BuiltInMethod { Integer.class, int.class, int.class, BigDecimal.class, RoundingMode.class), INTO(ExtendedEnumerable.class, "into", Collection.class), REMOVE_ALL(ExtendedEnumerable.class, "removeAll", Collection.class), + UPDATE(ExtendedEnumerable.class, "update", List.class, Function1.class, + Function1.class, Function1.class), SCHEMA_GET_SUB_SCHEMA(Schema.class, "getSubSchema", String.class), SCHEMA_GET_TABLE(Schema.class, "getTable", String.class), SCHEMA_PLUS_ADD_TABLE(SchemaPlus.class, "add", String.class, Table.class), diff --git a/core/src/test/java/org/apache/calcite/adapter/enumerable/EnumerableTableModifyTest.java b/core/src/test/java/org/apache/calcite/adapter/enumerable/EnumerableTableModifyTest.java new file mode 100644 index 000000000000..103a00c0998e --- /dev/null +++ b/core/src/test/java/org/apache/calcite/adapter/enumerable/EnumerableTableModifyTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.adapter.enumerable; + +import org.apache.calcite.linq4j.Linq4j; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +/** Tests for {@link EnumerableTableModify} row-consumption semantics. */ +class EnumerableTableModifyTest { + + @Test void testApplyUpdateOneToOneUpdatesOnlyFirstNMatchingRows() { + final List sink = new ArrayList<>(); + sink.add(new Object[] {1, 10}); + sink.add(new Object[] {1, 10}); + sink.add(new Object[] {1, 10}); + sink.add(new Object[] {2, 20}); + + // Source row layout: [original_i, original_j, new_j]. + final List source = Arrays.asList( + new Object[] {1, 10, 100}, + new Object[] {1, 10, 200}); + + final long count = EnumerableTableModify.applyUpdateOneToOne( + Linq4j.asEnumerable(source), sink, 2, new int[] {1}); + + assertThat(count, is(2L)); + assertThat(toValueRows(sink), + is(Arrays.asList( + Arrays.asList(1, 100), + Arrays.asList(1, 200), + Arrays.asList(1, 10), + Arrays.asList(2, 20)))); + } + + @Test void testApplyDeleteArrayRowsDeletesOnlyFirstNMatchingRows() { + final List sink = new ArrayList<>(); + sink.add(new Object[] {1, 10}); + sink.add(new Object[] {1, 10}); + sink.add(new Object[] {1, 10}); + sink.add(new Object[] {2, 20}); + + final List source = Arrays.asList( + new Object[] {1, 10}, + new Object[] {1, 10}); + + EnumerableTableModify.applyDeleteArrayRows(Linq4j.asEnumerable(source), sink); + + assertThat(toValueRows(sink), + is(Arrays.asList( + Arrays.asList(1, 10), + Arrays.asList(2, 20)))); + } + + private static List> toValueRows(List rows) { + final List> valueRows = new ArrayList<>(); + for (Object[] row : rows) { + valueRows.add(Arrays.asList(row)); + } + return valueRows; + } +} + diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java index d859519b4178..54dcc20f41e7 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java @@ -377,6 +377,15 @@ protected OrderedQueryable asOrderedQueryable() { return EnumerableDefaults.remove(getThis(), sink); } + @Override public long update( + List sink, + Function1 sinkKeySelector, + Function1 sourceKeySelector, + Function1 transform) { + return EnumerableDefaults.update(getThis(), sink, sinkKeySelector, + sourceKeySelector, transform); + } + @Override public Enumerable hashJoin( Enumerable inner, Function1 outerKeySelector, Function1 innerKeySelector, diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java index 988450df3643..ae26a602132d 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java @@ -61,6 +61,7 @@ import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.NoSuchElementException; import java.util.Objects; @@ -4399,6 +4400,40 @@ public static > C remove( return sink; } + /** + * Default implementation of + * {@link ExtendedEnumerable#update(List, Function1, Function1, Function1)}. + * + *

Builds a map from source-row keys to replacement rows in a single pass + * over the source, then performs a single pass over the sink, replacing + * matched rows in place. + */ + public static long update( + Enumerable source, + List sink, + Function1 sinkKeySelector, + Function1 sourceKeySelector, + Function1 sourceTransform) { + final Map updateMap = new HashMap<>(); + try (Enumerator e = source.enumerator()) { + while (e.moveNext()) { + final T row = e.current(); + updateMap.put(sourceKeySelector.apply(row), sourceTransform.apply(row)); + } + } + long updateCount = 0; + final ListIterator it = sink.listIterator(); + while (it.hasNext()) { + final T current = it.next(); + final T newRow = updateMap.get(sinkKeySelector.apply(current)); + if (newRow != null) { + it.set(newRow); + updateCount++; + } + } + return updateCount; + } + /** * Hash table with null-safe key set. * diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java index 160f2afa0b1e..ff28c3fe822c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java @@ -547,6 +547,32 @@ Enumerable intersect(Enumerable enumerable1, */ > C removeAll(C sink); + /** + * Updates rows of {@code sink} based on the contents of this sequence. + * + *

For each element {@code x} of this sequence, {@code sourceKeySelector} + * computes a key, and {@code sourceTransform} computes a replacement row. + * Then for each element {@code y} of {@code sink}, {@code sinkKeySelector} + * computes a key; if it matches a key produced from this sequence, {@code y} + * is replaced (in place) with the corresponding replacement row. + * + *

The sink is a {@link List} so that elements can be replaced + * in place while preserving order. + * + * @param sink List to be updated in place + * @param sinkKeySelector Function that extracts a key from a sink row + * @param sourceKeySelector Function that extracts a key from a source row + * @param transform Function that produces the replacement row from a + * source row + * @param Key type + * @return Number of rows replaced + */ + long update( + List sink, + Function1 sinkKeySelector, + Function1 sourceKeySelector, + Function1 transform); + /** * Correlates the elements of two sequences based on * matching keys. The default equality comparer is used to compare diff --git a/server/src/test/java/org/apache/calcite/test/ServerTest.java b/server/src/test/java/org/apache/calcite/test/ServerTest.java index 40e1430c8799..a58ef5b2b8ea 100644 --- a/server/src/test/java/org/apache/calcite/test/ServerTest.java +++ b/server/src/test/java/org/apache/calcite/test/ServerTest.java @@ -108,6 +108,159 @@ static Connection connect() throws SQLException { executor.execute((SqlTruncateTable) o, context); } + @Test void testUpdate() throws Exception { + try (Connection c = connect(); + Statement s = c.createStatement()) { + s.execute("create table t (i int not null, j int not null)"); + s.executeUpdate("insert into t values (1, 10)"); + s.executeUpdate("insert into t values (2, 20)"); + s.executeUpdate("insert into t values (3, 30)"); + + // Update one row + int count = s.executeUpdate("update t set j = 99 where i = 2"); + assertThat(count, is(1)); + + try (ResultSet r = s.executeQuery("select i, j from t order by i")) { + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(1)); + assertThat(r.getInt(2), is(10)); + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(2)); + assertThat(r.getInt(2), is(99)); + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(3)); + assertThat(r.getInt(2), is(30)); + assertThat(r.next(), is(false)); + } + + // Update multiple rows + count = s.executeUpdate("update t set j = 0 where i > 1"); + assertThat(count, is(2)); + + try (ResultSet r = s.executeQuery("select sum(j) from t")) { + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(10)); + assertThat(r.next(), is(false)); + } + + // Update zero rows (no predicate match) + count = s.executeUpdate("update t set j = 100 where i = 99"); + assertThat(count, is(0)); + } + } + + @Test void testUpdateDuplicateRows() throws Exception { + try (Connection c = connect(); + Statement s = c.createStatement()) { + s.execute("create table t (i int not null, j int not null)"); + s.executeUpdate("insert into t values (1, 10)"); + s.executeUpdate("insert into t values (1, 10)"); + s.executeUpdate("insert into t values (1, 10)"); + s.executeUpdate("insert into t values (2, 20)"); + + final int count = s.executeUpdate("update t set j = 99 where i = 1 and j = 10"); + assertThat(count, is(3)); + + try (ResultSet r = s.executeQuery( + "select " + + "sum(case when i = 1 and j = 99 then 1 else 0 end), " + + "sum(case when i = 1 and j = 10 then 1 else 0 end), " + + "sum(case when i = 2 and j = 20 then 1 else 0 end) " + + "from t")) { + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(3)); + assertThat(r.getInt(2), is(0)); + assertThat(r.getInt(3), is(1)); + assertThat(r.next(), is(false)); + } + } + } + + /** Tests that INSERT ... SELECT returns the correct row count when + * 0, 1, or multiple rows are produced by the source query. + * Exercises {@link org.apache.calcite.server.MutableArrayTable} via the + * enumerable INSERT path. */ + @Test void testInsertSelectRowCount() throws Exception { + try (Connection c = connect(); + Statement s = c.createStatement()) { + s.execute("create table src (i int not null, j int not null)"); + s.executeUpdate("insert into src values (1, 10)"); + s.executeUpdate("insert into src values (2, 20)"); + s.execute("create table dst (i int not null, j int not null)"); + + // Insert 0 rows (source query returns nothing) + int count = s.executeUpdate("insert into dst select * from src where 1 = 0"); + assertThat(count, is(0)); + + // Insert 1 row + count = s.executeUpdate("insert into dst select * from src where i = 1"); + assertThat(count, is(1)); + + // Insert multiple rows + count = s.executeUpdate("insert into dst select * from src"); + assertThat(count, is(2)); + } + } + + /** Tests that DELETE returns the correct row count when + * 0, 1, or multiple rows match the predicate. + * Exercises {@link org.apache.calcite.server.MutableArrayTable} via the + * enumerable DELETE path. */ + @Test void testDelete() throws Exception { + try (Connection c = connect(); + Statement s = c.createStatement()) { + s.execute("create table t (i int not null, j int not null)"); + s.executeUpdate("insert into t values (1, 10)"); + s.executeUpdate("insert into t values (2, 20)"); + s.executeUpdate("insert into t values (3, 30)"); + + // Delete 0 rows (no predicate match) + int count = s.executeUpdate("delete from t where i = 99"); + assertThat(count, is(0)); + + // Verify all 3 rows are still present + try (ResultSet r = s.executeQuery("select count(*) from t")) { + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(3)); + } + + // Delete 1 row + count = s.executeUpdate("delete from t where i = 2"); + assertThat(count, is(1)); + + // Delete multiple rows (both remaining rows: i=1 and i=3) + count = s.executeUpdate("delete from t where i > 0"); + assertThat(count, is(2)); + + // Verify table is empty + try (ResultSet r = s.executeQuery("select count(*) from t")) { + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(0)); + } + } + } + + @Test void testDeleteDuplicateRows() throws Exception { + try (Connection c = connect(); + Statement s = c.createStatement()) { + s.execute("create table t (i int not null, j int not null)"); + s.executeUpdate("insert into t values (1, 10)"); + s.executeUpdate("insert into t values (1, 10)"); + s.executeUpdate("insert into t values (1, 10)"); + s.executeUpdate("insert into t values (2, 20)"); + + final int count = s.executeUpdate("delete from t where i = 1 and j = 10"); + assertThat(count, is(3)); + + try (ResultSet r = s.executeQuery("select i, j from t")) { + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(2)); + assertThat(r.getInt(2), is(20)); + assertThat(r.next(), is(false)); + } + } + } + @Test void testStatement() throws Exception { try (Connection c = connect(); Statement s = c.createStatement();