From d44aea156ba4946e4942b13844e0716d7b3eaecb Mon Sep 17 00:00:00 2001 From: Pranav Iyer Date: Wed, 20 May 2026 17:56:05 -0700 Subject: [PATCH] feat(gax): Implement cert-rotation retries for grpc and http-json. --- .../com/google/api/gax/grpc/ChannelPool.java | 8 + .../google/api/gax/grpc/GrpcCallContext.java | 58 +++-- .../api/gax/grpc/GrpcTransportChannel.java | 8 + .../api/gax/httpjson/HttpJsonCallContext.java | 68 +++-- .../InstantiatingHttpJsonChannelProvider.java | 22 +- .../gax/httpjson/ManagedHttpJsonChannel.java | 2 + .../ManagedHttpJsonInterceptorChannel.java | 5 + .../httpjson/RefreshingHttpJsonChannel.java | 233 ++++++++++++++++++ .../google/api/gax/rpc/ApiCallContext.java | 8 + .../api/gax/rpc/ApiResultRetryAlgorithm.java | 8 + .../google/api/gax/rpc/AttemptCallable.java | 22 ++ .../api/gax/rpc/BidiStreamingCallable.java | 38 ++- .../api/gax/rpc/ClientStreamingCallable.java | 33 ++- .../rpc/ServerStreamingAttemptCallable.java | 13 + .../google/api/gax/rpc/TransportChannel.java | 8 + 15 files changed, 488 insertions(+), 46 deletions(-) create mode 100644 sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index d611c96ff4c8..d35dbc8d12ca 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -82,6 +82,7 @@ class ChannelPool extends ManagedChannel { private ScheduledFuture resizeFuture = null; private final Object entryWriteLock = new Object(); + private long lastRefreshTimeNanos = 0; @VisibleForTesting final AtomicReference> entries = new AtomicReference<>(); private final AtomicInteger indexTicker = new AtomicInteger(); private final String authority; @@ -441,6 +442,13 @@ void refresh() { // - then thread2 will shut down channel that thread1 will put back into circulation (after it // replaces the list) synchronized (entryWriteLock) { + long now = System.nanoTime(); + if (now - lastRefreshTimeNanos < TimeUnit.SECONDS.toNanos(5)) { + LOG.fine("Channel pool was refreshed recently, skipping duplicate refresh"); + return; + } + lastRefreshTimeNanos = now; + LOG.fine("Refreshing all channels"); ArrayList newEntries = new ArrayList<>(entries.get()); diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java index 7ff7c54de6f0..fb5e2edb0d07 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java @@ -97,6 +97,7 @@ public final class GrpcCallContext implements ApiCallContext { private final ApiCallContextOptions options; private final EndpointContext endpointContext; private final boolean isDirectPath; + @Nullable private final TransportChannel transportChannel; /** Returns an empty instance with a null channel and default {@link CallOptions}. */ public static GrpcCallContext createDefault() { @@ -113,7 +114,8 @@ public static GrpcCallContext createDefault() { null, null, null, - false); + false, + null); } /** Returns an instance with the given channel and {@link CallOptions}. */ @@ -131,7 +133,8 @@ public static GrpcCallContext of(Channel channel, CallOptions callOptions) { null, null, null, - false); + false, + null); } private GrpcCallContext( @@ -147,7 +150,8 @@ private GrpcCallContext( @Nullable RetrySettings retrySettings, @Nullable Set retryableCodes, @Nullable EndpointContext endpointContext, - boolean isDirectPath) { + boolean isDirectPath, + @Nullable TransportChannel transportChannel) { this.channel = channel; this.credentials = credentials; Preconditions.checkNotNull(callOptions); @@ -167,6 +171,7 @@ private GrpcCallContext( this.endpointContext = endpointContext == null ? EndpointContext.getDefaultInstance() : endpointContext; this.isDirectPath = isDirectPath; + this.transportChannel = transportChannel; } /** @@ -208,7 +213,13 @@ public GrpcCallContext withCredentials(Credentials newCredentials) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); + } + + @Override + public TransportChannel getTransportChannel() { + return transportChannel; } @Override @@ -232,7 +243,8 @@ public GrpcCallContext withTransportChannel(TransportChannel inputChannel) { retrySettings, retryableCodes, endpointContext, - transportChannel.isDirectPath()); + transportChannel.isDirectPath(), + inputChannel); } @Override @@ -251,7 +263,8 @@ public GrpcCallContext withEndpointContext(EndpointContext endpointContext) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** This method is obsolete. Use {@link #withTimeoutDuration(java.time.Duration)} instead. */ @@ -286,7 +299,8 @@ public GrpcCallContext withTimeoutDuration(@Nullable java.time.Duration timeout) retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** This method is obsolete. Use {@link #getTimeoutDuration()} instead. */ @@ -335,7 +349,8 @@ public GrpcCallContext withStreamWaitTimeoutDuration( retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** @@ -370,7 +385,8 @@ public GrpcCallContext withStreamIdleTimeoutDuration( retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @BetaApi("The surface for channel affinity is not stable yet and may change in the future.") @@ -388,7 +404,8 @@ public GrpcCallContext withChannelAffinity(@Nullable Integer affinity) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @BetaApi("The surface for extra headers is not stable yet and may change in the future.") @@ -410,7 +427,8 @@ public GrpcCallContext withExtraHeaders(Map> extraHeaders) retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -433,7 +451,8 @@ public GrpcCallContext withRetrySettings(RetrySettings retrySettings) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -456,7 +475,8 @@ public GrpcCallContext withRetryableCodes(Set retryableCodes) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } @Override @@ -558,7 +578,8 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { newRetrySettings, newRetryableCodes, endpointContext, - newIsDirectPath); + newIsDirectPath, + transportChannel); } /** The {@link Channel} set on this context. */ @@ -641,7 +662,8 @@ public GrpcCallContext withChannel(Channel newChannel) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** Returns a new instance with the call options set to the given call options. */ @@ -659,7 +681,8 @@ public GrpcCallContext withCallOptions(CallOptions newCallOptions) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } public GrpcCallContext withRequestParamsDynamicHeaderOption(String requestParams) { @@ -704,7 +727,8 @@ public GrpcCallContext withOption(Key key, T value) { retrySettings, retryableCodes, endpointContext, - isDirectPath); + isDirectPath, + transportChannel); } /** {@inheritDoc} */ diff --git a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java index 2fa0908f17bc..80d471701d5a 100644 --- a/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java +++ b/sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcTransportChannel.java @@ -66,6 +66,14 @@ public Channel getChannel() { return getManagedChannel(); } + @Override + public void refresh() { + Channel channel = getChannel(); + if (channel instanceof ChannelPool) { + ((ChannelPool) channel).refresh(); + } + } + @Override public void shutdown() { getManagedChannel().shutdown(); diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java index c946e9aab03d..aa167d93e8b6 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java @@ -80,6 +80,7 @@ public final class HttpJsonCallContext implements ApiCallContext { @Nullable private final RetrySettings retrySettings; @Nullable private final ImmutableSet retryableCodes; private final EndpointContext endpointContext; + @Nullable private final TransportChannel transportChannel; /** Returns an empty instance. */ public static HttpJsonCallContext createDefault() { @@ -94,6 +95,7 @@ public static HttpJsonCallContext createDefault() { null, null, null, + null, null); } @@ -109,6 +111,7 @@ public static HttpJsonCallContext of(HttpJsonChannel channel, HttpJsonCallOption null, null, null, + null, null); } @@ -123,7 +126,8 @@ private HttpJsonCallContext( ApiTracer tracer, RetrySettings defaultRetrySettings, Set defaultRetryableCodes, - @Nullable EndpointContext endpointContext) { + @Nullable EndpointContext endpointContext, + @Nullable TransportChannel transportChannel) { this.channel = channel; this.callOptions = callOptions; this.timeout = timeout; @@ -139,6 +143,7 @@ private HttpJsonCallContext( // a valid EndpointContext with user configurations after the client has been initialized. this.endpointContext = endpointContext == null ? EndpointContext.getDefaultInstance() : endpointContext; + this.transportChannel = transportChannel; } /** @@ -231,7 +236,8 @@ public HttpJsonCallContext merge(ApiCallContext inputCallContext) { newTracer, newRetrySettings, newRetryableCodes, - endpointContext); + endpointContext, + this.transportChannel); } @Override @@ -249,7 +255,24 @@ public HttpJsonCallContext withTransportChannel(TransportChannel inputChannel) { "Expected HttpJsonTransportChannel, got " + inputChannel.getClass().getName()); } HttpJsonTransportChannel transportChannel = (HttpJsonTransportChannel) inputChannel; - return withChannel(transportChannel.getChannel()); + return new HttpJsonCallContext( + transportChannel.getChannel(), + this.callOptions, + this.timeout, + this.streamWaitTimeout, + this.streamIdleTimeout, + this.extraHeaders, + this.options, + this.tracer, + this.retrySettings, + this.retryableCodes, + this.endpointContext, + transportChannel); + } + + @Override + public TransportChannel getTransportChannel() { + return transportChannel; } /** This method is obsolete. Use {@link #withTimeoutDuration(java.time.Duration)} instead. */ @@ -273,7 +296,8 @@ public HttpJsonCallContext withEndpointContext(EndpointContext endpointContext) this.tracer, this.retrySettings, this.retryableCodes, - endpointContext); + endpointContext, + this.transportChannel); } @Override @@ -299,7 +323,8 @@ public HttpJsonCallContext withTimeoutDuration(java.time.Duration timeout) { this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } /** This method is obsolete. Use {@link #getTimeoutDuration()} instead. */ @@ -346,7 +371,8 @@ public HttpJsonCallContext withStreamWaitTimeoutDuration( this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } /** This method is obsolete. Use {@link #getStreamWaitTimeoutDuration()} instead. */ @@ -398,7 +424,8 @@ public HttpJsonCallContext withStreamIdleTimeoutDuration( this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } /** This method is obsolete. Use {@link #getStreamIdleTimeoutDuration()} instead. */ @@ -437,7 +464,8 @@ public ApiCallContext withExtraHeaders(Map> extraHeaders) { this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } @BetaApi("The surface for extra headers is not stable yet and may change in the future.") @@ -461,7 +489,8 @@ public ApiCallContext withOption(Key key, T value) { this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } /** {@inheritDoc} */ @@ -533,7 +562,8 @@ public HttpJsonCallContext withRetrySettings(RetrySettings retrySettings) { this.tracer, retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } @Override @@ -554,7 +584,8 @@ public HttpJsonCallContext withRetryableCodes(Set retryableCode this.tracer, this.retrySettings, retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } public HttpJsonCallContext withChannel(HttpJsonChannel newChannel) { @@ -569,7 +600,8 @@ public HttpJsonCallContext withChannel(HttpJsonChannel newChannel) { this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } public HttpJsonCallContext withCallOptions(HttpJsonCallOptions newCallOptions) { @@ -584,7 +616,8 @@ public HttpJsonCallContext withCallOptions(HttpJsonCallOptions newCallOptions) { this.tracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } @Deprecated @@ -620,7 +653,8 @@ public HttpJsonCallContext withTracer(@Nonnull ApiTracer newTracer) { newTracer, this.retrySettings, this.retryableCodes, - this.endpointContext); + this.endpointContext, + this.transportChannel); } @Override @@ -640,7 +674,8 @@ public boolean equals(Object o) { && Objects.equals(this.tracer, that.tracer) && Objects.equals(this.retrySettings, that.retrySettings) && Objects.equals(this.retryableCodes, that.retryableCodes) - && Objects.equals(this.endpointContext, that.endpointContext); + && Objects.equals(this.endpointContext, that.endpointContext) + && Objects.equals(this.transportChannel, that.transportChannel); } @Override @@ -654,6 +689,7 @@ public int hashCode() { tracer, retrySettings, retryableCodes, - endpointContext); + endpointContext, + transportChannel); } } diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java index daf94a498cc4..347701816dc9 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/InstantiatingHttpJsonChannelProvider.java @@ -198,19 +198,23 @@ HttpTransport createHttpTransport() throws IOException, GeneralSecurityException } private HttpJsonTransportChannel createChannel() throws IOException, GeneralSecurityException { - HttpTransport httpTransportToUse = httpTransport; - if (httpTransportToUse == null) { - httpTransportToUse = createHttpTransport(); - } - - // Pass the executor to the ManagedChannel. If no executor was provided (or null), - // the channel will use a default executor for the calls. - ManagedHttpJsonChannel channel = - ManagedHttpJsonChannel.newBuilder() + java.util.function.Supplier channelFactory = () -> { + try { + HttpTransport httpTransportToUse = httpTransport; + if (httpTransportToUse == null) { + httpTransportToUse = createHttpTransport(); + } + return ManagedHttpJsonChannel.newBuilder() .setEndpoint(endpoint) .setExecutor(executor) .setHttpTransport(httpTransportToUse) .build(); + } catch (Exception e) { + throw new java.lang.RuntimeException("Failed to create fresh ManagedHttpJsonChannel", e); + } + }; + + ManagedHttpJsonChannel channel = new RefreshingHttpJsonChannel(channelFactory); HttpJsonClientInterceptor headerInterceptor = new HttpJsonHeaderInterceptor(headerProvider.getHeaders()); diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java index bd3bed855608..6d800e579897 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonChannel.java @@ -86,6 +86,8 @@ public HttpJsonClientCall newCall( deadlineScheduledExecutorService); } + public void refresh() {} + @VisibleForTesting Executor getExecutor() { return executor; diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java index 3e71031f1c9d..f01d37a02c3f 100644 --- a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/ManagedHttpJsonInterceptorChannel.java @@ -55,6 +55,11 @@ public HttpJsonClientCall newCall( return interceptor.interceptCall(methodDescriptor, callOptions, channel); } + @Override + public void refresh() { + channel.refresh(); + } + @Override public synchronized void shutdown() { channel.shutdown(); diff --git a/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java new file mode 100644 index 000000000000..71754bd7d14d --- /dev/null +++ b/sdk-platform-java/gax-java/gax-httpjson/src/main/java/com/google/api/gax/httpjson/RefreshingHttpJsonChannel.java @@ -0,0 +1,233 @@ +/* + * Copyright 2026 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.httpjson; + +import com.google.api.core.InternalApi; +import com.google.api.gax.httpjson.ForwardingHttpJsonClientCall.SimpleForwardingHttpJsonClientCall; +import com.google.api.gax.httpjson.ForwardingHttpJsonClientCallListener.SimpleForwardingHttpJsonClientCallListener; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * An implementation of {@link ManagedHttpJsonChannel} that supports dynamic mTLS certificate + * rotation by thread-safely hot-swapping the underlying active HTTP/JSON channel while gracefully + * retiring older connections after all active in-flight requests complete. + */ +@InternalApi +public class RefreshingHttpJsonChannel extends ManagedHttpJsonChannel { + + private static final Logger LOG = Logger.getLogger(RefreshingHttpJsonChannel.class.getName()); + private static final long REFRESH_COOLDOWN_MS = 5000; + + private final Supplier channelFactory; + private final AtomicReference activeEntry; + private final AtomicBoolean refreshInProgress = new AtomicBoolean(false); + private final Object lock = new Object(); + private long lastRefreshTimeMs = 0; + + public RefreshingHttpJsonChannel(Supplier channelFactory) { + this.channelFactory = channelFactory; + this.activeEntry = new AtomicReference<>(new ChannelEntry(channelFactory.get())); + } + + @Override + public void refresh() { + // 1. Lock-free CAS coalescing check to prevent duplicate queueing/blocking of concurrent threads + if (!refreshInProgress.compareAndSet(false, true)) { + return; + } + try { + synchronized (lock) { + long now = System.currentTimeMillis(); + if (now - lastRefreshTimeMs < REFRESH_COOLDOWN_MS) { + LOG.fine("HTTP/JSON channel pool refreshed recently, skipping duplicate refresh"); + return; + } + + LOG.info("mTLS certificate rotation detected. Triggering HTTP/JSON channel pool refresh."); + ChannelEntry newEntry = new ChannelEntry(channelFactory.get()); + ChannelEntry oldEntry = activeEntry.getAndSet(newEntry); + + if (oldEntry != null) { + oldEntry.requestShutdown(); + } + + lastRefreshTimeMs = now; + } + } finally { + refreshInProgress.set(false); + } + } + + private ChannelEntry getRetainedEntry() { + while (true) { + ChannelEntry entry = activeEntry.get(); + if (entry.retain()) { + return entry; + } + } + } + + @Override + public HttpJsonClientCall newCall( + ApiMethodDescriptor methodDescriptor, HttpJsonCallOptions callOptions) { + ChannelEntry entry = getRetainedEntry(); + try { + HttpJsonClientCall delegateCall = + entry.channel.newCall(methodDescriptor, callOptions); + return new ReleasingHttpJsonClientCall<>(delegateCall, entry); + } catch (Exception e) { + entry.release(); + throw e; + } + } + + @Override + public void shutdown() { + activeEntry.get().requestShutdown(); + } + + @Override + public boolean isShutdown() { + return activeEntry.get().channel.isShutdown(); + } + + @Override + public boolean isTerminated() { + return activeEntry.get().channel.isTerminated(); + } + + @Override + public void shutdownNow() { + activeEntry.get().channel.shutdownNow(); + } + + @Override + public boolean awaitTermination(long duration, TimeUnit unit) throws InterruptedException { + return activeEntry.get().channel.awaitTermination(duration, unit); + } + + @Override + public void close() { + shutdown(); + } + + /** Internal container to manage request reference-counting and graceful shutdown. */ + private static class ChannelEntry { + private final ManagedHttpJsonChannel channel; + private final AtomicInteger outstandingCalls = new AtomicInteger(0); + private final AtomicBoolean shutdownRequested = new AtomicBoolean(false); + private final AtomicBoolean shutdownInitiated = new AtomicBoolean(false); + + ChannelEntry(ManagedHttpJsonChannel channel) { + this.channel = channel; + } + + boolean retain() { + outstandingCalls.incrementAndGet(); + if (shutdownRequested.get()) { + release(); + return false; + } + return true; + } + + void release() { + int count = outstandingCalls.decrementAndGet(); + if (shutdownRequested.get() && count == 0) { + shutdown(); + } + } + + void requestShutdown() { + shutdownRequested.set(true); + if (outstandingCalls.get() == 0) { + shutdown(); + } + } + + private void shutdown() { + if (shutdownInitiated.compareAndSet(false, true)) { + try { + channel.shutdown(); + } catch (Exception e) { + LOG.log(Level.WARNING, "Error shutting down retired HTTP/JSON channel", e); + } + } + } + } + + /** A client call decorator that decrements the entry counter upon call completion. */ + private static class ReleasingHttpJsonClientCall + extends SimpleForwardingHttpJsonClientCall { + + private final ChannelEntry entry; + private final AtomicBoolean wasClosed = new AtomicBoolean(false); + private final AtomicBoolean wasReleased = new AtomicBoolean(false); + + ReleasingHttpJsonClientCall(HttpJsonClientCall delegate, ChannelEntry entry) { + super(delegate); + this.entry = entry; + } + + @Override + public void start(Listener responseListener, HttpJsonMetadata requestHeaders) { + try { + super.start( + new SimpleForwardingHttpJsonClientCallListener(responseListener) { + @Override + public void onClose(int statusCode, HttpJsonMetadata trailers) { + if (!wasClosed.compareAndSet(false, true)) { + return; + } + try { + super.onClose(statusCode, trailers); + } finally { + if (wasReleased.compareAndSet(false, true)) { + entry.release(); + } + } + } + }, + requestHeaders); + } catch (Exception e) { + if (wasReleased.compareAndSet(false, true)) { + entry.release(); + } + throw e; + } + } + } +} diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java index 09af475e4833..fc7fb5e989fe 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java @@ -63,6 +63,14 @@ public interface ApiCallContext extends RetryingContext { /** Returns a new ApiCallContext with the given channel set. */ ApiCallContext withTransportChannel(TransportChannel channel); + /** + * Returns the {@link TransportChannel} associated with this call context, or {@code null} if none + * is set. + */ + default TransportChannel getTransportChannel() { + return null; + } + /** Returns a new ApiCallContext with the given Endpoint Context. */ ApiCallContext withEndpointContext(EndpointContext endpointContext); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java index 688fc32cd14b..7c8fad8497e9 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ApiResultRetryAlgorithm.java @@ -38,6 +38,10 @@ class ApiResultRetryAlgorithm extends BasicResultRetryAlgorithm internalFuture = callable.futureCall(request, callContext); + + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + final ApiCallContext finalContext = callContext; + ApiFutures.addCallback( + internalFuture, + new com.google.api.core.ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + if (t instanceof UnauthenticatedException) { + TransportChannel transportChannel = finalContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + } + + @Override + public void onSuccess(ResponseT result) {} + }, + com.google.common.util.concurrent.MoreExecutors.directExecutor()); + } + externalFuture.setAttemptFuture(internalFuture); } catch (Throwable e) { externalFuture.setAttemptFuture(ApiFutures.immediateFailedFuture(e)); diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java index 38efb2da3755..59d6099b2d5b 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/BidiStreamingCallable.java @@ -236,11 +236,45 @@ public BidiStreamingCallable withDefaultCallContext( return new BidiStreamingCallable() { @Override public ClientStream internalCall( - ResponseObserver responseObserver, + final ResponseObserver responseObserver, ClientStreamReadyObserver onReady, ApiCallContext thisCallContext) { + final ApiCallContext mergedContext = defaultCallContext.merge(thisCallContext); + ResponseObserver refreshingObserver = responseObserver; + + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + refreshingObserver = + new ResponseObserver() { + @Override + public void onStart(StreamController controller) { + responseObserver.onStart(controller); + } + + @Override + public void onResponse(ResponseT response) { + responseObserver.onResponse(response); + } + + @Override + public void onError(Throwable t) { + if (t instanceof UnauthenticatedException) { + TransportChannel transportChannel = mergedContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + responseObserver.onError(t); + } + + @Override + public void onComplete() { + responseObserver.onComplete(); + } + }; + } + return BidiStreamingCallable.this.internalCall( - responseObserver, onReady, defaultCallContext.merge(thisCallContext)); + refreshingObserver, onReady, mergedContext); } }; } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java index 13ef1c64568b..c172e93ba20b 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ClientStreamingCallable.java @@ -73,9 +73,38 @@ public ClientStreamingCallable withDefaultCallContext( return new ClientStreamingCallable() { @Override public ApiStreamObserver clientStreamingCall( - ApiStreamObserver responseObserver, ApiCallContext thisCallContext) { + final ApiStreamObserver responseObserver, ApiCallContext thisCallContext) { + final ApiCallContext mergedContext = defaultCallContext.merge(thisCallContext); + ApiStreamObserver refreshingObserver = responseObserver; + + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + refreshingObserver = + new ApiStreamObserver() { + @Override + public void onNext(ResponseT response) { + responseObserver.onNext(response); + } + + @Override + public void onError(Throwable t) { + if (t instanceof UnauthenticatedException) { + TransportChannel transportChannel = mergedContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + return ClientStreamingCallable.this.clientStreamingCall( - responseObserver, defaultCallContext.merge(thisCallContext)); + refreshingObserver, mergedContext); } }; } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java index da0c8de632da..3fe6441d762c 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/ServerStreamingAttemptCallable.java @@ -219,6 +219,7 @@ public Void call() { .getTracer() .attemptStarted(request, outerRetryingFuture.getAttemptSettings().getOverallAttemptCount()); + final ApiCallContext finalContext = attemptContext; innerCallable.call( request, new StateCheckingResponseObserver() { @@ -234,6 +235,18 @@ public void onResponseImpl(ResponseT response) { @Override public void onErrorImpl(Throwable t) { + if ("true".equalsIgnoreCase(System.getenv("isMwlidEnvironment"))) { + Throwable cause = t; + if (cause instanceof com.google.api.gax.retrying.ServerStreamingAttemptException) { + cause = cause.getCause(); + } + if (cause instanceof UnauthenticatedException) { + TransportChannel transportChannel = finalContext.getTransportChannel(); + if (transportChannel != null) { + transportChannel.refresh(); + } + } + } onAttemptError(t); } diff --git a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java index d54352e9b246..65b3cce0e0a3 100644 --- a/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java +++ b/sdk-platform-java/gax-java/gax/src/main/java/com/google/api/gax/rpc/TransportChannel.java @@ -47,4 +47,12 @@ public interface TransportChannel extends BackgroundResource { * Returns an empty {@link ApiCallContext} that is compatible with this {@code TransportChannel}. */ ApiCallContext getEmptyCallContext(); + + /** + * Refreshes or recreates the underlying network connections of this transport channel. + * + *

By default, this is a no-op for transports that do not require stateful connection lifecycle + * management. + */ + default void refresh() {} }