Skip to content

Commit 4b3dc4c

Browse files
committed
tcp: handshake iff egress is successful
1 parent 7186cfd commit 4b3dc4c

5 files changed

Lines changed: 66 additions & 41 deletions

File tree

intra/listener.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ type SocketSummary struct {
2929
RPID string
3030
// UID of the app that owns this socket (sans ICMP).
3131
UID string
32+
// Source IP.
33+
Source string
3234
// Remote IP, if dialed in.
3335
Target string
3436
// Total bytes downloaded.
@@ -125,19 +127,20 @@ func icmpSummary(id, uid string) *SocketSummary {
125127
}
126128
}
127129

128-
func tcpSummary(id, uid string, dst netip.Addr) *SocketSummary {
130+
func tcpSummary(id, uid string, src, dst netip.Addr) *SocketSummary {
129131
return &SocketSummary{
130132
Proto: ProtoTypeTCP,
131133
ID: id,
132134
UID: uid,
135+
Source: src.String(),
133136
Target: dst.String(),
134137
start: time.Now(),
135138
Msg: errNone.Error(),
136139
}
137140
}
138141

139-
func udpSummary(id, uid string, dst netip.Addr) *SocketSummary {
140-
s := tcpSummary(id, uid, dst)
142+
func udpSummary(id, uid string, src, dst netip.Addr) *SocketSummary {
143+
s := tcpSummary(id, uid, src, dst)
141144
s.Proto = ProtoTypeUDP
142145
return s
143146
}

intra/netstack/tcp.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,8 @@ const rcvwnd = 0
2929

3030
const maxInFlight = 512 // arbitrary
3131

32-
// syn-ack before delivering to handler?
33-
const earlyConnect = false
34-
35-
// retry connect when earlyConnect fails?
36-
const retryLateConnect = earlyConnect && true
32+
// retry connect when early connect (done when no happy eyeballs) fails?
33+
const retryLateConnect = false
3734

3835
var (
3936
// defaults: github.com/google/gvisor/blob/fa49677e141db/pkg/tcpip/transport/tcp/protocol.go#L73
@@ -71,7 +68,8 @@ type GTCPConn struct {
7168
func InboundTCP(who string, s *stack.Stack, in net.Conn, to, from netip.AddrPort, h GTCPConnHandler) error {
7269
newgc := makeGTCPConn(who, s, nil /*not a forwarder req*/, to, from)
7370

74-
if earlyConnect {
71+
// early syn/ack is okay if happy eyeballs isn't strictly required
72+
if !settings.HappyEyeballs.Load() {
7573
open, err := newgc.tryConnect()
7674

7775
if settings.Debug {
@@ -127,7 +125,7 @@ func tcpForwarder(who string, s *stack.Stack, h GTCPConnHandler) *tcp.Forwarder
127125

128126
// setup endpoint right away, so that netstack's internal state is consistent
129127
// in case there are multiple forwarders dispatching from the TUN device.
130-
if earlyConnect {
128+
if !settings.HappyEyeballs.Load() { // syn-ack before delivering to handler?
131129
opened, err := gtcp.tryConnect()
132130

133131
if settings.Debug {

intra/netstack/udp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func OutboundUDP(id string, s *stack.Stack, h GUDPConnHandler) {
8484

8585
func InboundUDP(who string, s *stack.Stack, in net.Conn, to, from netip.AddrPort, h GUDPConnHandler) error {
8686
newgc := makeGUDPConn(who, s, nil /*not a forwarder req*/, to, from)
87-
if earlyConnect {
87+
if !settings.HappyEyeballs.Load() { // ref comment in netstack/tcp.go
8888
err := newgc.Establish()
8989

9090
if settings.Debug {
@@ -163,7 +163,7 @@ func udpForwarder(who string, s *stack.Stack, h GUDPConnHandler) *udp.Forwarder
163163

164164
// setup to recv right away, so that netstack's internal state is consistent
165165
// in case there are multiple forwarders dispatching from the TUN device.
166-
if earlyConnect {
166+
if !settings.HappyEyeballs.Load() {
167167
err := gc.Establish()
168168

169169
if settings.Debug {

intra/tcp.go

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
"net/netip"
3333
"time"
3434

35+
x "github.com/celzero/firestack/intra/backend"
3536
"github.com/celzero/firestack/intra/dnsx"
3637
"github.com/celzero/firestack/intra/log"
3738
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
@@ -96,7 +97,7 @@ func (h *tcpHandler) Error(gconn *netstack.GTCPConn, src, dst netip.AddrPort, er
9697
h.maybeReplaceDest(res, &dst)
9798

9899
cid, uid, fid, pids := h.judge(res)
99-
smm := tcpSummary(cid, uid, dst.Addr())
100+
smm := tcpSummary(cid, uid, src.Addr(), dst.Addr())
100101

101102
if isAnyBlockPid(pids) {
102103
smm.PID = ipn.Block
@@ -122,7 +123,7 @@ func (h *tcpHandler) Error(gconn *netstack.GTCPConn, src, dst netip.AddrPort, er
122123
func (h *tcpHandler) ReverseProxy(gconn *netstack.GTCPConn, in net.Conn, to, from netip.AddrPort) (open bool) {
123124
fm := h.onInflow(to, from)
124125
cid, uid, _, pids := h.judge(fm)
125-
smm := tcpSummary(cid, uid, from.Addr())
126+
smm := tcpSummary(cid, uid, to.Addr(), from.Addr())
126127

127128
if settings.Debug {
128129
log.VV("tcp: %s [%s]: reverse: %s => %s; pids: %v", cid, uid, from, to, pids)
@@ -149,6 +150,20 @@ func (h *tcpHandler) ReverseProxy(gconn *netstack.GTCPConn, in net.Conn, to, fro
149150
return true
150151
}
151152

153+
func (h *tcpHandler) handshakeIfNeededOrClose(gconn *netstack.GTCPConn, smm *SocketSummary) (bool, error) {
154+
const allow bool = true // allowed
155+
const deny bool = !allow // blocked
156+
157+
if _, err := gconn.Establish(); err != nil {
158+
err = log.EE("tcp: %s handshake err %v; %s => %s for %s",
159+
smm.ID, err, smm.Source, smm.Target, smm.UID)
160+
clos(gconn)
161+
h.queueSummary(smm.done(err))
162+
return deny, err // == !open
163+
}
164+
return allow, nil
165+
}
166+
152167
// Proxy implements netstack.GTCPConnHandler
153168
// It must be called from a goroutine.
154169
func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort) (open bool) {
@@ -180,7 +195,7 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
180195
actualTargets = []netip.AddrPort{target}
181196
}
182197
// actualTargets[0] may be same as target
183-
smm = tcpSummary(cid, uid, actualTargets[0].Addr())
198+
smm = tcpSummary(cid, uid, src.Addr(), actualTargets[0].Addr())
184199

185200
if h.status.Load() == HDLEND {
186201
err = log.EE("tcp: proxy: %s end %s => %s [%v]", cid, src, target, actualTargets)
@@ -206,16 +221,8 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
206221
return deny
207222
}
208223

209-
// handshake; since we assume a duplex-stream from here on
210-
if _, err := gconn.Establish(); err != nil {
211-
err = log.EE("tcp: %s connect1 err %v; %s => %s for %s", cid, err, src, target, uid)
212-
clos(gconn)
213-
h.queueSummary(smm.done(err))
214-
return deny // == !open
215-
}
216-
217224
if isAnyBasePid(pids) { // see udp.go:Connect
218-
if h.dnsOverride(gconn, target, uid) {
225+
if synack, _ := h.handshakeIfNeededOrClose(gconn, smm); synack && h.dnsOverride(gconn, target, uid) {
219226
// SocketSummary not sent; x.DNSSummary supercedes it
220227
// conn closed by resolver
221228
return allow
@@ -227,15 +234,18 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
227234
cid, src, target, actualTargets, uid, pids)
228235
}
229236

237+
cont := true
230238
boundSrc := makeAnyAddrPort(src)
231239
// pick all realips to connect to
232240
for i, dstipp := range actualTargets {
241+
targetstr := dstipp.Addr().String()
242+
233243
var px ipn.Proxy = nil
234244
px, err = h.prox.ProxyTo(dstipp, uid, pids)
235245

236246
// last chosen (but not dialed in) proxy (which error)
237-
smm.Target = dstipp.Addr().String() // addr may be invalid
238-
smm.PID = pidstr(px) // px may be nil
247+
smm.Target = targetstr // addr may be invalid
248+
smm.PID = pidstr(px) // px may be nil
239249
smm.RPID = ipn.ViaID(px)
240250

241251
if err != nil || px == nil {
@@ -244,13 +254,13 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
244254
continue
245255
}
246256

247-
if err = h.handle(px, gconn, boundSrc, dstipp, smm); err == nil {
257+
if cont, err = h.handle(px, gconn, boundSrc, dstipp, smm); err == nil {
248258
return allow // smm instead queued by handle() => forward()
249259
} else {
250260
end := time.Since(smm.start)
251-
err = log.WE("tcp: dial: #%d: %s failed; addr(%s) / fallback? %t; for uid %s (%s); w err(%v)",
252-
i, cid, dstipp, fallingback, uid, core.FmtPeriod(end), err)
253-
if end > retryTimeout {
261+
err = log.WE("tcp: dial: #%d: %s failed; addr(%s) / fallback? %t / cont? %t; for uid %s (%s); w err(%v)",
262+
i, cid, dstipp, fallingback, cont, uid, core.FmtPeriod(end), err)
263+
if !cont || end > retryTimeout {
254264
break // return err
255265
} // else: continue; try the next realip
256266
}
@@ -262,24 +272,33 @@ func (h *tcpHandler) Proxy(gconn *netstack.GTCPConn, src, target netip.AddrPort)
262272
}
263273

264274
// handle connects to the target via the proxy, and pipes data between the src, target; thread-safe.
265-
func (h *tcpHandler) handle(px ipn.Proxy, src *netstack.GTCPConn, boundSrc, target netip.AddrPort, smm *SocketSummary) (err error) {
275+
func (h *tcpHandler) handle(px ipn.Proxy, src *netstack.GTCPConn, boundSrc, target netip.AddrPort, smm *SocketSummary) (next bool, err error) {
276+
cont := true
277+
stop := !cont
278+
targetstr := target.String()
279+
280+
// make sure to not synack in HappyEyeballs scenarios
281+
if canroute := px.Router().Contains(x.StrOf(targetstr)); !canroute {
282+
return cont, log.WE("proxy(%s) has no route to %s", pidstr(px), targetstr)
283+
}
284+
266285
var pc protect.Conn
267286
var dst net.Conn
268287

269288
start := time.Now()
270289

271290
if settings.Debug {
272291
log.VV("tcp: %s dial %s: attempt: %s [%s] => %s for %s",
273-
smm.ID, pidstr(px), src.LocalAddr(), boundSrc, target, smm.UID)
292+
smm.ID, pidstr(px), src.LocalAddr(), boundSrc, targetstr, smm.UID)
274293
}
275294

276295
// github.com/google/gvisor/blob/5ba35f516b5c2/test/benchmarks/tcp/tcp_proxy.go#L359
277296
// ref: stackoverflow.com/questions/63656117
278297
// ref: stackoverflow.com/questions/40328025
279298
if settings.PortForward.Load() {
280-
pc, err = px.Dialer().DialBind("tcp", boundSrc.String(), target.String())
299+
pc, err = px.Dialer().DialBind("tcp", boundSrc.String(), targetstr)
281300
} else {
282-
pc, err = px.Dialer().Dial("tcp", target.String())
301+
pc, err = px.Dialer().Dial("tcp", targetstr)
283302
}
284303
if err == nil {
285304
smm.Rtt = time.Since(start).Milliseconds()
@@ -307,15 +326,20 @@ func (h *tcpHandler) handle(px ipn.Proxy, src *netstack.GTCPConn, boundSrc, targ
307326
smm.RPID = ipn.ViaID(px)
308327

309328
if err != nil {
310-
clos(src, pc)
329+
clos(pc)
311330
log.W("tcp: err dialing %s proxy(%s) to dst(%v) for %s: %v",
312-
smm.ID, px.ID(), target, smm.UID, err)
313-
return err
331+
smm.ID, smm.PID, smm.Target, smm.UID, err)
332+
return cont, err
333+
}
334+
335+
if _, synackerr := h.handshakeIfNeededOrClose(src, smm); synackerr != nil {
336+
clos(pc)
337+
return stop, synackerr
314338
}
315339

316340
core.Go("tcp.forward."+smm.ID, func() {
317341
h.listener.PostFlow(smm.postMark())
318342
h.forward(src, rwext{dst, tcptimeout}, smm) // src always *gonet.TCPConn
319343
})
320-
return nil // handled; takes ownership of src
344+
return cont, nil // handled; takes ownership of src
321345
}

intra/udp.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func NewUDPHandler(pctx context.Context, resolver dnsx.Resolver, prox ipn.ProxyP
9191
func (h *udpHandler) ReverseProxy(gconn *netstack.GUDPConn, in net.Conn, to, from netip.AddrPort) (ok bool) {
9292
fm := h.onInflow(to, from)
9393
cid, uid, _, pids := h.judge(fm)
94-
smm := udpSummary(cid, uid, from.Addr())
94+
smm := udpSummary(cid, uid, to.Addr(), from.Addr())
9595

9696
if settings.Debug {
9797
log.VV("udp: %s [%s]: reverse: %s => %s; pids: %v", cid, uid, from, to, pids)
@@ -135,7 +135,7 @@ func (h *udpHandler) Error(gconn *netstack.GUDPConn, src, target netip.AddrPort,
135135
res, undidAlg, realips, domains := h.onFlow(src, target)
136136
h.maybeReplaceDest(res, &target)
137137
cid, uid, fid, pids := h.judge(res)
138-
smm := udpSummary(cid, uid, target.Addr())
138+
smm := udpSummary(cid, uid, src.Addr(), target.Addr())
139139

140140
if isAnyBlockPid(pids) {
141141
smm.PID = ipn.Block
@@ -206,7 +206,7 @@ func (h *udpHandler) Connect(gconn *netstack.GUDPConn, src, target netip.AddrPor
206206
if len(actualTargets) <= 0 { // unlikely
207207
actualTargets = []netip.AddrPort{target}
208208
}
209-
smm = udpSummary(cid, uid, actualTargets[0].Addr())
209+
smm = udpSummary(cid, uid, src.Addr(), actualTargets[0].Addr())
210210

211211
if h.status.Load() == HDLEND {
212212
err = log.EE("udp: connect: %s [%s] %v => %v, end", cid, uid, src, target)

0 commit comments

Comments
 (0)