@@ -16,27 +16,24 @@ import (
1616 "syscall"
1717 "time"
1818
19+ _ "github.com/go-sql-driver/mysql"
1920 "github.com/gorilla/websocket"
2021 "github.com/prometheus/client_golang/prometheus"
2122 "github.com/prometheus/client_golang/prometheus/promhttp"
2223 logrus "github.com/sirupsen/logrus"
23-
24- _ "github.com/go-sql-driver/mysql"
2524 _ "modernc.org/sqlite"
2625)
2726
2827// ///////////////////
2928// GLOBAL CONFIG
3029// ///////////////////
3130var (
32- dbDSN string
33- dbDriver string
34- serverAddr string
35- allowedOrigins []string
36-
37- pongWait = 60 * time .Second
38- pingPeriod = 30 * time .Second
39-
31+ dbDSN string
32+ dbDriver string
33+ serverAddr string
34+ allowedOrigins []string
35+ pongWait = 60 * time .Second
36+ pingPeriod = 30 * time .Second
4037 offlineTTL time.Duration
4138 maxQueuedMessagesPerPlayer int
4239 maxConnectionsPerPlayer int
10198 // Rate limiting
10299 limiters = make (map [string ]* limiter )
103100 lm sync.Mutex
101+
102+ // logFileHandle holds the open log file so it can be closed on shutdown.
103+ logFileHandle * os.File
104104)
105105
106106type limiter struct {
@@ -250,7 +250,8 @@ func registerConnection(playerID string, c *websocket.Conn, token string) {
250250 flushPendingMessages (playerID , c )
251251}
252252
253- // unregisterConnection removes a websocket connection for a player and decrements the active connections metric.
253+ // unregisterConnection removes a websocket connection for a player, explicitly closes the underlying
254+ // websocket (releasing the file descriptor), and decrements the active connections metric.
254255// If no connections remain for the player, the player's entry is removed from the players map.
255256func unregisterConnection (playerID string , c * websocket.Conn ) {
256257 mu .Lock ()
@@ -259,6 +260,8 @@ func unregisterConnection(playerID string, c *websocket.Conn) {
259260 if len (players [playerID ]) == 0 {
260261 delete (players , playerID )
261262 }
263+
264+ _ = c .Close ()
262265 connections .Dec ()
263266}
264267
@@ -317,8 +320,8 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
317320 // check connection limit
318321 mu .Lock ()
319322 current := len (players [playerID ])
320- mu .Unlock ()
321323 if current >= maxConnectionsPerPlayer {
324+ mu .Unlock ()
322325 logrus .WithFields (logrus.Fields {
323326 "player_id" : playerID ,
324327 "current" : current ,
@@ -327,6 +330,7 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
327330 http .Error (w , "too many connections" , http .StatusTooManyRequests )
328331 return
329332 }
333+ mu .Unlock ()
330334
331335 // upgrade to WebSocket
332336 conn , err := upgrader .Upgrade (w , r , nil )
@@ -353,9 +357,18 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
353357
354358 ticker := time .NewTicker (pingPeriod )
355359 defer ticker .Stop ()
360+
361+ done := make (chan struct {})
362+ defer close (done )
363+
356364 go func () {
357- for range ticker .C {
358- _ = conn .WriteControl (websocket .PingMessage , []byte {}, time .Now ().Add (5 * time .Second ))
365+ for {
366+ select {
367+ case <- done :
368+ return
369+ case <- ticker .C :
370+ _ = conn .WriteControl (websocket .PingMessage , []byte {}, time .Now ().Add (5 * time .Second ))
371+ }
359372 }
360373 }()
361374
@@ -499,28 +512,105 @@ func initMetrics() {
499512 prometheus .MustRegister (connections , messagesPublished , messagesDelivered )
500513}
501514
515+ // startOfflineMessageCleanup periodically removes expired pending messages from the offline queue.
516+ // It stops when stopCh is closed.
517+ func startOfflineMessageCleanup (stopCh <- chan struct {}) {
518+ go func () {
519+ ticker := time .NewTicker (30 * time .Second )
520+ defer ticker .Stop ()
521+
522+ for {
523+ select {
524+ case <- stopCh :
525+ return
526+ case <- ticker .C :
527+ now := time .Now ()
528+ pendingMu .Lock ()
529+ for pid , msgs := range pendingMessages {
530+ filtered := msgs [:0 ]
531+ for _ , pm := range msgs {
532+ if now .Sub (pm .timestamp ) <= offlineTTL {
533+ filtered = append (filtered , pm )
534+ }
535+ }
536+ if len (filtered ) == 0 {
537+ delete (pendingMessages , pid )
538+ } else {
539+ pendingMessages [pid ] = filtered
540+ }
541+ }
542+ pendingMu .Unlock ()
543+ }
544+ }
545+ }()
546+ }
547+
548+ // buildServer constructs and returns the HTTP server and its ServeMux with all routes registered.
549+ func buildServer () * http.Server {
550+ mux := http .NewServeMux ()
551+ mux .HandleFunc ("/ws" , wsHandler )
552+ mux .HandleFunc ("/publish" , publishHandler )
553+ mux .HandleFunc ("/broadcast" , broadcastHandler )
554+ mux .Handle ("/metrics" , promhttp .Handler ())
555+
556+ return & http.Server {Addr : serverAddr , Handler : mux }
557+ }
558+
559+ // runServer starts the HTTP server and blocks until it shuts down.
560+ // It listens for SIGINT/SIGTERM, closes stopCh to signal background goroutines,
561+ // then performs a graceful HTTP shutdown followed by closing all websocket connections.
562+ // Returns an error if the server exits unexpectedly.
563+ func runServer (server * http.Server , stopCh chan struct {}) error {
564+ quit := make (chan os.Signal , 1 )
565+ signal .Notify (quit , syscall .SIGINT , syscall .SIGTERM )
566+
567+ go func () {
568+ <- quit
569+ logrus .Info ("Shutting down server..." )
570+ // Signal all background goroutines (cleanup, revalidation) to stop.
571+ close (stopCh )
572+ ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
573+ defer cancel ()
574+ _ = server .Shutdown (ctx )
575+ closeAllConnections ()
576+ }()
577+
578+ logrus .Infof ("Server listening on %s" , serverAddr )
579+ if err := server .ListenAndServe (); err != nil && err != http .ErrServerClosed {
580+ return fmt .Errorf ("server error: %w" , err )
581+ }
582+ return nil
583+ }
584+
502585// startTokenRevalidation periodically validates all active websocket tokens.
503586// Invalid tokens cause connections to be closed and removed.
504- func startTokenRevalidation (interval time.Duration ) {
587+ // The provided stopCh can be closed to stop the revalidation loop and its ticker.
588+ func startTokenRevalidation (interval time.Duration , stopCh <- chan struct {}) {
505589 ticker := time .NewTicker (interval )
506590 go func () {
507- for range ticker .C {
508- mu .Lock ()
509- for playerID , conns := range players {
510- for c , wc := range conns {
511- _ , valid := validateToken (wc .token , false )
512- if ! valid {
513- logrus .WithFields (logrus.Fields {
514- "player_id" : playerID ,
515- }).Info ("Token invalid, closing connection" )
516- _ = wc .conn .WriteMessage (websocket .CloseMessage , websocket .FormatCloseMessage (websocket .ClosePolicyViolation , "token expired" ))
517- _ = wc .conn .Close ()
518- delete (conns , c )
519- connections .Dec ()
591+ defer ticker .Stop ()
592+ for {
593+ select {
594+ case <- stopCh :
595+ return
596+ case <- ticker .C :
597+ mu .Lock ()
598+ for playerID , conns := range players {
599+ for c , wc := range conns {
600+ _ , valid := validateToken (wc .token , false )
601+ if ! valid {
602+ logrus .WithFields (logrus.Fields {
603+ "player_id" : playerID ,
604+ }).Info ("Token invalid, closing connection" )
605+ _ = wc .conn .WriteMessage (websocket .CloseMessage , websocket .FormatCloseMessage (websocket .ClosePolicyViolation , "token expired" ))
606+ _ = wc .conn .Close ()
607+ delete (conns , c )
608+ connections .Dec ()
609+ }
520610 }
521611 }
612+ mu .Unlock ()
522613 }
523- mu .Unlock ()
524614 }
525615 }()
526616}
@@ -588,6 +678,7 @@ func daemonizeSelf() error {
588678// setupLogging configures logrus logging for the application.
589679// It sets the output destination and log level based on global flags.
590680// Returns an error if the log file cannot be opened or if the log level is invalid.
681+ // The opened log file handle is stored in logFileHandle so it can be closed on shutdown.
591682func setupLogging () error {
592683 logrus .SetFormatter (& logrus.JSONFormatter {})
593684
@@ -596,6 +687,8 @@ func setupLogging() error {
596687 if err != nil {
597688 return fmt .Errorf ("failed to open log file %s: %w" , logFile , err )
598689 }
690+
691+ logFileHandle = f
599692 logrus .SetOutput (f )
600693 } else {
601694 logrus .SetOutput (os .Stdout )
@@ -657,11 +750,17 @@ func run() error {
657750 return fmt .Errorf ("failed to setup logging: %w" , err )
658751 }
659752
753+ if logFileHandle != nil {
754+ defer logFileHandle .Close ()
755+ }
756+
660757 // Initialize DB
661758 if err := initDB (); err != nil {
662759 return fmt .Errorf ("failed to init DB: %w" , err )
663760 }
664761
762+ defer db .Close ()
763+
665764 // Daemonize if needed
666765 if daemonize {
667766 if err := daemonizeSelf (); err != nil {
@@ -671,57 +770,11 @@ func run() error {
671770
672771 initMetrics ()
673772
674- // Start offline message cleanup
675- go func () {
676- ticker := time .NewTicker (30 * time .Second )
677- for range ticker .C {
678- now := time .Now ()
679- pendingMu .Lock ()
680- for pid , msgs := range pendingMessages {
681- filtered := msgs [:0 ]
682- for _ , pm := range msgs {
683- if now .Sub (pm .timestamp ) <= offlineTTL {
684- filtered = append (filtered , pm )
685- }
686- }
687- if len (filtered ) == 0 {
688- delete (pendingMessages , pid )
689- } else {
690- pendingMessages [pid ] = filtered
691- }
692- }
693- pendingMu .Unlock ()
694- }
695- }()
773+ // stopCh is closed on shutdown to signal background goroutines to exit.
774+ stopCh := make (chan struct {})
696775
697- // start WS token revalidation
698- startTokenRevalidation (tokenRevalidationPeriod )
776+ startOfflineMessageCleanup ( stopCh )
777+ startTokenRevalidation (tokenRevalidationPeriod , stopCh )
699778
700- mux := http .NewServeMux ()
701- mux .HandleFunc ("/ws" , wsHandler )
702- mux .HandleFunc ("/publish" , publishHandler )
703- mux .HandleFunc ("/broadcast" , broadcastHandler )
704- mux .Handle ("/metrics" , promhttp .Handler ())
705-
706- server := & http.Server {Addr : serverAddr , Handler : mux }
707-
708- // Graceful shutdown
709- quit := make (chan os.Signal , 1 )
710- signal .Notify (quit , syscall .SIGINT , syscall .SIGTERM )
711- go func () {
712- <- quit
713- logrus .Info ("Shutting down server..." )
714- ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
715- defer cancel ()
716- _ = server .Shutdown (ctx )
717- closeAllConnections ()
718- }()
719-
720- logrus .Infof ("Server listening on %s" , serverAddr )
721- err := server .ListenAndServe ()
722- if err != nil && err != http .ErrServerClosed {
723- return fmt .Errorf ("server error: %w" , err )
724- }
725-
726- return nil
779+ return runServer (buildServer (), stopCh )
727780}
0 commit comments