Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ func run(ctx context.Context) error {
ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer stop()

w, err := watcher.NewWatcher(ctx, apiURL, apiKey, region, clusterID, nodePoolID, nodeDesiredGPUCount,
w, err := watcher.NewWatcher(ctx, apiURL, apiKey, region, clusterID, nodePoolID,
watcher.WithRebootTimeWindowMinutes(rebootTimeWindowMinutes),
watcher.WithDesiredGPUCount(nodeDesiredGPUCount),
)
if err != nil {
return err
Expand Down
16 changes: 16 additions & 0 deletions pkg/watcher/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package watcher

import (
"log/slog"
"strconv"
"time"

Expand All @@ -13,6 +14,7 @@ type Option func(*watcher)

var defaultOptions = []Option{
WithRebootTimeWindowMinutes("40"),
WithDesiredGPUCount("0"),
}

// WithKubernetesClient returns Option to set Kubernetes API client.
Expand Down Expand Up @@ -48,6 +50,20 @@ func WithRebootTimeWindowMinutes(s string) Option {
n, err := strconv.Atoi(s)
if err == nil && n > 0 {
w.rebootTimeWindowMinutes = time.Duration(n)
} else {
slog.Info("RebootTimeWindowMinutes is invalid", "value", s)
}
}
}

// WithDesiredGPUCount returns an Option that sets the desired GPU count
// parsed from s. An invalid value (non-numeric or negative) is logged and
// ignored, so the previously applied value is kept — with the defaults,
// that is "0", which skips the per-node GPU count check entirely.
func WithDesiredGPUCount(s string) Option {
	return func(w *watcher) {
		n, err := strconv.Atoi(s)
		// Accept only non-negative integers; 0 is valid and means
		// "do not check the GPU count".
		if err == nil && n >= 0 {
			w.nodeDesiredGPUCount = n
		} else {
			// Log-and-ignore keeps the default/previous value instead of failing.
			slog.Info("DesiredGPUCount is invalid", "value", s)
		}
	}
}
10 changes: 2 additions & 8 deletions pkg/watcher/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type watcher struct {
nodeSelector *metav1.LabelSelector
}

func NewWatcher(ctx context.Context, apiURL, apiKey, region, clusterID, nodePoolID, nodeDesiredGPUCount string, opts ...Option) (Watcher, error) {
func NewWatcher(ctx context.Context, apiURL, apiKey, region, clusterID, nodePoolID string, opts ...Option) (Watcher, error) {
w := &watcher{
clusterID: clusterID,
apiKey: apiKey,
Expand All @@ -63,12 +63,6 @@ func NewWatcher(ctx context.Context, apiURL, apiKey, region, clusterID, nodePool
return nil, fmt.Errorf("CIVO_API_KEY not set")
}

n, err := strconv.Atoi(nodeDesiredGPUCount)
if err != nil {
return nil, fmt.Errorf("CIVO_NODE_DESIRED_GPU_COUNT has an invalid value, %s: %w", nodeDesiredGPUCount, err)
}

w.nodeDesiredGPUCount = n
w.nodeSelector = &metav1.LabelSelector{
MatchLabels: map[string]string{
nodePoolLabelKey: nodePoolID,
Expand Down Expand Up @@ -212,7 +206,7 @@ func isNodeReady(node *corev1.Node) bool {

func isNodeDesiredGPU(node *corev1.Node, desired int) bool {
if desired == 0 {
slog.Info("Desired GPU count is set to 0", "node", node.GetName())
slog.Info("Desired GPU count is set to 0, so the GPU count check is skipped", "node", node.GetName())
return true
}

Expand Down
138 changes: 82 additions & 56 deletions pkg/watcher/watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,12 @@ var (

func TestNew(t *testing.T) {
type args struct {
clusterID string
region string
apiKey string
apiURL string
nodePoolID string
nodeDesiredGPUCount string
opts []Option
clusterID string
region string
apiKey string
apiURL string
nodePoolID string
opts []Option
}
type test struct {
name string
Expand All @@ -47,17 +46,15 @@ func TestNew(t *testing.T) {
{
name: "Returns no error when given valid input",
args: args{
clusterID: testClusterID,
region: testRegion,
apiKey: testApiKey,
apiURL: testApiURL,
nodePoolID: testNodePoolID,
nodeDesiredGPUCount: testNodeDesiredGPUCount,
clusterID: testClusterID,
region: testRegion,
apiKey: testApiKey,
apiURL: testApiURL,
nodePoolID: testNodePoolID,
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithRebootTimeWindowMinutes("invalid time"), // It is invalid, but the default time (40) will be used.
WithRebootTimeWindowMinutes("0"), // It is invalid, but the default time (40) will be used.
WithDesiredGPUCount(testNodeDesiredGPUCount),
},
},
checkFunc: func(w *watcher) error {
Expand Down Expand Up @@ -97,51 +94,66 @@ func TestNew(t *testing.T) {
},
},
{
name: "Returns an error when clusterID is missing",
name: "Returns no error when input is invalid, but default value is set",
args: args{
region: testRegion,
apiKey: testApiKey,
apiURL: testApiURL,
nodePoolID: testNodePoolID,
nodeDesiredGPUCount: testNodeDesiredGPUCount,
clusterID: testClusterID,
region: testRegion,
apiKey: testApiKey,
apiURL: testApiURL,
nodePoolID: testNodePoolID,
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount("invalid"), // It is invalid, but the default count (0) will be used.
WithDesiredGPUCount("-1"), // It is invalid, but the default count (0) will be used.
WithRebootTimeWindowMinutes("invalid time"), // It is invalid, but the default time (40) will be used.
WithRebootTimeWindowMinutes("0"), // It is invalid, but the default time (40) will be used.
},
},
wantErr: true,
checkFunc: func(w *watcher) error {
if w.nodeDesiredGPUCount != 0 {
return fmt.Errorf("w.nodeDesiredGPUCount mismatch: got %d, want %d", w.nodeDesiredGPUCount, 0)
}
if w.rebootTimeWindowMinutes != testRebootTimeWindowMinutes {
return fmt.Errorf("w.rebootTimeWindowMinutes mismatch: got %v, want %v", w.rebootTimeWindowMinutes, testRebootTimeWindowMinutes)
}
return nil
},
},
{
name: "Returns an error when nodeDesiredGPUCount is invalid",
name: "Returns no error when nodeDesiredGPUCount is 0",
args: args{
clusterID: testClusterID,
region: testRegion,
apiKey: testApiKey,
apiURL: testApiURL,
nodePoolID: testNodePoolID,
nodeDesiredGPUCount: "invalid_number",
clusterID: testClusterID,
region: testRegion,
apiKey: testApiKey,
apiURL: testApiURL,
nodePoolID: testNodePoolID,
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount("0"),
},
},
wantErr: true,
checkFunc: func(w *watcher) error {
if w.nodeDesiredGPUCount != 0 {
return fmt.Errorf("w.nodeDesiredGPUCount mismatch: got %d, want %d", w.nodeDesiredGPUCount, 0)
}
return nil
},
},
{
name: "Returns an error when nodeDesiredGPUCount is 0",
name: "Returns an error when clusterID is missing",
args: args{
clusterID: testClusterID,
region: testRegion,
apiKey: testApiKey,
apiURL: testApiURL,
nodePoolID: testNodePoolID,
nodeDesiredGPUCount: "0",
region: testRegion,
apiKey: testApiKey,
apiURL: testApiURL,
nodePoolID: testNodePoolID,
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
},
},
wantErr: false,
wantErr: true,
},
}

Expand All @@ -153,7 +165,6 @@ func TestNew(t *testing.T) {
test.args.region,
test.args.clusterID,
test.args.nodePoolID,
test.args.nodeDesiredGPUCount,
test.args.opts...)
if (err != nil) != test.wantErr {
t.Errorf("error = %v, wantErr %v", err, test.wantErr)
Expand All @@ -177,9 +188,8 @@ func TestNew(t *testing.T) {

func TestRun(t *testing.T) {
type args struct {
opts []Option
nodeDesiredGPUCount string
nodePoolID string
opts []Option
nodePoolID string
}
type test struct {
name string
Expand All @@ -195,9 +205,9 @@ func TestRun(t *testing.T) {
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount(testNodeDesiredGPUCount),
},
nodeDesiredGPUCount: testNodeDesiredGPUCount,
nodePoolID: testNodePoolID,
nodePoolID: testNodePoolID,
},
beforeFunc: func(w *watcher) {
t.Helper()
Expand Down Expand Up @@ -241,9 +251,9 @@ func TestRun(t *testing.T) {
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount(testNodeDesiredGPUCount),
},
nodeDesiredGPUCount: testNodeDesiredGPUCount,
nodePoolID: testNodePoolID,
nodePoolID: testNodePoolID,
},
beforeFunc: func(w *watcher) {
t.Helper()
Expand Down Expand Up @@ -298,9 +308,9 @@ func TestRun(t *testing.T) {
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount(testNodeDesiredGPUCount),
},
nodeDesiredGPUCount: testNodeDesiredGPUCount,
nodePoolID: testNodePoolID,
nodePoolID: testNodePoolID,
},
beforeFunc: func(w *watcher) {
t.Helper()
Expand Down Expand Up @@ -351,9 +361,9 @@ func TestRun(t *testing.T) {
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount(testNodeDesiredGPUCount),
},
nodeDesiredGPUCount: testNodeDesiredGPUCount,
nodePoolID: testNodePoolID,
nodePoolID: testNodePoolID,
},
beforeFunc: func(w *watcher) {
t.Helper()
Expand Down Expand Up @@ -394,9 +404,9 @@ func TestRun(t *testing.T) {
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount(testNodeDesiredGPUCount),
},
nodeDesiredGPUCount: testNodeDesiredGPUCount,
nodePoolID: testNodePoolID,
nodePoolID: testNodePoolID,
},
beforeFunc: func(w *watcher) {
t.Helper()
Expand All @@ -415,9 +425,9 @@ func TestRun(t *testing.T) {
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount(testNodeDesiredGPUCount),
},
nodeDesiredGPUCount: testNodeDesiredGPUCount,
nodePoolID: testNodePoolID,
nodePoolID: testNodePoolID,
},
beforeFunc: func(w *watcher) {
t.Helper()
Expand Down Expand Up @@ -462,7 +472,7 @@ func TestRun(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
w, err := NewWatcher(t.Context(),
testApiURL, testApiKey, testRegion, testClusterID, test.args.nodePoolID, test.args.nodeDesiredGPUCount, test.args.opts...)
testApiURL, testApiKey, testRegion, testClusterID, test.args.nodePoolID, test.args.opts...)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -682,6 +692,19 @@ func TestIsNodeDesiredGPU(t *testing.T) {
desired: 8,
want: true,
},
{
name: "Returns true when desired GPU count is 0, so count check is skipped",
node: &corev1.Node{
ObjectMeta: metav1.ObjectMeta{
Name: "node-01",
},
Status: corev1.NodeStatus{
Allocatable: corev1.ResourceList{},
},
},
desired: 0,
want: true,
},
{
name: "Returns false when GPU count is 0",
node: &corev1.Node{
Expand Down Expand Up @@ -744,6 +767,7 @@ func TestRebootNode(t *testing.T) {
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount(testNodeDesiredGPUCount),
},
},
beforeFunc: func(t *testing.T, w *watcher) {
Expand Down Expand Up @@ -772,6 +796,7 @@ func TestRebootNode(t *testing.T) {
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount(testNodeDesiredGPUCount),
},
},
beforeFunc: func(t *testing.T, w *watcher) {
Expand All @@ -791,6 +816,7 @@ func TestRebootNode(t *testing.T) {
opts: []Option{
WithKubernetesClient(fake.NewSimpleClientset()),
WithCivoClient(&FakeClient{}),
WithDesiredGPUCount(testNodeDesiredGPUCount),
},
},
beforeFunc: func(t *testing.T, w *watcher) {
Expand Down Expand Up @@ -818,7 +844,7 @@ func TestRebootNode(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
w, err := NewWatcher(t.Context(),
testApiURL, testApiKey, testRegion, testClusterID, testNodePoolID, testNodeDesiredGPUCount, test.args.opts...)
testApiURL, testApiKey, testRegion, testClusterID, testNodePoolID, test.args.opts...)
if err != nil {
t.Fatal(err)
}
Expand Down