Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions pkg/watcher/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"strconv"
"sync"
"time"

"github.com/civo/civogo"
Expand Down Expand Up @@ -39,6 +40,9 @@ type watcher struct {
nodeDesiredGPUCount int
rebootTimeWindowMinutes time.Duration

// NOTE: This is only effective when running with a single node-agent. If we want to run multiple instances, additional logic modifications will be required.
lastRebootCmdTimes sync.Map

nodeSelector *metav1.LabelSelector
}

Expand Down Expand Up @@ -157,11 +161,23 @@ func (w *watcher) run(ctx context.Context) error {

for _, node := range nodes.Items {
if !isNodeDesiredGPU(&node, w.nodeDesiredGPUCount) || !isNodeReady(&node) {

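// Decision table. The values are the time elapsed since each timestamp,
// compared against an example threshold of 60.
// LTT: time elapsed since the node's LastTransitionTime.
// LRCT: time elapsed since the last reboot command (lastRebootCmdTimes).
// - LTT > 60, LRCT < 60: don't reboot
// - LTT < 60, LRCT < 60: don't reboot
// - LTT < 60, LRCT > 60: don't reboot
// - LTT > 60, LRCT > 60: reboot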
slog.Info("Node is not ready, attempting to reboot", "node", node.GetName())
if isReadyOrNotReadyStatusChangedAfter(&node, thresholdTime) {
slog.Info("Skipping reboot because Ready/NotReady status was updated recently", "node", node.GetName())
continue
}
if w.isLastRebootCommandTimeAfter(node.GetName(), thresholdTime) {
slog.Info("Skipping reboot because Reboot command was executed recently", "node", node.GetName())
continue
}
if err := w.rebootNode(node.GetName()); err != nil {
slog.Error("Failed to reboot Node", "node", node.GetName(), "error", err)
return fmt.Errorf("failed to reboot node: %w", err)
Expand Down Expand Up @@ -193,6 +209,32 @@ func isReadyOrNotReadyStatusChangedAfter(node *corev1.Node, thresholdTime time.T
return lastChangedTime.After(thresholdTime)
}

// isLastRebootCommandTimeAfter reports whether the reboot command most recently
// recorded for nodeName in w.lastRebootCmdTimes was issued strictly after
// thresholdTime.
//
// A reboot can take a while, during which the node's LastTransitionTime may not
// move; consulting the recorded command time keeps the watcher from sending the
// reboot command repeatedly in that window.
//
// It returns false when no record exists for nodeName. A record whose value is
// not a time.Time is deleted from the map and also yields false.
//
// NOTE: This is only effective when running with a single node-agent. If we want
// to run multiple instances, additional logic modifications will be required.
func (w *watcher) isLastRebootCommandTimeAfter(nodeName string, thresholdTime time.Time) bool {
	recorded, found := w.lastRebootCmdTimes.Load(nodeName)
	if !found {
		slog.Info("LastRebootCommandTime not found", "node", nodeName)
		return false
	}

	switch issuedAt := recorded.(type) {
	case time.Time:
		slog.Info("Checking if LastRebootCommandTime has changed recently",
			"node", nodeName,
			"lastRebootCommandTime", issuedAt.String(),
			"thresholdTime", thresholdTime.String())

		return issuedAt.After(thresholdTime)
	default:
		// Anything other than a time.Time cannot be compared, so drop the
		// record and treat the node as having no reboot command on file.
		slog.Info("LastRebootCommandTime is invalid, so it will be removed from the records", "node", nodeName, "value", recorded)
		w.lastRebootCmdTimes.Delete(nodeName)
		return false
	}
}

func isNodeReady(node *corev1.Node) bool {
for _, cond := range node.Status.Conditions {
if cond.Type == corev1.NodeReady {
Expand Down Expand Up @@ -241,5 +283,6 @@ func (w *watcher) rebootNode(name string) error {
return fmt.Errorf("failed to reboot instance, clusterID: %s, instanceID: %s: %w", w.clusterID, instance.ID, err)
}
slog.Info("Instance is rebooting", "instanceID", instance.ID, "node", name)
w.lastRebootCmdTimes.Store(name, time.Now())
return nil
}
126 changes: 126 additions & 0 deletions pkg/watcher/watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,50 @@ func TestRun(t *testing.T) {
t.Helper()
client := w.client.(*fake.Clientset)

w.lastRebootCmdTimes.Store("node-01", time.Now())

nodes := &corev1.NodeList{
Items: []corev1.Node{
{
ObjectMeta: metav1.ObjectMeta{
Name: "node-01",
Labels: map[string]string{
nodePoolLabelKey: testNodePoolID,
},
},
Status: corev1.NodeStatus{
Conditions: []corev1.NodeCondition{
{
Type: corev1.NodeReady,
Status: corev1.ConditionFalse,
},
},
Allocatable: corev1.ResourceList{
gpuResourceName: resource.MustParse("8"),
},
},
},
},
}
client.Fake.PrependReactor("list", "nodes", func(action k8stesting.Action) (handled bool, ret runtime.Object, err error) {
return true, nodes, nil
})
},
},
{
name: "Returns nil and skips reboot when GPU count matches desired but node is not ready, and LastRebootCmdTime is more recent than thresholdTime",
args: args{
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount(testNodeDesiredGPUCount),
},
nodePoolID: testNodePoolID,
},
beforeFunc: func(w *watcher) {
t.Helper()
client := w.client.(*fake.Clientset)

nodes := &corev1.NodeList{
Items: []corev1.Node{
{
Expand Down Expand Up @@ -600,6 +644,88 @@ func TestIsReadyOrNotReadyStatusChangedAfter(t *testing.T) {
}
}

// TestIsLastRebootCommandTimeAfter verifies watcher.isLastRebootCommandTimeAfter
// for each branch: a record newer than the threshold, a record older than the
// threshold, no record at all, and a record holding a non-time.Time value.
//
// Every case seeds the record under the same key it later queries ("node-01"),
// so the "older than threshold" and "invalid type" cases reach their own
// branches instead of passing through the "not found" path.
func TestIsLastRebootCommandTimeAfter(t *testing.T) {
	type test struct {
		name          string
		nodeName      string
		opts          []Option
		thresholdTime time.Time
		// beforeFunc, when set, seeds w.lastRebootCmdTimes before the call under test.
		beforeFunc func(*watcher)
		want       bool
	}

	tests := []test{
		{
			name: "Return true when last reboot command time is after threshold",
			opts: []Option{
				WithKubernetesClient(fake.NewSimpleClientset()),
				WithCivoClient(&FakeClient{}),
			},
			nodeName:      "node-01",
			thresholdTime: time.Now().Add(-time.Hour),
			beforeFunc: func(w *watcher) {
				w.lastRebootCmdTimes.Store("node-01", time.Now())
			},
			want: true,
		},
		{
			name: "Return false when last reboot command time is before threshold",
			opts: []Option{
				WithKubernetesClient(fake.NewSimpleClientset()),
				WithCivoClient(&FakeClient{}),
			},
			nodeName:      "node-01",
			thresholdTime: time.Now().Add(-time.Hour),
			beforeFunc: func(w *watcher) {
				// Two hours ago is older than the one-hour threshold.
				w.lastRebootCmdTimes.Store("node-01", time.Now().Add(-2*time.Hour))
			},
			want: false,
		},
		{
			name: "Return false when last reboot command time not found",
			opts: []Option{
				WithKubernetesClient(fake.NewSimpleClientset()),
				WithCivoClient(&FakeClient{}),
			},
			nodeName:      "node-01",
			thresholdTime: time.Now().Add(-time.Hour),
			want:          false,
		},
		{
			name: "Return false when type of last reboot command time is invalid",
			opts: []Option{
				WithKubernetesClient(fake.NewSimpleClientset()),
				WithCivoClient(&FakeClient{}),
			},
			nodeName:      "node-01",
			thresholdTime: time.Now().Add(-time.Hour),
			beforeFunc: func(w *watcher) {
				// A string is not a time.Time, so the type assertion must fail.
				w.lastRebootCmdTimes.Store("node-01", "invalid-type")
			},
			want: false,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			w, err := NewWatcher(t.Context(),
				testApiURL, testApiKey, testRegion, testClusterID, testNodePoolID, test.opts...)
			if err != nil {
				t.Fatal(err)
			}

			obj := w.(*watcher)
			if test.beforeFunc != nil {
				test.beforeFunc(obj)
			}
			got := obj.isLastRebootCommandTimeAfter(test.nodeName, test.thresholdTime)
			if got != test.want {
				t.Errorf("got = %v, want %v", got, test.want)
			}
		})
	}
}

func TestIsNodeReady(t *testing.T) {
type test struct {
name string
Expand Down