Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 58 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,28 @@ type FileEntry struct {
Type string // "blob" (file) or "tree" (directory)
}

// BranchProtection holds the branch protection settings the scanner needs.
type BranchProtection struct {
RequiredReviewers int
RequiredStatusChecks []string
}

// Repo represents a GitHub repository with the fields the scanner needs.
type Repo struct {
Name string
Description string
DefaultBranch string
Archived bool
Files []FileEntry // all files and directories in the repo
Name string
Description string
DefaultBranch string
Archived bool
Files []FileEntry // all files and directories in the repo
BranchProtection *BranchProtection // nil if no protection configured
}

// GitHubClient is the interface for all GitHub API interactions.
// The scanner depends only on this interface, making it testable via mocks.
type GitHubClient interface {
ListRepos(ctx context.Context, org string) ([]Repo, error)
GetTree(ctx context.Context, owner, repo, branch string) ([]FileEntry, error)
GetBranchProtection(ctx context.Context, owner, repo, branch string) (*BranchProtection, error)
CreateIssue(ctx context.Context, owner, repo, title, body string) error
}

Expand Down Expand Up @@ -130,6 +138,51 @@ func (c *realGitHubClient) GetTree(ctx context.Context, owner, repo, branch stri
return files, nil
}

func (c *realGitHubClient) GetBranchProtection(ctx context.Context, owner, repo, branch string) (*BranchProtection, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/branches/%s/protection", owner, repo, branch)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("create request for %s: %w", url, err)
}
req.Header.Set("Authorization", "Bearer "+c.token)
req.Header.Set("Accept", "application/vnd.github+json")

resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request %s: %w", url, err)
}
defer resp.Body.Close()

if resp.StatusCode == http.StatusNotFound {
return nil, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("get branch protection for %s/%s: status %d", owner, repo, resp.StatusCode)
}

var result struct {
RequiredPullRequestReviews *struct {
RequiredApprovingReviewCount int `json:"required_approving_review_count"`
} `json:"required_pull_request_reviews"`
RequiredStatusChecks *struct {
Contexts []string `json:"contexts"`
} `json:"required_status_checks"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("decode branch protection for %s/%s: %w", owner, repo, err)
}

bp := &BranchProtection{}
if result.RequiredPullRequestReviews != nil {
bp.RequiredReviewers = result.RequiredPullRequestReviews.RequiredApprovingReviewCount
}
if result.RequiredStatusChecks != nil {
bp.RequiredStatusChecks = result.RequiredStatusChecks.Contexts
}
return bp, nil
}

func (c *realGitHubClient) CreateIssue(ctx context.Context, owner, repo, title, body string) error {
url := fmt.Sprintf("https://api.github.com/repos/%s/%s/issues", owner, repo)

Expand Down
22 changes: 17 additions & 5 deletions client_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import "context"

// MockGitHubClient implements GitHubClient with canned responses for testing.
type MockGitHubClient struct {
Repos []Repo
Err error
Tree map[string][]FileEntry // repo name -> file entries
TreeErr error
IssueErr error
Repos []Repo
Err error
Tree map[string][]FileEntry // repo name -> file entries
TreeErr error
Protection map[string]*BranchProtection // repo name -> branch protection
ProtectionErr error
IssueErr error
// CreatedIssue records the last CreateIssue call for assertions.
CreatedIssue struct {
Owner, Repo, Title, Body string
Expand All @@ -29,6 +31,16 @@ func (m *MockGitHubClient) GetTree(ctx context.Context, owner, repo, branch stri
return nil, nil
}

func (m *MockGitHubClient) GetBranchProtection(ctx context.Context, owner, repo, branch string) (*BranchProtection, error) {
if m.ProtectionErr != nil {
return nil, m.ProtectionErr
}
if m.Protection != nil {
return m.Protection[repo], nil
}
return nil, nil
}

func (m *MockGitHubClient) CreateIssue(ctx context.Context, owner, repo, title, body string) error {
m.CreatedIssue.Owner = owner
m.CreatedIssue.Repo = repo
Expand Down
27 changes: 27 additions & 0 deletions rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ func AllRules() []Rule {
HasCIWorkflow{},
HasTestDirectory{},
HasCodeowners{},
HasBranchProtection{},
HasRequiredReviewers{},
HasRequiredStatusChecks{},
}
}

Expand Down Expand Up @@ -109,6 +112,30 @@ func (r HasCodeowners) Check(repo Repo) bool {
hasFile(repo.Files, ".github/CODEOWNERS")
}

// HasBranchProtection checks that the default branch has protection rules enabled.
type HasBranchProtection struct{}

func (r HasBranchProtection) Name() string { return "Has branch protection" }
func (r HasBranchProtection) Check(repo Repo) bool {
return repo.BranchProtection != nil
}

// HasRequiredReviewers checks that at least one approving review is required.
type HasRequiredReviewers struct{}

func (r HasRequiredReviewers) Name() string { return "Has required reviewers" }
func (r HasRequiredReviewers) Check(repo Repo) bool {
return repo.BranchProtection != nil && repo.BranchProtection.RequiredReviewers >= 1
}

// HasRequiredStatusChecks checks that at least one status check is required before merging.
type HasRequiredStatusChecks struct{}

func (r HasRequiredStatusChecks) Name() string { return "Requires status checks before merging" }
func (r HasRequiredStatusChecks) Check(repo Repo) bool {
return repo.BranchProtection != nil && len(repo.BranchProtection.RequiredStatusChecks) > 0
}

func findFile(files []FileEntry, path string) (FileEntry, bool) {
for _, f := range files {
if f.Path == path {
Expand Down
66 changes: 66 additions & 0 deletions rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,69 @@ func TestHasCodeowners_Fail(t *testing.T) {
t.Error("expected fail when CODEOWNERS is missing from all locations")
}
}

func TestHasBranchProtection_Pass(t *testing.T) {
rule := HasBranchProtection{}

if !rule.Check(Repo{BranchProtection: &BranchProtection{}}) {
t.Error("expected pass when branch protection is enabled")
}
}

func TestHasBranchProtection_Fail(t *testing.T) {
rule := HasBranchProtection{}

if rule.Check(Repo{BranchProtection: nil}) {
t.Error("expected fail when branch protection is nil")
}
}

func TestHasRequiredReviewers_Pass(t *testing.T) {
rule := HasRequiredReviewers{}

if !rule.Check(Repo{BranchProtection: &BranchProtection{RequiredReviewers: 1}}) {
t.Error("expected pass when required reviewers >= 1")
}
}

func TestHasRequiredReviewers_Fail_Zero(t *testing.T) {
rule := HasRequiredReviewers{}

if rule.Check(Repo{BranchProtection: &BranchProtection{RequiredReviewers: 0}}) {
t.Error("expected fail when required reviewers is 0")
}
}

func TestHasRequiredReviewers_Fail_NoProtection(t *testing.T) {
rule := HasRequiredReviewers{}

if rule.Check(Repo{BranchProtection: nil}) {
t.Error("expected fail when branch protection is nil")
}
}

func TestHasRequiredStatusChecks_Pass(t *testing.T) {
rule := HasRequiredStatusChecks{}

bp := &BranchProtection{RequiredStatusChecks: []string{"ci/build"}}
if !rule.Check(Repo{BranchProtection: bp}) {
t.Error("expected pass when status checks are configured")
}
}

func TestHasRequiredStatusChecks_Fail_Empty(t *testing.T) {
rule := HasRequiredStatusChecks{}

bp := &BranchProtection{RequiredStatusChecks: []string{}}
if rule.Check(Repo{BranchProtection: bp}) {
t.Error("expected fail when status checks list is empty")
}
}

func TestHasRequiredStatusChecks_Fail_NoProtection(t *testing.T) {
rule := HasRequiredStatusChecks{}

if rule.Check(Repo{BranchProtection: nil}) {
t.Error("expected fail when branch protection is nil")
}
}
6 changes: 6 additions & 0 deletions scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ func Scan(ctx context.Context, client GitHubClient, org string) ([]RepoResult, e
}
repo.Files = files

protection, err := client.GetBranchProtection(ctx, org, repo.Name, repo.DefaultBranch)
if err != nil {
return nil, fmt.Errorf("get branch protection for repo %s: %w", repo.Name, err)
}
repo.BranchProtection = protection

rr := RepoResult{RepoName: repo.Name}
for _, rule := range rules {
rr.Results = append(rr.Results, RuleResult{
Expand Down
Loading