From 869fbc692f67c44f29b1766d3deb9665d87450a2 Mon Sep 17 00:00:00 2001 From: packet-mover Date: Mon, 6 Apr 2026 22:57:58 +0200 Subject: [PATCH] feat: support GitHub rulesets for branch protection rules Check rulesets first via GET /repos/{owner}/{repo}/rules/branches/{branch} (single API call). Fall back to classic branch protection only if no rulesets are configured. No merging needed. --- client.go | 44 ++++++++++++++++++++++++++++++++++++++++ client_mock_test.go | 16 +++++++++++++-- scanner.go | 10 +++++++-- scanner_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index ea9d536..dd86a39 100644 --- a/client.go +++ b/client.go @@ -37,6 +37,7 @@ type GitHubClient interface { ListRepos(ctx context.Context, org string) ([]Repo, error) GetTree(ctx context.Context, owner, repo, branch string) ([]FileEntry, error) GetBranchProtection(ctx context.Context, owner, repo, branch string) (*BranchProtection, error) + GetRulesets(ctx context.Context, owner, repo, branch string) (*BranchProtection, error) CreateIssue(ctx context.Context, owner, repo, title, body string) error } @@ -183,6 +184,49 @@ func (c *realGitHubClient) GetBranchProtection(ctx context.Context, owner, repo, return bp, nil } +func (c *realGitHubClient) GetRulesets(ctx context.Context, owner, repo, branch string) (*BranchProtection, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/%s/rules/branches/%s", owner, repo, branch) + + var rules []struct { + Type string `json:"type"` + Parameters *struct { + RequiredApprovingReviewCount int `json:"required_approving_review_count"` + RequiredStatusChecks []struct { + Context string `json:"context"` + } `json:"required_status_checks"` + } `json:"parameters"` + } + if err := c.doRequest(ctx, url, &rules); err != nil { + return nil, fmt.Errorf("get branch rules for %s/%s: %w", owner, repo, err) + } + + var bp BranchProtection + found := false + + for _, rule := range rules { + if rule.Parameters == nil { + continue + } + switch rule.Type { + case "pull_request": + found = true + if rule.Parameters.RequiredApprovingReviewCount > bp.RequiredReviewers { + bp.RequiredReviewers = rule.Parameters.RequiredApprovingReviewCount + } + case "required_status_checks": + found = true + for _, sc := range rule.Parameters.RequiredStatusChecks { + bp.RequiredStatusChecks = append(bp.RequiredStatusChecks, sc.Context) + } + } + } + + if !found { + return nil, nil + } + return &bp, nil +} + func (c *realGitHubClient) CreateIssue(ctx context.Context, owner, repo, title, body string) error { url := fmt.Sprintf("https://api.github.com/repos/%s/%s/issues", owner, repo) diff --git a/client_mock_test.go b/client_mock_test.go index 308207d..cd631a2 100644 --- a/client_mock_test.go +++ b/client_mock_test.go @@ -6,10 +6,12 @@ import "context" type MockGitHubClient struct { Repos []Repo Err error - Tree map[string][]FileEntry // repo name -> file entries + Tree map[string][]FileEntry // repo name -> file entries TreeErr error - Protection map[string]*BranchProtection // repo name -> branch protection + Protection map[string]*BranchProtection // repo name -> classic branch protection ProtectionErr error + Rulesets map[string]*BranchProtection // repo name -> rulesets protection + RulesetsErr error IssueErr error // CreatedIssue records the last CreateIssue call for assertions. CreatedIssue struct { @@ -41,6 +43,16 @@ func (m *MockGitHubClient) GetBranchProtection(ctx context.Context, owner, repo, return nil, nil } +func (m *MockGitHubClient) GetRulesets(ctx context.Context, owner, repo, branch string) (*BranchProtection, error) { + if m.RulesetsErr != nil { + return nil, m.RulesetsErr + } + if m.Rulesets != nil { + return m.Rulesets[repo], nil + } + return nil, nil +} + func (m *MockGitHubClient) CreateIssue(ctx context.Context, owner, repo, title, body string) error { m.CreatedIssue.Owner = owner m.CreatedIssue.Repo = repo diff --git a/scanner.go b/scanner.go index 438575d..73e0d8e 100644 --- a/scanner.go +++ b/scanner.go @@ -65,9 +65,15 @@ func Scan(ctx context.Context, client GitHubClient, org string) ([]RepoResult, e } repo.Files = files - protection, err := client.GetBranchProtection(ctx, org, repo.Name, repo.DefaultBranch) + protection, err := client.GetRulesets(ctx, org, repo.Name, repo.DefaultBranch) if err != nil { - return nil, fmt.Errorf("get branch protection for repo %s: %w", repo.Name, err) + return nil, fmt.Errorf("get rulesets for repo %s: %w", repo.Name, err) + } + if protection == nil { + protection, err = client.GetBranchProtection(ctx, org, repo.Name, repo.DefaultBranch) + if err != nil { + return nil, fmt.Errorf("get branch protection for repo %s: %w", repo.Name, err) + } } repo.BranchProtection = protection diff --git a/scanner_test.go b/scanner_test.go index 41c65fb..807b58e 100644 --- a/scanner_test.go +++ b/scanner_test.go @@ -95,3 +95,52 @@ func TestScan_PropagatesClientError(t *testing.T) { t.Fatal("expected error, got nil") } } + +func TestScan_UsesRulesetsWhenAvailable(t *testing.T) { + client := &MockGitHubClient{ + Repos: []Repo{ + {Name: "modern-repo", Description: "Uses rulesets", DefaultBranch: "main"}, + }, + Rulesets: map[string]*BranchProtection{ + "modern-repo": {RequiredReviewers: 2}, + }, + Protection: map[string]*BranchProtection{ + "modern-repo": {RequiredReviewers: 1}, + }, + } + + results, err := Scan(context.Background(), client, "test-org") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should use rulesets (2 reviewers), not classic (1 reviewer) + for _, r := range results[0].Results { + if r.RuleName == "Has required reviewers" && !r.Passed { + t.Error("expected pass from rulesets") + } + } +} + +func TestScan_FallsBackToClassicProtection(t *testing.T) { + client := &MockGitHubClient{ + Repos: []Repo{ + {Name: "legacy-repo", Description: "Uses classic", DefaultBranch: "main"}, + }, + // No rulesets - returns nil + Protection: map[string]*BranchProtection{ + "legacy-repo": {RequiredReviewers: 1}, + }, + } + + results, err := Scan(context.Background(), client, "test-org") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + for _, r := range results[0].Results { + if r.RuleName == "Has branch protection" && !r.Passed { + t.Error("expected pass from classic branch protection fallback") + } + } +}