diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java
index 186ec6903f91..5874f2b9aa31 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java
@@ -17,6 +17,8 @@
package org.apache.calcite.adapter.enumerable;
import org.apache.calcite.adapter.java.JavaTypeFactory;
+import org.apache.calcite.linq4j.Enumerable;
+import org.apache.calcite.linq4j.Enumerator;
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
@@ -28,17 +30,23 @@
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.TableModify;
+import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.ModifiableTable;
import org.apache.calcite.util.BuiltInMethod;
import org.checkerframework.checker.nullness.qual.Nullable;
-import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
+import java.util.ArrayDeque;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
+import java.util.Deque;
+import java.util.HashMap;
import java.util.List;
+import java.util.ListIterator;
+import java.util.Map;
import static com.google.common.base.Preconditions.checkArgument;
@@ -80,100 +88,394 @@ public EnumerableTableModify(RelOptCluster cluster, RelTraitSet traits,
final BlockBuilder builder = new BlockBuilder();
final Result result =
implementor.visitChild(this, 0, (EnumerableRel) getInput(), pref);
- Expression childExp =
- builder.append(
- "child", result.block);
+
+ // Enumerable produced by the input relational expression.
+ final Expression sourceExp =
+ builder.append("source", result.block);
+
+ // Variable that will hold the table's mutable backing collection.
final ParameterExpression collectionParameter =
Expressions.parameter(Collection.class,
builder.newName("collection"));
+
+ // Expression that yields the ModifiableTable instance at runtime.
final Expression expression = table.getExpression(ModifiableTable.class);
requireNonNull(expression, "expression"); // TODO: user error in validator
checkArgument(
ModifiableTable.class.isAssignableFrom(
Types.toClass(expression.getType())),
"not assignable from type %s", expression.getType());
+
+ // collection = table.getModifiableCollection()
builder.add(
Expressions.declare(
Modifier.FINAL,
collectionParameter,
Expressions.call(
expression,
- BuiltInMethod.MODIFIABLE_TABLE_GET_MODIFIABLE_COLLECTION
- .method)));
+ BuiltInMethod.MODIFIABLE_TABLE_GET_MODIFIABLE_COLLECTION.method)));
+
+ // Physical row representation of this TableModify node's output
+ // (a single ROWCOUNT field).
+ final PhysType physType =
+ PhysTypeImpl.of(
+ implementor.getTypeFactory(),
+ getRowType(),
+ pref == Prefer.ARRAY ? JavaRowFormat.ARRAY : JavaRowFormat.SCALAR);
+
+ switch (getOperation()) {
+ case INSERT:
+ return implementInsert(implementor, result, builder, sourceExp, collectionParameter, physType);
+ case UPDATE:
+ return implementUpdate(implementor, builder, sourceExp, collectionParameter, physType);
+ case DELETE:
+ return implementDelete(implementor, builder, sourceExp, collectionParameter, physType);
+ default:
+ throw new AssertionError("unsupported operation: " + getOperation());
+ }
+ }
+
+ /** Generates code for an UPDATE statement.
+ *
+ *
Applies updates to matching rows in the backing collection and returns
+ * the number of updated rows as a single-element enumerable.
+ *
+ * @param implementor code-generation context
+ * @param builder block under construction
+ * @param sourceExp enumerable of source rows produced by the input
+ * @param collectionParameter the modifiable backing collection of the table
+ * @param physType physical type of this node's output row
+ */
+ private Result implementUpdate(
+ EnumerableRelImplementor implementor,
+ BlockBuilder builder,
+ Expression sourceExp,
+ ParameterExpression collectionParameter,
+ PhysType physType) {
+ final List updateCols = requireNonNull(getUpdateColumnList());
+ final List tableFields = table.getRowType().getFieldList();
+ final int tableFieldCount = tableFields.size();
+
+ // Child row layout for UPDATE:
+ // [originalField_0, ..., originalField_N-1, newValue_0, ..., newValue_M-1]
+ // where N = tableFieldCount and M = updateCols.size().
+
+ // Resolve each SET-column name to its 0-based position in the table row.
+ final int[] updateColumnIndices = new int[updateCols.size()];
+ for (int i = 0; i < updateCols.size(); i++) {
+ final String colName = updateCols.get(i);
+ int found = -1;
+ for (int j = 0; j < tableFields.size(); j++) {
+ if (tableFields.get(j).getName().equals(colName)) {
+ found = j;
+ break;
+ }
+ }
+ if (found < 0) {
+ throw new AssertionError("column '" + colName + "' not found in table");
+ }
+ updateColumnIndices[i] = found;
+ }
+
+ // Generate code that applies one-to-one update consumption:
+ // each source row updates at most one matching sink row.
+ final Expression updateCountExp =
+ builder.append(
+ "updateCount",
+ Expressions.call(
+ EnumerableTableModify.class,
+ "applyUpdateOneToOne",
+ // Source rows are produced by the child relational expression.
+ sourceExp,
+ // Sink is the table's mutable backing collection.
+ Expressions.convert_(collectionParameter, List.class),
+ // Number of original-row fields in each source payload.
+ Expressions.constant(tableFieldCount),
+ // Table column positions to overwrite from trailing source values.
+ Expressions.constant(updateColumnIndices)));
+
+ // Return the number of updated rows as the single output row.
+ builder.add(
+ Expressions.return_(
+ null,
+ Expressions.call(
+ BuiltInMethod.SINGLETON_ENUMERABLE.method,
+ Expressions.convert_(
+ updateCountExp,
+ long.class))));
+
+ return implementor.result(physType, builder.toBlock());
+ }
+
+ /**
+ * Applies UPDATE with one-to-one, first-match consumption semantics.
+ *
+ * Each source row contributes one replacement row keyed by the original
+ * row content. As sink rows are scanned in order, the first matching row for
+ * each queued source update is replaced and consumed, so duplicate keys update
+ * only as many rows as appear in the source.
+ */
+ @SuppressWarnings({"unchecked"})
+ public static long applyUpdateOneToOne(Enumerable> source, List> sink,
+ int tableFieldCount, int[] updateColumnIndices) {
+ final Map, Deque> updatesByKey = new HashMap<>();
+ try (Enumerator> e = source.enumerator()) {
+ while (e.moveNext()) {
+ final Object[] sourceRow = (Object[]) e.current();
+ final List key = Arrays.asList(Arrays.copyOf(sourceRow, tableFieldCount));
+ final Object[] newRow = applyUpdate(sourceRow, tableFieldCount, updateColumnIndices);
+ updatesByKey.computeIfAbsent(key, k -> new ArrayDeque<>()).addLast(newRow);
+ }
+ }
+
+ long updateCount = 0;
+ final ListIterator it = ((List) sink).listIterator();
+ while (it.hasNext()) {
+ final Object[] current = it.next();
+ final List key = Arrays.asList(current);
+ final Deque pending = updatesByKey.get(key);
+ if (pending == null || pending.isEmpty()) {
+ continue;
+ }
+ it.set(pending.removeFirst());
+ updateCount++;
+ if (pending.isEmpty()) {
+ updatesByKey.remove(key);
+ }
+ }
+ return updateCount;
+ }
+
+ /** Generates code for a DELETE statement.
+ *
+ * The source produces every row that matches the WHERE clause. Those rows
+ * are removed from the backing collection and the number of deleted rows is
+ * returned as a single-element enumerable.
+ *
+ * @param implementor code-generation context
+ * @param builder block under construction
+ * @param sourceExp enumerable of rows to delete — every row matched
+ * by the WHERE clause, as produced by the input
+ * relational expression
+ * @param collectionParameter the modifiable backing collection of the table
+ * @param physType physical type of this node's output row
+ */
+ private Result implementDelete(
+ EnumerableRelImplementor implementor,
+ BlockBuilder builder,
+ Expression sourceExp,
+ ParameterExpression collectionParameter,
+ PhysType physType) {
+
+ // Snapshot the collection size before the delete so we can compute the
+ // number of rows removed as (sizeBefore - sizeAfter).
final Expression countParameter =
builder.append(
"count",
Expressions.call(collectionParameter, "size"),
false);
- Expression convertedChildExp;
+
+ final String deleteMethodName =
+ EnumerableTableScan.deduceFormat(table) == JavaRowFormat.ARRAY
+ ? "applyDeleteArrayRows"
+ : "applyDeleteScalarRows";
+
+ // Remove every row in sourceExp from the backing collection.
+ builder.add(
+ Expressions.statement(
+ Expressions.call(
+ EnumerableTableModify.class,
+ deleteMethodName,
+ sourceExp,
+ collectionParameter)));
+
+ // Snapshot the size again and return (sizeBefore - sizeAfter) as the delete count.
+ final Expression updatedCountParameter =
+ builder.append(
+ "updatedCount",
+ Expressions.call(collectionParameter, "size"),
+ false);
+
+ builder.add(
+ Expressions.return_(
+ null,
+ Expressions.call(
+ BuiltInMethod.SINGLETON_ENUMERABLE.method,
+ Expressions.convert_(
+ Expressions.subtract(countParameter, updatedCountParameter),
+ long.class))));
+
+ return implementor.result(physType, builder.toBlock());
+ }
+
+ /** Generates code for an INSERT statement.
+ *
+ *
All rows produced by the source are added to the backing collection and
+ * the number of inserted rows is returned as a single-element enumerable.
+ *
+ * @param implementor code-generation context
+ * @param result compiled result of the input relational expression
+ * @param builder block under construction
+ * @param sourceExp enumerable of rows to insert — the output of the input
+ * relational expression (VALUES, SELECT, etc.) with any
+ * upstream filtering or projection already applied
+ * @param collectionParameter the modifiable backing collection of the table
+ * @param physType physical type of this node's output row
+ */
+ private Result implementInsert(
+ EnumerableRelImplementor implementor,
+ Result result,
+ BlockBuilder builder,
+ Expression sourceExp,
+ ParameterExpression collectionParameter,
+ PhysType physType) {
+
+ // Snapshot the collection size before the insert so we can compute the
+ // number of rows added as (sizeAfter - sizeBefore).
+ final Expression countParameter =
+ builder.append(
+ "count",
+ Expressions.call(collectionParameter, "size"),
+ false);
+
+ // insertExp is the enumerable that will actually be streamed into the
+ // collection. When the source row type matches the table's row type
+ // exactly it is the same as sourceExp; otherwise it is sourceExp wrapped
+ // in a field-by-field cast projection so that every value lands in the
+ // Java type the table's backing collection expects.
+ Expression insertExp;
if (!getInput().getRowType().equals(getRowType())) {
+ // The source row type doesn't match the table's row type (e.g. types
+ // differ in nullability or precision), so wrap the source in a projection
+ // that casts each field to the exact Java type the table expects.
final JavaTypeFactory typeFactory =
(JavaTypeFactory) getCluster().getTypeFactory();
final JavaRowFormat format = EnumerableTableScan.deduceFormat(table);
- PhysType physType =
+ PhysType tablePhysType =
PhysTypeImpl.of(typeFactory, table.getRowType(), format);
+ // One cast expression per field: sourceField -> tableFieldType.
List expressionList = new ArrayList<>();
- final PhysType childPhysType = result.physType;
+ final PhysType sourcePhysType = result.physType;
final ParameterExpression o_ =
- Expressions.parameter(childPhysType.getJavaRowType(), "o");
- final int fieldCount =
- childPhysType.getRowType().getFieldCount();
+ Expressions.parameter(sourcePhysType.getJavaRowType(), "o");
+ final int fieldCount = sourcePhysType.getRowType().getFieldCount();
for (int i = 0; i < fieldCount; i++) {
expressionList.add(
- childPhysType.fieldReference(o_, i, physType.getJavaFieldType(i)));
+ sourcePhysType.fieldReference(o_, i, tablePhysType.getJavaFieldType(i)));
}
- convertedChildExp =
+ // insertExp = sourceExp.select(o -> new TableRow(cast(o.f0), cast(o.f1), ...))
+ insertExp =
builder.append(
- "convertedChild",
+ "insertRows",
Expressions.call(
- childExp,
+ sourceExp,
BuiltInMethod.SELECT.method,
- Expressions.lambda(
- physType.record(expressionList), o_)));
+ Expressions.lambda(tablePhysType.record(expressionList), o_)));
} else {
- convertedChildExp = childExp;
- }
- final Method method;
- switch (getOperation()) {
- case INSERT:
- method = BuiltInMethod.INTO.method;
- break;
- case DELETE:
- method = BuiltInMethod.REMOVE_ALL.method;
- break;
- default:
- throw new AssertionError(getOperation());
+ insertExp = sourceExp;
}
+
+ // Stream all rows from insertExp into the backing collection.
builder.add(
Expressions.statement(
Expressions.call(
- convertedChildExp, method, collectionParameter)));
+ insertExp, BuiltInMethod.INTO.method, collectionParameter)));
+
+ // Snapshot the size again and return (sizeAfter - sizeBefore) as the insert count.
final Expression updatedCountParameter =
builder.append(
"updatedCount",
Expressions.call(collectionParameter, "size"),
false);
+
builder.add(
Expressions.return_(
null,
Expressions.call(
BuiltInMethod.SINGLETON_ENUMERABLE.method,
Expressions.convert_(
- Expressions.condition(
- Expressions.greaterThanOrEqual(
- updatedCountParameter, countParameter),
- Expressions.subtract(
- updatedCountParameter, countParameter),
- Expressions.subtract(
- countParameter, updatedCountParameter)),
+ Expressions.subtract(updatedCountParameter, countParameter),
long.class))));
- final PhysType physType =
- PhysTypeImpl.of(
- implementor.getTypeFactory(),
- getRowType(),
- pref == Prefer.ARRAY
- ? JavaRowFormat.ARRAY : JavaRowFormat.SCALAR);
+
return implementor.result(physType, builder.toBlock());
}
+ /**
+ * Removes from {@code sink} one occurrence per occurrence in {@code source}
+ * when rows are represented as {@code Object[]}.
+ */
+ @SuppressWarnings({"unchecked"})
+ public static void applyDeleteArrayRows(Enumerable> source, Collection> sink) {
+ final Collection sinkRows = (Collection) sink;
+ final List sourceRows = new ArrayList<>();
+ try (Enumerator> e = source.enumerator()) {
+ while (e.moveNext()) {
+ sourceRows.add((Object[]) e.current());
+ }
+ }
+ // Drain source first, then mutate sink, to avoid iterator interference
+ // when source and sink share the same backing collection.
+ for (Object[] row : sourceRows) {
+ removeFirstArrayMatch(sinkRows, row);
+ }
+ }
+
+ /**
+ * Removes from {@code sink} one occurrence per occurrence in {@code source}
+ * for scalar/custom row representations.
+ */
+ @SuppressWarnings({"unchecked"})
+ public static void applyDeleteScalarRows(Enumerable> source, Collection> sink) {
+ final Collection sinkRows = (Collection) sink;
+ final List sourceRows = new ArrayList<>();
+ try (Enumerator> e = source.enumerator()) {
+ while (e.moveNext()) {
+ sourceRows.add(e.current());
+ }
+ }
+ // Drain source first, then mutate sink, to avoid iterator interference
+ // when source and sink share the same backing collection.
+ for (Object row : sourceRows) {
+ sinkRows.remove(row);
+ }
+ }
+
+ /**
+ * Removes the first row in {@code sinkRows} whose values match {@code target}.
+ *
+ * This helper is used for delete semantics that remove at most one sink
+ * row per source row when rows are represented as {@code Object[]}.
+ */
+ private static void removeFirstArrayMatch(Collection sinkRows,
+ Object[] target) {
+ for (java.util.Iterator it = sinkRows.iterator(); it.hasNext();) {
+ if (Arrays.equals(it.next(), target)) {
+ it.remove();
+ return;
+ }
+ }
+ }
+
+ /**
+ * Builds the replacement row for an UPDATE source row.
+ *
+ * @param row source row produced by the child expression
+ * @param tableFieldCount number of fields in the original table row
+ * @param updateColumnIndices 0-based indices of the columns being updated
+ * @return the replacement row
+ */
+ public static Object[] applyUpdate(
+ Object[] row,
+ int tableFieldCount,
+ int[] updateColumnIndices) {
+ // Source row layout: [originalField_0, ..., originalField_N-1, newValue_0, ..., newValue_M-1]
+ // where N = tableFieldCount and M = updateColumnIndices.length.
+ // Copy the first N fields and overwrite the positions named in the SET clause.
+ final Object[] newRow = Arrays.copyOf(row, tableFieldCount);
+ for (int i = 0; i < updateColumnIndices.length; i++) {
+ newRow[updateColumnIndices[i]] = row[tableFieldCount + i];
+ }
+ return newRow;
+ }
+
}
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index 98b67a02bbf3..98c334799506 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -181,6 +181,8 @@ public enum BuiltInMethod {
Integer.class, int.class, int.class, BigDecimal.class, RoundingMode.class),
INTO(ExtendedEnumerable.class, "into", Collection.class),
REMOVE_ALL(ExtendedEnumerable.class, "removeAll", Collection.class),
+ UPDATE(ExtendedEnumerable.class, "update", List.class, Function1.class,
+ Function1.class, Function1.class),
SCHEMA_GET_SUB_SCHEMA(Schema.class, "getSubSchema", String.class),
SCHEMA_GET_TABLE(Schema.class, "getTable", String.class),
SCHEMA_PLUS_ADD_TABLE(SchemaPlus.class, "add", String.class, Table.class),
diff --git a/core/src/test/java/org/apache/calcite/adapter/enumerable/EnumerableTableModifyTest.java b/core/src/test/java/org/apache/calcite/adapter/enumerable/EnumerableTableModifyTest.java
new file mode 100644
index 000000000000..103a00c0998e
--- /dev/null
+++ b/core/src/test/java/org/apache/calcite/adapter/enumerable/EnumerableTableModifyTest.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.adapter.enumerable;
+
+import org.apache.calcite.linq4j.Linq4j;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+/** Tests for {@link EnumerableTableModify} row-consumption semantics. */
+class EnumerableTableModifyTest {
+
+ @Test void testApplyUpdateOneToOneUpdatesOnlyFirstNMatchingRows() {
+ final List sink = new ArrayList<>();
+ sink.add(new Object[] {1, 10});
+ sink.add(new Object[] {1, 10});
+ sink.add(new Object[] {1, 10});
+ sink.add(new Object[] {2, 20});
+
+ // Source row layout: [original_i, original_j, new_j].
+ final List source = Arrays.asList(
+ new Object[] {1, 10, 100},
+ new Object[] {1, 10, 200});
+
+ final long count = EnumerableTableModify.applyUpdateOneToOne(
+ Linq4j.asEnumerable(source), sink, 2, new int[] {1});
+
+ assertThat(count, is(2L));
+ assertThat(toValueRows(sink),
+ is(Arrays.asList(
+ Arrays.asList(1, 100),
+ Arrays.asList(1, 200),
+ Arrays.asList(1, 10),
+ Arrays.asList(2, 20))));
+ }
+
+ @Test void testApplyDeleteArrayRowsDeletesOnlyFirstNMatchingRows() {
+ final List sink = new ArrayList<>();
+ sink.add(new Object[] {1, 10});
+ sink.add(new Object[] {1, 10});
+ sink.add(new Object[] {1, 10});
+ sink.add(new Object[] {2, 20});
+
+ final List source = Arrays.asList(
+ new Object[] {1, 10},
+ new Object[] {1, 10});
+
+ EnumerableTableModify.applyDeleteArrayRows(Linq4j.asEnumerable(source), sink);
+
+ assertThat(toValueRows(sink),
+ is(Arrays.asList(
+ Arrays.asList(1, 10),
+ Arrays.asList(2, 20))));
+ }
+
+ private static List> toValueRows(List rows) {
+ final List> valueRows = new ArrayList<>();
+ for (Object[] row : rows) {
+ valueRows.add(Arrays.asList(row));
+ }
+ return valueRows;
+ }
+}
+
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
index d859519b4178..54dcc20f41e7 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java
@@ -377,6 +377,15 @@ protected OrderedQueryable asOrderedQueryable() {
return EnumerableDefaults.remove(getThis(), sink);
}
+ @Override public long update(
+ List sink,
+ Function1 sinkKeySelector,
+ Function1 sourceKeySelector,
+ Function1 transform) {
+ return EnumerableDefaults.update(getThis(), sink, sinkKeySelector,
+ sourceKeySelector, transform);
+ }
+
@Override public Enumerable hashJoin(
Enumerable inner, Function1 outerKeySelector,
Function1 innerKeySelector,
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
index 988450df3643..ae26a602132d 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java
@@ -61,6 +61,7 @@
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
+import java.util.ListIterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
@@ -4399,6 +4400,40 @@ public static > C remove(
return sink;
}
+ /**
+ * Default implementation of
+ * {@link ExtendedEnumerable#update(List, Function1, Function1, Function1)}.
+ *
+ * Builds a map from source-row keys to replacement rows in a single pass
+ * over the source, then performs a single pass over the sink, replacing
+ * matched rows in place.
+ */
+ public static long update(
+ Enumerable source,
+ List sink,
+ Function1 sinkKeySelector,
+ Function1 sourceKeySelector,
+ Function1 sourceTransform) {
+ final Map updateMap = new HashMap<>();
+ try (Enumerator e = source.enumerator()) {
+ while (e.moveNext()) {
+ final T row = e.current();
+ updateMap.put(sourceKeySelector.apply(row), sourceTransform.apply(row));
+ }
+ }
+ long updateCount = 0;
+ final ListIterator it = sink.listIterator();
+ while (it.hasNext()) {
+ final T current = it.next();
+ final T newRow = updateMap.get(sinkKeySelector.apply(current));
+ if (newRow != null) {
+ it.set(newRow);
+ updateCount++;
+ }
+ }
+ return updateCount;
+ }
+
/**
* Hash table with null-safe key set.
*
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
index 160f2afa0b1e..ff28c3fe822c 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java
@@ -547,6 +547,32 @@ Enumerable intersect(Enumerable enumerable1,
*/
> C removeAll(C sink);
+ /**
+ * Updates rows of {@code sink} based on the contents of this sequence.
+ *
+ * For each element {@code x} of this sequence, {@code sourceKeySelector}
+ * computes a key, and {@code sourceTransform} computes a replacement row.
+ * Then for each element {@code y} of {@code sink}, {@code sinkKeySelector}
+ * computes a key; if it matches a key produced from this sequence, {@code y}
+ * is replaced (in place) with the corresponding replacement row.
+ *
+ *
The sink is a {@link List} so that elements can be replaced
+ * in place while preserving order.
+ *
+ * @param sink List to be updated in place
+ * @param sinkKeySelector Function that extracts a key from a sink row
+ * @param sourceKeySelector Function that extracts a key from a source row
+ * @param transform Function that produces the replacement row from a
+ * source row
+ * @param Key type
+ * @return Number of rows replaced
+ */
+ long update(
+ List sink,
+ Function1 sinkKeySelector,
+ Function1 sourceKeySelector,
+ Function1 transform);
+
/**
* Correlates the elements of two sequences based on
* matching keys. The default equality comparer is used to compare
diff --git a/server/src/test/java/org/apache/calcite/test/ServerTest.java b/server/src/test/java/org/apache/calcite/test/ServerTest.java
index 40e1430c8799..a58ef5b2b8ea 100644
--- a/server/src/test/java/org/apache/calcite/test/ServerTest.java
+++ b/server/src/test/java/org/apache/calcite/test/ServerTest.java
@@ -108,6 +108,159 @@ static Connection connect() throws SQLException {
executor.execute((SqlTruncateTable) o, context);
}
+ @Test void testUpdate() throws Exception {
+ try (Connection c = connect();
+ Statement s = c.createStatement()) {
+ s.execute("create table t (i int not null, j int not null)");
+ s.executeUpdate("insert into t values (1, 10)");
+ s.executeUpdate("insert into t values (2, 20)");
+ s.executeUpdate("insert into t values (3, 30)");
+
+ // Update one row
+ int count = s.executeUpdate("update t set j = 99 where i = 2");
+ assertThat(count, is(1));
+
+ try (ResultSet r = s.executeQuery("select i, j from t order by i")) {
+ assertThat(r.next(), is(true));
+ assertThat(r.getInt(1), is(1));
+ assertThat(r.getInt(2), is(10));
+ assertThat(r.next(), is(true));
+ assertThat(r.getInt(1), is(2));
+ assertThat(r.getInt(2), is(99));
+ assertThat(r.next(), is(true));
+ assertThat(r.getInt(1), is(3));
+ assertThat(r.getInt(2), is(30));
+ assertThat(r.next(), is(false));
+ }
+
+ // Update multiple rows
+ count = s.executeUpdate("update t set j = 0 where i > 1");
+ assertThat(count, is(2));
+
+ try (ResultSet r = s.executeQuery("select sum(j) from t")) {
+ assertThat(r.next(), is(true));
+ assertThat(r.getInt(1), is(10));
+ assertThat(r.next(), is(false));
+ }
+
+ // Update zero rows (no predicate match)
+ count = s.executeUpdate("update t set j = 100 where i = 99");
+ assertThat(count, is(0));
+ }
+ }
+
+ @Test void testUpdateDuplicateRows() throws Exception {
+ try (Connection c = connect();
+ Statement s = c.createStatement()) {
+ s.execute("create table t (i int not null, j int not null)");
+ s.executeUpdate("insert into t values (1, 10)");
+ s.executeUpdate("insert into t values (1, 10)");
+ s.executeUpdate("insert into t values (1, 10)");
+ s.executeUpdate("insert into t values (2, 20)");
+
+ final int count = s.executeUpdate("update t set j = 99 where i = 1 and j = 10");
+ assertThat(count, is(3));
+
+ try (ResultSet r = s.executeQuery(
+ "select "
+ + "sum(case when i = 1 and j = 99 then 1 else 0 end), "
+ + "sum(case when i = 1 and j = 10 then 1 else 0 end), "
+ + "sum(case when i = 2 and j = 20 then 1 else 0 end) "
+ + "from t")) {
+ assertThat(r.next(), is(true));
+ assertThat(r.getInt(1), is(3));
+ assertThat(r.getInt(2), is(0));
+ assertThat(r.getInt(3), is(1));
+ assertThat(r.next(), is(false));
+ }
+ }
+ }
+
+ /** Tests that INSERT ... SELECT returns the correct row count when
+ * 0, 1, or multiple rows are produced by the source query.
+ * Exercises {@link org.apache.calcite.server.MutableArrayTable} via the
+ * enumerable INSERT path. */
+ @Test void testInsertSelectRowCount() throws Exception {
+ try (Connection c = connect();
+ Statement s = c.createStatement()) {
+ s.execute("create table src (i int not null, j int not null)");
+ s.executeUpdate("insert into src values (1, 10)");
+ s.executeUpdate("insert into src values (2, 20)");
+ s.execute("create table dst (i int not null, j int not null)");
+
+ // Insert 0 rows (source query returns nothing)
+ int count = s.executeUpdate("insert into dst select * from src where 1 = 0");
+ assertThat(count, is(0));
+
+ // Insert 1 row
+ count = s.executeUpdate("insert into dst select * from src where i = 1");
+ assertThat(count, is(1));
+
+ // Insert multiple rows
+ count = s.executeUpdate("insert into dst select * from src");
+ assertThat(count, is(2));
+ }
+ }
+
+ /** Tests that DELETE returns the correct row count when
+ * 0, 1, or multiple rows match the predicate.
+ * Exercises {@link org.apache.calcite.server.MutableArrayTable} via the
+ * enumerable DELETE path. */
+ @Test void testDelete() throws Exception {
+ try (Connection c = connect();
+ Statement s = c.createStatement()) {
+ s.execute("create table t (i int not null, j int not null)");
+ s.executeUpdate("insert into t values (1, 10)");
+ s.executeUpdate("insert into t values (2, 20)");
+ s.executeUpdate("insert into t values (3, 30)");
+
+ // Delete 0 rows (no predicate match)
+ int count = s.executeUpdate("delete from t where i = 99");
+ assertThat(count, is(0));
+
+ // Verify all 3 rows are still present
+ try (ResultSet r = s.executeQuery("select count(*) from t")) {
+ assertThat(r.next(), is(true));
+ assertThat(r.getInt(1), is(3));
+ }
+
+ // Delete 1 row
+ count = s.executeUpdate("delete from t where i = 2");
+ assertThat(count, is(1));
+
+ // Delete multiple rows (both remaining rows: i=1 and i=3)
+ count = s.executeUpdate("delete from t where i > 0");
+ assertThat(count, is(2));
+
+ // Verify table is empty
+ try (ResultSet r = s.executeQuery("select count(*) from t")) {
+ assertThat(r.next(), is(true));
+ assertThat(r.getInt(1), is(0));
+ }
+ }
+ }
+
+ @Test void testDeleteDuplicateRows() throws Exception {
+ try (Connection c = connect();
+ Statement s = c.createStatement()) {
+ s.execute("create table t (i int not null, j int not null)");
+ s.executeUpdate("insert into t values (1, 10)");
+ s.executeUpdate("insert into t values (1, 10)");
+ s.executeUpdate("insert into t values (1, 10)");
+ s.executeUpdate("insert into t values (2, 20)");
+
+ final int count = s.executeUpdate("delete from t where i = 1 and j = 10");
+ assertThat(count, is(3));
+
+ try (ResultSet r = s.executeQuery("select i, j from t")) {
+ assertThat(r.next(), is(true));
+ assertThat(r.getInt(1), is(2));
+ assertThat(r.getInt(2), is(20));
+ assertThat(r.next(), is(false));
+ }
+ }
+ }
+
@Test void testStatement() throws Exception {
try (Connection c = connect();
Statement s = c.createStatement();