diff --git a/services/wireguard/connection/connection.go b/services/wireguard/connection/connection.go index 3657b8b00c..7756475b49 100644 --- a/services/wireguard/connection/connection.go +++ b/services/wireguard/connection/connection.go @@ -86,6 +86,15 @@ func (c *Connection) State() <-chan connectionstate.State { return c.stateCh } +// sendState safely sends a state to the state channel, preventing panics from sends on closed channel. +func (c *Connection) sendState(state connectionstate.State) { + select { + case c.stateCh <- state: + case <-c.done: + // Connection is stopped, don't send + } +} + // Statistics returns connection statistics channel. func (c *Connection) Statistics() (connectionstate.Statistics, error) { stats, err := c.connectionEndpoint.PeerStats() @@ -128,7 +137,7 @@ func (c *Connection) start(ctx context.Context, start startConn, options connect } }() - c.stateCh <- connectionstate.Connecting + c.sendState(connectionstate.Connecting) if options.ProviderNATConn != nil { options.ProviderNATConn.Close() @@ -170,7 +179,7 @@ func (c *Connection) start(ctx context.Context, start startConn, options connect return errors.Wrap(err, "failed while waiting for a peer handshake") } - c.stateCh <- connectionstate.Connected + c.sendState(connectionstate.Connected) return nil } @@ -224,7 +233,7 @@ func (c *Connection) GetConfig() (connection.ConsumerConfig, error) { func (c *Connection) Stop() { c.stopOnce.Do(func() { log.Info().Msg("Stopping WireGuard connection") - c.stateCh <- connectionstate.Disconnecting + c.sendState(connectionstate.Disconnecting) if c.removeAllowedIPRule != nil { c.removeAllowedIPRule() @@ -236,9 +245,14 @@ func (c *Connection) Stop() { } } - c.stateCh <- connectionstate.NotConnected + // Send final state before closing channels + select { + case c.stateCh <- connectionstate.NotConnected: + default: + } - close(c.stateCh) + // Close done first so sendState calls will bail out instead of sending on closed stateCh close(c.done) + close(c.stateCh) }) }