diff --git a/security-framework/src/secure_transport.rs b/security-framework/src/secure_transport.rs index f522c713..326a5d32 100644 --- a/security-framework/src/secure_transport.rs +++ b/security-framework/src/secure_transport.rs @@ -889,7 +889,15 @@ where S: Write { let mut ret = errSecSuccess; while start < data.len() { - match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) { + let write_res = panic::catch_unwind(AssertUnwindSafe(|| { + let count = conn.stream.write(&data[start..])?; + // Need to flush during the handshake so that the handshake doesn't stall on buffered + // write streams. It would be better if we only flushed automatically during the + // handshake, and not for the remainder of the stream. + conn.stream.flush()?; + Ok(count) + })); + match write_res { Ok(Ok(0)) => { ret = errSSLClosedNoNotify; break; @@ -1449,6 +1457,49 @@ mod test { ctx.handshake(stream).expect_err("expected failure"); } + #[test] + fn connect_buffered_stream() { + use std::io::BufWriter; + + /// Small wrapper around a `TcpStream` to provide buffered writes. + #[derive(Debug)] + struct BufferedTcpStream { + reader: TcpStream, + writer: BufWriter, + } + + impl BufferedTcpStream { + fn new(tcp: TcpStream) -> std::io::Result { + Ok(Self { + writer: BufWriter::with_capacity(500, tcp.try_clone()?), + reader: tcp, + }) + } + } + + impl Read for BufferedTcpStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.reader.read(buf) + } + } + + impl Write for BufferedTcpStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.writer.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.writer.flush() + } + } + + let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)); + p!(ctx.set_peer_domain_name("google.com")); + let stream = p!(TcpStream::connect("google.com:443")); + let stream = p!(BufferedTcpStream::new(stream)); + p!(ctx.handshake(stream)); + } + #[test] fn load_page() { let mut ctx = p!(SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM));