@@ -32,7 +32,6 @@ package netstack
3232import (
3333 "fmt"
3434 "sync/atomic"
35- "syscall"
3635 "time"
3736
3837 "github.com/celzero/firestack/intra/core"
@@ -53,6 +52,8 @@ var _ FdSwapper = (*linkFdSwap)(nil)
5352
5453const invalidfd int = - 1
5554
55+ const waitttl = wrapttl
56+
5657type FdSwapper interface {
5758 // Cur returns the current FD.
5859 Cur () int
@@ -71,16 +72,17 @@ type SeamlessEndpoint interface {
7172// NetworkDispatcher.
7273type linkDispatcher interface {
7374 stop ()
74- swap (fd int ) error
75- dispatch () (bool , tcpip.Error )
75+ prepare (fd * fds )
76+ dispatch (fd * fds ) (bool , tcpip.Error )
77+ wrapup (prev * fds , ttl time.Duration )
7678}
7779
7880type endpoint struct {
7981 sync.RWMutex
8082 // fds is the set of file descriptors each identifying one inbound/outbound
8183 // channel. The endpoint will dispatch from all inbound channels as well as
8284 // hash outbound packets to specific channels based on the packet hash.
83- fds * core.Volatile [int ]
85+ fds * core.Volatile [* fds ]
8486
8587 // mtu (maximum transmission unit) is the maximum size of a packet.
8688 mtu atomic.Uint32
@@ -196,7 +198,7 @@ func NewFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) {
196198
197199 e := & endpoint {
198200 mtu : atomic.Uint32 {},
199- fds : core .NewVolatile (invalidfd ),
201+ fds : core .NewVolatile (invalidFds ),
200202 caps : caps ,
201203 addr : opts .Address ,
202204 hdrSize : hdrSize ,
@@ -239,16 +241,11 @@ func (e *endpoint) Cur() int {
239241}
240242
241243func (e * endpoint ) Dispose () (err error ) {
242- if e .fds == nil {
243- log .W ("ns: tun: Dispose: no fds" )
244- return nil
245- }
246-
247244 e .Lock ()
248245 defer e .Unlock ()
249246
250- prevfd := e .fds .Swap (invalidfd ) // prevfd may be invalidfd
251- if prevfd == invalidfd {
247+ prevfd := e .fds .Swap (invalidFds ) // prevfd may be invalidfd
248+ if ! prevfd . ok () {
252249 log .W ("ns: tun: Dispose: invalid prevfd" )
253250 return nil
254251 }
@@ -258,49 +255,50 @@ func (e *endpoint) Dispose() (err error) {
258255 // nothing to do
259256 return nil
260257 }
261- // e.inboundDispatcher.swap() will close prevfd
262- // e.dispatchLoop() will auto-exit due to invalidfd
263- return e .inboundDispatcher .swap (invalidfd )
258+ // e.inboundDispatcher.prepare() will not close prevfd
259+ // dispatchLoop() will auto-exit on invalidfd
260+ core .Go ("ns.dispose.wrapup" , func () { e .inboundDispatcher .wrapup (prevfd , wrapttl ) })
261+ e .inboundDispatcher .prepare (invalidFds )
262+
263+ return nil
264264}
265265
266266// Implements Swapper.
267267func (e * endpoint ) Swap (fd int ) (err error ) {
268- if err = unix .SetNonblock (fd , true ); err != nil {
269- clos (fd )
270- return fmt .Errorf ("ns: tun: set non blocking(%d) failed: %v" , fd , err )
271- }
272-
273268 e .Lock ()
274269 defer e .Unlock ()
275270
276- prevfd := e .fds .Swap (fd ) // commence WritePackets() on fd
277- log .D ("ns: swap: tun fd %d => %d" , prevfd , fd )
271+ f , err := newTun (fd ) // fd may be invalid (ex: -1)
272+ if err != nil {
273+ fd = invalidfd
274+ err = log .EE ("ns: tun: swap: (%d) err: %v / %v; using invalidfd" , fd , err )
275+ }
276+
277+ prevfd := e .fds .Swap (f ) // commence WritePackets() on fd
278+ core .Go (
279+ "ns.swap.wrapup" , func () { e .inboundDispatcher .wrapup (prevfd , wrapttl ) },
280+ ) // closes prevfd, which may be invalidfd
281+
282+ log .D ("ns: tun: swap: fd %s => %d" , prevfd , fd )
278283
279284 if e .inboundDispatcher == nil { // prevfd must be 0 value if inbound is nil
280285 e .inboundDispatcher , err = createInboundDispatcher (e , fd )
281286 } else {
282- err = e .inboundDispatcher .swap (fd ) // always closes prevfd
283- // todo: on err != nil e.inboundDispatcher = nil?
287+ e .inboundDispatcher .prepare (f )
284288 }
285289
286290 hasDispatcher := e .dispatcher != nil
287291 if err == nil && hasDispatcher { // attached?
288- log .I ("ns: tun(%d => %d): swap: restart looper %t for new fd" ,
292+ log .I ("ns: tun: (%s => %d) swap: restart looper %t for new fd" ,
289293 prevfd , fd , hasDispatcher )
290- go e . dispatchLoop (e .inboundDispatcher )
294+ go dispatchLoop (e .inboundDispatcher , f , & e . wg )
291295 } else {
292- log .E ("ns: tun(%d => %d): swap: no dispatcher? %t for new fd; err %v" ,
296+ log .E ("ns: tun: (%s => %d) swap: no dispatcher? %t for new fd; err %v" ,
293297 prevfd , fd , ! hasDispatcher , err )
294298 }
295299 return
296300}
297301
298- func clos (fd int ) {
299- if fd > 0 || fd != invalidfd {
300- _ = syscall .Close (fd )
301- }
302- }
303-
304302// Attach launches the goroutine that reads packets from the file descriptor and
305303// dispatches them via the provided dispatcher.
306304func (e * endpoint ) Attach (dispatcher stack.NetworkDispatcher ) {
@@ -310,41 +308,52 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
310308 defer e .Unlock ()
311309
312310 rx := e .inboundDispatcher
313- fd := e .fd ()
311+
312+ fds := e .fds .Load ()
313+ fd := fds .tun ()
314+
314315 attach := dispatcher != nil // nil means the NIC is being removed.
315316 pipe := rx != nil // nil means there's no read dispatcher.
316317 exists := e .dispatcher != nil // nil means the NIC is already detached.
318+
317319 // Attach is called when the NIC is being created and then enabled.
318320 // stack.CreateNIC -> nic.newNIC -> ep.Attach
319321 if dispatcher == nil && e .dispatcher != nil {
320322 log .I ("ns: tun(%d): attach: detach dispatcher (and inbound? %t)" , fd , pipe )
323+ allLoopersExited := true
321324 if rx != nil {
325+ fds .stop ()
322326 go rx .stop () // avoid mutex; closes fd
323- e .Wait () // on all inboundDispatcher w/ mutex locked?
327+
328+ allLoopersExited = e .wait (waitttl ) // on all inboundDispatcher w/ mutex locked?
324329 }
325330 e .dispatcher = nil
326331 e .inboundDispatcher = nil // rx
327- e .fds .Store (invalidfd )
328- log .I ("ns: tun(%d): attach: done detaching dispatcher" , fd )
332+ e .fds .Store (invalidFds )
333+ logei (! allLoopersExited )("ns: tun(%d): attach: done detaching dispatcher; all loopers done? %t" ,
334+ fd , allLoopersExited )
329335 return
330336 }
337+
331338 if dispatcher != nil && e .dispatcher == nil {
332339 log .I ("ns: tun(%d): attach: new dispatcher & looper" , fd )
333340 e .dispatcher = dispatcher
334- if e .inboundDispatcher == nil && fd != invalidfd { // unlikely
341+ if e .inboundDispatcher == nil && fds . ok () { // unlikely
335342 var err error
336343 e .inboundDispatcher , err = createInboundDispatcher (e , fd )
337344 logeif (err )("ns: tun(%d): attach: just-in-time createInboundDispatcher; err? %v" , fd , err )
338345 rx = e .inboundDispatcher
339346 }
340- go e . dispatchLoop (rx )
347+ go dispatchLoop (rx , fds , & e . wg )
341348 return
342349 }
350+
343351 if dispatcher != nil {
344352 log .W ("ns: tun(%d): attach: discard? %t; but switch to new anyway" , fd , exists )
345353 e .dispatcher = dispatcher
346354 return
347355 }
356+
348357 log .W ("ns: tun(%d): attach: discard? %t; hadDispatcher? %t hadInbound? %t" , fd , exists , attach , pipe )
349358}
350359
@@ -382,6 +391,11 @@ func (e *endpoint) Wait() {
382391 e .wg .Wait ()
383392}
384393
394+ func (e * endpoint ) wait (d time.Duration ) bool {
395+ // wait on e.Wait() until ttl expires
396+ return core .Await (func () { e .Wait () }, d )
397+ }
398+
385399// AddHeader implements stack.LinkEndpoint.AddHeader.
386400func (e * endpoint ) AddHeader (pkt * stack.PacketBuffer ) {
387401 if e .hdrSize > 0 && pkt != nil {
@@ -416,10 +430,7 @@ func (e *endpoint) ParseHeader(pkt *stack.PacketBuffer) bool {
416430
417431// fd returns the file descriptor associated with the endpoint.
418432func (e * endpoint ) fd () int {
419- if fd := e .fds .Load (); fd > 0 {
420- return fd
421- }
422- return invalidfd
433+ return e .fds .Load ().tun ()
423434}
424435
425436// writePackets writes outbound packets to the file descriptor. If it is not
@@ -482,31 +493,33 @@ func (e *endpoint) notifyRestart() {
482493*/
483494
484495// dispatchLoop reads packets from the file descriptor in a loop and dispatches
485- // them to the network stack. Must be run as a goroutine.
486- func ( e * endpoint ) dispatchLoop (inbound linkDispatcher ) tcpip.Error {
496+ // them to the network stack (linkDispatcher) . Must be run as a goroutine.
497+ func dispatchLoop (inbound linkDispatcher , f * fds , wg * sync. WaitGroup ) tcpip.Error {
487498 // defer core.RecoverFn("ns.e.dispatch", e.notifyRestart)
488499 defer core .Recover (core .Exit11 , "ns.e.dispatch" )
489500
490- e . wg .Add (1 )
491- defer e . wg .Done ()
501+ wg .Add (1 )
502+ defer wg .Done ()
492503
493- fd := e .fd ()
494504 if inbound == nil || core .IsNil (inbound ) {
495- log .W ("ns: tun(%d): dispatchLoop: inbound nil" , fd )
505+ defer f .stop ()
506+ log .W ("ns: tun(%d): dispatchLoop: inbound nil" , f .tun ())
496507 return & tcpip.ErrUnknownDevice {}
497508 }
498509
499510 start := time .Now ()
500- log .I ("ns: tun(%d): dispatchLoop: start" , fd )
511+ log .I ("ns: tun(%d): dispatchLoop: start" , f . tun () )
501512 for {
502- cont , err := inbound .dispatch ()
513+ cont , err := inbound .dispatch (f )
514+ if err != nil {
515+ log .W ("ns: tun(%d): dispatchLoop: dur: %s; continue? %t; err: %v" ,
516+ f .tun (), core .FmtTimeAsPeriod (start ), cont , err )
517+ }
503518 if ! cont {
504- elapsed := time .Since (start )
505- log .W ("ns: tun(%d): dispatchLoop: exit; dur: %s; err? %v" , fd , elapsed , err )
519+ defer f .stop ()
506520 return err
507- } else if err != nil {
508- log .W ("ns: tun(%d): dispatchLoop: continue on err: %v" , fd , err )
509521 } // else: continue dispatching
522+
510523 }
511524}
512525
0 commit comments