@@ -14,6 +14,7 @@ import (
1414 "golang.org/x/net/proxy"
1515 "google.golang.org/grpc"
1616 "google.golang.org/grpc/codes"
17+ "google.golang.org/grpc/connectivity"
1718 "google.golang.org/grpc/credentials"
1819 "google.golang.org/grpc/credentials/insecure"
1920 "google.golang.org/grpc/keepalive"
@@ -30,72 +31,54 @@ func LoggerRecoveryHandler(log *logrus.Entry) recovery.RecoveryHandlerFunc {
3031 }
3132}
3233
33- // BlockingDial is a helper method to dial the given address, using optional TLS credentials,
34+ // BlockingNewClient is a helper method to dial the given address, using optional TLS credentials,
3435// and blocking until the returned connection is ready. If the given credentials are nil, the
3536// connection will be insecure (plain-text).
3637// Lifted from: https://github.com/fullstorydev/grpcurl/blob/master/grpcurl.go
37- func BlockingDial (ctx context.Context , network , address string , creds credentials.TransportCredentials , opts ... grpc.DialOption ) (* grpc.ClientConn , error ) {
38- // grpc.Dial doesn't provide any information on permanent connection errors (like
39- // TLS handshake failures). So in order to provide good error messages, we need a
40- // custom dialer that can provide that info. That means we manage the TLS handshake.
41- result := make (chan any , 1 )
42- writeResult := func (res any ) {
43- // non-blocking write: we only need the first result
44- select {
45- case result <- res :
46- default :
47- }
38+ func BlockingNewClient (ctx context.Context , network , address string , creds credentials.TransportCredentials , opts ... grpc.DialOption ) (* grpc.ClientConn , error ) {
39+ rawConn , err := proxy .Dial (ctx , network , address )
40+ if err != nil {
41+ return nil , fmt .Errorf ("error dial proxy: %w" , err )
4842 }
49-
50- dialer := func (ctx context.Context , address string ) (net.Conn , error ) {
51- conn , err := proxy .Dial (ctx , network , address )
43+ if creds != nil {
44+ rawConn , _ , err = creds .ClientHandshake (ctx , address , rawConn )
5245 if err != nil {
53- writeResult (err )
54- return nil , fmt .Errorf ("error dial proxy: %w" , err )
46+ return nil , fmt .Errorf ("error creating connection: %w" , err )
5547 }
56- if creds != nil {
57- conn , _ , err = creds .ClientHandshake (ctx , address , conn )
58- if err != nil {
59- writeResult (err )
60- return nil , fmt .Errorf ("error creating connection: %w" , err )
61- }
62- }
63- return conn , nil
6448 }
6549
66- // Even with grpc.FailOnNonTempDialError, this call will usually timeout in
67- // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
68- // know when we're done. So we run it in a goroutine and then use result
69- // channel to either get the channel or fail-fast.
70- go func () {
71- opts = append (opts ,
72- //nolint:staticcheck
73- grpc .WithBlock (),
74- //nolint:staticcheck
75- grpc .FailOnNonTempDialError (true ),
76- grpc .WithContextDialer (dialer ),
77- grpc .WithTransportCredentials (insecure .NewCredentials ()), // we are handling TLS, so tell grpc not to
78- grpc .WithKeepaliveParams (keepalive.ClientParameters {Time : common .GetGRPCKeepAliveTime ()}),
79- )
80- //nolint:staticcheck
81- conn , err := grpc .DialContext (ctx , address , opts ... )
82- var res any
83- if err != nil {
84- res = err
85- } else {
86- res = conn
87- }
88- writeResult (res )
89- }()
50+ customDialer := func (_ context.Context , _ string ) (net.Conn , error ) {
51+ return rawConn , nil
52+ }
53+
54+ opts = append (opts ,
55+ grpc .WithContextDialer (customDialer ),
56+ grpc .WithTransportCredentials (insecure .NewCredentials ()), // we are handling TLS, so tell grpc not to
57+ grpc .WithKeepaliveParams (keepalive.ClientParameters {Time : common .GetGRPCKeepAliveTime ()}),
58+ )
9059
91- select {
92- case res := <- result :
93- if conn , ok := res .(* grpc.ClientConn ); ok {
94- return conn , nil
60+ conn , err := grpc .NewClient (address , opts ... )
61+ if err != nil {
62+ return nil , err
63+ }
64+
65+ conn .Connect ()
66+ if err := waitForReady (ctx , conn ); err != nil {
67+ return nil , err
68+ }
69+
70+ return conn , nil
71+ }
72+
73+ func waitForReady (ctx context.Context , conn * grpc.ClientConn ) error {
74+ for {
75+ state := conn .GetState ()
76+ if state == connectivity .Ready {
77+ return nil
78+ }
79+ if ! conn .WaitForStateChange (ctx , state ) {
80+ return ctx .Err () // timeout or canceled
9581 }
96- return nil , res .(error )
97- case <- ctx .Done ():
98- return nil , ctx .Err ()
9982 }
10083}
10184
@@ -119,15 +102,15 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
119102 ctx , cancel := context .WithTimeout (context .Background (), dialTime )
120103 defer cancel ()
121104
122- conn , err := BlockingDial (ctx , "tcp" , address , creds )
105+ conn , err := BlockingNewClient (ctx , "tcp" , address , creds )
123106 if err == nil {
124107 _ = conn .Close ()
125108 testResult .TLS = true
126109 creds := credentials .NewTLS (& tls.Config {})
127110 ctx , cancel := context .WithTimeout (context .Background (), dialTime )
128111 defer cancel ()
129112
130- conn , err := BlockingDial (ctx , "tcp" , address , creds )
113+ conn , err := BlockingNewClient (ctx , "tcp" , address , creds )
131114 if err == nil {
132115 _ = conn .Close ()
133116 } else {
@@ -142,7 +125,7 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) {
142125 // refused). Test if server accepts plain-text connections
143126 ctx , cancel = context .WithTimeout (context .Background (), dialTime )
144127 defer cancel ()
145- conn , err = BlockingDial (ctx , "tcp" , address , nil )
128+ conn , err = BlockingNewClient (ctx , "tcp" , address , nil )
146129 if err == nil {
147130 _ = conn .Close ()
148131 testResult .TLS = false
0 commit comments