Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions e2e/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,18 @@ var CachedVMSizeSupportsNVMe = cachedFunc(func(ctx context.Context, req VMSizeSK
var CachedIsVMSizeGen2Only = cachedFunc(func(ctx context.Context, req VMSizeSKURequest) (bool, error) {
return config.Azure.IsVMSizeGen2Only(ctx, req.Location, req.VMSize)
})

// GetLatestExtensionVersionRequest is the cache key for VM extension version lookups.
type GetLatestExtensionVersionRequest struct {
Location string
ExtType string
Publisher string
}

// CachedGetLatestVMExtensionImageVersion caches the result of querying the Azure API
// for the latest VM extension image version.
var CachedGetLatestVMExtensionImageVersion = cachedFunc(
func(ctx context.Context, req GetLatestExtensionVersionRequest) (string, error) {
return config.Azure.GetLatestVMExtensionImageVersion(ctx, req.Location, req.ExtType, req.Publisher)
},
)
Comment thread
surajssd marked this conversation as resolved.
123 changes: 123 additions & 0 deletions e2e/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package e2e

import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_cachedFunc_returns_consistent_results(t *testing.T) {
var callCount atomic.Int32
fn := cachedFunc(func(ctx context.Context, key string) (string, error) {
callCount.Add(1)
return "result-" + key, nil
})

ctx := context.Background()
first, err := fn(ctx, "a")
require.NoError(t, err)

second, err := fn(ctx, "a")
require.NoError(t, err)

assert.Equal(t, first, second, "cached function should return the same result on repeated calls")
assert.Equal(t, int32(1), callCount.Load(), "underlying function should only be called once for the same key")
}

func Test_cachedFunc_warm_call_is_faster_than_cold(t *testing.T) {
fn := cachedFunc(func(ctx context.Context, key string) (string, error) {
// simulate a slow operation like a network call
time.Sleep(10 * time.Millisecond)
return "result", nil
})

ctx := context.Background()

start := time.Now()
_, err := fn(ctx, "key")
coldDuration := time.Since(start)
require.NoError(t, err)

start = time.Now()
_, err = fn(ctx, "key")
warmDuration := time.Since(start)
require.NoError(t, err)

assert.Less(t, warmDuration, coldDuration, "warm (cached) call should be faster than cold call")
Comment on lines +33 to +51
Copy link

Copilot AI Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assertion makes the unit test depend on timing. Even with the 10ms sleep, timing-based comparisons can be flaky under load (GC pauses, scheduler delays). Consider asserting caching via call counts (as done in the other tests) instead of comparing durations, or use a much larger margin to reduce flake risk.

Suggested change
fn := cachedFunc(func(ctx context.Context, key string) (string, error) {
// simulate a slow operation like a network call
time.Sleep(10 * time.Millisecond)
return "result", nil
})
ctx := context.Background()
start := time.Now()
_, err := fn(ctx, "key")
coldDuration := time.Since(start)
require.NoError(t, err)
start = time.Now()
_, err = fn(ctx, "key")
warmDuration := time.Since(start)
require.NoError(t, err)
assert.Less(t, warmDuration, coldDuration, "warm (cached) call should be faster than cold call")
var callCount atomic.Int32
fn := cachedFunc(func(ctx context.Context, key string) (string, error) {
// simulate a slow operation like a network call
callCount.Add(1)
time.Sleep(10 * time.Millisecond)
return "result", nil
})
ctx := context.Background()
_, err := fn(ctx, "key")
require.NoError(t, err)
_, err = fn(ctx, "key")
require.NoError(t, err)
assert.Equal(t, int32(1), callCount.Load(), "warm (cached) call should not invoke the slow underlying function again for the same key")

Copilot uses AI. Check for mistakes.
}
Comment on lines +32 to +52
Copy link

Copilot AI Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test_cachedFunc_warm_call_is_faster_than_cold compares wall-clock durations to assert caching, which can be flaky due to scheduler noise (warm call can occasionally be slower). A more deterministic assertion is to track underlying function invocations (e.g., via an atomic counter) and assert it’s called exactly once per key, optionally keeping the sleep to make the cold path slower but not asserting on timing.

Copilot uses AI. Check for mistakes.

func Test_cachedFunc_different_keys_produce_different_cache_entries(t *testing.T) {
var callCount atomic.Int32
fn := cachedFunc(func(ctx context.Context, key string) (string, error) {
callCount.Add(1)
return "result-" + key, nil
})

ctx := context.Background()

resultA, err := fn(ctx, "a")
require.NoError(t, err)

resultB, err := fn(ctx, "b")
require.NoError(t, err)

assert.Equal(t, "result-a", resultA)
assert.Equal(t, "result-b", resultB)
assert.NotEqual(t, resultA, resultB, "different keys should produce different results")
assert.Equal(t, int32(2), callCount.Load(), "underlying function should be called once per unique key")
}

func Test_cachedFunc_caches_errors(t *testing.T) {
var callCount atomic.Int32
expectedErr := fmt.Errorf("something went wrong")
fn := cachedFunc(func(ctx context.Context, key string) (string, error) {
callCount.Add(1)
return "", expectedErr
})

ctx := context.Background()

_, err1 := fn(ctx, "a")
require.ErrorIs(t, err1, expectedErr)

_, err2 := fn(ctx, "a")
require.ErrorIs(t, err2, expectedErr)

assert.Equal(t, int32(1), callCount.Load(), "underlying function should only be called once even when it returns an error")
}

func Test_cachedFunc_with_struct_key(t *testing.T) {
type request struct {
Location string
Type string
}

var callCount atomic.Int32
fn := cachedFunc(func(ctx context.Context, req request) (string, error) {
callCount.Add(1)
return req.Location + "-" + req.Type, nil
})

ctx := context.Background()

r1, err := fn(ctx, request{Location: "eastus", Type: "ext1"})
require.NoError(t, err)
assert.Equal(t, "eastus-ext1", r1)

// same key should return cached result
r2, err := fn(ctx, request{Location: "eastus", Type: "ext1"})
require.NoError(t, err)
assert.Equal(t, r1, r2)

// different key should call the function again
r3, err := fn(ctx, request{Location: "westus", Type: "ext1"})
require.NoError(t, err)
assert.Equal(t, "westus-ext1", r3)

assert.Equal(t, int32(2), callCount.Load(), "underlying function should be called once per unique struct key")
}
79 changes: 52 additions & 27 deletions e2e/config/azure.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package config

import (
"cmp"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"sort"
"slices"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -745,75 +746,99 @@ func (a *AzureClient) DeleteSnapshot(ctx context.Context, resourceGroupName, sna
return nil
}

// vmExtensionImageVersionLister abstracts the ListVersions method of the VM extension images client for testability.
type vmExtensionImageVersionLister interface {
ListVersions(ctx context.Context, location string, publisherName string, typeParam string,
options *armcompute.VirtualMachineExtensionImagesClientListVersionsOptions,
) (armcompute.VirtualMachineExtensionImagesClientListVersionsResponse, error)
}

// GetLatestVMExtensionImageVersion lists VM extension images for a given extension name and returns the latest version.
// This is equivalent to: az vm extension image list -n Compute.AKS.Linux.AKSNode --latest
func (a *AzureClient) GetLatestVMExtensionImageVersion(ctx context.Context, location, extType, extPublisher string) (string, error) {
return getLatestVMExtensionImageVersion(ctx, a.VMExtensionImages, location, extType, extPublisher)
}

// getLatestVMExtensionImageVersion lists VM extension images using the provided lister and returns the latest version.
func getLatestVMExtensionImageVersion(ctx context.Context, lister vmExtensionImageVersionLister, location, extType, extPublisher string) (string, error) {
// List extension versions
resp, err := a.VMExtensionImages.ListVersions(ctx, location, extPublisher, extType, &armcompute.VirtualMachineExtensionImagesClientListVersionsOptions{})
resp, err := lister.ListVersions(ctx, location, extPublisher, extType, &armcompute.VirtualMachineExtensionImagesClientListVersionsOptions{})
if err != nil {
return "", fmt.Errorf("listing extension versions: %w", err)
}

if len(resp.VirtualMachineExtensionImageArray) == 0 {
return "", fmt.Errorf("no extension versions found")
}

version := make([]VMExtenstionVersion, len(resp.VirtualMachineExtensionImageArray))
versions := make([]vmExtensionVersion, len(resp.VirtualMachineExtensionImageArray))
for i, ext := range resp.VirtualMachineExtensionImageArray {
version[i] = parseVersion(ext)
versions[i] = parseVersion(ctx, ext)
}

sort.Slice(version, func(i, j int) bool {
return version[i].Less(version[j])
latest := slices.MaxFunc(versions, func(a, b vmExtensionVersion) int {
return a.cmp(b)
})

return *version[len(version)-1].Original.Name, nil
if latest.original.Name == nil {
return "", fmt.Errorf("latest extension version has nil name")
Comment on lines +778 to +782
Copy link

Copilot AI Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getLatestVMExtensionImageVersion can return a malformed/non-numeric version name without error. parseVersion currently maps parse failures to 0.0.0, so if all returned names are malformed (or parse to 0), slices.MaxFunc will still pick one and the function will return its Name, bypassing the caller’s fallback-on-error behavior. Consider tracking whether a version parsed successfully and returning an error when no valid numeric version is found (so callers can reliably fall back).

Copilot uses AI. Check for mistakes.
}
return *latest.original.Name, nil
Comment thread
surajssd marked this conversation as resolved.
}

// VMExtenstionVersion represents a parsed version of a VM extension image.
type VMExtenstionVersion struct {
Original *armcompute.VirtualMachineExtensionImage
Major int
Minor int
Patch int
// vmExtensionVersion represents a parsed version of a VM extension image.
type vmExtensionVersion struct {
original *armcompute.VirtualMachineExtensionImage
major int
minor int
patch int
}

// parseVersion parses the version from a VM extension image name, which can be in the format 1.151, 1.0.1, etc.
// You can find all the versions of a specific VM extension by running:
// az vm extension image list -n Compute.AKS.Linux.AKSNode
func parseVersion(v *armcompute.VirtualMachineExtensionImage) VMExtenstionVersion {
func parseVersion(ctx context.Context, v *armcompute.VirtualMachineExtensionImage) vmExtensionVersion {
version := vmExtensionVersion{original: v}
if v.Name == nil {
toolkit.Logf(ctx, "warning: VM extension image has nil name, skipping version parse")
return version
}

// Split by dots
parts := strings.Split(*v.Name, ".")

version := VMExtenstionVersion{Original: v}

if len(parts) >= 1 {
if major, err := strconv.Atoi(parts[0]); err == nil {
version.Major = major
version.major = major
} else {
toolkit.Logf(ctx, "warning: failed to parse major version from %q: %v", *v.Name, err)
}
}
if len(parts) >= 2 {
if minor, err := strconv.Atoi(parts[1]); err == nil {
version.Minor = minor
version.minor = minor
} else {
toolkit.Logf(ctx, "warning: failed to parse minor version from %q: %v", *v.Name, err)
}
}
if len(parts) >= 3 {
if patch, err := strconv.Atoi(parts[2]); err == nil {
version.Patch = patch
version.patch = patch
} else {
toolkit.Logf(ctx, "warning: failed to parse patch version from %q: %v", *v.Name, err)
}
Comment thread
surajssd marked this conversation as resolved.
}

return version
}

func (v VMExtenstionVersion) Less(other VMExtenstionVersion) bool {
if v.Major != other.Major {
return v.Major < other.Major
// cmp compares two versions, returning -1, 0, or 1.
func (v vmExtensionVersion) cmp(other vmExtensionVersion) int {
if c := cmp.Compare(v.major, other.major); c != 0 {
return c
}
if v.Minor != other.Minor {
return v.Minor < other.Minor
if c := cmp.Compare(v.minor, other.minor); c != 0 {
return c
}
return v.Patch < other.Patch
return cmp.Compare(v.patch, other.patch)
}

// getResourceSKU queries the Azure Resource SKUs API to find the SKU for the given VM size and location.
Expand Down
Loading
Loading