diff --git a/main.go b/main.go index 49e3ec9..94f678b 100644 --- a/main.go +++ b/main.go @@ -28,8 +28,9 @@ func run(ctx context.Context) error { ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) defer stop() - w, err := watcher.NewWatcher(ctx, apiURL, apiKey, region, clusterID, nodePoolID, nodeDesiredGPUCount, + w, err := watcher.NewWatcher(ctx, apiURL, apiKey, region, clusterID, nodePoolID, watcher.WithRebootTimeWindowMinutes(rebootTimeWindowMinutes), + watcher.WithDesiredGPUCount(nodeDesiredGPUCount), ) if err != nil { return err diff --git a/pkg/watcher/options.go b/pkg/watcher/options.go index 5fbad75..904634d 100644 --- a/pkg/watcher/options.go +++ b/pkg/watcher/options.go @@ -1,6 +1,7 @@ package watcher import ( + "log/slog" "strconv" "time" @@ -13,6 +14,7 @@ type Option func(*watcher) var defaultOptions = []Option{ WithRebootTimeWindowMinutes("40"), + WithDesiredGPUCount("0"), } // WithKubernetesClient returns Option to set Kubernetes API client. @@ -48,6 +50,20 @@ func WithRebootTimeWindowMinutes(s string) Option { n, err := strconv.Atoi(s) if err == nil && n > 0 { w.rebootTimeWindowMinutes = time.Duration(n) + } else { + slog.Info("RebootTimeWindowMinutes is invalid", "value", s) + } + } +} + +// WithDesiredGPUCount returns Option to set desired GPU count.
+func WithDesiredGPUCount(s string) Option { + return func(w *watcher) { + n, err := strconv.Atoi(s) + if err == nil && n >= 0 { + w.nodeDesiredGPUCount = n + } else { + slog.Info("DesiredGPUCount is invalid", "value", s) } } } diff --git a/pkg/watcher/watcher.go b/pkg/watcher/watcher.go index 0f3d60f..41b3ec6 100644 --- a/pkg/watcher/watcher.go +++ b/pkg/watcher/watcher.go @@ -42,7 +42,7 @@ type watcher struct { nodeSelector *metav1.LabelSelector } -func NewWatcher(ctx context.Context, apiURL, apiKey, region, clusterID, nodePoolID, nodeDesiredGPUCount string, opts ...Option) (Watcher, error) { +func NewWatcher(ctx context.Context, apiURL, apiKey, region, clusterID, nodePoolID string, opts ...Option) (Watcher, error) { w := &watcher{ clusterID: clusterID, apiKey: apiKey, @@ -63,12 +63,6 @@ func NewWatcher(ctx context.Context, apiURL, apiKey, region, clusterID, nodePool return nil, fmt.Errorf("CIVO_API_KEY not set") } - n, err := strconv.Atoi(nodeDesiredGPUCount) - if err != nil { - return nil, fmt.Errorf("CIVO_NODE_DESIRED_GPU_COUNT has an invalid value, %s: %w", nodeDesiredGPUCount, err) - } - - w.nodeDesiredGPUCount = n w.nodeSelector = &metav1.LabelSelector{ MatchLabels: map[string]string{ nodePoolLabelKey: nodePoolID, @@ -212,7 +206,7 @@ func isNodeReady(node *corev1.Node) bool { func isNodeDesiredGPU(node *corev1.Node, desired int) bool { if desired == 0 { - slog.Info("Desired GPU count is set to 0", "node", node.GetName()) + slog.Info("Desired GPU count is set to 0, so the GPU count check is skipped", "node", node.GetName()) return true } diff --git a/pkg/watcher/watcher_test.go b/pkg/watcher/watcher_test.go index c805922..147c259 100644 --- a/pkg/watcher/watcher_test.go +++ b/pkg/watcher/watcher_test.go @@ -28,13 +28,12 @@ var ( func TestNew(t *testing.T) { type args struct { - clusterID string - region string - apiKey string - apiURL string - nodePoolID string - nodeDesiredGPUCount string - opts []Option + clusterID string + region string + apiKey string + 
apiURL string + nodePoolID string + opts []Option } type test struct { name string @@ -47,17 +46,15 @@ func TestNew(t *testing.T) { { name: "Returns no error when given valid input", args: args{ - clusterID: testClusterID, - region: testRegion, - apiKey: testApiKey, - apiURL: testApiURL, - nodePoolID: testNodePoolID, - nodeDesiredGPUCount: testNodeDesiredGPUCount, + clusterID: testClusterID, + region: testRegion, + apiKey: testApiKey, + apiURL: testApiURL, + nodePoolID: testNodePoolID, opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), - WithRebootTimeWindowMinutes("invalid time"), // It is invalid, but the default time (40) will be used. - WithRebootTimeWindowMinutes("0"), // It is invalid, but the default time (40) will be used. + WithDesiredGPUCount(testNodeDesiredGPUCount), }, }, checkFunc: func(w *watcher) error { @@ -97,51 +94,66 @@ func TestNew(t *testing.T) { }, }, { - name: "Returns an error when clusterID is missing", + name: "Returns no error when input is invalid, but default value is set", args: args{ - region: testRegion, - apiKey: testApiKey, - apiURL: testApiURL, - nodePoolID: testNodePoolID, - nodeDesiredGPUCount: testNodeDesiredGPUCount, + clusterID: testClusterID, + region: testRegion, + apiKey: testApiKey, + apiURL: testApiURL, + nodePoolID: testNodePoolID, opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), + WithDesiredGPUCount("invalid"), // It is invalid, but the default count (0) will be used. + WithDesiredGPUCount("-1"), // It is invalid, but the default count (0) will be used. + WithRebootTimeWindowMinutes("invalid time"), // It is invalid, but the default time (40) will be used. + WithRebootTimeWindowMinutes("0"), // It is invalid, but the default time (40) will be used. 
}, }, - wantErr: true, + checkFunc: func(w *watcher) error { + if w.nodeDesiredGPUCount != 0 { + return fmt.Errorf("w.nodeDesiredGPUCount mismatch: got %d, want %d", w.nodeDesiredGPUCount, 0) + } + if w.rebootTimeWindowMinutes != testRebootTimeWindowMinutes { + return fmt.Errorf("w.rebootTimeWindowMinutes mismatch: got %v, want %v", w.rebootTimeWindowMinutes, testRebootTimeWindowMinutes) + } + return nil + }, }, { - name: "Returns an error when nodeDesiredGPUCount is invalid", + name: "Returns no error when nodeDesiredGPUCount is 0", args: args{ - clusterID: testClusterID, - region: testRegion, - apiKey: testApiKey, - apiURL: testApiURL, - nodePoolID: testNodePoolID, - nodeDesiredGPUCount: "invalid_number", + clusterID: testClusterID, + region: testRegion, + apiKey: testApiKey, + apiURL: testApiURL, + nodePoolID: testNodePoolID, opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), + WithDesiredGPUCount("0"), }, }, - wantErr: true, + checkFunc: func(w *watcher) error { + if w.nodeDesiredGPUCount != 0 { + return fmt.Errorf("w.nodeDesiredGPUCount mismatch: got %d, want %d", w.nodeDesiredGPUCount, 0) + } + return nil + }, }, { - name: "Returns an error when nodeDesiredGPUCount is 0", + name: "Returns an error when clusterID is missing", args: args{ - clusterID: testClusterID, - region: testRegion, - apiKey: testApiKey, - apiURL: testApiURL, - nodePoolID: testNodePoolID, - nodeDesiredGPUCount: "0", + region: testRegion, + apiKey: testApiKey, + apiURL: testApiURL, + nodePoolID: testNodePoolID, opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), }, }, - wantErr: false, + wantErr: true, }, } @@ -153,7 +165,6 @@ func TestNew(t *testing.T) { test.args.region, test.args.clusterID, test.args.nodePoolID, - test.args.nodeDesiredGPUCount, test.args.opts...)
if (err != nil) != test.wantErr { t.Errorf("error = %v, wantErr %v", err, test.wantErr) @@ -177,9 +188,8 @@ func TestNew(t *testing.T) { func TestRun(t *testing.T) { type args struct { - opts []Option - nodeDesiredGPUCount string - nodePoolID string + opts []Option + nodePoolID string } type test struct { name string @@ -195,9 +205,9 @@ func TestRun(t *testing.T) { opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), + WithDesiredGPUCount(testNodeDesiredGPUCount), }, - nodeDesiredGPUCount: testNodeDesiredGPUCount, - nodePoolID: testNodePoolID, + nodePoolID: testNodePoolID, }, beforeFunc: func(w *watcher) { t.Helper() @@ -241,9 +251,9 @@ func TestRun(t *testing.T) { opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), + WithDesiredGPUCount(testNodeDesiredGPUCount), }, - nodeDesiredGPUCount: testNodeDesiredGPUCount, - nodePoolID: testNodePoolID, + nodePoolID: testNodePoolID, }, beforeFunc: func(w *watcher) { t.Helper() @@ -298,9 +308,9 @@ func TestRun(t *testing.T) { opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), + WithDesiredGPUCount(testNodeDesiredGPUCount), }, - nodeDesiredGPUCount: testNodeDesiredGPUCount, - nodePoolID: testNodePoolID, + nodePoolID: testNodePoolID, }, beforeFunc: func(w *watcher) { t.Helper() @@ -351,9 +361,9 @@ func TestRun(t *testing.T) { opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), + WithDesiredGPUCount(testNodeDesiredGPUCount), }, - nodeDesiredGPUCount: testNodeDesiredGPUCount, - nodePoolID: testNodePoolID, + nodePoolID: testNodePoolID, }, beforeFunc: func(w *watcher) { t.Helper() @@ -394,9 +404,9 @@ func TestRun(t *testing.T) { opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), + WithDesiredGPUCount(testNodeDesiredGPUCount), }, - nodeDesiredGPUCount: testNodeDesiredGPUCount, - nodePoolID: testNodePoolID, + 
nodePoolID: testNodePoolID, }, beforeFunc: func(w *watcher) { t.Helper() @@ -415,9 +425,9 @@ func TestRun(t *testing.T) { opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), + WithDesiredGPUCount(testNodeDesiredGPUCount), }, - nodeDesiredGPUCount: testNodeDesiredGPUCount, - nodePoolID: testNodePoolID, + nodePoolID: testNodePoolID, }, beforeFunc: func(w *watcher) { t.Helper() @@ -462,7 +472,7 @@ func TestRun(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { w, err := NewWatcher(t.Context(), - testApiURL, testApiKey, testRegion, testClusterID, test.args.nodePoolID, test.args.nodeDesiredGPUCount, test.args.opts...) + testApiURL, testApiKey, testRegion, testClusterID, test.args.nodePoolID, test.args.opts...) if err != nil { t.Fatal(err) } @@ -682,6 +692,19 @@ func TestIsNodeDesiredGPU(t *testing.T) { desired: 8, want: true, }, + { + name: "Returns true when desired GPU count is 0, so count check is skipped", + node: &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-01", + }, + Status: corev1.NodeStatus{ + Allocatable: corev1.ResourceList{}, + }, + }, + desired: 0, + want: true, + }, { name: "Returns false when GPU count is 0", node: &corev1.Node{ @@ -744,6 +767,7 @@ func TestRebootNode(t *testing.T) { opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), + WithDesiredGPUCount(testNodeDesiredGPUCount), }, }, beforeFunc: func(t *testing.T, w *watcher) { @@ -772,6 +796,7 @@ func TestRebootNode(t *testing.T) { opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), + WithDesiredGPUCount(testNodeDesiredGPUCount), }, }, beforeFunc: func(t *testing.T, w *watcher) { @@ -791,6 +816,7 @@ func TestRebootNode(t *testing.T) { opts: []Option{ WithKubernetesClient(fake.NewSimpleClientset()), WithCivoClient(&FakeClient{}), + WithDesiredGPUCount(testNodeDesiredGPUCount), }, }, beforeFunc: func(t *testing.T, w *watcher) 
{ @@ -818,7 +844,7 @@ func TestRebootNode(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { w, err := NewWatcher(t.Context(), - testApiURL, testApiKey, testRegion, testClusterID, testNodePoolID, testNodeDesiredGPUCount, test.args.opts...) + testApiURL, testApiKey, testRegion, testClusterID, testNodePoolID, test.args.opts...) if err != nil { t.Fatal(err) }