@@ -32,6 +32,7 @@ import (
3232 "net/netip"
3333 "time"
3434
35+ x "github.com/celzero/firestack/intra/backend"
3536 "github.com/celzero/firestack/intra/dnsx"
3637 "github.com/celzero/firestack/intra/log"
3738 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
@@ -96,7 +97,7 @@ func (h *tcpHandler) Error(gconn *netstack.GTCPConn, src, dst netip.AddrPort, er
9697 h .maybeReplaceDest (res , & dst )
9798
9899 cid , uid , fid , pids := h .judge (res )
99- smm := tcpSummary (cid , uid , dst .Addr ())
100+ smm := tcpSummary (cid , uid , src . Addr (), dst .Addr ())
100101
101102 if isAnyBlockPid (pids ) {
102103 smm .PID = ipn .Block
@@ -122,7 +123,7 @@ func (h *tcpHandler) Error(gconn *netstack.GTCPConn, src, dst netip.AddrPort, er
122123func (h * tcpHandler ) ReverseProxy (gconn * netstack.GTCPConn , in net.Conn , to , from netip.AddrPort ) (open bool ) {
123124 fm := h .onInflow (to , from )
124125 cid , uid , _ , pids := h .judge (fm )
125- smm := tcpSummary (cid , uid , from .Addr ())
126+ smm := tcpSummary (cid , uid , to . Addr (), from .Addr ())
126127
127128 if settings .Debug {
128129 log .VV ("tcp: %s [%s]: reverse: %s => %s; pids: %v" , cid , uid , from , to , pids )
@@ -149,6 +150,20 @@ func (h *tcpHandler) ReverseProxy(gconn *netstack.GTCPConn, in net.Conn, to, fro
149150 return true
150151}
151152
153+ func (h * tcpHandler ) handshakeIfNeededOrClose (gconn * netstack.GTCPConn , smm * SocketSummary ) (bool , error ) {
154+ const allow bool = true // allowed
155+ const deny bool = ! allow // blocked
156+
157+ if _ , err := gconn .Establish (); err != nil {
158+ err = log .EE ("tcp: %s handshake err %v; %s => %s for %s" ,
159+ smm .ID , err , smm .Source , smm .Target , smm .UID )
160+ clos (gconn )
161+ h .queueSummary (smm .done (err ))
162+ return deny , err // == !open
163+ }
164+ return allow , nil
165+ }
166+
152167// Proxy implements netstack.GTCPConnHandler
153168// It must be called from a goroutine.
154169func (h * tcpHandler ) Proxy (gconn * netstack.GTCPConn , src , target netip.AddrPort ) (open bool ) {
@@ -180,7 +195,7 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
180195 actualTargets = []netip.AddrPort {target }
181196 }
182197 // actualTargets[0] may be same as target
183- smm = tcpSummary (cid , uid , actualTargets [0 ].Addr ())
198+ smm = tcpSummary (cid , uid , src . Addr (), actualTargets [0 ].Addr ())
184199
185200 if h .status .Load () == HDLEND {
186201 err = log .EE ("tcp: proxy: %s end %s => %s [%v]" , cid , src , target , actualTargets )
@@ -206,16 +221,8 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
206221 return deny
207222 }
208223
209- // handshake; since we assume a duplex-stream from here on
210- if _ , err := gconn .Establish (); err != nil {
211- err = log .EE ("tcp: %s connect1 err %v; %s => %s for %s" , cid , err , src , target , uid )
212- clos (gconn )
213- h .queueSummary (smm .done (err ))
214- return deny // == !open
215- }
216-
217224 if isAnyBasePid (pids ) { // see udp.go:Connect
218- if h .dnsOverride (gconn , target , uid ) {
225+ if synack , _ := h . handshakeIfNeededOrClose ( gconn , smm ); synack && h .dnsOverride (gconn , target , uid ) {
219226 // SocketSummary not sent; x.DNSSummary supercedes it
220227 // conn closed by resolver
221228 return allow
@@ -227,15 +234,18 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
227234 cid , src , target , actualTargets , uid , pids )
228235 }
229236
237+ cont := true
230238 boundSrc := makeAnyAddrPort (src )
231239 // pick all realips to connect to
232240 for i , dstipp := range actualTargets {
241+ targetstr := dstipp .Addr ().String ()
242+
233243 var px ipn.Proxy = nil
234244 px , err = h .prox .ProxyTo (dstipp , uid , pids )
235245
236246 // last chosen (but not dialed in) proxy (which error)
237- smm .Target = dstipp . Addr (). String () // addr may be invalid
238- smm .PID = pidstr (px ) // px may be nil
247+ smm .Target = targetstr // addr may be invalid
248+ smm .PID = pidstr (px ) // px may be nil
239249 smm .RPID = ipn .ViaID (px )
240250
241251 if err != nil || px == nil {
@@ -244,13 +254,13 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
244254 continue
245255 }
246256
247- if err = h .handle (px , gconn , boundSrc , dstipp , smm ); err == nil {
257+ if cont , err = h .handle (px , gconn , boundSrc , dstipp , smm ); err == nil {
248258 return allow // smm instead queued by handle() => forward()
249259 } else {
250260 end := time .Since (smm .start )
251- err = log .WE ("tcp: dial: #%d: %s failed; addr(%s) / fallback? %t; for uid %s (%s); w err(%v)" ,
252- i , cid , dstipp , fallingback , uid , core .FmtPeriod (end ), err )
253- if end > retryTimeout {
261+ err = log .WE ("tcp: dial: #%d: %s failed; addr(%s) / fallback? %t / cont? %t ; for uid %s (%s); w err(%v)" ,
262+ i , cid , dstipp , fallingback , cont , uid , core .FmtPeriod (end ), err )
263+ if ! cont || end > retryTimeout {
254264 break // return err
255265 } // else: continue; try the next realip
256266 }
@@ -262,24 +272,33 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
262272}
263273
264274// handle connects to the target via the proxy, and pipes data between the src, target; thread-safe.
265- func (h * tcpHandler ) handle (px ipn.Proxy , src * netstack.GTCPConn , boundSrc , target netip.AddrPort , smm * SocketSummary ) (err error ) {
275+ func (h * tcpHandler ) handle (px ipn.Proxy , src * netstack.GTCPConn , boundSrc , target netip.AddrPort , smm * SocketSummary ) (next bool , err error ) {
276+ cont := true
277+ stop := ! cont
278+ targetstr := target .String ()
279+
280+ // make sure to not synack in HappyEyeballs scenarios
281+ if canroute := px .Router ().Contains (x .StrOf (targetstr )); ! canroute {
282+ return cont , log .WE ("proxy(%s) has no route to %s" , pidstr (px ), targetstr )
283+ }
284+
266285 var pc protect.Conn
267286 var dst net.Conn
268287
269288 start := time .Now ()
270289
271290 if settings .Debug {
272291 log .VV ("tcp: %s dial %s: attempt: %s [%s] => %s for %s" ,
273- smm .ID , pidstr (px ), src .LocalAddr (), boundSrc , target , smm .UID )
292+ smm .ID , pidstr (px ), src .LocalAddr (), boundSrc , targetstr , smm .UID )
274293 }
275294
276295 // github.com/google/gvisor/blob/5ba35f516b5c2/test/benchmarks/tcp/tcp_proxy.go#L359
277296 // ref: stackoverflow.com/questions/63656117
278297 // ref: stackoverflow.com/questions/40328025
279298 if settings .PortForward .Load () {
280- pc , err = px .Dialer ().DialBind ("tcp" , boundSrc .String (), target . String () )
299+ pc , err = px .Dialer ().DialBind ("tcp" , boundSrc .String (), targetstr )
281300 } else {
282- pc , err = px .Dialer ().Dial ("tcp" , target . String () )
301+ pc , err = px .Dialer ().Dial ("tcp" , targetstr )
283302 }
284303 if err == nil {
285304 smm .Rtt = time .Since (start ).Milliseconds ()
@@ -307,15 +326,20 @@ func (h *tcpHandler) handle(px ipn.Proxy, src *netstack.GTCPConn, boundSrc, targ
307326 smm .RPID = ipn .ViaID (px )
308327
309328 if err != nil {
310- clos (src , pc )
329+ clos (pc )
311330 log .W ("tcp: err dialing %s proxy(%s) to dst(%v) for %s: %v" ,
312- smm .ID , px .ID (), target , smm .UID , err )
313- return err
331+ smm .ID , smm .PID , smm .Target , smm .UID , err )
332+ return cont , err
333+ }
334+
335+ if _ , synackerr := h .handshakeIfNeededOrClose (src , smm ); synackerr != nil {
336+ clos (pc )
337+ return stop , synackerr
314338 }
315339
316340 core .Go ("tcp.forward." + smm .ID , func () {
317341 h .listener .PostFlow (smm .postMark ())
318342 h .forward (src , rwext {dst , tcptimeout }, smm ) // src always *gonet.TCPConn
319343 })
320- return nil // handled; takes ownership of src
344+ return cont , nil // handled; takes ownership of src
321345}
0 commit comments