diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 87299800f..f3ba1fa3b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,7 +60,7 @@ jobs: report_paths: "**/build/test-results/test/TEST-*.xml" unit_test_jdk8: - name: Unit test with docker service [JDK8] + name: Unit test with CLI runs-on: ubuntu-latest-16-cores timeout-minutes: 30 steps: @@ -82,7 +82,7 @@ jobs: - name: Set up Gradle uses: gradle/actions/setup-gradle@ac396bf1a80af16236baf54bd7330ae21dc6ece5 # v6 - - name: Start containerized server and dependencies + - name: Start CLI server env: TEMPORAL_CLI_VERSION: 1.7.0 run: | @@ -110,6 +110,7 @@ jobs: --dynamic-config-value history.enableRequestIdRefLinks=true \ --dynamic-config-value frontend.WorkerHeartbeatsEnabled=true \ --dynamic-config-value frontend.ListWorkersEnabled=true \ + --dynamic-config-value frontend.enableCancelWorkerPollsOnShutdown=true \ --dynamic-config-value 'component.callbacks.allowedAddresses=[{"Pattern":"localhost:7243","AllowInsecure":true}]' \ --dynamic-config-value frontend.activityAPIsEnabled=true \ --dynamic-config-value activity.enableStandalone=true \ diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java index b01503956..1dceb67fb 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java @@ -37,6 +37,7 @@ public ActivityPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, double activitiesPerSecond, @Nonnull TrackingSlotSupplier slotSupplier, @@ -53,6 +54,7 @@ public ActivityPollTask( .setNamespace(namespace) .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + 
pollRequest.setWorkerInstanceKey(workerInstanceKey); if (activitiesPerSecond > 0) { pollRequest.setTaskQueueMetadata( TaskQueueMetadata.newBuilder() diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java index 520ce7a37..d2fddde3f 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java @@ -105,6 +105,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), taskQueueActivitiesPerSecond, this.slotSupplier, @@ -113,7 +114,7 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - namespaceCapabilities.isPollerAutoscaling(), + namespaceCapabilities, workerMetricsScope); } else { @@ -125,6 +126,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), taskQueueActivitiesPerSecond, this.slotSupplier, @@ -133,7 +135,8 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - workerMetricsScope); + workerMetricsScope, + namespaceCapabilities); } poller.start(); workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java index 60ebcbf65..b23d16184 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java @@ -43,6 +43,7 @@ public AsyncActivityPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, 
double activitiesPerSecond, @Nonnull TrackingSlotSupplier slotSupplier, @@ -59,6 +60,7 @@ public AsyncActivityPollTask( .setNamespace(namespace) .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + pollRequest.setWorkerInstanceKey(workerInstanceKey); if (activitiesPerSecond > 0) { pollRequest.setTaskQueueMetadata( TaskQueueMetadata.newBuilder() diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java index efc4dc807..1ba3b84d1 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java @@ -41,6 +41,7 @@ public AsyncNexusPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, @Nonnull Scope metricsScope, @Nonnull Supplier serverCapabilities, @@ -57,6 +58,8 @@ public AsyncNexusPollTask( .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + pollRequest.setWorkerInstanceKey(workerInstanceKey); + if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequest.setDeploymentOptions( WorkerVersioningProtoUtils.deploymentOptionsToProto( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java index 510634379..c56111c02 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java @@ -29,7 +29,6 @@ final class AsyncPoller extends BasePoller { private final List> asyncTaskPollers; private final PollerOptions pollerOptions; private final PollerBehaviorAutoscaling pollerBehavior; - private final boolean serverSupportsAutoscaling; private final Scope 
workerMetricsScope; private Throttler pollRateThrottler; private final Thread.UncaughtExceptionHandler uncaughtExceptionHandler = @@ -43,7 +42,7 @@ final class AsyncPoller extends BasePoller { PollTaskAsync asyncTaskPoller, ShutdownableTaskExecutor taskExecutor, PollerOptions pollerOptions, - boolean serverSupportsAutoscaling, + NamespaceCapabilities namespaceCapabilities, Scope workerMetricsScope) { this( slotSupplier, @@ -51,7 +50,7 @@ final class AsyncPoller extends BasePoller { Collections.singletonList(asyncTaskPoller), taskExecutor, pollerOptions, - serverSupportsAutoscaling, + namespaceCapabilities, workerMetricsScope); } @@ -61,9 +60,9 @@ final class AsyncPoller extends BasePoller { List> asyncTaskPollers, ShutdownableTaskExecutor taskExecutor, PollerOptions pollerOptions, - boolean serverSupportsAutoscaling, + NamespaceCapabilities namespaceCapabilities, Scope workerMetricsScope) { - super(taskExecutor); + super(taskExecutor, namespaceCapabilities); Objects.requireNonNull(slotSupplier, "slotSupplier cannot be null"); Objects.requireNonNull(slotReservationData, "slotReservation data should not be null"); Objects.requireNonNull(asyncTaskPollers, "asyncTaskPollers should not be null"); @@ -82,7 +81,6 @@ final class AsyncPoller extends BasePoller { + " is not supported for AsyncPoller. 
Only PollerBehaviorAutoscaling is supported."); } this.pollerBehavior = (PollerBehaviorAutoscaling) pollerOptions.getPollerBehavior(); - this.serverSupportsAutoscaling = serverSupportsAutoscaling; this.pollerOptions = pollerOptions; this.workerMetricsScope = workerMetricsScope; } @@ -114,7 +112,7 @@ public boolean start() { pollerBehavior.getMinConcurrentTaskPollers(), pollerBehavior.getMaxConcurrentTaskPollers(), pollerBehavior.getInitialConcurrentTaskPollers(), - serverSupportsAutoscaling, + namespaceCapabilities.isPollerAutoscaling(), (newTarget) -> { log.debug( "Updating maximum number of pollers for {} to: {}", @@ -136,12 +134,14 @@ public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean return super.shutdown(shutdownManager, interruptTasks) .thenApply( (f) -> { - for (PollTaskAsync asyncTaskPoller : asyncTaskPollers) { - try { - log.debug("Shutting down async poller: {}", asyncTaskPoller.getLabel()); - asyncTaskPoller.cancel(new RuntimeException("Shutting down poller")); - } catch (Throwable e) { - log.error("Error while cancelling poll task", e); + if (interruptTasks || !namespaceCapabilities.isGracefulPollShutdown()) { + for (PollTaskAsync asyncTaskPoller : asyncTaskPollers) { + try { + log.debug("Shutting down async poller: {}", asyncTaskPoller.getLabel()); + asyncTaskPoller.cancel(new RuntimeException("Shutting down poller")); + } catch (Throwable e) { + log.error("Error while cancelling poll task", e); + } } } return null; diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java index c30dbc9e1..3bfa796a3 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java @@ -52,6 +52,7 @@ public AsyncWorkflowPollTask( @Nonnull String taskQueue, @Nullable String stickyTaskQueue, @Nonnull String 
identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull Scope metricsScope, @@ -67,6 +68,8 @@ public AsyncWorkflowPollTask( .setNamespace(Objects.requireNonNull(namespace)) .setIdentity(Objects.requireNonNull(identity)); + pollRequestBuilder.setWorkerInstanceKey(workerInstanceKey); + if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequestBuilder.setDeploymentOptions( WorkerVersioningProtoUtils.deploymentOptionsToProto( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java index 9b8141fc0..855145b31 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java @@ -27,9 +27,14 @@ abstract class BasePoller implements SuspendableWorker { protected ExecutorService pollExecutor; - protected BasePoller(ShutdownableTaskExecutor taskExecutor) { + protected final NamespaceCapabilities namespaceCapabilities; + + protected BasePoller( + ShutdownableTaskExecutor taskExecutor, NamespaceCapabilities namespaceCapabilities) { Objects.requireNonNull(taskExecutor, "taskExecutor should not be null"); this.taskExecutor = taskExecutor; + this.namespaceCapabilities = + Objects.requireNonNull(namespaceCapabilities, "namespaceCapabilities should not be null"); } @Override @@ -55,15 +60,24 @@ public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean return CompletableFuture.completedFuture(null); } - return shutdownManager - // it's ok to forcefully shutdown pollers, because they are stuck in a long poll call - // so we don't risk loosing any progress doing that. 
- .shutdownExecutorNow(pollExecutor, this + "#pollExecutor", Duration.ofSeconds(1)) - .exceptionally( - e -> { - log.error("Unexpected exception during shutdown", e); - return null; - }); + CompletableFuture pollExecutorShutdown; + if (namespaceCapabilities.isGracefulPollShutdown() && !interruptTasks) { + // When graceful poll shutdown is enabled, the server will complete outstanding polls with + // empty responses after ShutdownWorker is called. We simply wait for polls to return. + pollExecutorShutdown = + shutdownManager.shutdownExecutor( + pollExecutor, this + "#pollExecutor", Duration.ofSeconds(80)); + } else { + // ShutdownNow and old servers forcibly stop outstanding polls. + pollExecutorShutdown = + shutdownManager.shutdownExecutorNow( + pollExecutor, this + "#pollExecutor", Duration.ofSeconds(1)); + } + return pollExecutorShutdown.exceptionally( + e -> { + log.error("Unexpected exception during shutdown", e); + return null; + }); } @Override diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java index 8dcaa6f33..7fe0335b1 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java @@ -52,8 +52,9 @@ public MultiThreadedPoller( PollTask pollTask, ShutdownableTaskExecutor taskExecutor, PollerOptions pollerOptions, - Scope workerMetricsScope) { - super(taskExecutor); + Scope workerMetricsScope, + NamespaceCapabilities namespaceCapabilities) { + super(taskExecutor, namespaceCapabilities); Objects.requireNonNull(identity, "identity cannot be null"); Objects.requireNonNull(pollTask, "poll service should not be null"); Objects.requireNonNull(pollerOptions, "pollerOptions should not be null"); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java 
b/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java index 4fa9d09a5..8c9f23270 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java @@ -1,5 +1,6 @@ package io.temporal.internal.worker; +import io.temporal.api.namespace.v1.NamespaceInfo.Capabilities; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -9,14 +10,31 @@ */ public final class NamespaceCapabilities { private final AtomicBoolean pollerAutoscaling = new AtomicBoolean(false); + private final AtomicBoolean gracefulPollShutdown = new AtomicBoolean(false); private final AtomicBoolean workerHeartbeats = new AtomicBoolean(false); + public void setFromCapabilities(Capabilities capabilities) { + if (capabilities.getPollerAutoscaling()) { + pollerAutoscaling.set(true); + } + if (capabilities.getWorkerPollCompleteOnShutdown()) { + gracefulPollShutdown.set(true); + } + if (capabilities.getWorkerHeartbeats()) { + workerHeartbeats.set(true); + } + } + public boolean isPollerAutoscaling() { return pollerAutoscaling.get(); } - public void setPollerAutoscaling(boolean value) { - pollerAutoscaling.set(value); + public boolean isGracefulPollShutdown() { + return gracefulPollShutdown.get(); + } + + public void setGracefulPollShutdown(boolean value) { + gracefulPollShutdown.set(value); } public boolean isWorkerHeartbeats() { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java index 4116825b9..0ccab5944 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java @@ -34,6 +34,7 @@ public NexusPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions 
versioningOptions, @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull Scope metricsScope, @@ -49,6 +50,7 @@ public NexusPollTask( .setNamespace(namespace) .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + pollRequest.setWorkerInstanceKey(workerInstanceKey); if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequest.setDeploymentOptions( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java index d826e5543..ac364a747 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java @@ -111,6 +111,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), workerMetricsScope, service.getServerCapabilities(), @@ -118,7 +119,7 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - namespaceCapabilities.isPollerAutoscaling(), + namespaceCapabilities, workerMetricsScope); } else { poller = @@ -129,6 +130,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), this.slotSupplier, workerMetricsScope, @@ -136,7 +138,8 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - workerMetricsScope); + workerMetricsScope, + namespaceCapabilities); } poller.start(); workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ShutdownManager.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ShutdownManager.java index d59911b0a..2e534841e 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ShutdownManager.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ShutdownManager.java @@ -149,10 
+149,18 @@ public CompletableFuture waitOnWorkerShutdownRequest( future.complete(null); } }, - scheduledExecutorService); + this::executeOrRunDirect); return future; } + private void executeOrRunDirect(Runnable command) { + try { + scheduledExecutorService.execute(command); + } catch (RejectedExecutionException e) { + command.run(); + } + } + @Override public void close() { scheduledExecutorService.shutdownNow(); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java index f8baba01d..559370772 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java @@ -40,6 +40,7 @@ public static final class Builder { private Duration drainStickyTaskQueueTimeout; private boolean usingVirtualThreads; private WorkerDeploymentOptions deploymentOptions; + private String workerInstanceKey; private Builder() {} @@ -64,6 +65,7 @@ private Builder(SingleWorkerOptions options) { this.drainStickyTaskQueueTimeout = options.getDrainStickyTaskQueueTimeout(); this.usingVirtualThreads = options.isUsingVirtualThreads(); this.deploymentOptions = options.getDeploymentOptions(); + this.workerInstanceKey = options.getWorkerInstanceKey(); } public Builder setIdentity(String identity) { @@ -155,6 +157,11 @@ public Builder setDeploymentOptions(WorkerDeploymentOptions deploymentOptions) { return this; } + public Builder setWorkerInstanceKey(String workerInstanceKey) { + this.workerInstanceKey = workerInstanceKey; + return this; + } + public SingleWorkerOptions build() { PollerOptions pollerOptions = this.pollerOptions; if (pollerOptions == null) { @@ -193,7 +200,8 @@ public SingleWorkerOptions build() { this.defaultHeartbeatThrottleInterval, drainStickyTaskQueueTimeout, usingVirtualThreads, - this.deploymentOptions); + this.deploymentOptions, + 
this.workerInstanceKey); } } @@ -214,6 +222,7 @@ public SingleWorkerOptions build() { private final Duration drainStickyTaskQueueTimeout; private final boolean usingVirtualThreads; private final WorkerDeploymentOptions deploymentOptions; + private final String workerInstanceKey; private SingleWorkerOptions( String identity, @@ -232,7 +241,8 @@ private SingleWorkerOptions( Duration defaultHeartbeatThrottleInterval, Duration drainStickyTaskQueueTimeout, boolean usingVirtualThreads, - WorkerDeploymentOptions deploymentOptions) { + WorkerDeploymentOptions deploymentOptions, + String workerInstanceKey) { this.identity = identity; this.binaryChecksum = binaryChecksum; this.buildId = buildId; @@ -250,6 +260,7 @@ private SingleWorkerOptions( this.drainStickyTaskQueueTimeout = drainStickyTaskQueueTimeout; this.usingVirtualThreads = usingVirtualThreads; this.deploymentOptions = deploymentOptions; + this.workerInstanceKey = workerInstanceKey; } public String getIdentity() { @@ -331,6 +342,10 @@ public WorkerDeploymentOptions getDeploymentOptions() { return deploymentOptions; } + public String getWorkerInstanceKey() { + return workerInstanceKey; + } + public WorkerVersioningOptions getWorkerVersioningOptions() { return new WorkerVersioningOptions( this.getBuildId(), this.isUsingBuildIdForVersioning(), this.getDeploymentOptions()); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java index 51ab7a700..18cf7fd4a 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java @@ -3,9 +3,7 @@ import static io.temporal.internal.common.InternalUtils.createStickyTaskQueue; import io.temporal.api.common.v1.Payloads; -import io.temporal.api.enums.v1.TaskQueueType; import io.temporal.api.taskqueue.v1.TaskQueue; -import 
io.temporal.api.worker.v1.WorkerHeartbeat; import io.temporal.client.WorkflowClient; import io.temporal.common.converter.DataConverter; import io.temporal.common.converter.EncodedValues; @@ -24,11 +22,9 @@ import io.temporal.workflow.Functions.Func1; import java.lang.reflect.Type; import java.time.Duration; -import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.concurrent.*; -import java.util.function.Supplier; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.slf4j.Logger; @@ -64,8 +60,6 @@ public SyncWorkflowWorker( @Nonnull WorkflowClient client, @Nonnull String namespace, @Nonnull String taskQueue, - @Nonnull String workerInstanceKey, - @Nonnull Supplier> activeTaskQueueTypesSupplier, @Nonnull SingleWorkerOptions singleWorkerOptions, @Nonnull SingleWorkerOptions localActivityOptions, @Nonnull WorkflowRunLockManager runLocks, @@ -123,8 +117,6 @@ public SyncWorkflowWorker( client.getWorkflowServiceStubs(), namespace, taskQueue, - workerInstanceKey, - activeTaskQueueTypesSupplier, stickyTaskQueueName, singleWorkerOptions, runLocks, @@ -250,10 +242,6 @@ public TrackingSlotSupplier getLocalActivitySlotSupplier( return laWorker.getSlotSupplier(); } - public void setHeartbeatSupplier(Supplier supplier) { - workflowWorker.setHeartbeatSupplier(supplier); - } - public boolean hasStickyQueue() { return workflowWorker.hasStickyQueue(); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java index cdb5e5163..18607b5d1 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java @@ -47,6 +47,7 @@ public WorkflowPollTask( @Nonnull String taskQueue, @Nullable String stickyTaskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions 
versioningOptions, @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull StickyQueueBalancer stickyQueueBalancer, @@ -73,6 +74,7 @@ public WorkflowPollTask( PollWorkflowTaskQueueRequest.newBuilder() .setNamespace(Objects.requireNonNull(namespace)) .setIdentity(Objects.requireNonNull(identity)); + pollRequestBuilder.setWorkerInstanceKey(workerInstanceKey); if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequestBuilder.setDeploymentOptions( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java index f98316d5d..fbb82e467 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java @@ -13,11 +13,8 @@ import io.temporal.api.common.v1.WorkflowExecution; import io.temporal.api.enums.v1.QueryResultType; import io.temporal.api.enums.v1.TaskQueueKind; -import io.temporal.api.enums.v1.TaskQueueType; -import io.temporal.api.enums.v1.WorkerStatus; import io.temporal.api.enums.v1.WorkflowTaskFailedCause; import io.temporal.api.failure.v1.Failure; -import io.temporal.api.worker.v1.WorkerHeartbeat; import io.temporal.api.workflowservice.v1.*; import io.temporal.failure.ApplicationFailure; import io.temporal.internal.logging.LoggerTag; @@ -33,7 +30,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.slf4j.Logger; @@ -41,7 +37,6 @@ import org.slf4j.MDC; final class WorkflowWorker implements SuspendableWorker { - private static final String GRACEFUL_SHUTDOWN_MESSAGE = "graceful shutdown"; private static final Logger log = LoggerFactory.getLogger(WorkflowWorker.class); private final WorkflowRunLockManager runLocks; @@ -58,9 +53,6 @@ final class 
WorkflowWorker implements SuspendableWorker { private final GrpcRetryer grpcRetryer; private final EagerActivityDispatcher eagerActivityDispatcher; private final TrackingSlotSupplier slotSupplier; - private volatile Supplier heartbeatSupplier; - private final String workerInstanceKey; - private final Supplier> activeTaskQueueTypesSupplier; private final TaskCounter taskCounter = new TaskCounter(); private final PollerTracker pollerTracker = new PollerTracker(); @@ -79,8 +71,6 @@ public WorkflowWorker( @Nonnull WorkflowServiceStubs service, @Nonnull String namespace, @Nonnull String taskQueue, - @Nonnull String workerInstanceKey, - @Nonnull Supplier> activeTaskQueueTypesSupplier, @Nullable String stickyTaskQueueName, @Nonnull SingleWorkerOptions options, @Nonnull WorkflowRunLockManager runLocks, @@ -92,8 +82,6 @@ public WorkflowWorker( this.service = Objects.requireNonNull(service); this.namespace = Objects.requireNonNull(namespace); this.taskQueue = Objects.requireNonNull(taskQueue); - this.workerInstanceKey = Objects.requireNonNull(workerInstanceKey); - this.activeTaskQueueTypesSupplier = Objects.requireNonNull(activeTaskQueueTypesSupplier); this.options = Objects.requireNonNull(options); this.stickyTaskQueueName = stickyTaskQueueName; this.pollerOptions = getPollerOptions(options); @@ -133,6 +121,7 @@ public boolean start() { taskQueue, null, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, workerMetricsScope, @@ -146,6 +135,7 @@ public boolean start() { taskQueue, stickyTaskQueueName, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, workerMetricsScope, @@ -162,6 +152,7 @@ public boolean start() { taskQueue, null, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, workerMetricsScope, @@ -175,7 +166,7 @@ public boolean start() { pollers, this.pollTaskExecutor, pollerOptions, - 
namespaceCapabilities.isPollerAutoscaling(), + namespaceCapabilities, workerMetricsScope); } else { PollerBehaviorSimpleMaximum pollerBehavior = @@ -193,6 +184,7 @@ public boolean start() { taskQueue, stickyTaskQueueName, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, stickyQueueBalancer, @@ -202,7 +194,8 @@ public boolean start() { stickyPollerTracker), pollTaskExecutor, pollerOptions, - workerMetricsScope); + workerMetricsScope, + namespaceCapabilities); } poller.start(); workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1); @@ -221,7 +214,8 @@ public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean !interruptTasks && !options.getDrainStickyTaskQueueTimeout().isZero() && stickyTaskQueueName != null - && stickyQueueBalancer != null; + && stickyQueueBalancer != null + && !namespaceCapabilities.isGracefulPollShutdown(); CompletableFuture pollerShutdown = CompletableFuture.completedFuture(null) @@ -232,46 +226,23 @@ public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean stickyQueueBalancer, options.getDrainStickyTaskQueueTimeout()) : CompletableFuture.completedFuture(null)) .thenCompose(ignore -> poller.shutdown(shutdownManager, interruptTasks)); - return CompletableFuture.allOf( - pollerShutdown.thenCompose( - ignore -> { - ShutdownWorkerRequest.Builder shutdownReq = - ShutdownWorkerRequest.newBuilder() - .setIdentity(options.getIdentity()) - .setNamespace(namespace) - .setTaskQueue(taskQueue) - .setWorkerInstanceKey(workerInstanceKey) - .setReason(GRACEFUL_SHUTDOWN_MESSAGE) - .addAllTaskQueueTypes(activeTaskQueueTypesSupplier.get()); - if (stickyTaskQueueName != null) { - shutdownReq.setStickyTaskQueue(stickyTaskQueueName); - } - if (heartbeatSupplier != null) { - shutdownReq.setWorkerHeartbeat( - heartbeatSupplier.get().toBuilder() - .setStatus(WorkerStatus.WORKER_STATUS_SHUTTING_DOWN) - .build()); - } - return 
shutdownManager.waitOnWorkerShutdownRequest( - service.futureStub().shutdownWorker(shutdownReq.build())); - }), - pollerShutdown - .thenCompose( - ignore -> - !interruptTasks - ? shutdownManager.waitForSupplierPermitsReleasedUnlimited( - slotSupplier, supplierName) - : CompletableFuture.completedFuture(null)) - .thenCompose( - ignore -> - pollTaskExecutor != null - ? pollTaskExecutor.shutdown(shutdownManager, interruptTasks) - : CompletableFuture.completedFuture(null)) - .exceptionally( - e -> { - log.error("Unexpected exception during shutdown", e); - return null; - })); + return pollerShutdown + .thenCompose( + ignore -> + !interruptTasks + ? shutdownManager.waitForSupplierPermitsReleasedUnlimited( + slotSupplier, supplierName) + : CompletableFuture.completedFuture(null)) + .thenCompose( + ignore -> + pollTaskExecutor != null + ? pollTaskExecutor.shutdown(shutdownManager, interruptTasks) + : CompletableFuture.completedFuture(null)) + .exceptionally( + e -> { + log.error("Unexpected exception during shutdown", e); + return null; + }); } @Override @@ -363,10 +334,6 @@ public WorkflowTaskDispatchHandle reserveWorkflowExecutor() { .orElse(null); } - public void setHeartbeatSupplier(Supplier supplier) { - this.heartbeatSupplier = supplier; - } - public TrackingSlotSupplier getSlotSupplier() { return slotSupplier; } diff --git a/temporal-sdk/src/main/java/io/temporal/worker/Worker.java b/temporal-sdk/src/main/java/io/temporal/worker/Worker.java index ce599c6ad..c846667d6 100644 --- a/temporal-sdk/src/main/java/io/temporal/worker/Worker.java +++ b/temporal-sdk/src/main/java/io/temporal/worker/Worker.java @@ -13,6 +13,7 @@ import io.temporal.api.worker.v1.WorkerHostInfo; import io.temporal.api.worker.v1.WorkerPollerInfo; import io.temporal.api.worker.v1.WorkerSlotsInfo; +import io.temporal.api.workflowservice.v1.ShutdownWorkerRequest; import io.temporal.client.WorkflowClient; import io.temporal.client.WorkflowClientOptions; import io.temporal.common.Experimental; @@ 
-27,6 +28,7 @@ import io.temporal.internal.worker.TaskCounter; import io.temporal.serviceclient.MetricsTag; import io.temporal.serviceclient.Version; +import io.temporal.serviceclient.WorkflowServiceStubs; import io.temporal.worker.tuning.*; import io.temporal.workflow.Functions; import io.temporal.workflow.Functions.Func; @@ -59,17 +61,22 @@ public final class Worker { private static final Logger log = LoggerFactory.getLogger(Worker.class); private final WorkerOptions options; private final String taskQueue; + private final String workerInstanceKey = UUID.randomUUID().toString(); private final List plugins; + private final WorkflowServiceStubs service; + private final String namespace; + private final String identity; + private final String stickyTaskQueueName; final SyncWorkflowWorker workflowWorker; final SyncActivityWorker activityWorker; final SyncNexusWorker nexusWorker; private final AtomicBoolean started = new AtomicBoolean(); private volatile boolean shuttingDown = false; - private final String workerInstanceKey = UUID.randomUUID().toString(); private volatile Instant startTime; private final WorkflowClientOptions clientOptions; private final @Nonnull WorkflowExecutorCache cache; private final Map previousHeartbeatSnapshots = new ConcurrentHashMap<>(); + private volatile Supplier heartbeatSupplier; private static final class TaskSnapshot { final int processed; @@ -106,22 +113,30 @@ private static final class TaskSnapshot { @Nonnull NamespaceCapabilities namespaceCapabilities) { Objects.requireNonNull(client, "client should not be null"); + Objects.requireNonNull(namespaceCapabilities, "namespaceCapabilities should not be null"); this.plugins = Objects.requireNonNull(plugins, "plugins should not be null"); Preconditions.checkArgument( !Strings.isNullOrEmpty(taskQueue), "taskQueue should not be an empty string"); this.taskQueue = taskQueue; + this.service = client.getWorkflowServiceStubs(); this.options = 
WorkerOptions.newBuilder(options).validateAndBuildWithDefaults(); this.clientOptions = client.getOptions(); this.cache = cache; factoryOptions = WorkerFactoryOptions.newBuilder(factoryOptions).validateAndBuildWithDefaults(); WorkflowClientOptions clientOptions = client.getOptions(); String namespace = clientOptions.getNamespace(); + this.namespace = namespace; Map tags = new ImmutableMap.Builder(1).put(MetricsTag.TASK_QUEUE, taskQueue).build(); Scope taggedScope = metricsScope.tagged(tags); SingleWorkerOptions activityOptions = toActivityOptions( - factoryOptions, this.options, clientOptions, contextPropagators, taggedScope); + factoryOptions, + this.options, + clientOptions, + contextPropagators, + taggedScope, + workerInstanceKey); if (this.options.isLocalActivityWorkerOnly()) { activityWorker = null; } else { @@ -149,7 +164,12 @@ private static final class TaskSnapshot { SingleWorkerOptions nexusOptions = toNexusOptions( - factoryOptions, this.options, clientOptions, contextPropagators, taggedScope); + factoryOptions, + this.options, + clientOptions, + contextPropagators, + taggedScope, + workerInstanceKey); SlotSupplier nexusSlotSupplier = this.options.getWorkerTuner() == null ? 
new FixedSizeSlotSupplier<>(this.options.getMaxConcurrentNexusExecutionSize()) @@ -167,10 +187,16 @@ private static final class TaskSnapshot { clientOptions, taskQueue, contextPropagators, - taggedScope); + taggedScope, + workerInstanceKey); SingleWorkerOptions localActivityOptions = toLocalActivityOptions( - factoryOptions, this.options, clientOptions, contextPropagators, taggedScope); + factoryOptions, + this.options, + clientOptions, + contextPropagators, + taggedScope, + workerInstanceKey); SlotSupplier workflowSlotSupplier = this.options.getWorkerTuner() == null @@ -183,18 +209,20 @@ private static final class TaskSnapshot { : this.options.getWorkerTuner().getLocalActivitySlotSupplier(); attachMetricsToResourceController(taggedScope, localActivitySlotSupplier); + this.identity = singleWorkerOptions.getIdentity(); + this.stickyTaskQueueName = + useStickyTaskQueue ? getStickyTaskQueueName(client.getOptions().getIdentity()) : null; + workflowWorker = new SyncWorkflowWorker( client, namespace, taskQueue, - workerInstanceKey, - this::getActiveTaskQueueTypes, singleWorkerOptions, localActivityOptions, runLocks, cache, - useStickyTaskQueue ? 
getStickyTaskQueueName(client.getOptions().getIdentity()) : null, + stickyTaskQueueName, workflowThreadExecutor, eagerActivityDispatcher, workflowSlotSupplier, @@ -455,18 +483,48 @@ void start() { CompletableFuture shutdown(ShutdownManager shutdownManager, boolean interruptUserTasks) { shuttingDown = true; - CompletableFuture workflowWorkerShutdownFuture = - workflowWorker.shutdown(shutdownManager, interruptUserTasks); - CompletableFuture nexusWorkerShutdownFuture = - nexusWorker.shutdown(shutdownManager, interruptUserTasks); - if (activityWorker != null) { - return CompletableFuture.allOf( - activityWorker.shutdown(shutdownManager, interruptUserTasks), - workflowWorkerShutdownFuture, - nexusWorkerShutdownFuture); - } else { - return CompletableFuture.allOf(workflowWorkerShutdownFuture, nexusWorkerShutdownFuture); + ShutdownWorkerRequest.Builder requestBuilder = + ShutdownWorkerRequest.newBuilder() + .setNamespace(namespace) + .setIdentity(identity) + .setWorkerInstanceKey(workerInstanceKey) + .setTaskQueue(taskQueue) + .setReason("graceful shutdown") + .addAllTaskQueueTypes(getActiveTaskQueueTypes()); + if (stickyTaskQueueName != null) { + requestBuilder.setStickyTaskQueue(stickyTaskQueueName); + } + if (heartbeatSupplier != null) { + requestBuilder.setWorkerHeartbeat( + heartbeatSupplier.get().toBuilder() + .setStatus(WorkerStatus.WORKER_STATUS_SHUTTING_DOWN) + .build()); } + CompletableFuture shutdownWorkerRpc = + shutdownManager.waitOnWorkerShutdownRequest( + service.futureStub().shutdownWorker(requestBuilder.build())); + + // When interrupting tasks (shutdownNow), fire the RPC but don't block on it — proceed to + // shut down pollers immediately. For graceful shutdown, wait for the RPC so the server can + // complete outstanding polls with empty responses before we start waiting on them. + CompletableFuture preShutdown = + interruptUserTasks ? 
CompletableFuture.completedFuture(null) : shutdownWorkerRpc; + + return preShutdown.thenCompose( + ignore -> { + CompletableFuture workflowWorkerShutdownFuture = + workflowWorker.shutdown(shutdownManager, interruptUserTasks); + CompletableFuture nexusWorkerShutdownFuture = + nexusWorker.shutdown(shutdownManager, interruptUserTasks); + if (activityWorker != null) { + return CompletableFuture.allOf( + activityWorker.shutdown(shutdownManager, interruptUserTasks), + workflowWorkerShutdownFuture, + nexusWorkerShutdownFuture); + } else { + return CompletableFuture.allOf(workflowWorkerShutdownFuture, nexusWorkerShutdownFuture); + } + }); } boolean isTerminated() { @@ -491,6 +549,10 @@ String getWorkerInstanceKey() { return workerInstanceKey; } + void setHeartbeatSupplier(Supplier supplier) { + this.heartbeatSupplier = supplier; + } + List getActiveTaskQueueTypes() { List types = new ArrayList<>(); if (workflowWorker.isAnyTypeSupported()) { @@ -826,8 +888,10 @@ private static SingleWorkerOptions toActivityOptions( WorkerOptions options, WorkflowClientOptions clientOptions, List contextPropagators, - Scope metricsScope) { - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + Scope metricsScope, + String workerInstanceKey) { + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setUsingVirtualThreads(options.isUsingVirtualThreadsOnActivityWorker()) .setPollerOptions( PollerOptions.newBuilder() @@ -848,8 +912,10 @@ private static SingleWorkerOptions toNexusOptions( WorkerOptions options, WorkflowClientOptions clientOptions, List contextPropagators, - Scope metricsScope) { - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + Scope metricsScope, + String workerInstanceKey) { + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setPollerOptions( PollerOptions.newBuilder() 
.setPollerBehavior( @@ -870,7 +936,8 @@ private static SingleWorkerOptions toWorkflowWorkerOptions( WorkflowClientOptions clientOptions, String taskQueue, List contextPropagators, - Scope metricsScope) { + Scope metricsScope, + String workerInstanceKey) { Map tags = new ImmutableMap.Builder(1).put(MetricsTag.TASK_QUEUE, taskQueue).build(); @@ -899,7 +966,8 @@ private static SingleWorkerOptions toWorkflowWorkerOptions( } } - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior( @@ -921,8 +989,10 @@ private static SingleWorkerOptions toLocalActivityOptions( WorkerOptions options, WorkflowClientOptions clientOptions, List contextPropagators, - Scope metricsScope) { - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + Scope metricsScope, + String workerInstanceKey) { + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) @@ -939,7 +1009,8 @@ private static SingleWorkerOptions.Builder toSingleWorkerOptions( WorkerFactoryOptions factoryOptions, WorkerOptions options, WorkflowClientOptions clientOptions, - List contextPropagators) { + List contextPropagators, + String workerInstanceKey) { String buildId = null; if (options.getBuildId() != null) { buildId = options.getBuildId(); @@ -962,7 +1033,8 @@ private static SingleWorkerOptions.Builder toSingleWorkerOptions( .setWorkerInterceptors(factoryOptions.getWorkerInterceptors()) .setMaxHeartbeatThrottleInterval(options.getMaxHeartbeatThrottleInterval()) .setDefaultHeartbeatThrottleInterval(options.getDefaultHeartbeatThrottleInterval()) - .setDeploymentOptions(options.getDeploymentOptions()); + 
.setDeploymentOptions(options.getDeploymentOptions()) + .setWorkerInstanceKey(workerInstanceKey); } /** diff --git a/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java b/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java index bbf8af3db..c0a949f82 100644 --- a/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java +++ b/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java @@ -268,17 +268,8 @@ public synchronized void start() { DescribeNamespaceRequest.newBuilder() .setNamespace(workflowClient.getOptions().getNamespace()) .build()); - if (describeNamespaceResponse.getNamespaceInfo().getCapabilities().getWorkerHeartbeats()) { - namespaceCapabilities.setWorkerHeartbeats(true); - } else { - log.debug( - "Server does not support worker heartbeats for namespace {}", - workflowClient.getOptions().getNamespace()); - } - - if (describeNamespaceResponse.getNamespaceInfo().getCapabilities().getPollerAutoscaling()) { - namespaceCapabilities.setPollerAutoscaling(true); - } + namespaceCapabilities.setFromCapabilities( + describeNamespaceResponse.getNamespaceInfo().getCapabilities()); // Build plugin execution chain (reverse order for proper nesting) Consumer startChain = WorkerFactory::doStart; @@ -321,7 +312,7 @@ private void doStart() { Supplier heartbeatSupplier = worker.buildHeartbeatCallback(workerGroupingKey); hbManager.registerWorker(namespace, worker.getWorkerInstanceKey(), heartbeatSupplier); - worker.workflowWorker.setHeartbeatSupplier(heartbeatSupplier); + worker.setHeartbeatSupplier(heartbeatSupplier); } } diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java index 2ade97762..5faa34ca7 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java @@ -133,7 +133,7 @@ private AsyncPoller newPoller( 
pollTask, taskExecutor, options, - false, + new NamespaceCapabilities(), new NoopScope()); } diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/GracefulPollShutdownTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/GracefulPollShutdownTest.java new file mode 100644 index 000000000..137efa1da --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/GracefulPollShutdownTest.java @@ -0,0 +1,246 @@ +package io.temporal.internal.worker; + +import static org.junit.Assert.*; + +import com.uber.m3.tally.NoopScope; +import io.temporal.api.namespace.v1.NamespaceInfo.Capabilities; +import io.temporal.worker.tuning.PollerBehaviorSimpleMaximum; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nonnull; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** + * Tests that an in-flight poll survives shutdown when graceful poll shutdown is enabled, and is + * killed promptly when it is not. 
+ */ +@RunWith(Parameterized.class) +public class GracefulPollShutdownTest { + + @Parameterized.Parameter public boolean graceful; + + @Parameterized.Parameters(name = "graceful={0}") + public static Object[] data() { + return new Object[] {true, false}; + } + + @Test(timeout = 10_000) + public void inflightPollSurvivesShutdownOnlyWhenGraceful() throws Exception { + NamespaceCapabilities capabilities = new NamespaceCapabilities(); + capabilities.setFromCapabilities( + Capabilities.newBuilder().setWorkerPollCompleteOnShutdown(graceful).build()); + + AtomicReference processedTask = new AtomicReference<>(); + CountDownLatch taskProcessedLatch = new CountDownLatch(1); + ShutdownableTaskExecutor taskExecutor = + new ShutdownableTaskExecutor() { + @Override + public void process(@Nonnull String task) { + processedTask.set(task); + taskProcessedLatch.countDown(); + } + + @Override + public boolean isShutdown() { + return false; + } + + @Override + public boolean isTerminated() { + return false; + } + + @Override + public CompletableFuture shutdown( + ShutdownManager shutdownManager, boolean interruptTasks) { + return CompletableFuture.completedFuture(null); + } + + @Override + public void awaitTermination(long timeout, TimeUnit unit) {} + }; + + // -- poll task: first call returns immediately, second blocks until released -- + CountDownLatch secondPollStarted = new CountDownLatch(1); + CountDownLatch releasePoll = new CountDownLatch(1); + + MultiThreadedPoller.PollTask pollTask = + new MultiThreadedPoller.PollTask() { + private int callCount = 0; + + @Override + public synchronized String poll() { + callCount++; + if (callCount == 1) { + return "task-1"; + } else if (callCount == 2) { + secondPollStarted.countDown(); + try { + releasePoll.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } + return "task-2"; + } + // Subsequent calls just block until interrupted (simulates long poll) + try { + 
Thread.sleep(Long.MAX_VALUE); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return null; + } + }; + + // -- create poller with 1 thread so polls are sequential -- + MultiThreadedPoller poller = + new MultiThreadedPoller<>( + "test-identity", + pollTask, + taskExecutor, + PollerOptions.newBuilder() + .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) + .setPollThreadNamePrefix("test-poller") + .build(), + new NoopScope(), + capabilities); + + poller.start(); + + // Wait for the first task to be processed (proves poller is running) + assertTrue("first task should be processed", taskProcessedLatch.await(5, TimeUnit.SECONDS)); + assertEquals("task-1", processedTask.get()); + + // Wait for the second poll to be in-flight + assertTrue("second poll should start", secondPollStarted.await(5, TimeUnit.SECONDS)); + + // Trigger shutdown (don't interrupt tasks) + ShutdownManager shutdownManager = new ShutdownManager(); + CompletableFuture shutdownFuture = poller.shutdown(shutdownManager, false); + + if (graceful) { + // In graceful mode the poller waits for the in-flight poll to complete. + // The shutdown should NOT have completed yet since the poll is still blocked. + assertFalse("shutdown should not complete while poll is in-flight", shutdownFuture.isDone()); + + // Simulate the server returning the poll response (as it would after ShutdownWorker RPC) + releasePoll.countDown(); + + // Wait for shutdown to complete - the poll should return "task-2" and be processed + shutdownFuture.get(5, TimeUnit.SECONDS); + + assertEquals("task-2", processedTask.get()); + } else { + // In legacy mode the poller forcefully interrupts in-flight polls. + // Shutdown should complete quickly without releasing the blocked poll. + shutdownFuture.get(5, TimeUnit.SECONDS); + + // The second task should NOT have been processed since the poll was killed. 
+ assertNotEquals( + "task-2 should not be processed in legacy mode", "task-2", processedTask.get()); + } + + shutdownManager.close(); + } + + @Test(timeout = 10_000) + public void shutdownNowInterruptsInflightPollWhenGraceful() throws Exception { + NamespaceCapabilities capabilities = new NamespaceCapabilities(); + capabilities.setFromCapabilities( + Capabilities.newBuilder().setWorkerPollCompleteOnShutdown(true).build()); + + AtomicReference processedTask = new AtomicReference<>(); + CountDownLatch taskProcessedLatch = new CountDownLatch(1); + ShutdownableTaskExecutor taskExecutor = + new ShutdownableTaskExecutor() { + @Override + public void process(@Nonnull String task) { + processedTask.set(task); + taskProcessedLatch.countDown(); + } + + @Override + public boolean isShutdown() { + return false; + } + + @Override + public boolean isTerminated() { + return false; + } + + @Override + public CompletableFuture shutdown( + ShutdownManager shutdownManager, boolean interruptTasks) { + return CompletableFuture.completedFuture(null); + } + + @Override + public void awaitTermination(long timeout, TimeUnit unit) {} + }; + + CountDownLatch secondPollStarted = new CountDownLatch(1); + CountDownLatch releasePoll = new CountDownLatch(1); + + MultiThreadedPoller.PollTask pollTask = + new MultiThreadedPoller.PollTask() { + private int callCount = 0; + + @Override + public synchronized String poll() { + callCount++; + if (callCount == 1) { + return "task-1"; + } else if (callCount == 2) { + secondPollStarted.countDown(); + try { + releasePoll.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } + return "task-2"; + } + try { + Thread.sleep(Long.MAX_VALUE); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return null; + } + }; + + MultiThreadedPoller poller = + new MultiThreadedPoller<>( + "test-identity", + pollTask, + taskExecutor, + PollerOptions.newBuilder() + .setPollerBehavior(new 
PollerBehaviorSimpleMaximum(1)) + .setPollThreadNamePrefix("test-poller") + .build(), + new NoopScope(), + capabilities); + + poller.start(); + + assertTrue("first task should be processed", taskProcessedLatch.await(5, TimeUnit.SECONDS)); + assertEquals("task-1", processedTask.get()); + assertTrue("second poll should start", secondPollStarted.await(5, TimeUnit.SECONDS)); + + ShutdownManager shutdownManager = new ShutdownManager(); + try { + poller.shutdown(shutdownManager, true).get(5, TimeUnit.SECONDS); + assertNotEquals( + "task-2 should not be processed by shutdownNow", "task-2", processedTask.get()); + } finally { + releasePoll.countDown(); + shutdownManager.close(); + } + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java index e4223c0b5..c6f11a61a 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java @@ -80,6 +80,7 @@ public void supplierIsCalledAppropriately() { TASK_QUEUE, "stickytaskqueue", "", + "test-instance-key", new WorkerVersioningOptions("", false, null), trackingSS, stickyQueueBalancer, @@ -172,6 +173,7 @@ public void asyncPollerSupplierIsCalledAppropriately() throws Exception { TASK_QUEUE, null, "", + "test-instance-key", new WorkerVersioningOptions("", false, null), trackingSS, metricsScope, diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java index 59538ac8b..ab806c960 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java @@ -68,6 +68,7 @@ public void stickyQueueBacklogResetTest() { "taskqueue", "stickytaskqueue", "", + "test-instance-key", new 
WorkerVersioningOptions("", false, null), slotSupplier, stickyQueueBalancer, @@ -97,6 +98,7 @@ public void stickyQueueBacklogResetTest() { .setKind(TaskQueueKind.TASK_QUEUE_KIND_STICKY) .build()) .setNamespace("default") + .setWorkerInstanceKey("test-instance-key") .build()))) .thenReturn(pollResponse); if (throwOnPoll) { diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java index d4e6e947c..d4f1824c2 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java @@ -14,7 +14,6 @@ import com.uber.m3.util.ImmutableMap; import io.temporal.api.common.v1.WorkflowExecution; import io.temporal.api.common.v1.WorkflowType; -import io.temporal.api.enums.v1.TaskQueueType; import io.temporal.api.workflowservice.v1.*; import io.temporal.common.reporter.TestStatsReporter; import io.temporal.internal.common.InternalUtils; @@ -30,12 +29,8 @@ import io.temporal.worker.tuning.SlotSupplier; import io.temporal.worker.tuning.WorkflowSlotInfo; import java.time.Duration; -import java.util.Arrays; -import java.util.List; import java.util.UUID; import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Supplier; import org.junit.Test; import org.mockito.stubbing.Answer; import org.slf4j.Logger; @@ -74,12 +69,11 @@ public void concurrentPollRequestLockTest() throws Exception { client, "default", "task_queue", - "test-worker-instance-key", - java.util.Collections::emptyList, "sticky_task_queue", SingleWorkerOptions.newBuilder() .setIdentity("test_identity") .setBuildId(UUID.randomUUID().toString()) + .setWorkerInstanceKey(UUID.randomUUID().toString()) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(3)) @@ -246,12 +240,11 @@ public void 
respondWorkflowTaskFailureMetricTest() throws Exception { client, "default", "task_queue", - "test-worker-instance-key", - java.util.Collections::emptyList, "sticky_task_queue", SingleWorkerOptions.newBuilder() .setIdentity("test_identity") .setBuildId(UUID.randomUUID().toString()) + .setWorkerInstanceKey(UUID.randomUUID().toString()) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) @@ -391,12 +384,11 @@ public boolean isAnyTypeSupported() { client, "default", "taskQueue", - "test-worker-instance-key", - java.util.Collections::emptyList, "sticky", SingleWorkerOptions.newBuilder() .setIdentity("test_identity") .setBuildId(UUID.randomUUID().toString()) + .setWorkerInstanceKey(UUID.randomUUID().toString()) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) @@ -444,80 +436,6 @@ public boolean isAnyTypeSupported() { worker.shutdown(new ShutdownManager(), true).get(); } - @Test - public void activeTaskQueueTypesEvaluatedAtShutdownTime() throws Exception { - WorkflowServiceStubs client = mock(WorkflowServiceStubs.class); - when(client.getServerCapabilities()) - .thenReturn(() -> GetSystemInfoResponse.Capabilities.newBuilder().build()); - - WorkflowRunLockManager runLockManager = new WorkflowRunLockManager(); - Scope metricsScope = new NoopScope(); - WorkflowExecutorCache cache = new WorkflowExecutorCache(10, runLockManager, metricsScope); - SlotSupplier slotSupplier = new FixedSizeSlotSupplier<>(10); - - WorkflowTaskHandler taskHandler = mock(WorkflowTaskHandler.class); - when(taskHandler.isAnyTypeSupported()).thenReturn(true); - - // Supplier that starts with WORKFLOW only, then adds NEXUS later - AtomicReference> typesRef = - new AtomicReference<>(Arrays.asList(TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW)); - Supplier> supplier = typesRef::get; - - EagerActivityDispatcher eagerActivityDispatcher = mock(EagerActivityDispatcher.class); - WorkflowWorker worker = - new 
WorkflowWorker( - client, - "default", - "task_queue", - "test-worker-instance-key", - supplier, - null, - SingleWorkerOptions.newBuilder() - .setIdentity("test_identity") - .setBuildId(UUID.randomUUID().toString()) - .setPollerOptions( - PollerOptions.newBuilder() - .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) - .build()) - .setMetricsScope(metricsScope) - .build(), - runLockManager, - cache, - taskHandler, - eagerActivityDispatcher, - slotSupplier, - new NamespaceCapabilities()); - - // Simulate registering Nexus after construction - typesRef.set( - Arrays.asList( - TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW, - TaskQueueType.TASK_QUEUE_TYPE_ACTIVITY, - TaskQueueType.TASK_QUEUE_TYPE_NEXUS)); - - WorkflowServiceGrpc.WorkflowServiceFutureStub futureStub = - mock(WorkflowServiceGrpc.WorkflowServiceFutureStub.class); - when(client.futureStub()).thenReturn(futureStub); - when(futureStub.shutdownWorker(any(ShutdownWorkerRequest.class))) - .thenReturn(Futures.immediateFuture(ShutdownWorkerResponse.newBuilder().build())); - - worker.shutdown(new ShutdownManager(), true).get(5, TimeUnit.SECONDS); - - org.mockito.ArgumentCaptor captor = - org.mockito.ArgumentCaptor.forClass(ShutdownWorkerRequest.class); - verify(futureStub).shutdownWorker(captor.capture()); - List shutdownTypes = captor.getValue().getTaskQueueTypesList(); - assertTrue( - "ShutdownWorkerRequest should include NEXUS type added after construction", - shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_NEXUS)); - assertTrue( - "ShutdownWorkerRequest should include WORKFLOW type", - shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW)); - assertTrue( - "ShutdownWorkerRequest should include ACTIVITY type", - shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_ACTIVITY)); - } - private ReplayWorkflowFactory setUpMockWorkflowFactory() throws Throwable { ReplayWorkflow mockWorkflow = mock(ReplayWorkflow.class); ReplayWorkflowFactory mockFactory = mock(ReplayWorkflowFactory.class); diff --git 
a/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java b/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java new file mode 100644 index 000000000..e9f4c9a36 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java @@ -0,0 +1,155 @@ +package io.temporal.worker; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.google.common.util.concurrent.Futures; +import com.uber.m3.tally.NoopScope; +import com.uber.m3.tally.Scope; +import io.nexusrpc.handler.OperationHandler; +import io.nexusrpc.handler.OperationImpl; +import io.nexusrpc.handler.ServiceImpl; +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.api.enums.v1.TaskQueueType; +import io.temporal.api.enums.v1.WorkerStatus; +import io.temporal.api.worker.v1.WorkerHeartbeat; +import io.temporal.api.workflowservice.v1.GetSystemInfoResponse; +import io.temporal.api.workflowservice.v1.ShutdownWorkerRequest; +import io.temporal.api.workflowservice.v1.ShutdownWorkerResponse; +import io.temporal.api.workflowservice.v1.WorkflowServiceGrpc; +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowClientOptions; +import io.temporal.internal.sync.WorkflowThreadExecutor; +import io.temporal.internal.worker.NamespaceCapabilities; +import io.temporal.internal.worker.ShutdownManager; +import io.temporal.internal.worker.WorkflowExecutorCache; +import io.temporal.internal.worker.WorkflowRunLockManager; +import io.temporal.serviceclient.WorkflowServiceStubs; +import io.temporal.workflow.WorkflowInterface; +import io.temporal.workflow.WorkflowMethod; +import io.temporal.workflow.shared.TestNexusServices; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +public 
class WorkerShutdownTest { + + @WorkflowInterface + public interface TestWorkflow { + @WorkflowMethod + void run(); + } + + public static class TestWorkflowImpl implements TestWorkflow { + @Override + public void run() {} + } + + @ActivityInterface + public interface TestActivity { + @ActivityMethod + void doThing(); + } + + public static class TestActivityImpl implements TestActivity { + @Override + public void doThing() {} + } + + @ServiceImpl(service = TestNexusServices.TestNexusService1.class) + public static class TestNexusServiceImpl { + @OperationImpl + public OperationHandler<String, String> operation() { + return OperationHandler.sync((ctx, details, now) -> "Hello " + now); + } + } + + /** + * Verifies that the active task queue types in the ShutdownWorkerRequest are evaluated at + * shutdown time, not at Worker construction time. Types registered after construction must be + * reflected in the request. + */ + @Test + public void activeTaskQueueTypesEvaluatedAtShutdownTime() throws Exception { + WorkflowServiceStubs service = mock(WorkflowServiceStubs.class); + when(service.getServerCapabilities()) + .thenReturn(() -> GetSystemInfoResponse.Capabilities.newBuilder().build()); + + WorkflowServiceGrpc.WorkflowServiceFutureStub futureStub = + mock(WorkflowServiceGrpc.WorkflowServiceFutureStub.class); + when(service.futureStub()).thenReturn(futureStub); + when(futureStub.shutdownWorker(any(ShutdownWorkerRequest.class))) + .thenReturn(Futures.immediateFuture(ShutdownWorkerResponse.newBuilder().build())); + + WorkflowServiceGrpc.WorkflowServiceBlockingStub blockingStub = + mock(WorkflowServiceGrpc.WorkflowServiceBlockingStub.class); + when(service.blockingStub()).thenReturn(blockingStub); + when(blockingStub.withOption(any(), any())).thenReturn(blockingStub); + + WorkflowClient client = mock(WorkflowClient.class); + when(client.getWorkflowServiceStubs()).thenReturn(service); + when(client.getOptions()) + .thenReturn( + WorkflowClientOptions.newBuilder() +
.setNamespace("test-ns") + .setIdentity("test-worker") + .validateAndBuildWithDefaults()); + + Scope metricsScope = new NoopScope(); + WorkflowRunLockManager runLocks = new WorkflowRunLockManager(); + WorkflowExecutorCache cache = new WorkflowExecutorCache(10, runLocks, metricsScope); + WorkflowThreadExecutor wfThreadExecutor = mock(WorkflowThreadExecutor.class); + + Worker worker = + new Worker( + client, + "test-task-queue", + WorkerFactoryOptions.newBuilder().build(), + WorkerOptions.newBuilder().build(), + metricsScope, + runLocks, + cache, + true, + wfThreadExecutor, + Collections.emptyList(), + Collections.emptyList(), + new NamespaceCapabilities()); + + // Register types AFTER worker construction. The request built by shutdown should reflect + // these registrations, proving that getActiveTaskQueueTypes() is evaluated lazily. + worker.registerWorkflowImplementationTypes(TestWorkflowImpl.class); + worker.registerActivitiesImplementations(new TestActivityImpl()); + worker.registerNexusServiceImplementation(new TestNexusServiceImpl()); + Supplier<WorkerHeartbeat> heartbeatSupplier = + () -> WorkerHeartbeat.newBuilder().setStatus(WorkerStatus.WORKER_STATUS_RUNNING).build(); + worker.setHeartbeatSupplier(heartbeatSupplier); + + worker.shutdown(new ShutdownManager(), true).get(5, TimeUnit.SECONDS); + + ArgumentCaptor<ShutdownWorkerRequest> captor = + ArgumentCaptor.forClass(ShutdownWorkerRequest.class); + verify(futureStub).shutdownWorker(captor.capture()); + List<TaskQueueType> shutdownTypes = captor.getValue().getTaskQueueTypesList(); + assertTrue( + "ShutdownWorkerRequest should include WORKFLOW type registered after construction", + shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW)); + assertTrue( + "ShutdownWorkerRequest should include ACTIVITY type registered after construction", + shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_ACTIVITY)); + assertTrue( + "ShutdownWorkerRequest should include NEXUS type registered after construction", +
shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_NEXUS)); + assertEquals( + "ShutdownWorkerRequest heartbeat should report SHUTTING_DOWN", + WorkerStatus.WORKER_STATUS_SHUTTING_DOWN, + captor.getValue().getWorkerHeartbeat().getStatus()); + assertTrue( + "ShutdownWorkerRequest sticky task queue should be derived from worker identity", + captor.getValue().getStickyTaskQueue().startsWith("test-worker:")); + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/worker/shutdown/GracefulPollShutdownIntegrationTest.java b/temporal-sdk/src/test/java/io/temporal/worker/shutdown/GracefulPollShutdownIntegrationTest.java new file mode 100644 index 000000000..e72acdb6b --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/worker/shutdown/GracefulPollShutdownIntegrationTest.java @@ -0,0 +1,149 @@ +package io.temporal.worker.shutdown; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; + +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.activity.ActivityOptions; +import io.temporal.api.enums.v1.EventType; +import io.temporal.api.history.v1.HistoryEvent; +import io.temporal.api.namespace.v1.NamespaceInfo.Capabilities; +import io.temporal.api.workflowservice.v1.DescribeNamespaceRequest; +import io.temporal.api.workflowservice.v1.DescribeNamespaceResponse; +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowOptions; +import io.temporal.common.WorkflowExecutionHistory; +import io.temporal.testing.internal.SDKTestWorkflowRule; +import io.temporal.workflow.Workflow; +import io.temporal.workflow.WorkflowInterface; +import io.temporal.workflow.WorkflowMethod; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import org.junit.Rule; +import org.junit.Test; + +public class GracefulPollShutdownIntegrationTest { + + private static final int 
WORKFLOW_COUNT = 10; + + @Rule + public SDKTestWorkflowRule testWorkflowRule = + SDKTestWorkflowRule.newBuilder() + .setUseExternalService(true) + .setDoNotStart(true) + .setTestTimeoutSeconds(30) + .setWorkflowTypes(LoopWorkflowImpl.class) + .setActivityImplementations(new NoopActivityImpl()) + .build(); + + @Test + public void shutdownDuringActiveTimerActivityWorkflows() throws Exception { + assumeTrue( + "Requires real server with graceful poll shutdown support", + SDKTestWorkflowRule.useExternalService); + assumeTrue( + "Server does not support graceful poll shutdown", + getNamespaceCapabilities().getWorkerPollCompleteOnShutdown()); + + testWorkflowRule.getTestEnvironment().start(); + + WorkflowClient client = testWorkflowRule.getWorkflowClient(); + List<String> workflowIds = new ArrayList<>(WORKFLOW_COUNT); + for (int i = 0; i < WORKFLOW_COUNT; i++) { + String workflowId = testWorkflowRule.getTaskQueue() + "-" + i; + LoopWorkflow workflow = + client.newWorkflowStub( + LoopWorkflow.class, + WorkflowOptions.newBuilder() + .setWorkflowId(workflowId) + .setTaskQueue(testWorkflowRule.getTaskQueue()) + .setWorkflowExecutionTimeout(Duration.ofMinutes(5)) + .build()); + WorkflowClient.start(workflow::run); + workflowIds.add(workflowId); + } + + Thread.sleep(2_000); + + long shutdownStartNanos = System.nanoTime(); + try { + testWorkflowRule.getTestEnvironment().shutdown(); + testWorkflowRule.getTestEnvironment().awaitTermination(10, TimeUnit.SECONDS); + Duration shutdownElapsed = Duration.ofNanos(System.nanoTime() - shutdownStartNanos); + assertTrue( + "Worker shutdown took " + shutdownElapsed + ", expected less than 5 seconds", + shutdownElapsed.compareTo(Duration.ofSeconds(5)) < 0); + } finally { + for (String workflowId : workflowIds) { + client.newUntypedWorkflowStub(workflowId).terminate("test cleanup"); + } + } + + for (String workflowId : workflowIds) { + WorkflowExecutionHistory history = client.fetchHistory(workflowId); + List<HistoryEvent> badEvents = +
history.getEvents().stream() + .filter( + e -> + e.getEventType() == EventType.EVENT_TYPE_WORKFLOW_TASK_FAILED + || e.getEventType() == EventType.EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT) + .collect(Collectors.toList()); + assertTrue( + "Workflow " + + workflowId + + " had unexpected workflow task failures/timeouts: " + + badEvents, + badEvents.isEmpty()); + } + } + + private Capabilities getNamespaceCapabilities() { + DescribeNamespaceResponse response = + testWorkflowRule + .getWorkflowClient() + .getWorkflowServiceStubs() + .blockingStub() + .describeNamespace( + DescribeNamespaceRequest.newBuilder() + .setNamespace(testWorkflowRule.getWorkflowClient().getOptions().getNamespace()) + .build()); + return response.getNamespaceInfo().getCapabilities(); + } + + @WorkflowInterface + public interface LoopWorkflow { + @WorkflowMethod + void run(); + } + + public static class LoopWorkflowImpl implements LoopWorkflow { + + private final NoopActivity activities = + Workflow.newActivityStub( + NoopActivity.class, + ActivityOptions.newBuilder().setStartToCloseTimeout(Duration.ofSeconds(10)).build()); + + @Override + public void run() { + while (true) { + Workflow.sleep(Duration.ofMillis(10)); + activities.noop(); + } + } + } + + @ActivityInterface + public interface NoopActivity { + @ActivityMethod + void noop(); + } + + public static class NoopActivityImpl implements NoopActivity { + @Override + public void noop() {} + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/worker/shutdown/StickyWorkflowDrainShutdownTest.java b/temporal-sdk/src/test/java/io/temporal/worker/shutdown/StickyWorkflowDrainShutdownTest.java index 67a83b981..b74fd2808 100644 --- a/temporal-sdk/src/test/java/io/temporal/worker/shutdown/StickyWorkflowDrainShutdownTest.java +++ b/temporal-sdk/src/test/java/io/temporal/worker/shutdown/StickyWorkflowDrainShutdownTest.java @@ -3,6 +3,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import 
io.temporal.api.workflowservice.v1.DescribeNamespaceRequest; +import io.temporal.api.workflowservice.v1.DescribeNamespaceResponse; import io.temporal.client.WorkflowClient; import io.temporal.client.WorkflowStub; import io.temporal.serviceclient.WorkflowServiceStubsOptions; @@ -53,6 +55,7 @@ public StickyWorkflowDrainShutdownTest(PollerBehavior pollerBehaviorAutoscaling) @Test public void testShutdown() throws InterruptedException { + boolean gracefulPollShutdownSupported = isGracefulPollShutdownSupported(); TestWorkflow1 workflow = testWorkflowRule.newWorkflowStub(TestWorkflow1.class); WorkflowClient.start(workflow::execute, null); testWorkflowRule.getTestEnvironment().shutdown(); @@ -62,10 +65,16 @@ public void testShutdown() throws InterruptedException { assertTrue(testWorkflowRule.getTestEnvironment().getWorkerFactory().isTerminated()); System.out.println("Shutdown completed"); long endTime = System.currentTimeMillis(); - assertTrue("Drain time should be respected", endTime - startTime > DRAIN_TIME.toMillis()); - // Workflow should complete successfully since the drain time is longer than the workflow - // execution time - assertEquals("Success", workflow.execute(null)); + WorkflowStub untyped = WorkflowStub.fromTyped(workflow); + if (gracefulPollShutdownSupported) { + assertTrue("Drain time should be skipped", endTime - startTime < DRAIN_TIME.toMillis()); + untyped.terminate("test cleanup"); + } else { + assertTrue("Drain time should be respected", endTime - startTime > DRAIN_TIME.toMillis()); + // Workflow should complete successfully since the drain time is longer than the workflow + // execution time. 
+ assertEquals("Success", untyped.getResult(String.class)); + } } @Test @@ -94,4 +103,17 @@ public String execute(String now) { return "Success"; } } + + private boolean isGracefulPollShutdownSupported() { + DescribeNamespaceResponse response = + testWorkflowRule + .getWorkflowClient() + .getWorkflowServiceStubs() + .blockingStub() + .describeNamespace( + DescribeNamespaceRequest.newBuilder() + .setNamespace(testWorkflowRule.getWorkflowClient().getOptions().getNamespace()) + .build()); + return response.getNamespaceInfo().getCapabilities().getWorkerPollCompleteOnShutdown(); + } }