diff --git a/xray/xray.go b/xray/xray.go
index 166cce5..18929ef 100644
--- a/xray/xray.go
+++ b/xray/xray.go
@@ -60,13 +60,17 @@ import (
 	_ "github.com/xtls/xray-core/main/json"
 )
 
-// drainTimeout is how long an instance stays in the servers map after Close() is called.
-// This gives any in-flight operations a chance to finish before the instance is actually closed,
-// which helps xray-core's goroutines clean up properly.
-const drainTimeout = 30 * time.Second
+// DrainTimeout is how long an instance stays in the servers map after Close()
+// is called. This gives in-flight operations a chance to finish before the
+// instance is actually closed, preventing goroutine leaks in xray-core.
+//
+// It is a variable so tests can shorten it. Writes are not synchronized with
+// the sweeper, so change it only while no sweeper is running.
+var DrainTimeout = 30 * time.Second
 
-// sweepInterval how often the background sweeper runs.
-const sweepInterval = 10 * time.Second
+// SweepInterval is how often the background sweeper runs. Like DrainTimeout,
+// change it only while no sweeper is running.
+var SweepInterval = 10 * time.Second
 
 type Server struct {
 	Instance  *core.Instance
@@ -75,52 +79,96 @@ type Server struct {
 }
 
 var (
 	mu      sync.Mutex
 	servers = make(map[string]*Server)
+
+	// Sweeper lifecycle state. sweeperMu serializes start/stop and guards
+	// stopCh, which is non-nil exactly while a sweeper goroutine is running.
+	sweeperMu sync.Mutex
+	stopCh    chan struct{}
+	sweeperWG sync.WaitGroup
 )
 
-func init() {
-	go sweeper()
+// startSweeper launches the background sweeper goroutine unless one is
+// already running. getServer, setServer, Close and CloseAll call it, so
+// callers of the package never have to start the sweeper themselves.
+func startSweeper() {
+	sweeperMu.Lock()
+	defer sweeperMu.Unlock()
+	if stopCh != nil {
+		return // a sweeper is already running
+	}
+	stop := make(chan struct{})
+	stopCh = stop
+	sweeperWG.Add(1)
+	go func() {
+		defer sweeperWG.Done()
+		// Use the local channel, not the package-level stopCh, which
+		// StopSweeper sets back to nil.
+		sweeper(stop)
+	}()
 }
 
-func sweeper() {
+// StopSweeper stops the running sweeper goroutine (if any) and waits for it
+// to exit; the next call into the package starts a fresh one. It may be
+// called repeatedly and concurrently with the rest of the API.
+// Intended for use in tests only.
+func StopSweeper() {
+	sweeperMu.Lock()
+	defer sweeperMu.Unlock()
+	if stopCh == nil {
+		return // no sweeper running
+	}
+	close(stopCh)
+	// Waiting with sweeperMu held is deadlock-free because the sweeper
+	// goroutine never takes sweeperMu, and it keeps startSweeper from
+	// launching a second sweeper before this one has exited.
+	sweeperWG.Wait()
+	stopCh = nil
+}
+
+// sweeper runs until stop is closed. Every SweepInterval it collects the
+// entries whose drain period has elapsed and passes them to tryCloseAndDelete.
+func sweeper(stop <-chan struct{}) {
 	for {
-		time.Sleep(sweepInterval)
+		select {
+		case <-stop:
+			return
+		case <-time.After(SweepInterval):
+		}
 
 		// Collect expired URLs under lock, then release lock before closing
 		// to avoid blocking all map operations while Instance.Close() runs.
 		var expired []struct {
 			url string
 			srv *Server
 		}
 		mu.Lock()
 		now := time.Now()
 		for url, srv := range servers {
-			if !srv.DrainedAt.IsZero() && now.Sub(srv.DrainedAt) > drainTimeout {
+			if !srv.DrainedAt.IsZero() && now.Sub(srv.DrainedAt) > DrainTimeout {
 				expired = append(expired, struct {
 					url string
 					srv *Server
 				}{url, srv})
 			}
 		}
 		mu.Unlock()
 
 		// Close instances outside the critical section.
 		for _, e := range expired {
-			e.srv.Instance.Close() //nolint: errcheck
-			mu.Lock()
-			delete(servers, e.url)
-			mu.Unlock()
+			tryCloseAndDelete(e.url, e.srv)
 		}
 	}
 }
 
 func getServer(proxyURL string) *Server {
+	startSweeper()
 	mu.Lock()
 	defer mu.Unlock()
 
 	if proxy, ok := servers[proxyURL]; ok {
 		// If draining, revive it.
 		if !proxy.DrainedAt.IsZero() {
 			proxy.DrainedAt = time.Time{}
 		}
@@ -130,20 +178,51 @@ func getServer(proxyURL string) *Server {
 }
 
 func setServer(proxyURL string, instance *core.Instance, port int) {
+	startSweeper()
 	mu.Lock()
 	defer mu.Unlock()
 
 	servers[proxyURL] = &Server{
 		Instance:  instance,
 		SocksPort: port,
 		DrainedAt: time.Time{},
+	}
+}
+
+// tryCloseAndDelete removes the entry for url from the map and closes its
+// instance, provided the entry is still the one the sweeper collected and is
+// still due. Under the lock it checks that:
+//   - the map still holds srv for url (no newer server replaced it);
+//   - srv is still draining (getServer did not revive it);
+//   - its drain period has elapsed (it was not revived and closed again).
+//
+// The entry is deleted while the lock is held, so it cannot be revived
+// afterwards; the instance is closed only after the lock is released, so a
+// slow Instance.Close() does not block other map operations.
+func tryCloseAndDelete(url string, srv *Server) {
+	if srv == nil {
+		return
+	}
+
+	mu.Lock()
+	due := servers[url] == srv &&
+		!srv.DrainedAt.IsZero() &&
+		time.Since(srv.DrainedAt) > DrainTimeout
+	if due {
+		delete(servers, url)
+	}
+	mu.Unlock()
+
+	if due && srv.Instance != nil {
+		srv.Instance.Close() //nolint: errcheck
 	}
 }
 
 // Close marks the server as draining. The sweeper goroutine will actually close
-// the xray instance after drainTimeout elapses, giving in-flight operations a
+// the xray instance after DrainTimeout elapses, giving in-flight operations a
 // chance to finish cleanly and preventing premature close from leaking goroutines.
 func Close(proxyURL string) {
+	startSweeper()
 	mu.Lock()
 	defer mu.Unlock()
 
@@ -154,8 +233,9 @@ func Close(proxyURL string) {
 	}
 }
 // CloseAll marks all servers as draining immediately. The sweeper will close
-// each one after drainTimeout.
+// each one after DrainTimeout.
 func CloseAll() {
+	startSweeper()
 	mu.Lock()
 	defer mu.Unlock()
 
@@ -166,3 +246,15 @@ func CloseAll() {
 		}
 	}
 }
+
+// ResetForTest clears all entries from the servers map and resets the sweeper,
+// so tests get a clean state without reassigning the map variable (which would
+// race with any goroutines still iterating over the old map). Safe to call from tests.
+func ResetForTest() {
+	mu.Lock()
+	for url := range servers {
+		delete(servers, url)
+	}
+	mu.Unlock()
+	StopSweeper()
+}
diff --git a/xray/xray_test.go b/xray/xray_test.go
new file mode 100644
index 0000000..341f7f3
--- /dev/null
+++ b/xray/xray_test.go
@@ -0,0 +1,323 @@
+package xray
+
+import (
+	"os"
+	"sync"
+	"testing"
+	"time"
+)
+
+func TestMain(m *testing.M) {
+	code := m.Run()
+	ResetForTest()
+	os.Exit(code)
+}
+
+// itoa avoids importing strconv just for int-to-string in tests.
+func itoa(i int) string {
+	if i == 0 {
+		return "0"
+	}
+	var buf [20]byte
+	p := len(buf)
+	for i > 0 {
+		p--
+		buf[p] = byte('0' + i%10)
+		i /= 10
+	}
+	return string(buf[p:])
+}
+
+func TestSetAndGet(t *testing.T) {
+	ResetForTest()
+	DrainTimeout = 50 * time.Millisecond
+	SweepInterval = 10 * time.Millisecond
+
+	// Go through setServer so the test really covers it; a nil Instance is
+	// enough because nothing in this test closes the server.
+	setServer("socks5://127.0.0.1:1080", nil, 1080)
+
+	srv := getServer("socks5://127.0.0.1:1080")
+	if srv == nil {
+		t.Fatal("expected server, got nil")
+	}
+
+	mu.Lock()
+	defer mu.Unlock()
+	if srv.SocksPort != 1080 {
+		t.Errorf("expected port 1080, got %d", srv.SocksPort)
+	}
+}
+
+func TestGetNonExistent(t *testing.T) {
+	ResetForTest()
+	srv := getServer("socks5://127.0.0.1:9999")
+	if srv != nil {
+		t.Error("expected nil for non-existent server")
+	}
+}
+
+func TestCloseRevivesServer(t *testing.T) {
+	ResetForTest()
+	DrainTimeout = 50 * time.Millisecond
+	SweepInterval = 10 * time.Millisecond
+
+	// Set up an active server.
+	mu.Lock()
+	servers["socks5://127.0.0.1:1080"] = &Server{SocksPort: 1080, DrainedAt: time.Time{}}
+	mu.Unlock()
+
+	// Close it — marks DrainedAt non-zero.
+	Close("socks5://127.0.0.1:1080")
+
+	// Verify DrainedAt is non-zero (read through map under lock).
+	mu.Lock()
+	wasZero := servers["socks5://127.0.0.1:1080"].DrainedAt.IsZero()
+	mu.Unlock()
+	if wasZero {
+		t.Error("expected DrainedAt to be non-zero after Close()")
+	}
+
+	// getServer should revive it (reset DrainedAt to zero).
+	got := getServer("socks5://127.0.0.1:1080")
+	if got == nil {
+		t.Fatal("expected server after getServer")
+	}
+
+	// Verify DrainedAt is now zero — read through the map under lock.
+	mu.Lock()
+	stillZero := servers["socks5://127.0.0.1:1080"].DrainedAt.IsZero()
+	mu.Unlock()
+	if !stillZero {
+		t.Error("expected DrainedAt to be reset to zero after getServer (revive)")
+	}
+}
+
+func TestCloseIdempotent(t *testing.T) {
+	ResetForTest()
+	mu.Lock()
+	servers["socks5://127.0.0.1:1080"] = &Server{SocksPort: 1080, DrainedAt: time.Time{}}
+	mu.Unlock()
+
+	Close("socks5://127.0.0.1:1080")
+	Close("socks5://127.0.0.1:1080") // second call must not panic
+
+	mu.Lock()
+	defer mu.Unlock()
+	if servers["socks5://127.0.0.1:1080"].DrainedAt.IsZero() {
+		t.Error("expected DrainedAt non-zero")
+	}
+}
+
+func TestCloseNonExistent(t *testing.T) {
+	ResetForTest()
+	// Must not panic.
+	Close("socks5://127.0.0.1:9999")
+}
+
+func TestCloseAll(t *testing.T) {
+	ResetForTest()
+	DrainTimeout = 50 * time.Millisecond
+	SweepInterval = 10 * time.Millisecond
+
+	mu.Lock()
+	servers["socks5://127.0.0.1:1080"] = &Server{SocksPort: 1080, DrainedAt: time.Time{}}
+	servers["socks5://127.0.0.1:1081"] = &Server{SocksPort: 1081, DrainedAt: time.Time{}}
+	servers["socks5://127.0.0.1:1082"] = &Server{SocksPort: 1082, DrainedAt: time.Time{}}
+	mu.Unlock()
+
+	CloseAll()
+
+	mu.Lock()
+	defer mu.Unlock()
+	for _, port := range []int{1080, 1081, 1082} {
+		key := "socks5://127.0.0.1:" + itoa(port)
+		if servers[key].DrainedAt.IsZero() {
+			t.Errorf("expected server %s to be draining after CloseAll", key)
+		}
+	}
+}
+
+func TestSweeperRemovesExpired(t *testing.T) {
+	ResetForTest()
+	DrainTimeout = 80 * time.Millisecond
+	SweepInterval = 15 * time.Millisecond
+
+	// Inject an expired server directly into the map.
+	mu.Lock()
+	servers["socks5://127.0.0.1:1080"] = &Server{SocksPort: 1080, DrainedAt: time.Now().Add(-200 * time.Millisecond)}
+	mu.Unlock()
+
+	// Start the lazy sweeper via an unrelated URL; getServer on the entry itself would revive it.
+	getServer("socks5://127.0.0.1:9998")
+
+	// Inject another expired server after sweeper is running.
+	mu.Lock()
+	servers["socks5://127.0.0.1:1080"] = &Server{SocksPort: 1080, DrainedAt: time.Now().Add(-200 * time.Millisecond)}
+	mu.Unlock()
+
+	// Wait enough for sweeper to run and remove the entry.
+	time.Sleep(300 * time.Millisecond)
+
+	mu.Lock()
+	_, ok := servers["socks5://127.0.0.1:1080"]
+	mu.Unlock()
+
+	if ok {
+		t.Error("expected server to be removed by sweeper after DrainTimeout")
+	}
+}
+
+func TestSweeperSkipsRevivedEntry(t *testing.T) {
+	ResetForTest()
+	DrainTimeout = 50 * time.Millisecond
+	SweepInterval = 10 * time.Millisecond
+
+	// Entry is old enough to be collected, but we'll revive it before sweeper runs.
+	mu.Lock()
+	servers["socks5://127.0.0.1:1080"] = &Server{SocksPort: 1080, DrainedAt: time.Now().Add(-100 * time.Millisecond)}
+	mu.Unlock()
+
+	// Revive via getServer before sweeper picks it up.
+	getServer("socks5://127.0.0.1:1080")
+
+	time.Sleep(200 * time.Millisecond)
+
+	mu.Lock()
+	srv, ok := servers["socks5://127.0.0.1:1080"]
+	stillZero := srv != nil && srv.DrainedAt.IsZero()
+	mu.Unlock()
+
+	if !ok {
+		t.Error("expected server to still exist after revive")
+	}
+	if !stillZero {
+		t.Error("expected DrainedAt to be zero after revive")
+	}
+}
+
+func TestSweeperSkipsActiveEntry(t *testing.T) {
+	ResetForTest()
+	DrainTimeout = 50 * time.Millisecond
+	SweepInterval = 10 * time.Millisecond
+
+	mu.Lock()
+	servers["socks5://127.0.0.1:1080"] = &Server{SocksPort: 1080, DrainedAt: time.Time{}}
+	mu.Unlock()
+
+	time.Sleep(200 * time.Millisecond)
+
+	mu.Lock()
+	_, ok := servers["socks5://127.0.0.1:1080"]
+	mu.Unlock()
+
+	if !ok {
+		t.Error("expected active server to NOT be removed")
+	}
+}
+
+func TestTryCloseAndDelete_NotInMap(t *testing.T) {
+	ResetForTest()
+	// Neither a nil server nor one that was never stored may panic.
+	tryCloseAndDelete("socks5://127.0.0.1:9999", nil)
+	tryCloseAndDelete("socks5://127.0.0.1:9999", &Server{DrainedAt: time.Now()})
+}
+
+func TestTryCloseAndDelete_WrongPointer(t *testing.T) {
+	ResetForTest()
+	mu.Lock()
+	servers["socks5://127.0.0.1:1080"] = &Server{SocksPort: 1080, DrainedAt: time.Now()}
+	mu.Unlock()
+
+	// Try to close with a different (non-existent) pointer.
+	ghost := &Server{SocksPort: 9999, DrainedAt: time.Now()}
+	tryCloseAndDelete("socks5://127.0.0.1:1080", ghost)
+
+	mu.Lock()
+	defer mu.Unlock()
+	if _, ok := servers["socks5://127.0.0.1:1080"]; !ok {
+		t.Error("expected server to remain when wrong pointer is passed")
+	}
+}
+
+func TestTryCloseAndDelete_RevivedEntry(t *testing.T) {
+	ResetForTest()
+	srv := &Server{SocksPort: 1080, DrainedAt: time.Now()}
+	mu.Lock()
+	servers["socks5://127.0.0.1:1080"] = srv
+	mu.Unlock()
+
+	// Manually revive the entry (simulate getServer racing with sweeper).
+	mu.Lock()
+	srv.DrainedAt = time.Time{}
+	mu.Unlock()
+
+	// Same pointer as stored, so only the DrainedAt==0 check can make it skip.
+	tryCloseAndDelete("socks5://127.0.0.1:1080", srv)
+
+	mu.Lock()
+	_, ok := servers["socks5://127.0.0.1:1080"]
+	mu.Unlock()
+
+	if !ok {
+		t.Error("expected server to remain after tryCloseAndDelete on revived entry")
+	}
+}
+
+func TestTryCloseAndDelete_ReplacedEntry(t *testing.T) {
+	ResetForTest()
+	old := &Server{SocksPort: 1080, DrainedAt: time.Now()}
+	mu.Lock()
+	servers["socks5://127.0.0.1:1080"] = old
+	mu.Unlock()
+
+	// Replace with a new server for same URL.
+	mu.Lock()
+	servers["socks5://127.0.0.1:1080"] = &Server{SocksPort: 9999, DrainedAt: time.Time{}}
+	mu.Unlock()
+
+	// tryCloseAndDelete with old pointer should not delete the new entry.
+	tryCloseAndDelete("socks5://127.0.0.1:1080", old)
+
+	mu.Lock()
+	srv, ok := servers["socks5://127.0.0.1:1080"]
+	mu.Unlock()
+
+	if !ok {
+		t.Fatal("expected server to still exist")
+	}
+	if srv.SocksPort != 9999 {
+		t.Errorf("expected new server port 9999, got %d", srv.SocksPort)
+	}
+}
+
+func TestConcurrentGetSetClose(t *testing.T) {
+	ResetForTest()
+	DrainTimeout = 50 * time.Millisecond
+	SweepInterval = 10 * time.Millisecond
+
+	var wg sync.WaitGroup
+	for i := 0; i < 50; i++ {
+		wg.Add(1)
+		go func(idx int) {
+			defer wg.Done()
+			url := "socks5://127.0.0.1:" + itoa(1000+idx%10)
+			mu.Lock()
+			servers[url] = &Server{SocksPort: 1000 + idx%10, DrainedAt: time.Time{}}
+			mu.Unlock()
+			_ = getServer(url)
+			Close(url)
+			_ = getServer(url)
+		}(i)
+	}
+	wg.Wait()
+
+	// No crash = pass. Verify map is consistent.
+	mu.Lock()
+	defer mu.Unlock()
+	for url, srv := range servers {
+		if url == "" || srv == nil {
+			t.Errorf("nil entry in map: url=%q srv=%v", url, srv)
+		}
+	}
+}