Skip to content

Commit 2af7188

Browse files
committed
netstack: refactor swap, dispatch, attach
1 parent 5cdd0c3 commit 2af7188

4 files changed

Lines changed: 113 additions & 97 deletions

File tree

intra/netstack/dispatchers.go

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ package netstack
2525

2626
import (
2727
"math/rand"
28-
"net"
2928
"sync"
3029
"sync/atomic"
3130
"time"
@@ -187,14 +186,8 @@ var _ linkDispatcher = (*readVDispatcher)(nil)
187186
// newReadVDispatcher creates a new linkDispatcher that vector reads packets from
188187
// fd and dispatches them to endpoint e. It assumes ownership of fd but not of e.
189188
func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
190-
tun, err := newTun(fd)
191-
if err != nil {
192-
clos(fd)
193-
return nil, err
194-
}
195189
d := &readVDispatcher{
196190
e: e,
197-
fds: core.NewVolatile(tun),
198191
buf: newIovecBuffer(bufcfg),
199192
mgr: newSupervisor(e, fd),
200193
}
@@ -206,26 +199,10 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
206199

207200
// swap atomically swaps existing fd for this new one.
208201
// On error, it closes fd.
209-
func (d *readVDispatcher) swap(fd int) error {
210-
done := d.closed.Load()
211-
if done {
212-
clos(fd)
213-
return net.ErrClosed
202+
func (d *readVDispatcher) prepare(f *fds) {
203+
if !d.closed.Load() {
204+
d.mgr.swap(f.tun()) // used for diagnostics only
214205
}
215-
216-
note := log.I
217-
f, err := newTun(fd) // fd may be invalid (ex: -1)
218-
if err != nil {
219-
defer clos(fd)
220-
note = log.W
221-
}
222-
223-
prev := d.fds.Swap(f) // f may be nil
224-
d.wrapup(prev, wrapttl) // prev may be nil
225-
d.mgr.swap(f.tun()) // used for diagnostics only
226-
227-
note("ns: dispatch: swap: tun(%d => %d); err %v", prev.tun(), fd, err)
228-
return err
229206
}
230207

231208
// stop stops the dispatcher once. Safe to call multiple times.
@@ -234,7 +211,6 @@ func (d *readVDispatcher) stop() {
234211

235212
d.once.Do(func() {
236213
d.closed.Store(true)
237-
d.fds.Load().stop()
238214
d.mgr.stop()
239215
log.I("ns: dispatch: closed!")
240216
})
@@ -244,8 +220,8 @@ const abort = false // abort indicates that the dispatcher should stop.
244220
const cont = true // cont indicates that the dispatcher should continue delivering packets despite an error.
245221

246222
// dispatch reads packets from the current file descriptor in d.fds and dispatches it to netstack.
247-
func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
248-
return d.io(d.fds.Load())
223+
func (d *readVDispatcher) dispatch(fd *fds) (bool, tcpip.Error) {
224+
return d.io(fd)
249225
}
250226

251227
// wrapup reads packets from fds and dispatches it to netstack
@@ -255,6 +231,8 @@ func (d *readVDispatcher) wrapup(fds *fds, noMoreThan30s time.Duration) {
255231
if !fds.ok() { // fds may be nil
256232
return
257233
}
234+
defer fds.stop()
235+
258236
// Loopback is set to true when VPN is in lockdown mode (block connections
259237
// without vpn). It is observed that by closing the previous tun after delay
260238
// results in "connection was reset" errors in netstack's TCP handler, which
@@ -263,21 +241,11 @@ func (d *readVDispatcher) wrapup(fds *fds, noMoreThan30s time.Duration) {
263241
// quite a while or the device switches to a new network (on Android 14+).
264242
if !threadSafe || settings.Loopingback.Load() {
265243
log.I("ns: tun(%d): wrapup: immediate (loopback)", fds.tun())
266-
fds.stop()
267244
return
268245
}
269246

270-
noMoreThan30s = min(30*time.Second, noMoreThan30s)
271-
secs := int64(noMoreThan30s.Seconds())
272-
273-
go func() {
274-
<-time.After(noMoreThan30s)
275-
log.W("ns: tun(%d): drain: timeout! %dsecs", fds.tun(), secs)
276-
fds.stop()
277-
}()
278-
279-
go func() {
280-
log.I("ns: tun(%d): drain: start w timeout in %dsecs", fds.tun(), secs)
247+
awaited := core.Await(func() {
248+
log.I("ns: tun(%d): drain: start w timeout in %s", fds.tun(), core.FmtPeriod(noMoreThan30s))
281249
for {
282250
cont, err := d.io(fds)
283251
if fd := fds.tun(); !cont {
@@ -287,7 +255,9 @@ func (d *readVDispatcher) wrapup(fds *fds, noMoreThan30s time.Duration) {
287255
log.W("ns: tun(%d): drain: continue on err: %v", fd, err)
288256
} // else: continue draining
289257
}
290-
}()
258+
}, min(30*time.Second, noMoreThan30s))
259+
260+
logei(awaited)("ns: tun(%d): drain: timeout!", fds.tun())
291261
}
292262

293263
// io reads packets from fds and dispatches it to netstack.

intra/netstack/fdbased.go

Lines changed: 68 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ package netstack
3232
import (
3333
"fmt"
3434
"sync/atomic"
35-
"syscall"
3635
"time"
3736

3837
"github.com/celzero/firestack/intra/core"
@@ -53,6 +52,8 @@ var _ FdSwapper = (*linkFdSwap)(nil)
5352

5453
const invalidfd int = -1
5554

55+
const waitttl = wrapttl
56+
5657
type FdSwapper interface {
5758
// Cur returns the current FD.
5859
Cur() int
@@ -71,16 +72,17 @@ type SeamlessEndpoint interface {
7172
// NetworkDispatcher.
7273
type linkDispatcher interface {
7374
stop()
74-
swap(fd int) error
75-
dispatch() (bool, tcpip.Error)
75+
prepare(fd *fds)
76+
dispatch(fd *fds) (bool, tcpip.Error)
77+
wrapup(prev *fds, ttl time.Duration)
7678
}
7779

7880
type endpoint struct {
7981
sync.RWMutex
8082
// fds is the set of file descriptors each identifying one inbound/outbound
8183
// channel. The endpoint will dispatch from all inbound channels as well as
8284
// hash outbound packets to specific channels based on the packet hash.
83-
fds *core.Volatile[int]
85+
fds *core.Volatile[*fds]
8486

8587
// mtu (maximum transmission unit) is the maximum size of a packet.
8688
mtu atomic.Uint32
@@ -196,7 +198,7 @@ func NewFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) {
196198

197199
e := &endpoint{
198200
mtu: atomic.Uint32{},
199-
fds: core.NewVolatile(invalidfd),
201+
fds: core.NewVolatile(invalidFds),
200202
caps: caps,
201203
addr: opts.Address,
202204
hdrSize: hdrSize,
@@ -239,16 +241,11 @@ func (e *endpoint) Cur() int {
239241
}
240242

241243
func (e *endpoint) Dispose() (err error) {
242-
if e.fds == nil {
243-
log.W("ns: tun: Dispose: no fds")
244-
return nil
245-
}
246-
247244
e.Lock()
248245
defer e.Unlock()
249246

250-
prevfd := e.fds.Swap(invalidfd) // prevfd may be invalidfd
251-
if prevfd == invalidfd {
247+
prevfd := e.fds.Swap(invalidFds) // prevfd may be invalidfd
248+
if !prevfd.ok() {
252249
log.W("ns: tun: Dispose: invalid prevfd")
253250
return nil
254251
}
@@ -258,49 +255,50 @@ func (e *endpoint) Dispose() (err error) {
258255
// nothing to do
259256
return nil
260257
}
261-
// e.inboundDispatcher.swap() will close prevfd
262-
// e.dispatchLoop() will auto-exit due to invalidfd
263-
return e.inboundDispatcher.swap(invalidfd)
258+
// e.inboundDispatcher.prepare() will not close prevfd
259+
// dispatchLoop() will auto-exit on invalidfd
260+
core.Go("ns.dispose.wrapup", func() { e.inboundDispatcher.wrapup(prevfd, wrapttl) })
261+
e.inboundDispatcher.prepare(invalidFds)
262+
263+
return nil
264264
}
265265

266266
// Implements Swapper.
267267
func (e *endpoint) Swap(fd int) (err error) {
268-
if err = unix.SetNonblock(fd, true); err != nil {
269-
clos(fd)
270-
return fmt.Errorf("ns: tun: set non blocking(%d) failed: %v", fd, err)
271-
}
272-
273268
e.Lock()
274269
defer e.Unlock()
275270

276-
prevfd := e.fds.Swap(fd) // commence WritePackets() on fd
277-
log.D("ns: swap: tun fd %d => %d", prevfd, fd)
271+
f, err := newTun(fd) // fd may be invalid (ex: -1)
272+
if err != nil {
273+
fd = invalidfd
274+
err = log.EE("ns: tun: swap: (%d) err: %v / %v; using invalidfd", fd, err)
275+
}
276+
277+
prevfd := e.fds.Swap(f) // commence WritePackets() on fd
278+
core.Go(
279+
"ns.swap.wrapup", func() { e.inboundDispatcher.wrapup(prevfd, wrapttl) },
280+
) // closes prevfd, which may be invalidfd
281+
282+
log.D("ns: tun: swap: fd %s => %d", prevfd, fd)
278283

279284
if e.inboundDispatcher == nil { // prevfd must be 0 value if inbound is nil
280285
e.inboundDispatcher, err = createInboundDispatcher(e, fd)
281286
} else {
282-
err = e.inboundDispatcher.swap(fd) // always closes prevfd
283-
// todo: on err != nil e.inboundDispatcher = nil?
287+
e.inboundDispatcher.prepare(f)
284288
}
285289

286290
hasDispatcher := e.dispatcher != nil
287291
if err == nil && hasDispatcher { // attached?
288-
log.I("ns: tun(%d => %d): swap: restart looper %t for new fd",
292+
log.I("ns: tun: (%s => %d) swap: restart looper %t for new fd",
289293
prevfd, fd, hasDispatcher)
290-
go e.dispatchLoop(e.inboundDispatcher)
294+
go dispatchLoop(e.inboundDispatcher, f, &e.wg)
291295
} else {
292-
log.E("ns: tun(%d => %d): swap: no dispatcher? %t for new fd; err %v",
296+
log.E("ns: tun: (%s => %d) swap: no dispatcher? %t for new fd; err %v",
293297
prevfd, fd, !hasDispatcher, err)
294298
}
295299
return
296300
}
297301

298-
func clos(fd int) {
299-
if fd > 0 || fd != invalidfd {
300-
_ = syscall.Close(fd)
301-
}
302-
}
303-
304302
// Attach launches the goroutine that reads packets from the file descriptor and
305303
// dispatches them via the provided dispatcher.
306304
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -310,41 +308,52 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
310308
defer e.Unlock()
311309

312310
rx := e.inboundDispatcher
313-
fd := e.fd()
311+
312+
fds := e.fds.Load()
313+
fd := fds.tun()
314+
314315
attach := dispatcher != nil // nil means the NIC is being removed.
315316
pipe := rx != nil // nil means there's no read dispatcher.
316317
exists := e.dispatcher != nil // nil means the NIC is already detached.
318+
317319
// Attach is called when the NIC is being created and then enabled.
318320
// stack.CreateNIC -> nic.newNIC -> ep.Attach
319321
if dispatcher == nil && e.dispatcher != nil {
320322
log.I("ns: tun(%d): attach: detach dispatcher (and inbound? %t)", fd, pipe)
323+
allLoopersExited := true
321324
if rx != nil {
325+
fds.stop()
322326
go rx.stop() // avoid mutex; closes fd
323-
e.Wait() // on all inboundDispatcher w/ mutex locked?
327+
328+
allLoopersExited = e.wait(waitttl) // on all inboundDispatcher w/ mutex locked?
324329
}
325330
e.dispatcher = nil
326331
e.inboundDispatcher = nil // rx
327-
e.fds.Store(invalidfd)
328-
log.I("ns: tun(%d): attach: done detaching dispatcher", fd)
332+
e.fds.Store(invalidFds)
333+
logei(!allLoopersExited)("ns: tun(%d): attach: done detaching dispatcher; all loopers done? %t",
334+
fd, allLoopersExited)
329335
return
330336
}
337+
331338
if dispatcher != nil && e.dispatcher == nil {
332339
log.I("ns: tun(%d): attach: new dispatcher & looper", fd)
333340
e.dispatcher = dispatcher
334-
if e.inboundDispatcher == nil && fd != invalidfd { // unlikely
341+
if e.inboundDispatcher == nil && fds.ok() { // unlikely
335342
var err error
336343
e.inboundDispatcher, err = createInboundDispatcher(e, fd)
337344
logeif(err)("ns: tun(%d): attach: just-in-time createInboundDispatcher; err? %v", fd, err)
338345
rx = e.inboundDispatcher
339346
}
340-
go e.dispatchLoop(rx)
347+
go dispatchLoop(rx, fds, &e.wg)
341348
return
342349
}
350+
343351
if dispatcher != nil {
344352
log.W("ns: tun(%d): attach: discard? %t; but switch to new anyway", fd, exists)
345353
e.dispatcher = dispatcher
346354
return
347355
}
356+
348357
log.W("ns: tun(%d): attach: discard? %t; hadDispatcher? %t hadInbound? %t", fd, exists, attach, pipe)
349358
}
350359

@@ -382,6 +391,11 @@ func (e *endpoint) Wait() {
382391
e.wg.Wait()
383392
}
384393

394+
func (e *endpoint) wait(d time.Duration) bool {
395+
// wait on e.Wait() until ttl expires
396+
return core.Await(func() { e.Wait() }, d)
397+
}
398+
385399
// AddHeader implements stack.LinkEndpoint.AddHeader.
386400
func (e *endpoint) AddHeader(pkt *stack.PacketBuffer) {
387401
if e.hdrSize > 0 && pkt != nil {
@@ -416,10 +430,7 @@ func (e *endpoint) ParseHeader(pkt *stack.PacketBuffer) bool {
416430

417431
// fd returns the file descriptor associated with the endpoint.
418432
func (e *endpoint) fd() int {
419-
if fd := e.fds.Load(); fd > 0 {
420-
return fd
421-
}
422-
return invalidfd
433+
return e.fds.Load().tun()
423434
}
424435

425436
// writePackets writes outbound packets to the file descriptor. If it is not
@@ -482,31 +493,33 @@ func (e *endpoint) notifyRestart() {
482493
*/
483494

484495
// dispatchLoop reads packets from the file descriptor in a loop and dispatches
485-
// them to the network stack. Must be run as a goroutine.
486-
func (e *endpoint) dispatchLoop(inbound linkDispatcher) tcpip.Error {
496+
// them to the network stack (linkDispatcher). Must be run as a goroutine.
497+
func dispatchLoop(inbound linkDispatcher, f *fds, wg *sync.WaitGroup) tcpip.Error {
487498
// defer core.RecoverFn("ns.e.dispatch", e.notifyRestart)
488499
defer core.Recover(core.Exit11, "ns.e.dispatch")
489500

490-
e.wg.Add(1)
491-
defer e.wg.Done()
501+
wg.Add(1)
502+
defer wg.Done()
492503

493-
fd := e.fd()
494504
if inbound == nil || core.IsNil(inbound) {
495-
log.W("ns: tun(%d): dispatchLoop: inbound nil", fd)
505+
defer f.stop()
506+
log.W("ns: tun(%d): dispatchLoop: inbound nil", f.tun())
496507
return &tcpip.ErrUnknownDevice{}
497508
}
498509

499510
start := time.Now()
500-
log.I("ns: tun(%d): dispatchLoop: start", fd)
511+
log.I("ns: tun(%d): dispatchLoop: start", f.tun())
501512
for {
502-
cont, err := inbound.dispatch()
513+
cont, err := inbound.dispatch(f)
514+
if err != nil {
515+
log.W("ns: tun(%d): dispatchLoop: dur: %s; continue? %t; err: %v",
516+
f.tun(), core.FmtTimeAsPeriod(start), cont, err)
517+
}
503518
if !cont {
504-
elapsed := time.Since(start)
505-
log.W("ns: tun(%d): dispatchLoop: exit; dur: %s; err? %v", fd, elapsed, err)
519+
defer f.stop()
506520
return err
507-
} else if err != nil {
508-
log.W("ns: tun(%d): dispatchLoop: continue on err: %v", fd, err)
509521
} // else: continue dispatching
522+
510523
}
511524
}
512525

0 commit comments

Comments
 (0)