Skip to content

Commit 43261be

Browse files
authored
plug file-descriptor leaks (#32)
* plug file-descriptor leaks * reduce complexity * remove the fix comment
1 parent 8b5c6b1 commit 43261be

1 file changed

Lines changed: 133 additions & 80 deletions

File tree

main.go

Lines changed: 133 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,24 @@ import (
1616
"syscall"
1717
"time"
1818

19+
_ "github.com/go-sql-driver/mysql"
1920
"github.com/gorilla/websocket"
2021
"github.com/prometheus/client_golang/prometheus"
2122
"github.com/prometheus/client_golang/prometheus/promhttp"
2223
logrus "github.com/sirupsen/logrus"
23-
24-
_ "github.com/go-sql-driver/mysql"
2524
_ "modernc.org/sqlite"
2625
)
2726

2827
// ///////////////////
2928
// GLOBAL CONFIG
3029
// ///////////////////
3130
var (
32-
dbDSN string
33-
dbDriver string
34-
serverAddr string
35-
allowedOrigins []string
36-
37-
pongWait = 60 * time.Second
38-
pingPeriod = 30 * time.Second
39-
31+
dbDSN string
32+
dbDriver string
33+
serverAddr string
34+
allowedOrigins []string
35+
pongWait = 60 * time.Second
36+
pingPeriod = 30 * time.Second
4037
offlineTTL time.Duration
4138
maxQueuedMessagesPerPlayer int
4239
maxConnectionsPerPlayer int
@@ -101,6 +98,9 @@ var (
10198
// Rate limiting
10299
limiters = make(map[string]*limiter)
103100
lm sync.Mutex
101+
102+
// logFileHandle holds the open log file so it can be closed on shutdown.
103+
logFileHandle *os.File
104104
)
105105

106106
type limiter struct {
@@ -250,7 +250,8 @@ func registerConnection(playerID string, c *websocket.Conn, token string) {
250250
flushPendingMessages(playerID, c)
251251
}
252252

253-
// unregisterConnection removes a websocket connection for a player and decrements the active connections metric.
253+
// unregisterConnection removes a websocket connection for a player, explicitly closes the underlying
254+
// websocket (releasing the file descriptor), and decrements the active connections metric.
254255
// If no connections remain for the player, the player's entry is removed from the players map.
255256
func unregisterConnection(playerID string, c *websocket.Conn) {
256257
mu.Lock()
@@ -259,6 +260,8 @@ func unregisterConnection(playerID string, c *websocket.Conn) {
259260
if len(players[playerID]) == 0 {
260261
delete(players, playerID)
261262
}
263+
264+
_ = c.Close()
262265
connections.Dec()
263266
}
264267

@@ -317,8 +320,8 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
317320
// check connection limit
318321
mu.Lock()
319322
current := len(players[playerID])
320-
mu.Unlock()
321323
if current >= maxConnectionsPerPlayer {
324+
mu.Unlock()
322325
logrus.WithFields(logrus.Fields{
323326
"player_id": playerID,
324327
"current": current,
@@ -327,6 +330,7 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
327330
http.Error(w, "too many connections", http.StatusTooManyRequests)
328331
return
329332
}
333+
mu.Unlock()
330334

331335
// upgrade to WebSocket
332336
conn, err := upgrader.Upgrade(w, r, nil)
@@ -353,9 +357,18 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
353357

354358
ticker := time.NewTicker(pingPeriod)
355359
defer ticker.Stop()
360+
361+
done := make(chan struct{})
362+
defer close(done)
363+
356364
go func() {
357-
for range ticker.C {
358-
_ = conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second))
365+
for {
366+
select {
367+
case <-done:
368+
return
369+
case <-ticker.C:
370+
_ = conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second))
371+
}
359372
}
360373
}()
361374

@@ -499,28 +512,105 @@ func initMetrics() {
499512
prometheus.MustRegister(connections, messagesPublished, messagesDelivered)
500513
}
501514

515+
// startOfflineMessageCleanup periodically removes expired pending messages from the offline queue.
516+
// It stops when stopCh is closed.
517+
func startOfflineMessageCleanup(stopCh <-chan struct{}) {
518+
go func() {
519+
ticker := time.NewTicker(30 * time.Second)
520+
defer ticker.Stop()
521+
522+
for {
523+
select {
524+
case <-stopCh:
525+
return
526+
case <-ticker.C:
527+
now := time.Now()
528+
pendingMu.Lock()
529+
for pid, msgs := range pendingMessages {
530+
filtered := msgs[:0]
531+
for _, pm := range msgs {
532+
if now.Sub(pm.timestamp) <= offlineTTL {
533+
filtered = append(filtered, pm)
534+
}
535+
}
536+
if len(filtered) == 0 {
537+
delete(pendingMessages, pid)
538+
} else {
539+
pendingMessages[pid] = filtered
540+
}
541+
}
542+
pendingMu.Unlock()
543+
}
544+
}
545+
}()
546+
}
547+
548+
// buildServer constructs and returns the HTTP server and its ServeMux with all routes registered.
549+
func buildServer() *http.Server {
550+
mux := http.NewServeMux()
551+
mux.HandleFunc("/ws", wsHandler)
552+
mux.HandleFunc("/publish", publishHandler)
553+
mux.HandleFunc("/broadcast", broadcastHandler)
554+
mux.Handle("/metrics", promhttp.Handler())
555+
556+
return &http.Server{Addr: serverAddr, Handler: mux}
557+
}
558+
559+
// runServer starts the HTTP server and blocks until it shuts down.
560+
// It listens for SIGINT/SIGTERM, closes stopCh to signal background goroutines,
561+
// then performs a graceful HTTP shutdown followed by closing all websocket connections.
562+
// Returns an error if the server exits unexpectedly.
563+
func runServer(server *http.Server, stopCh chan struct{}) error {
564+
quit := make(chan os.Signal, 1)
565+
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
566+
567+
go func() {
568+
<-quit
569+
logrus.Info("Shutting down server...")
570+
// Signal all background goroutines (cleanup, revalidation) to stop.
571+
close(stopCh)
572+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
573+
defer cancel()
574+
_ = server.Shutdown(ctx)
575+
closeAllConnections()
576+
}()
577+
578+
logrus.Infof("Server listening on %s", serverAddr)
579+
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
580+
return fmt.Errorf("server error: %w", err)
581+
}
582+
return nil
583+
}
584+
502585
// startTokenRevalidation periodically validates all active websocket tokens.
503586
// Invalid tokens cause connections to be closed and removed.
504-
func startTokenRevalidation(interval time.Duration) {
587+
// The provided stopCh can be closed to stop the revalidation loop and its ticker.
588+
func startTokenRevalidation(interval time.Duration, stopCh <-chan struct{}) {
505589
ticker := time.NewTicker(interval)
506590
go func() {
507-
for range ticker.C {
508-
mu.Lock()
509-
for playerID, conns := range players {
510-
for c, wc := range conns {
511-
_, valid := validateToken(wc.token, false)
512-
if !valid {
513-
logrus.WithFields(logrus.Fields{
514-
"player_id": playerID,
515-
}).Info("Token invalid, closing connection")
516-
_ = wc.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "token expired"))
517-
_ = wc.conn.Close()
518-
delete(conns, c)
519-
connections.Dec()
591+
defer ticker.Stop()
592+
for {
593+
select {
594+
case <-stopCh:
595+
return
596+
case <-ticker.C:
597+
mu.Lock()
598+
for playerID, conns := range players {
599+
for c, wc := range conns {
600+
_, valid := validateToken(wc.token, false)
601+
if !valid {
602+
logrus.WithFields(logrus.Fields{
603+
"player_id": playerID,
604+
}).Info("Token invalid, closing connection")
605+
_ = wc.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "token expired"))
606+
_ = wc.conn.Close()
607+
delete(conns, c)
608+
connections.Dec()
609+
}
520610
}
521611
}
612+
mu.Unlock()
522613
}
523-
mu.Unlock()
524614
}
525615
}()
526616
}
@@ -588,6 +678,7 @@ func daemonizeSelf() error {
588678
// setupLogging configures logrus logging for the application.
589679
// It sets the output destination and log level based on global flags.
590680
// Returns an error if the log file cannot be opened or if the log level is invalid.
681+
// The opened log file handle is stored in logFileHandle so it can be closed on shutdown.
591682
func setupLogging() error {
592683
logrus.SetFormatter(&logrus.JSONFormatter{})
593684

@@ -596,6 +687,8 @@ func setupLogging() error {
596687
if err != nil {
597688
return fmt.Errorf("failed to open log file %s: %w", logFile, err)
598689
}
690+
691+
logFileHandle = f
599692
logrus.SetOutput(f)
600693
} else {
601694
logrus.SetOutput(os.Stdout)
@@ -657,11 +750,17 @@ func run() error {
657750
return fmt.Errorf("failed to setup logging: %w", err)
658751
}
659752

753+
if logFileHandle != nil {
754+
defer logFileHandle.Close()
755+
}
756+
660757
// Initialize DB
661758
if err := initDB(); err != nil {
662759
return fmt.Errorf("failed to init DB: %w", err)
663760
}
664761

762+
defer db.Close()
763+
665764
// Daemonize if needed
666765
if daemonize {
667766
if err := daemonizeSelf(); err != nil {
@@ -671,57 +770,11 @@ func run() error {
671770

672771
initMetrics()
673772

674-
// Start offline message cleanup
675-
go func() {
676-
ticker := time.NewTicker(30 * time.Second)
677-
for range ticker.C {
678-
now := time.Now()
679-
pendingMu.Lock()
680-
for pid, msgs := range pendingMessages {
681-
filtered := msgs[:0]
682-
for _, pm := range msgs {
683-
if now.Sub(pm.timestamp) <= offlineTTL {
684-
filtered = append(filtered, pm)
685-
}
686-
}
687-
if len(filtered) == 0 {
688-
delete(pendingMessages, pid)
689-
} else {
690-
pendingMessages[pid] = filtered
691-
}
692-
}
693-
pendingMu.Unlock()
694-
}
695-
}()
773+
// stopCh is closed on shutdown to signal background goroutines to exit.
774+
stopCh := make(chan struct{})
696775

697-
// start WS token revalidation
698-
startTokenRevalidation(tokenRevalidationPeriod)
776+
startOfflineMessageCleanup(stopCh)
777+
startTokenRevalidation(tokenRevalidationPeriod, stopCh)
699778

700-
mux := http.NewServeMux()
701-
mux.HandleFunc("/ws", wsHandler)
702-
mux.HandleFunc("/publish", publishHandler)
703-
mux.HandleFunc("/broadcast", broadcastHandler)
704-
mux.Handle("/metrics", promhttp.Handler())
705-
706-
server := &http.Server{Addr: serverAddr, Handler: mux}
707-
708-
// Graceful shutdown
709-
quit := make(chan os.Signal, 1)
710-
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
711-
go func() {
712-
<-quit
713-
logrus.Info("Shutting down server...")
714-
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
715-
defer cancel()
716-
_ = server.Shutdown(ctx)
717-
closeAllConnections()
718-
}()
719-
720-
logrus.Infof("Server listening on %s", serverAddr)
721-
err := server.ListenAndServe()
722-
if err != nil && err != http.ErrServerClosed {
723-
return fmt.Errorf("server error: %w", err)
724-
}
725-
726-
return nil
779+
return runServer(buildServer(), stopCh)
727780
}

0 commit comments

Comments
 (0)