diff --git a/pkg/server/metrics.go b/pkg/server/metrics.go index 43e2c11..06d586c 100644 --- a/pkg/server/metrics.go +++ b/pkg/server/metrics.go @@ -103,4 +103,20 @@ func (s *Server) writePrometheusMetrics(w io.Writer, now time.Time) { } _, _ = fmt.Fprintf(w, "rsync_proxy_connection_duration_seconds{%s} %.3f\n", prometheusLabels(snapshot.Index, snapshot.Module, snapshot.UpstreamAddr), duration) } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_accepted_connections_total Total accepted connections since start.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_accepted_connections_total counter") + _, _ = fmt.Fprintf(w, "rsync_proxy_accepted_connections_total %d\n", s.acceptedConnCount.Load()) + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_completed_connections_total Total completed connections since start.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_completed_connections_total counter") + _, _ = fmt.Fprintf(w, "rsync_proxy_completed_connections_total %d\n", s.completedConnCount.Load()) + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_sent_bytes_total Total bytes sent to clients since start.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_sent_bytes_total counter") + _, _ = fmt.Fprintf(w, "rsync_proxy_sent_bytes_total %d\n", s.sentBytesTotal.Load()) + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_received_bytes_total Total bytes received from clients since start.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_received_bytes_total counter") + _, _ = fmt.Fprintf(w, "rsync_proxy_received_bytes_total %d\n", s.recvBytesTotal.Load()) } diff --git a/pkg/server/server.go b/pkg/server/server.go index aee8da5..857a676 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -147,6 +147,11 @@ type Server struct { connIndex atomic.Uint32 connInfo sync.Map + acceptedConnCount atomic.Uint64 + completedConnCount atomic.Uint64 + sentBytesTotal atomic.Uint64 + recvBytesTotal atomic.Uint64 + TCPListener net.Listener TLSListener net.Listener HTTPListener net.Listener @@ -695,15 +700,23 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err if err := closeRead(upConn, true); err != nil { s.errorLog.F("close upstream read: %v", err) } + downConn.Close() case <-sentClosed: if err := closeRead(downConn, false); err != nil { s.errorLog.F("close downstream read: %v", err) } + upConn.Close() } + <-sentClosed + <-receivedClosed sentBytes := info.SentBytes.Load() receivedBytes := info.ReceivedBytes.Load() + s.completedConnCount.Add(1) + s.sentBytesTotal.Add(uint64(sentBytes)) + s.recvBytesTotal.Add(uint64(receivedBytes)) + duration := time.Since(info.ConnectedAt) s.accessLog.F("client %s finishes module %s (sent: %d, received: %d, duration: %s)", ip, moduleName, sentBytes, receivedBytes, duration) return nil @@ -893,6 +906,7 @@ func (s *Server) Close() { func (s *Server) handleConn(ctx context.Context, conn net.Conn) { s.activeConnCount.Add(1) defer s.activeConnCount.Add(-1) + s.acceptedConnCount.Add(1) connIndex := s.connIndex.Add(1) defer func() { diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index fb80166..755674a 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -505,6 +505,65 @@ func TestPrometheusDurationIncludesFractionalSeconds(t *testing.T) { assert.Contains(t, buf.String(), "rsync_proxy_connection_duration_seconds{index=\"1\",module=\"fake\",upstream=\"127.0.0.1:873\"} 0.250\n") } +func TestMetricsIncludesLifetimeCounters(t *testing.T) { + srv := startServer(t) + defer srv.Close() + + payload := []byte("payload from upstream\n") + fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + require.NoError(t, err) + _, err = conn.Write(payload) + require.NoError(t, err) + }) + fakeRsync.Start() + defer fakeRsync.Close() + + upstreamAddr := fakeRsync.Listener.Addr().String() + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: upstreamAddr}}, + } + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + + rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + require.NoError(t, err) + conn := rsync.NewConn(rawConn) + defer conn.Close() + + _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") + require.NoError(t, err) + + _, err = io.ReadAll(conn) + require.NoError(t, err) + conn.Close() + + require.Eventually(t, func() bool { + return srv.GetActiveConnectionCount() == 0 + }, 3*time.Second, 10*time.Millisecond) + + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + text := string(body) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Contains(t, text, "# HELP rsync_proxy_accepted_connections_total") + assert.Contains(t, text, "# TYPE rsync_proxy_accepted_connections_total counter") + assert.Contains(t, text, "rsync_proxy_accepted_connections_total 1\n") + assert.Contains(t, text, "# HELP rsync_proxy_completed_connections_total") + assert.Contains(t, text, "# TYPE rsync_proxy_completed_connections_total counter") + assert.Contains(t, text, "rsync_proxy_completed_connections_total 1\n") + assert.Contains(t, text, "# HELP rsync_proxy_sent_bytes_total") + assert.Contains(t, text, "# TYPE rsync_proxy_sent_bytes_total counter") + assert.Contains(t, text, fmt.Sprintf("rsync_proxy_sent_bytes_total %d\n", len(payload))) + assert.Contains(t, text, "# HELP rsync_proxy_received_bytes_total") + assert.Contains(t, text, "# TYPE rsync_proxy_received_bytes_total counter") +} + func TestPrometheusLabelValueEscaping(t *testing.T) { assert.Equal(t, `plain`, prometheusEscapeLabelValue("plain")) assert.Equal(t, `quote\"value`, prometheusEscapeLabelValue(`quote"value`))