diff --git a/README.md b/README.md index 32e9a1ed..3039485f 100644 --- a/README.md +++ b/README.md @@ -257,6 +257,28 @@ acr purge \ --include-locked ``` +#### ABAC (Attribute-Based Access Control) registries + +Registries with ABAC enabled use repository-scoped permissions instead of registry-wide roles. When using `acr purge` with an ABAC registry, keep the following in mind: + +**Required permissions:** +- **Catalog listing:** The user must have permission to list repositories (e.g., the `Container Registry Repository Catalog Lister` role or equivalent `registry:catalog:*` scope). +- **Repository access:** The user needs the `Container Registry Repository Contributor` role for deletes, which can be scoped to specific repositories using ABAC conditions. + +**Partial access behavior:** + +If a broad `--filter` matches repositories that the user does not have permission to purge, the command will stop at the first unauthorized repository and report: +- Which repository failed due to insufficient permissions +- Which repositories were already successfully purged +- Which repositories were not yet processed + +To avoid this, use a more specific `--filter` to target only repositories you have access to. + +**Batch size (environment variable):** + +ABAC registries process repositories in batches, where each batch shares a single token scope. Token refresh happens dynamically when API calls detect token expiration. The batch size can be configured via the `ABAC_BATCH_SIZE` environment variable (default: 10). + + ### Integration with ACR Tasks To run a locally built version of the ACR-CLI using ACR Tasks follow these steps: diff --git a/cmd/acr/purge.go b/cmd/acr/purge.go index 837b1518..ec28087f 100644 --- a/cmd/acr/purge.go +++ b/cmd/acr/purge.go @@ -5,10 +5,13 @@ package main import ( "context" + "errors" "fmt" "net/http" + "os" "runtime" "sort" + "strconv" "strings" "time" @@ -16,6 +19,7 @@ import ( "github.com/Azure/acr-cli/cmd/repository" "github.com/Azure/acr-cli/internal/api" "github.com/Azure/acr-cli/internal/worker" + "github.com/Azure/go-autorest/autorest" "github.com/dlclark/regexp2" "github.com/spf13/cobra" ) @@ -89,6 +93,7 @@ type purgeParameters struct { includeLocked bool concurrency int repoPageSize int32 + verbose bool } // newPurgeCmd defines the purge command. @@ -178,9 +183,9 @@ func newPurgeCmd(rootParams *rootParameters) *cobra.Command { // Combine flags for clarity - these are mutually exclusive supportUntaggedCleanup := purgeParams.untagged || purgeParams.untaggedOnly - deletedTagsCount, deletedManifestsCount, err := purge(ctx, acrClient, loginURL, repoParallelism, agoDuration, purgeParams.keep, purgeParams.filterTimeout, supportUntaggedCleanup, purgeParams.untaggedOnly, tagFilters, purgeParams.dryRun, purgeParams.includeLocked) + deletedTagsCount, deletedManifestsCount, err := purge(ctx, acrClient, loginURL, repoParallelism, agoDuration, purgeParams.keep, purgeParams.filterTimeout, supportUntaggedCleanup, purgeParams.untaggedOnly, tagFilters, purgeParams.dryRun, purgeParams.includeLocked, purgeParams.verbose) - if err != nil { + if err != nil && !strings.Contains(err.Error(), "insufficient permissions") { fmt.Printf("Failed to complete purge: %v \n", err) } @@ -208,6 +213,7 @@ func newPurgeCmd(rootParams *rootParameters) *cobra.Command { cmd.Flags().Int64Var(&purgeParams.filterTimeout, "filter-timeout-seconds", defaultRegexpMatchTimeoutSeconds, "This limits the evaluation of the regex filter, and will return a timeout error if this duration is exceeded during a single evaluation. If written incorrectly a regexp filter with backtracking can result in an infinite loop.") cmd.Flags().IntVar(&purgeParams.concurrency, "concurrency", defaultPoolSize, concurrencyDescription) cmd.Flags().Int32Var(&purgeParams.repoPageSize, "repository-page-size", defaultRepoPageSize, repoPageSizeDescription) + cmd.Flags().BoolVar(&purgeParams.verbose, "verbose", false, "Enable verbose output including detailed repository names during ABAC token operations") cmd.Flags().BoolP("help", "h", false, "Print usage") // Make filter and ago conditionally required based on untagged-only flag cmd.MarkFlagsOneRequired("filter", "untagged-only") @@ -226,36 +232,92 @@ func purge(ctx context.Context, untaggedOnly bool, tagFilters map[string]string, dryRun bool, - includeLocked bool) (deletedTagsCount int, deletedManifestsCount int, err error) { - - // In order to print a summary of the deleted tags/manifests the counters get updated everytime a repo is purged. - for repoName, tagRegex := range tagFilters { - var singleDeletedTagsCount int - var manifestToTagsCountMap map[string]int - - // Handle tag deletion based on mode - if untaggedOnly { - // Initialize empty map for untagged-only mode (no tag deletion) - manifestToTagsCountMap = make(map[string]int) - } else { - // Standard mode: delete matching tags first - singleDeletedTagsCount, manifestToTagsCountMap, err = purgeTags(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, tagRegex, keep, filterTimeout, dryRun, includeLocked) - if err != nil { - return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge tags: %w", err) + includeLocked bool, + verbose bool) (deletedTagsCount int, deletedManifestsCount int, err error) { + + // Load ABAC batch size from environment variable + abacBatchSize := 10 // default + if envVal, exists := os.LookupEnv("ABAC_BATCH_SIZE"); exists { + if parsed, err := strconv.Atoi(envVal); err == nil && parsed > 0 { + abacBatchSize = parsed + } + } + + // Collect all repository names into a slice for batching + repos := make([]string, 0, len(tagFilters)) + for repoName := range tagFilters { + repos = append(repos, repoName) + } + + // Track which repositories have been successfully processed for error reporting. + var completedRepos []string + + // Process repositories in batches of abacBatchSize. + // For ABAC-enabled registries, we set the current repositories for the batch so that + // token refresh happens dynamically when needed (on API calls that detect token expiration). + // For non-ABAC registries, the batching loop is harmless (no special token handling needed). + for i := 0; i < len(repos); i += abacBatchSize { + end := i + abacBatchSize + if end > len(repos) { + end = len(repos) + } + batch := repos[i:end] + + // For ABAC registries, refresh the token with scopes for this batch of repositories. + // ABAC registries don't support wildcard repository scopes, so we must explicitly + // request access for each repository before operating on it. + if acrClient.IsAbac() { + if err := acrClient.RefreshTokenForAbac(ctx, batch); err != nil { + return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to refresh ABAC token for batch: %w", err) + } + if verbose { + fmt.Printf("ABAC: Setting token scope for %d repositories: %v\n", len(batch), batch) + } else { + fmt.Printf("ABAC: Setting token scope for %d repositories\n", len(batch)) } } - singleDeletedManifestsCount := 0 - // If the untagged flag is set or untagged-only mode is enabled, delete manifests - if removeUntaggedManifests { - singleDeletedManifestsCount, err = purgeDanglingManifests(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, keep, manifestToTagsCountMap, dryRun, includeLocked) - if err != nil { - return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge manifests: %w", err) + // Process all repositories in this batch + for _, repoName := range batch { + tagRegex := tagFilters[repoName] + var singleDeletedTagsCount int + var manifestToTagsCountMap map[string]int + + // Handle tag deletion based on mode + if untaggedOnly { + // Initialize empty map for untagged-only mode (no tag deletion) + manifestToTagsCountMap = make(map[string]int) + } else { + // Standard mode: delete matching tags first + singleDeletedTagsCount, manifestToTagsCountMap, err = purgeTags(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, tagRegex, keep, filterTimeout, dryRun, includeLocked) + if err != nil { + if isUnauthorizedError(err) { + remainingRepos := repos[i+indexOf(batch, repoName):] + return deletedTagsCount, deletedManifestsCount, + formatPermissionError(repoName, "purge tags", completedRepos, remainingRepos) + } + return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge tags: %w", err) + } + } + + singleDeletedManifestsCount := 0 + // If the untagged flag is set or untagged-only mode is enabled, delete manifests + if removeUntaggedManifests { + singleDeletedManifestsCount, err = purgeDanglingManifests(ctx, acrClient, repoParallelism, loginURL, repoName, agoDuration, keep, manifestToTagsCountMap, dryRun, includeLocked) + if err != nil { + if isUnauthorizedError(err) { + remainingRepos := repos[i+indexOf(batch, repoName):] + return deletedTagsCount, deletedManifestsCount, + formatPermissionError(repoName, "purge manifests", completedRepos, remainingRepos) + } + return deletedTagsCount, deletedManifestsCount, fmt.Errorf("failed to purge manifests: %w", err) + } } + // After every repository is purged the counters are updated. + deletedTagsCount += singleDeletedTagsCount + deletedManifestsCount += singleDeletedManifestsCount + completedRepos = append(completedRepos, repoName) } - // After every repository is purged the counters are updated. - deletedTagsCount += singleDeletedTagsCount - deletedManifestsCount += singleDeletedManifestsCount } return deletedTagsCount, deletedManifestsCount, nil @@ -563,3 +625,51 @@ func purgeDanglingManifests(ctx context.Context, acrClient api.AcrCLIClientInter } return deletedManifestsCount, nil } + +// isUnauthorizedError checks if an error is an HTTP 401 Unauthorized response. +// This is used to detect permission failures on ABAC-enabled registries where +// the user may have access to some repositories but not others. +func isUnauthorizedError(err error) bool { + if err == nil { + return false + } + var detailedErr autorest.DetailedError + if errors.As(err, &detailedErr) { + if statusCode, ok := detailedErr.StatusCode.(int); ok { + return statusCode == http.StatusUnauthorized + } + } + return strings.Contains(err.Error(), "StatusCode=401") +} + +// formatPermissionError builds a clear error message when a purge operation fails +// due to insufficient permissions on a repository. It reports which repository +// failed, which repositories were already processed, and which remain untouched. +func formatPermissionError(failedRepo string, operation string, completedRepos []string, remainingRepos []string) error { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("insufficient permissions to %s for repository %q", operation, failedRepo)) + + if len(completedRepos) > 0 { + sb.WriteString(fmt.Sprintf("\n Completed repositories (%d): %s", len(completedRepos), strings.Join(completedRepos, ", "))) + } else { + sb.WriteString("\n Completed repositories: none") + } + + // remainingRepos includes the failed repo; show the ones after it as not yet processed + if len(remainingRepos) > 1 { + sb.WriteString(fmt.Sprintf("\n Remaining repositories not yet processed (%d): %s", len(remainingRepos)-1, strings.Join(remainingRepos[1:], ", "))) + } + + sb.WriteString("\n Hint: use a more specific --filter to target only repositories you have permissions for") + return errors.New(sb.String()) +} + +// indexOf returns the index of s in slice, or 0 if not found. +func indexOf(slice []string, s string) int { + for i, v := range slice { + if v == s { + return i + } + } + return 0 +} diff --git a/cmd/acr/purge_test.go b/cmd/acr/purge_test.go index 912d3d7b..20fb3b05 100644 --- a/cmd/acr/purge_test.go +++ b/cmd/acr/purge_test.go @@ -552,9 +552,13 @@ func TestDryRun(t *testing.T) { t.Run("RepositoryNotFoundTest", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + // Mock IsAbac to return false (non-ABAC registry) to use standard wildcard token flow + mockClient.On("IsAbac").Return(false) + // Need a .Maybe() since it's only called for ABAC registries (this test mocks IsAbac to return false) + mockClient.On("IsTokenExpired").Return(false).Maybe() mockClient.On("GetAcrManifests", mock.Anything, testRepo, "", "").Return(notFoundManifestResponse, errors.New("testRepo not found")).Once() mockClient.On("GetAcrTags", mock.Anything, testRepo, "timedesc", "").Return(notFoundTagResponse, errors.New("testRepo not found")).Once() - deletedTags, deletedManifests, err := purge(testCtx, mockClient, testLoginURL, 60, -24*time.Hour, 0, 1, true, false, map[string]string{testRepo: "[\\s\\S]*"}, true, false) + deletedTags, deletedManifests, err := purge(testCtx, mockClient, testLoginURL, 60, -24*time.Hour, 0, 1, true, false, map[string]string{testRepo: "[\\s\\S]*"}, true, false, false) assert.Equal(0, deletedTags, "Number of deleted elements should be 0") assert.Equal(0, deletedManifests, "Number of deleted elements should be 0") assert.Equal(nil, err, "Error should be nil") diff --git a/cmd/acr/purge_untagged_only_test.go b/cmd/acr/purge_untagged_only_test.go index 4d59f23c..baf3c3b6 100644 --- a/cmd/acr/purge_untagged_only_test.go +++ b/cmd/acr/purge_untagged_only_test.go @@ -3,8 +3,11 @@ package main import ( + "bytes" "context" + "io" "net/http" + "os" "testing" "github.com/Azure/acr-cli/acr" @@ -25,6 +28,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyPurgeManifestsOnly", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // Setup mock response for manifests without tags manifestDigest := "sha256:abc123" @@ -85,6 +90,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { map[string]string{testRepo: ".*"}, false, // dryRun false, // includeLocked + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted in untagged-only mode") @@ -97,6 +103,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyNoFilterAllRepos", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // We won't test GetRepositories here since the purge function is called // with already-created tagFilters. Instead test that all repos are processed. @@ -137,6 +145,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { tagFilters, false, // dryRun false, // includeLocked + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted") @@ -149,6 +158,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyWithFilter", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() manifestDigest := "sha256:def456" mediaType := "application/vnd.docker.distribution.manifest.v2+json" @@ -206,6 +217,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { map[string]string{"specific-repo": ".*"}, false, // dryRun false, // includeLocked + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted in untagged-only mode") @@ -218,6 +230,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyDryRun", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() manifestDigest := "sha256:ghi789" mediaType := "application/vnd.docker.distribution.manifest.v2+json" @@ -270,6 +284,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { map[string]string{testRepo: ".*"}, true, // dryRun false, // includeLocked + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted in dry-run") @@ -282,6 +297,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyWithLockedManifests", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // Create locked and unlocked untagged manifests lockedDigest := "sha256:locked123" @@ -351,6 +368,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { map[string]string{testRepo: ".*"}, false, // dryRun false, // includeLocked = false + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted") @@ -363,6 +381,8 @@ func TestPurgeUntaggedOnly(t *testing.T) { t.Run("UntaggedOnlyWithIncludeLocked", func(t *testing.T) { assert := assert.New(t) mockClient := &mocks.AcrCLIClientInterface{} + mockClient.On("IsAbac").Return(false) + mockClient.On("IsTokenExpired").Return(false).Maybe() // Create locked untagged manifest lockedDigest := "sha256:locked789" @@ -429,6 +449,7 @@ func TestPurgeUntaggedOnly(t *testing.T) { map[string]string{testRepo: ".*"}, false, // dryRun true, // includeLocked = true + false, // verbose ) assert.Equal(0, deletedTagsCount, "No tags should be deleted") @@ -750,3 +771,194 @@ func TestPurgeDanglingManifestsWithAgoAndKeep(t *testing.T) { mockClient.AssertExpectations(t) }) } + +// TestPurgeAbacVerboseMode tests the verbose output for ABAC registries +func TestPurgeAbacVerboseMode(t *testing.T) { + testCtx := context.Background() + testLoginURL := "registry.azurecr.io" + defaultPoolSize := 1 + + // Test: ABAC verbose mode should output repository names + t.Run("AbacVerboseModeOutputsRepoNames", func(t *testing.T) { + assert := assert.New(t) + mockClient := &mocks.AcrCLIClientInterface{} + + // Mock ABAC registry + mockClient.On("IsAbac").Return(true) + mockClient.On("RefreshTokenForAbac", mock.Anything, mock.Anything).Return(nil) + + // Empty manifests result for each repo + emptyManifestsResult := &acr.Manifests{ + Response: autorest.Response{ + Response: &http.Response{ + StatusCode: 200, + }, + }, + ManifestsAttributes: &[]acr.ManifestAttributesBase{}, + } + + repos := []string{"repo1", "repo2", "repo3"} + for _, repo := range repos { + mockClient.On("GetAcrManifests", mock.Anything, repo, "", "").Return(emptyManifestsResult, nil).Once() + } + + tagFilters := make(map[string]string) + for _, repo := range repos { + tagFilters[repo] = ".*" + } + + // Capture stdout to verify verbose output + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + // Call purge with verbose=true and ABAC enabled + deletedTagsCount, deletedManifestsCount, err := purge( + testCtx, + mockClient, + testLoginURL, + defaultPoolSize, + 0, // ago + 0, // keep + 60, // filterTimeout + true, // removeUntaggedManifests + true, // untaggedOnly + tagFilters, + false, // dryRun + false, // includeLocked + true, // verbose = true + ) + + // Restore stdout and read captured output + w.Close() + os.Stdout = oldStdout + var buf bytes.Buffer + io.Copy(&buf, r) + output := buf.String() + + assert.Equal(0, deletedTagsCount, "No tags should be deleted") + assert.Equal(0, deletedManifestsCount, "No manifests deleted when none are untagged") + assert.Nil(err, "Error should be nil") + // Verify verbose output contains repository names + assert.Contains(output, "ABAC: Setting token scope for 3 repositories:", "Should output repo count") + assert.Contains(output, "repo1", "Should output repo1 in verbose mode") + assert.Contains(output, "repo2", "Should output repo2 in verbose mode") + assert.Contains(output, "repo3", "Should output repo3 in verbose mode") + mockClient.AssertCalled(t, "RefreshTokenForAbac", mock.Anything, mock.Anything) + mockClient.AssertExpectations(t) + }) + + // Test: ABAC non-verbose mode should only output count, not names + t.Run("AbacNonVerboseModeOutputsCountOnly", func(t *testing.T) { + assert := assert.New(t) + mockClient := &mocks.AcrCLIClientInterface{} + + // Mock ABAC registry + mockClient.On("IsAbac").Return(true) + mockClient.On("RefreshTokenForAbac", mock.Anything, mock.Anything).Return(nil) + + // Empty manifests result for each repo + emptyManifestsResult := &acr.Manifests{ + Response: autorest.Response{ + Response: &http.Response{ + StatusCode: 200, + }, + }, + ManifestsAttributes: &[]acr.ManifestAttributesBase{}, + } + + repos := []string{"repo1", "repo2", "repo3"} + for _, repo := range repos { + mockClient.On("GetAcrManifests", mock.Anything, repo, "", "").Return(emptyManifestsResult, nil).Once() + } + + tagFilters := make(map[string]string) + for _, repo := range repos { + tagFilters[repo] = ".*" + } + + // Capture stdout to verify non-verbose output + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + // Call purge with verbose=false and ABAC enabled + deletedTagsCount, deletedManifestsCount, err := purge( + testCtx, + mockClient, + testLoginURL, + defaultPoolSize, + 0, // ago + 0, // keep + 60, // filterTimeout + true, // removeUntaggedManifests + true, // untaggedOnly + tagFilters, + false, // dryRun + false, // includeLocked + false, // verbose = false + ) + + // Restore stdout and read captured output + w.Close() + os.Stdout = oldStdout + var buf bytes.Buffer + io.Copy(&buf, r) + output := buf.String() + + assert.Equal(0, deletedTagsCount, "No tags should be deleted") + assert.Equal(0, deletedManifestsCount, "No manifests deleted when none are untagged") + assert.Nil(err, "Error should be nil") + // Verify non-verbose output contains count but NOT the repository list + assert.Contains(output, "ABAC: Setting token scope for 3 repositories", "Should output repo count") + // The non-verbose output should NOT contain the bracketed list of repos + assert.NotContains(output, "[repo1", "Should NOT output repo list in non-verbose mode") + mockClient.AssertCalled(t, "RefreshTokenForAbac", mock.Anything, mock.Anything) + mockClient.AssertExpectations(t) + }) + + // Test: Non-ABAC registry should not call RefreshTokenForAbac + t.Run("NonAbacDoesNotCallRefreshTokenForAbac", func(t *testing.T) { + assert := assert.New(t) + mockClient := &mocks.AcrCLIClientInterface{} + + // Mock non-ABAC registry + mockClient.On("IsAbac").Return(false) + + // Empty manifests result + emptyManifestsResult := &acr.Manifests{ + Response: autorest.Response{ + Response: &http.Response{ + StatusCode: 200, + }, + }, + ManifestsAttributes: &[]acr.ManifestAttributesBase{}, + } + + mockClient.On("GetAcrManifests", mock.Anything, "test-repo", "", "").Return(emptyManifestsResult, nil).Once() + + // Call purge with verbose=true but non-ABAC registry + deletedTagsCount, deletedManifestsCount, err := purge( + testCtx, + mockClient, + testLoginURL, + defaultPoolSize, + 0, // ago + 0, // keep + 60, // filterTimeout + true, // removeUntaggedManifests + true, // untaggedOnly + map[string]string{"test-repo": ".*"}, + false, // dryRun + false, // includeLocked + true, // verbose = true + ) + + assert.Equal(0, deletedTagsCount, "No tags should be deleted") + assert.Equal(0, deletedManifestsCount, "No manifests deleted") + assert.Nil(err, "Error should be nil") + // Verify RefreshTokenForAbac was NOT called for non-ABAC + mockClient.AssertNotCalled(t, "RefreshTokenForAbac", mock.Anything, mock.Anything) + mockClient.AssertExpectations(t) + }) +} diff --git a/cmd/mocks/AcrCLIClientInterface.go b/cmd/mocks/AcrCLIClientInterface.go index d553d8d5..bac2214d 100644 --- a/cmd/mocks/AcrCLIClientInterface.go +++ b/cmd/mocks/AcrCLIClientInterface.go @@ -2,11 +2,15 @@ package mocks -import acr "github.com/Azure/acr-cli/acr" +import ( + acr "github.com/Azure/acr-cli/acr" -import autorest "github.com/Azure/go-autorest/autorest" -import context "context" -import mock "github.com/stretchr/testify/mock" + autorest "github.com/Azure/go-autorest/autorest" + + context "context" + + mock "github.com/stretchr/testify/mock" +) // AcrCLIClientInterface is an autogenerated mock type for the AcrCLIClientInterface type type AcrCLIClientInterface struct { @@ -196,3 +200,50 @@ func (_m *AcrCLIClientInterface) UpdateAcrManifestAttributes(ctx context.Context return r0, r1 } + +// IsAbac provides a mock function that returns whether the registry is ABAC-enabled +func (_m *AcrCLIClientInterface) IsAbac() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// IsTokenExpired provides a mock function for checking if token is expired +func (_m *AcrCLIClientInterface) IsTokenExpired() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// RefreshTokenForAbac provides a mock function for refreshing tokens with specific repository scopes +func (_m *AcrCLIClientInterface) RefreshTokenForAbac(ctx context.Context, repositories []string) error { + ret := _m.Called(ctx, repositories) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []string) error); ok { + r0 = rf(ctx, repositories) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetCurrentRepositories provides a mock function for setting the current repositories for ABAC token scope +func (_m *AcrCLIClientInterface) SetCurrentRepositories(repositories []string) { + _m.Called(repositories) +} diff --git a/internal/api/acrsdk.go b/internal/api/acrsdk.go index 0b3a7c11..68580750 100644 --- a/internal/api/acrsdk.go +++ b/internal/api/acrsdk.go @@ -7,6 +7,7 @@ package api import ( "bytes" "context" + "fmt" "io/ioutil" "strings" "time" @@ -52,6 +53,13 @@ type AcrCLIClient struct { // accessTokenExp refers to the expiration time for the access token, it is in a unix time format represented by a // 64 bit integer. accessTokenExp int64 + // isAbac indicates whether this registry uses Attribute-Based Access Control (ABAC). + // ABAC registries require repository-level permissions instead of registry-wide wildcards. + // This is detected by checking if the refresh token contains the "aad_identity" claim. + isAbac bool + // currentRepositories holds the repository names for which the current ABAC token has permissions. + // This is used for dynamic token refresh when the token expires during operations. + currentRepositories []string } // LoginURL returns the FQDN for a registry. @@ -91,10 +99,18 @@ func newAcrCLIClientWithBasicAuth(loginURL string, username string, password str } // newAcrCLIClientWithBearerAuth creates a client that uses bearer token authentication. +// It detects ABAC-enabled registries via the "aad_identity" claim in the refresh token. +// It always requests both catalog and wildcard repository access; on ABAC registries the +// wildcard is silently ignored by the server, so callers must use RefreshTokenForAbac to +// request repository-specific scopes before accessing individual repositories. func newAcrCLIClientWithBearerAuth(loginURL string, refreshToken string) (AcrCLIClient, error) { newAcrCLIClient := newAcrCLIClient(loginURL) + newAcrCLIClient.isAbac = hasAadIdentityClaim(refreshToken) + ctx := context.Background() - accessTokenResponse, err := newAcrCLIClient.AutorestClient.GetAcrAccessToken(ctx, loginURL, "registry:catalog:* repository:*:*", refreshToken) + scope := "registry:catalog:* repository:*:*" + + accessTokenResponse, err := newAcrCLIClient.AutorestClient.GetAcrAccessToken(ctx, loginURL, scope, refreshToken) if err != nil { return newAcrCLIClient, err } @@ -154,8 +170,43 @@ func GetAcrCLIClientWithAuth(loginURL string, username string, password string, } // refreshAcrCLIClientToken obtains a new token and gets its expiration time. -func refreshAcrCLIClientToken(ctx context.Context, c *AcrCLIClient) error { - accessTokenResponse, err := c.AutorestClient.GetAcrAccessToken(ctx, c.loginURL, "repository:*:*", c.token.RefreshToken) +// For non-ABAC registries, this uses the wildcard scope. +// For ABAC registries, this uses the currentRepositories to refresh with the appropriate scope. +func refreshAcrCLIClientToken(ctx context.Context, c *AcrCLIClient, repoName string) error { + var scope string + if c.isAbac { + // For ABAC registries, build scope from currentRepositories and ensure repoName is included + repoSet := make(map[string]bool) + for _, repo := range c.currentRepositories { + repoSet[repo] = true + } + // Ensure the current repoName is in the set + if repoName != "" { + repoSet[repoName] = true + } + var scopeParts []string + // "catalog" is a sentinel value meaning "include registry:catalog:* scope" + // (access to the catalog API for listing repositories). It must NOT be + // treated as a repository name, since "repository:catalog:..." would try + // to access a repository literally named "catalog". + if repoSet["catalog"] { + scopeParts = append(scopeParts, "registry:catalog:*") + delete(repoSet, "catalog") + } + for repo := range repoSet { + scopeParts = append(scopeParts, fmt.Sprintf("repository:%s:pull,delete,metadata_read,metadata_write", repo)) + } + if len(scopeParts) == 0 { + // Fallback: if no repositories specified, return error for ABAC + return errors.New("ABAC registry requires repository scope but none specified") + } + scope = strings.Join(scopeParts, " ") + } else { + // For non-ABAC registries, use the wildcard scope + scope = "repository:*:*" + } + + accessTokenResponse, err := c.AutorestClient.GetAcrAccessToken(ctx, c.loginURL, scope, c.token.RefreshToken) if err != nil { return err } @@ -173,6 +224,90 @@ func refreshAcrCLIClientToken(ctx context.Context, c *AcrCLIClient) error { return nil } +// hasAadIdentityClaim checks if a JWT token contains the "aad_identity" claim. +// The presence of this claim indicates that the registry is ABAC-enabled. +// ABAC (Attribute-Based Access Control) registries grant permissions at the repository level, +// not at the registry level, so wildcard scopes like "repository:*:*" will not work. +func hasAadIdentityClaim(tokenString string) bool { + parser := jwt.Parser{SkipClaimsValidation: true} + mapC := jwt.MapClaims{} + // We only need to check for the claim, not verify the signature + _, _, err := parser.ParseUnverified(tokenString, mapC) + if err != nil { + return false + } + _, ok := mapC["aad_identity"] + return ok +} + +// SetCurrentRepositories sets the repositories for which ABAC token refresh should request permissions. +// This should be called before performing operations on repositories in ABAC-enabled registries. +// When the token expires, the refresh will automatically request permissions for these repositories. +func (c *AcrCLIClient) SetCurrentRepositories(repositories []string) { + c.currentRepositories = repositories +} + +// RefreshTokenForAbac obtains a new access token scoped to specific repositories. +// This is used for ABAC-enabled registries where wildcard repository access is not allowed. +// The token will include permissions for all specified repositories. +// It also updates currentRepositories so subsequent automatic refreshes use the same scope. +// +// Parameters: +// - repositories: list of repository names to request access for +// +// The scope format is: "registry:catalog:* repository::pull repository::delete ..." +// This allows batching multiple repositories into a single token request for efficiency. +func (c *AcrCLIClient) RefreshTokenForAbac(ctx context.Context, repositories []string) error { + if c.token == nil { + return errors.New("no refresh token available for ABAC token refresh") + } + + // Update the current repositories so automatic refreshes use the same scope + c.currentRepositories = repositories + + // Build the scope string for all requested repositories. + // Each repository needs pull, delete, and metadata permissions for purge operations. + // "catalog" is a sentinel value meaning "include registry:catalog:* scope" + // (access to the catalog API for listing repositories). It must NOT be + // treated as a repository name, since "repository:catalog:..." would try + // to access a repository literally named "catalog". + var scopeParts []string + for _, repo := range repositories { + if repo == "catalog" { + scopeParts = append(scopeParts, "registry:catalog:*") + } else { + scopeParts = append(scopeParts, fmt.Sprintf("repository:%s:pull,delete,metadata_read,metadata_write", repo)) + } + } + scope := strings.Join(scopeParts, " ") + + accessTokenResponse, err := c.AutorestClient.GetAcrAccessToken(ctx, c.loginURL, scope, c.token.RefreshToken) + if err != nil { + return errors.Wrap(err, "failed to refresh token for ABAC repositories") + } + + token := &adal.Token{ + AccessToken: *accessTokenResponse.AccessToken, + RefreshToken: c.token.RefreshToken, + } + c.token = token + c.AutorestClient.Authorizer = autorest.NewBearerAuthorizer(token) + + exp, err := getExpiration(token.AccessToken) + if err != nil { + return err + } + c.accessTokenExp = exp + + return nil +} + +// IsAbac returns true if this client is connected to an ABAC-enabled registry. +// ABAC registries require repository-level token scopes instead of wildcard scopes. +func (c *AcrCLIClient) IsAbac() bool { + return c.isAbac +} + // getExpiration is used to obtain the expiration out of a jwt token. func getExpiration(token string) (int64, error) { parser := jwt.Parser{SkipClaimsValidation: true} @@ -198,10 +333,17 @@ func (c *AcrCLIClient) isExpired() bool { return (time.Now().Add(5 * time.Minute)).Unix() > c.accessTokenExp } +// IsTokenExpired returns true when the token is expired or close to expiring. +// This is the public version of isExpired for use by callers that need to check +// token expiration before making batched ABAC token refresh requests. +func (c *AcrCLIClient) IsTokenExpired() bool { + return c.isExpired() +} + // GetAcrTags list the tags of a repository with their attributes. func (c *AcrCLIClient) GetAcrTags(ctx context.Context, repoName string, orderBy string, last string) (*acrapi.RepositoryTagsType, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -216,7 +358,7 @@ func (c *AcrCLIClient) GetAcrTags(ctx context.Context, repoName string, orderBy // DeleteAcrTag deletes the tag by reference. func (c *AcrCLIClient) DeleteAcrTag(ctx context.Context, repoName string, reference string) (*autorest.Response, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -230,7 +372,7 @@ func (c *AcrCLIClient) DeleteAcrTag(ctx context.Context, repoName string, refere // GetAcrManifests list all the manifest in a repository with their attributes. func (c *AcrCLIClient) GetAcrManifests(ctx context.Context, repoName string, orderBy string, last string) (*acrapi.Manifests, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -244,7 +386,7 @@ func (c *AcrCLIClient) GetAcrManifests(ctx context.Context, repoName string, ord // DeleteManifest deletes a manifest using the digest as a reference. func (c *AcrCLIClient) DeleteManifest(ctx context.Context, repoName string, reference string) (*autorest.Response, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -259,7 +401,7 @@ func (c *AcrCLIClient) DeleteManifest(ctx context.Context, repoName string, refe // This is used when a manifest list is wanted, first the bytes are obtained and then unmarshalled into a new struct. func (c *AcrCLIClient) GetManifest(ctx context.Context, repoName string, reference string) ([]byte, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -299,7 +441,7 @@ func (c *AcrCLIClient) GetManifest(ctx context.Context, repoName string, referen // GetAcrManifestAttributes gets the attributes of a manifest. func (c *AcrCLIClient) GetAcrManifestAttributes(ctx context.Context, repoName string, reference string) (*acrapi.ManifestAttributes, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -313,7 +455,7 @@ func (c *AcrCLIClient) GetAcrManifestAttributes(ctx context.Context, repoName st // UpdateAcrTagAttributes updates tag attributes to enable/disable deletion and writing. func (c *AcrCLIClient) UpdateAcrTagAttributes(ctx context.Context, repoName string, reference string, value *acrapi.ChangeableAttributes) (*autorest.Response, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -327,7 +469,7 @@ func (c *AcrCLIClient) UpdateAcrTagAttributes(ctx context.Context, repoName stri // UpdateAcrManifestAttributes updates manifest attributes to enable/disable deletion and writing. func (c *AcrCLIClient) UpdateAcrManifestAttributes(ctx context.Context, repoName string, reference string, value *acrapi.ChangeableAttributes) (*autorest.Response, error) { if c.isExpired() { - if err := refreshAcrCLIClientToken(ctx, c); err != nil { + if err := refreshAcrCLIClientToken(ctx, c, repoName); err != nil { return nil, err } } @@ -348,4 +490,13 @@ type AcrCLIClientInterface interface { GetAcrManifestAttributes(ctx context.Context, repoName string, reference string) (*acrapi.ManifestAttributes, error) UpdateAcrTagAttributes(ctx context.Context, repoName string, reference string, value *acrapi.ChangeableAttributes) (*autorest.Response, error) UpdateAcrManifestAttributes(ctx context.Context, repoName string, reference string, value *acrapi.ChangeableAttributes) (*autorest.Response, error) + + // IsAbac returns true if the registry uses Attribute-Based Access Control. + IsAbac() bool + // IsTokenExpired returns true if the access token is expired or close to expiring. + IsTokenExpired() bool + // RefreshTokenForAbac refreshes the access token with scopes for specific repositories. + RefreshTokenForAbac(ctx context.Context, repositories []string) error + // SetCurrentRepositories sets the repositories for ABAC token refresh scope. + SetCurrentRepositories(repositories []string) } diff --git a/internal/api/acrsdk_test.go b/internal/api/acrsdk_test.go index 60224873..30d9c266 100644 --- a/internal/api/acrsdk_test.go +++ b/internal/api/acrsdk_test.go @@ -4,6 +4,7 @@ package api import ( + "context" "encoding/base64" "fmt" "net/http" @@ -234,3 +235,209 @@ func TestGetAcrCLIClientWithAuth(t *testing.T) { }) } } + +// TestHasAadIdentityClaim tests the ABAC detection function +func TestHasAadIdentityClaim(t *testing.T) { + tests := []struct { + name string + token string + expected bool + }{ + { + name: "token with aad_identity claim - ABAC enabled", + // JWT with {"aad_identity": "user@example.com"} in payload + token: strings.Join([]string{ + base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)), + base64.RawURLEncoding.EncodeToString([]byte(`{"exp":1563910981,"aad_identity":"user@example.com"}`)), + "", + }, "."), + expected: true, + }, + { + name: "token without aad_identity claim - non-ABAC", + // JWT without aad_identity + token: strings.Join([]string{ + base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)), + base64.RawURLEncoding.EncodeToString([]byte(`{"exp":1563910981}`)), + "", + }, "."), + expected: false, + }, + { + name: "invalid token", + token: "not-a-valid-jwt", + expected: false, + }, + { + name: "empty token", + token: "", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := hasAadIdentityClaim(tt.token) + if result != tt.expected { + t.Errorf("hasAadIdentityClaim() = %v, expected %v", result, tt.expected) + } + }) + } +} + +// TestAcrCLIClientIsAbac tests the IsAbac method +func TestAcrCLIClientIsAbac(t *testing.T) { + tests := []struct { + name string + isAbac bool + expected bool + }{ + { + name: "ABAC enabled client", + isAbac: true, + expected: true, + }, + { + name: "non-ABAC client", + isAbac: false, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := AcrCLIClient{ + isAbac: tt.isAbac, + } + result := client.IsAbac() + if result != tt.expected { + t.Errorf("IsAbac() = %v, expected %v", result, tt.expected) + } + }) + } +} + +// TestRefreshAcrCLIClientTokenAbac tests the ABAC-aware token refresh path. +// This ensures that when SDK methods (GetAcrTags, DeleteAcrTag, etc.) detect token expiry, +// the refresh uses repository-scoped tokens for ABAC registries instead of wildcard scope. +func TestRefreshAcrCLIClientTokenAbac(t *testing.T) { + testAccessToken := strings.Join([]string{ + base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)), + base64.RawURLEncoding.EncodeToString([]byte(`{"exp":9999999999}`)), // Far future expiry + "", + }, ".") + testRefreshToken := "test/refresh/token" + + tests := []struct { + name string + isAbac bool + currentRepositories []string + repoName string + expectedScopePrefix string // What the scope should start with or contain + shouldContainRepo string // Specific repo that must be in scope + wantErr bool + }{ + { + name: "ABAC with currentRepositories and repoName - includes both", + isAbac: true, + currentRepositories: []string{"repo1", "repo2"}, + repoName: "repo3", + expectedScopePrefix: "repository:", + shouldContainRepo: "repo3", + wantErr: false, + }, + { + name: "ABAC with only repoName - uses repoName for scope", + isAbac: true, + currentRepositories: []string{}, + repoName: "my-repo", + expectedScopePrefix: "repository:my-repo:", + shouldContainRepo: "my-repo", + wantErr: false, + }, + { + name: "ABAC with repoName already in currentRepositories - no duplicate", + isAbac: true, + currentRepositories: []string{"repo1", "repo2"}, + repoName: "repo1", + expectedScopePrefix: "repository:", + shouldContainRepo: "repo1", + wantErr: false, + }, + { + name: "ABAC with no repos and no repoName - returns error", + isAbac: true, + currentRepositories: []string{}, + repoName: "", + wantErr: true, + }, + { + name: "Non-ABAC registry - uses wildcard scope", + isAbac: false, + currentRepositories: []string{}, + repoName: "any-repo", + expectedScopePrefix: "repository:*:*", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedScope string + + // Create a test server that captures the scope parameter + as := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusNotFound) + return + } + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + capturedScope = r.PostForm.Get("scope") + // Return a valid access token + fmt.Fprintf(w, `{"access_token":%q}`, testAccessToken) + })) + defer as.Close() + + // Create client with test configuration + client := newAcrCLIClient(as.URL) + client.isAbac = tt.isAbac + client.currentRepositories = tt.currentRepositories + client.token = &adal.Token{ + AccessToken: testAccessToken, + RefreshToken: testRefreshToken, + } + // Replace transport to trust test server + client.AutorestClient.Sender = as.Client() + + // Call refreshAcrCLIClientToken + err := refreshAcrCLIClientToken(context.Background(), &client, tt.repoName) + + // Check error expectation + if (err != nil) != tt.wantErr { + t.Errorf("refreshAcrCLIClientToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } + + // Verify the scope was correct + if tt.expectedScopePrefix != "" && !strings.Contains(capturedScope, tt.expectedScopePrefix) { + t.Errorf("Expected scope to contain %q, got %q", tt.expectedScopePrefix, capturedScope) + } + + if tt.shouldContainRepo != "" && !strings.Contains(capturedScope, tt.shouldContainRepo) { + t.Errorf("Expected scope to contain repo %q, got %q", tt.shouldContainRepo, capturedScope) + } + + // For ABAC, verify we're NOT using wildcard + if tt.isAbac && strings.Contains(capturedScope, "repository:*:*") { + t.Errorf("ABAC refresh should NOT use wildcard scope, got %q", capturedScope) + } + }) + } +}