Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type GitHubClient interface {
ListRepos(ctx context.Context, org string) ([]Repo, error)
GetTree(ctx context.Context, owner, repo, branch string) ([]FileEntry, error)
GetBranchProtection(ctx context.Context, owner, repo, branch string) (*BranchProtection, error)
GetRulesets(ctx context.Context, owner, repo, branch string) (*BranchProtection, error)
CreateIssue(ctx context.Context, owner, repo, title, body string) error
}

Expand Down Expand Up @@ -183,6 +184,49 @@ func (c *realGitHubClient) GetBranchProtection(ctx context.Context, owner, repo,
return bp, nil
}

func (c *realGitHubClient) GetRulesets(ctx context.Context, owner, repo, branch string) (*BranchProtection, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/rules/branches/%s", owner, repo, branch)

var rules []struct {
Type string `json:"type"`
Parameters *struct {
RequiredApprovingReviewCount int `json:"required_approving_review_count"`
RequiredStatusChecks []struct {
Context string `json:"context"`
} `json:"required_status_checks"`
} `json:"parameters"`
}
if err := c.doRequest(ctx, url, &rules); err != nil {
return nil, fmt.Errorf("get branch rules for %s/%s: %w", owner, repo, err)
}

var bp BranchProtection
found := false

for _, rule := range rules {
if rule.Parameters == nil {
continue
}
switch rule.Type {
case "pull_request":
found = true
if rule.Parameters.RequiredApprovingReviewCount > bp.RequiredReviewers {
bp.RequiredReviewers = rule.Parameters.RequiredApprovingReviewCount
}
case "required_status_checks":
found = true
for _, sc := range rule.Parameters.RequiredStatusChecks {
bp.RequiredStatusChecks = append(bp.RequiredStatusChecks, sc.Context)
}
}
}

if !found {
return nil, nil
}
return &bp, nil
}

func (c *realGitHubClient) CreateIssue(ctx context.Context, owner, repo, title, body string) error {
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/issues", owner, repo)

Expand Down
16 changes: 14 additions & 2 deletions client_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import "context"
type MockGitHubClient struct {
Repos []Repo
Err error
Tree map[string][]FileEntry // repo name -> file entries
Tree map[string][]FileEntry // repo name -> file entries
TreeErr error
Protection map[string]*BranchProtection // repo name -> branch protection
Protection map[string]*BranchProtection // repo name -> classic branch protection
ProtectionErr error
Rulesets map[string]*BranchProtection // repo name -> rulesets protection
RulesetsErr error
IssueErr error
// CreatedIssue records the last CreateIssue call for assertions.
CreatedIssue struct {
Expand Down Expand Up @@ -41,6 +43,16 @@ func (m *MockGitHubClient) GetBranchProtection(ctx context.Context, owner, repo,
return nil, nil
}

func (m *MockGitHubClient) GetRulesets(ctx context.Context, owner, repo, branch string) (*BranchProtection, error) {
if m.RulesetsErr != nil {
return nil, m.RulesetsErr
}
if m.Rulesets != nil {
return m.Rulesets[repo], nil
}
return nil, nil
}

func (m *MockGitHubClient) CreateIssue(ctx context.Context, owner, repo, title, body string) error {
m.CreatedIssue.Owner = owner
m.CreatedIssue.Repo = repo
Expand Down
10 changes: 8 additions & 2 deletions scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,15 @@ func Scan(ctx context.Context, client GitHubClient, org string) ([]RepoResult, e
}
repo.Files = files

protection, err := client.GetBranchProtection(ctx, org, repo.Name, repo.DefaultBranch)
protection, err := client.GetRulesets(ctx, org, repo.Name, repo.DefaultBranch)
if err != nil {
return nil, fmt.Errorf("get branch protection for repo %s: %w", repo.Name, err)
return nil, fmt.Errorf("get rulesets for repo %s: %w", repo.Name, err)
}
if protection == nil {
protection, err = client.GetBranchProtection(ctx, org, repo.Name, repo.DefaultBranch)
if err != nil {
return nil, fmt.Errorf("get branch protection for repo %s: %w", repo.Name, err)
}
}
repo.BranchProtection = protection

Expand Down
49 changes: 49 additions & 0 deletions scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,52 @@ func TestScan_PropagatesClientError(t *testing.T) {
t.Fatal("expected error, got nil")
}
}

func TestScan_UsesRulesetsWhenAvailable(t *testing.T) {
client := &MockGitHubClient{
Repos: []Repo{
{Name: "modern-repo", Description: "Uses rulesets", DefaultBranch: "main"},
},
Rulesets: map[string]*BranchProtection{
"modern-repo": {RequiredReviewers: 2},
},
Protection: map[string]*BranchProtection{
"modern-repo": {RequiredReviewers: 1},
},
}

results, err := Scan(context.Background(), client, "test-org")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Should use rulesets (2 reviewers), not classic (1 reviewer)
for _, r := range results[0].Results {
if r.RuleName == "Has required reviewers" && !r.Passed {
t.Error("expected pass from rulesets")
}
}
}

func TestScan_FallsBackToClassicProtection(t *testing.T) {
client := &MockGitHubClient{
Repos: []Repo{
{Name: "legacy-repo", Description: "Uses classic", DefaultBranch: "main"},
},
// No rulesets - returns nil
Protection: map[string]*BranchProtection{
"legacy-repo": {RequiredReviewers: 1},
},
}

results, err := Scan(context.Background(), client, "test-org")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

for _, r := range results[0].Results {
if r.RuleName == "Has branch protection" && !r.Passed {
t.Error("expected pass from classic branch protection fallback")
}
}
}
Loading