Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion grpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ allowed_external_types = [
"bytes::*",
"tonic::*",
"futures_core::stream::Stream",
"tokio::sync::oneshot::Sender",
"tokio::sync::oneshot::*",
]

[features]
Expand Down
258 changes: 258 additions & 0 deletions grpc/src/client/metadata_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
use tokio::sync::oneshot;
use tonic::metadata::MetadataMap;

use crate::client::CallOptions;
use crate::client::InvokeOnce;
use crate::client::RecvStream;
use crate::client::interceptor::Intercept;
use crate::client::interceptor::InterceptOnce;
use crate::core::RequestHeaders;

/// An interceptor that attaches metadata to outgoing RPC headers.
pub struct AttachHeadersInterceptor {
md: MetadataMap,
}

impl AttachHeadersInterceptor {
pub fn new(md: MetadataMap) -> Self {
Self { md }
}
}

impl<I: InvokeOnce> Intercept<I> for AttachHeadersInterceptor {
type SendStream = I::SendStream;
type RecvStream = I::RecvStream;

async fn intercept(
&self,
mut headers: RequestHeaders,
options: CallOptions,
next: I,
) -> (Self::SendStream, Self::RecvStream) {
headers
.metadata_mut()
.as_mut()
.extend(self.md.as_ref().clone());

let md = headers.metadata_mut();
for entry in self.md.iter() {
match entry {
tonic::metadata::KeyAndValueRef::Ascii(k, v) => _ = md.insert(k, v.clone()),
tonic::metadata::KeyAndValueRef::Binary(k, v) => _ = md.insert_bin(k, v.clone()),
}
}
next.invoke_once(headers, options).await
}
}

/// An interceptor that reads the received headers' metadata from the stream and
/// sends them to the returned oneshot channel.
pub struct CaptureHeadersInterceptor {
tx: oneshot::Sender<MetadataMap>,
}

impl CaptureHeadersInterceptor {
pub fn new() -> (Self, oneshot::Receiver<MetadataMap>) {
let (tx, rx) = oneshot::channel();
(Self { tx }, rx)
}
}

impl<I: InvokeOnce> InterceptOnce<I> for CaptureHeadersInterceptor {
type SendStream = I::SendStream;
type RecvStream = CaptureHeadersRecvStream<I::RecvStream>;

async fn intercept_once(
self,
headers: RequestHeaders,
options: CallOptions,
next: I,
) -> (Self::SendStream, Self::RecvStream) {
let (tx, rx) = next.invoke_once(headers, options).await;
(tx, CaptureHeadersRecvStream::new(rx, self.tx))
}
}

pub struct CaptureHeadersRecvStream<R> {
rx: R,
tx: Option<oneshot::Sender<MetadataMap>>,
}

impl<R> CaptureHeadersRecvStream<R> {
pub fn new(rx: R, tx: oneshot::Sender<MetadataMap>) -> Self {
Self { rx, tx: Some(tx) }
}
}

impl<R: RecvStream> RecvStream for CaptureHeadersRecvStream<R> {
async fn next(&mut self, msg: &mut dyn super::RecvMessage) -> super::ClientResponseStreamItem {
let res = self.rx.next(msg).await;
if let super::ClientResponseStreamItem::Headers(headers) = &res
&& let Some(tx) = self.tx.take()
{
_ = tx.send(headers.metadata().clone());
}
res
}
}

/// An interceptor that reads the received trailers' metadata from the stream
/// and sends them to the returned oneshot channel.
pub struct CaptureTrailersInterceptor {
tx: oneshot::Sender<MetadataMap>,
}

impl CaptureTrailersInterceptor {
pub fn new() -> (Self, oneshot::Receiver<MetadataMap>) {
let (tx, rx) = oneshot::channel();
(Self { tx }, rx)
}
}

impl<I: InvokeOnce> InterceptOnce<I> for CaptureTrailersInterceptor {
type SendStream = I::SendStream;
type RecvStream = CaptureTrailersRecvStream<I::RecvStream>;

async fn intercept_once(
self,
headers: RequestHeaders,
options: CallOptions,
next: I,
) -> (Self::SendStream, Self::RecvStream) {
let (tx, rx) = next.invoke_once(headers, options).await;
(tx, CaptureTrailersRecvStream::new(rx, self.tx))
}
}

pub struct CaptureTrailersRecvStream<R> {
rx: R,
tx: Option<oneshot::Sender<MetadataMap>>,
}

impl<R> CaptureTrailersRecvStream<R> {
pub fn new(rx: R, tx: oneshot::Sender<MetadataMap>) -> Self {
Self { rx, tx: Some(tx) }
}
}

impl<R: RecvStream> RecvStream for CaptureTrailersRecvStream<R> {
async fn next(&mut self, msg: &mut dyn super::RecvMessage) -> super::ClientResponseStreamItem {
let res = self.rx.next(msg).await;
if let super::ClientResponseStreamItem::Trailers(trailers) = &res
&& let Some(tx) = self.tx.take()
{
_ = tx.send(trailers.metadata().clone());
}
res
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::Status;
use crate::StatusCode;
use crate::client::test_util::MockInvoker;
use crate::client::test_util::NopRecvMessage;
use crate::core::ClientResponseStreamItem;
use crate::core::ResponseHeaders;
use crate::core::Trailers;

#[tokio::test]
async fn test_attach_headers_interceptor() {
// Create test interceptor with metadata to attach.
let mut md = MetadataMap::new();
md.insert("x-test-header", "test-value".parse().unwrap());
md.insert_bin(
"x-test-header-bin",
tonic::metadata::MetadataValue::from_bytes(b"test-bin"),
);
let interceptor = AttachHeadersInterceptor::new(md);

// Call the interceptor with additional headers in place.
let (invoker, _) = MockInvoker::new();
let mut initial_headers = RequestHeaders::default();
initial_headers
.metadata_mut()
.insert("x-initial-header", "initial".parse().unwrap());
let _ = interceptor
.intercept(initial_headers, CallOptions::default(), &invoker)
.await;

// Verify the received headers include all values.
let final_headers = invoker.req_headers.lock().unwrap().take().unwrap();
assert_eq!(
final_headers.metadata().get("x-test-header").unwrap(),
"test-value"
);
assert_eq!(
final_headers
.metadata()
.get_bin("x-test-header-bin")
.unwrap(),
b"test-bin".as_slice()
);
assert_eq!(
final_headers.metadata().get("x-initial-header").unwrap(),
"initial"
);
}

#[tokio::test]
async fn test_capture_headers_interceptor() {
// Create test interceptor.
let (interceptor, rx) = CaptureHeadersInterceptor::new();

// Start a call through the interceptor.
let (invoker, mut controller) = MockInvoker::new();
let (_, mut recv_stream) = interceptor
.intercept_once(RequestHeaders::default(), CallOptions::default(), &invoker)
.await;

// Send a Headers response on the call.
let mut resp_md = MetadataMap::new();
resp_md.insert("x-resp-header", "resp-value".parse().unwrap());
let mut headers = ResponseHeaders::default();
*headers.metadata_mut() = resp_md;
controller
.send_resp(ClientResponseStreamItem::Headers(headers))
.await;

// Receive the sent Headers response.
let res = recv_stream.next(&mut NopRecvMessage).await;
assert!(matches!(res, ClientResponseStreamItem::Headers(_)));

// Verify the received headers are correct.
let captured_md = rx.await.unwrap();
assert_eq!(captured_md.get("x-resp-header").unwrap(), "resp-value");
}

#[tokio::test]
async fn test_capture_trailers_interceptor() {
// Create test interceptor.
let (interceptor, rx) = CaptureTrailersInterceptor::new();

// Start a call through the interceptor.
let (invoker, mut controller) = MockInvoker::new();
let (_, mut recv_stream) = interceptor
.intercept_once(RequestHeaders::default(), CallOptions::default(), &invoker)
.await;

// Send a Trailers response on the call.
let mut trailers_md = MetadataMap::new();
trailers_md.insert("x-trailer", "trailer-value".parse().unwrap());
let mut trailers = Trailers::new(Status::new(StatusCode::Ok, "ok"));
*trailers.metadata_mut() = trailers_md;
controller
.send_resp(ClientResponseStreamItem::Trailers(trailers))
.await;

// Receive the sent Trailers response.
let res = recv_stream.next(&mut NopRecvMessage).await;
assert!(matches!(res, ClientResponseStreamItem::Trailers(_)));

// Verify the received trailers are correct.
let captured_md = rx.await.unwrap();
assert_eq!(captured_md.get("x-trailer").unwrap(), "trailer-value");
}
}
43 changes: 41 additions & 2 deletions grpc/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use crate::core::SendMessage;

pub mod channel;
pub mod interceptor;
pub mod metadata_utils;
pub mod service_config;
pub mod stream_util;

Expand All @@ -45,6 +46,9 @@ pub(crate) mod name_resolution;
mod subchannel;
pub(crate) mod transport;

#[cfg(test)]
mod test_util;

/// A representation of the current state of a gRPC channel, also used for the
/// state of subchannels (individual connections within the channel).
///
Expand Down Expand Up @@ -87,6 +91,25 @@ pub struct CallOptions {
deadline: Option<Instant>,
}

impl CallOptions {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm - the CallOptions and SendOptions changes are both orthogonal to this PR? I know we have to pipe them through the affected call sites but I didn't see that we were actually using them yet. If there was a change here that required them I want to be sure to understand it.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aah yes, I think these were supposed to stay with #2549, but I'd rather keep it like this if that's OK since I have a whole slew of changes stacked here.

pub fn new() -> Self {
Self::default()
}

pub fn with_deadline(mut self, deadline: Instant) -> Self {
self.deadline = Some(deadline);
self
}

pub fn set_deadline(&mut self, deadline: Instant) {
self.deadline = Some(deadline);
}

pub fn deadline(&self) -> Option<Instant> {
self.deadline
}
}

/// A trait which may be implemented by types to perform RPCs (Remote Procedure
/// Calls, often shortened to "call").
///
Expand Down Expand Up @@ -193,7 +216,7 @@ impl<T: SendStream> DynSendStream for T {
}
}

impl SendStream for Box<dyn DynSendStream> {
impl<'a> SendStream for Box<dyn DynSendStream + 'a> {
async fn send(&mut self, msg: &dyn SendMessage, options: SendOptions) -> Result<(), ()> {
(**self).dyn_send(msg, options).await
}
Expand All @@ -212,6 +235,22 @@ pub struct SendOptions {
pub disable_compression: bool,
}

impl SendOptions {
pub fn new() -> Self {
Self::default()
}

pub fn with_final_msg(mut self, final_msg: bool) -> Self {
self.final_msg = final_msg;
self
}

pub fn with_disable_compression(mut self, disable_compression: bool) -> Self {
self.disable_compression = disable_compression;
self
}
}

/// Represents the receiving side of a client stream. When a `RecvStream` is
/// dropped, the associated call is cancelled if the server has not already
/// terminated the stream.
Expand Down Expand Up @@ -243,7 +282,7 @@ impl<T: RecvStream> DynRecvStream for T {
}
}

impl RecvStream for Box<dyn DynRecvStream> {
impl<'a> RecvStream for Box<dyn DynRecvStream + 'a> {
async fn next(&mut self, msg: &mut dyn RecvMessage) -> ClientResponseStreamItem {
(**self).dyn_next(msg).await
}
Expand Down
Loading
Loading