diff --git a/pkg/watcher/watcher.go b/pkg/watcher/watcher.go index 41b3ec6..5f28054 100644 --- a/pkg/watcher/watcher.go +++ b/pkg/watcher/watcher.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "strconv" + "sync" "time" "github.com/civo/civogo" @@ -39,6 +40,9 @@ type watcher struct { nodeDesiredGPUCount int rebootTimeWindowMinutes time.Duration + // NOTE: This is only effective when running with a single node-agent. If we want to run multiple instances, additional logic modifications will be required. + lastRebootCmdTimes sync.Map + nodeSelector *metav1.LabelSelector } @@ -157,11 +161,23 @@ func (w *watcher) run(ctx context.Context) error { for _, node := range nodes.Items { if !isNodeDesiredGPU(&node, w.nodeDesiredGPUCount) || !isNodeReady(&node) { + + // LTT: LastTransitionTime of node. + // LRCT: LastRebootCmdTimes + // 60: Threshold time (example) + // - LTT > 60 , LRCT < 60 don't reboot + // - LTT < 60 , LRCT < 60 don't reboot + // - LTT < 60 , LRCT > 60 don't reboot + // - LTT > 60, LRCT > 60 reboot slog.Info("Node is not ready, attempting to reboot", "node", node.GetName()) if isReadyOrNotReadyStatusChangedAfter(&node, thresholdTime) { slog.Info("Skipping reboot because Ready/NotReady status was updated recently", "node", node.GetName()) continue } + if w.isLastRebootCommandTimeAfter(node.GetName(), thresholdTime) { + slog.Info("Skipping reboot because Reboot command was executed recently", "node", node.GetName()) + continue + } if err := w.rebootNode(node.GetName()); err != nil { slog.Error("Failed to reboot Node", "node", node.GetName(), "error", err) return fmt.Errorf("failed to reboot node: %w", err) @@ -193,6 +209,32 @@ func isReadyOrNotReadyStatusChangedAfter(node *corev1.Node, thresholdTime time.T return lastChangedTime.After(thresholdTime) } +// isLastRebootCommandTimeAfter checks if the last reboot command time for the specified node +// is after the given threshold time. 
In case of delays in reboot, the +// LastTransitionTime of node might not be updated, so it compares the latest reboot +// command time to prevent sending reboot commands multiple times. +// NOTE: This is only effective when running with a single node-agent. If we want to run multiple instances, additional logic modifications will be required. +func (w *watcher) isLastRebootCommandTimeAfter(nodeName string, thresholdTime time.Time) bool { + v, ok := w.lastRebootCmdTimes.Load(nodeName) + if !ok { + slog.Info("LastRebootCommandTime not found", "node", nodeName) + return false + } + lastRebootCmdTime, ok := v.(time.Time) + if !ok { + slog.Info("LastRebootCommandTime is invalid, so it will be removed from the records", "node", nodeName, "value", v) + w.lastRebootCmdTimes.Delete(nodeName) + return false + } + + slog.Info("Checking if LastRebootCommandTime has changed recently", + "node", nodeName, + "lastRebootCommandTime", lastRebootCmdTime.String(), + "thresholdTime", thresholdTime.String()) + + return lastRebootCmdTime.After(thresholdTime) +} + func isNodeReady(node *corev1.Node) bool { for _, cond := range node.Status.Conditions { if cond.Type == corev1.NodeReady { @@ -241,5 +283,6 @@ func (w *watcher) rebootNode(name string) error { return fmt.Errorf("failed to reboot instance, clusterID: %s, instanceID: %s: %w", w.clusterID, instance.ID, err) } slog.Info("Instance is rebooting", "instanceID", instance.ID, "node", name) + w.lastRebootCmdTimes.Store(name, time.Now()) return nil } diff --git a/pkg/watcher/watcher_test.go b/pkg/watcher/watcher_test.go index 147c259..c69d17c 100644 --- a/pkg/watcher/watcher_test.go +++ b/pkg/watcher/watcher_test.go @@ -369,6 +369,50 @@ func TestRun(t *testing.T) { t.Helper() client := w.client.(*fake.Clientset) + w.lastRebootCmdTimes.Store("node-01", time.Now()) + + nodes := &corev1.NodeList{ + Items: []corev1.Node{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "node-01", + Labels: map[string]string{ + nodePoolLabelKey: testNodePoolID, 
+ }, + }, + Status: corev1.NodeStatus{ + Conditions: []corev1.NodeCondition{ + { + Type: corev1.NodeReady, + Status: corev1.ConditionFalse, + }, + }, + Allocatable: corev1.ResourceList{ + gpuResourceName: resource.MustParse("8"), + }, + }, + }, + }, + } + client.Fake.PrependReactor("list", "nodes", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) { + return true, nodes, nil + }) + }, + }, + { + name: "Returns nil and skips reboot when GPU count matches desired but node is not ready, and LastRebootCmdTime is more recent than thresholdTime", + args: args{ + opts: []Option{ + WithKubernetesClient(fake.NewSimpleClientset()), + WithCivoClient(&FakeClient{}), + WithDesiredGPUCount(testNodeDesiredGPUCount), + }, + nodePoolID: testNodePoolID, + }, + beforeFunc: func(w *watcher) { + t.Helper() + client := w.client.(*fake.Clientset) + + nodes := &corev1.NodeList{ Items: []corev1.Node{ { @@ -600,6 +644,88 @@ func TestIsReadyOrNotReadyStatusChangedAfter(t *testing.T) { } } +func TestIsLastRebootCommandTimeAfter(t *testing.T) { + type test struct { + name string + nodeName string + opts []Option + thresholdTime time.Time + beforeFunc func(*watcher) + want bool + } + + tests := []test{ + { + name: "Return true when last reboot command time is after threshold", + opts: []Option{ + WithKubernetesClient(fake.NewSimpleClientset()), + WithCivoClient(&FakeClient{}), + }, + nodeName: "node-01", + thresholdTime: time.Now().Add(-time.Hour), + beforeFunc: func(w *watcher) { + w.lastRebootCmdTimes.Store("node-01", time.Now()) + }, + want: true, + }, + { + name: "Return false when last reboot command time is before threshold", + opts: []Option{ + WithKubernetesClient(fake.NewSimpleClientset()), + WithCivoClient(&FakeClient{}), + }, + nodeName: "node-01", + thresholdTime: time.Now().Add(-time.Hour), + beforeFunc: func(w *watcher) { + w.lastRebootCmdTimes.Store("node-01", time.Now().Add(-2*time.Hour)) + }, + want: false, + }, + { + name: "Return false when last 
reboot command time not found", + opts: []Option{ + WithKubernetesClient(fake.NewSimpleClientset()), + WithCivoClient(&FakeClient{}), + }, + nodeName: "node-01", + thresholdTime: time.Now().Add(-time.Hour), + want: false, + }, + { + name: "Return false when type of last reboot command time is invalid", + opts: []Option{ + WithKubernetesClient(fake.NewSimpleClientset()), + WithCivoClient(&FakeClient{}), + }, + nodeName: "node-01", + thresholdTime: time.Now().Add(-time.Hour), + beforeFunc: func(w *watcher) { + w.lastRebootCmdTimes.Store("node-01", "invalid-type") + }, + want: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + w, err := NewWatcher(t.Context(), + testApiURL, testApiKey, testRegion, testClusterID, testNodePoolID, test.opts...) + if err != nil { + t.Fatal(err) + } + + obj := w.(*watcher) + if test.beforeFunc != nil { + test.beforeFunc(obj) + } + got := obj.isLastRebootCommandTimeAfter(test.nodeName, test.thresholdTime) + if got != test.want { + t.Errorf("got = %v, want %v", got, test.want) + } + }) + } +} + func TestIsNodeReady(t *testing.T) { type test struct { name string