feat: [CODE-2987]: Add codeowners as reviewers (#3439)

* Handle violations in branch CreatePullReqVerify
* Merge remote-tracking branch 'origin/main' into dd/codeowner-reviewers
* Add PullReq infix to Create rule interface and dedup reviewers in PR create
* Add create field to pullreq field in rule def
* Add codeowners as reviewers
This commit is contained in:
Darko Draskovic 2025-02-20 16:03:06 +00:00 committed by Harness
parent 3fee159170
commit e39ae83e78
8 changed files with 209 additions and 20 deletions

View File

@ -28,6 +28,7 @@ import (
pullreqevents "github.com/harness/gitness/app/events/pullreq" pullreqevents "github.com/harness/gitness/app/events/pullreq"
"github.com/harness/gitness/app/services/instrument" "github.com/harness/gitness/app/services/instrument"
labelsvc "github.com/harness/gitness/app/services/label" labelsvc "github.com/harness/gitness/app/services/label"
"github.com/harness/gitness/app/services/protection"
"github.com/harness/gitness/errors" "github.com/harness/gitness/errors"
"github.com/harness/gitness/git" "github.com/harness/gitness/git"
gitenum "github.com/harness/gitness/git/enum" gitenum "github.com/harness/gitness/git/enum"
@ -51,6 +52,8 @@ type CreateInput struct {
ReviewerIDs []int64 `json:"reviewer_ids"` ReviewerIDs []int64 `json:"reviewer_ids"`
Labels []*types.PullReqLabelAssignInput `json:"labels"` Labels []*types.PullReqLabelAssignInput `json:"labels"`
BypassRules bool `json:"bypass_rules"`
} }
func (in *CreateInput) Sanitize() error { func (in *CreateInput) Sanitize() error {
@ -110,7 +113,9 @@ func (c *Controller) Create(
return nil, err return nil, err
} }
targetWriteParams, err := controller.CreateRPCInternalWriteParams(ctx, c.urlProvider, session, targetRepo) targetWriteParams, err := controller.CreateRPCInternalWriteParams(
ctx, c.urlProvider, session, targetRepo,
)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create RPC write params: %w", err) return nil, fmt.Errorf("failed to create RPC write params: %w", err)
} }
@ -147,9 +152,22 @@ func (c *Controller) Create(
var reviewers []*types.PullReqReviewer var reviewers []*types.PullReqReviewer
reviewerInput, err := c.prepareReviewers(ctx, session, in.ReviewerIDs, targetRepo) reviewerInputEmailMap, err := c.prepareReviewers(ctx, session, in.ReviewerIDs, targetRepo)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to prepare reviewers: %w", err)
}
codeowners, err := c.prepareCodeowners(
ctx, session, targetRepo, in, mergeBaseSHA.String(), sourceSHA.String(),
)
if err != nil {
return nil, fmt.Errorf("failed to prepare codeowners: %w", err)
}
for _, codeowner := range codeowners {
if _, ok := reviewerInputEmailMap[codeowner.Email]; !ok {
reviewerInputEmailMap[codeowner.Email] = codeowner
}
} }
var labelAssignOuts []*labelsvc.AssignToPullReqOut var labelAssignOuts []*labelsvc.AssignToPullReqOut
@ -201,7 +219,7 @@ func (c *Controller) Create(
// Create reviewers and assign labels // Create reviewers and assign labels
reviewers, err = c.createReviewers(ctx, session, reviewerInput, targetRepo, pr) reviewers, err = c.createReviewers(ctx, session, reviewerInputEmailMap, targetRepo, pr)
if err != nil { if err != nil {
return err return err
} }
@ -268,14 +286,14 @@ func (c *Controller) prepareReviewers(
session *auth.Session, session *auth.Session,
reviewers []int64, reviewers []int64,
repo *types.RepositoryCore, repo *types.RepositoryCore,
) ([]*types.Principal, error) { ) (map[string]*types.Principal, error) {
if len(reviewers) == 0 { if len(reviewers) == 0 {
return []*types.Principal{}, nil return map[string]*types.Principal{}, nil
} }
principals := make([]*types.Principal, len(reviewers)) principalEmailMap := make(map[string]*types.Principal, len(reviewers))
for i, id := range reviewers { for _, id := range reviewers {
if id == session.Principal.ID { if id == session.Principal.ID {
return nil, usererror.BadRequest("PR creator cannot be added as a reviewer.") return nil, usererror.BadRequest("PR creator cannot be added as a reviewer.")
} }
@ -305,7 +323,55 @@ func (c *Controller) prepareReviewers(
"reviewer principal %s access error: %w", reviewerPrincipal.UID, err) "reviewer principal %s access error: %w", reviewerPrincipal.UID, err)
} }
principals[i] = reviewerPrincipal principalEmailMap[reviewerPrincipal.Email] = reviewerPrincipal
}
return principalEmailMap, nil
}
func (c *Controller) prepareCodeowners(
ctx context.Context,
session *auth.Session,
targetRepo *types.RepositoryCore,
in *CreateInput,
mergeBaseSHA string,
sourceSHA string,
) ([]*types.Principal, error) {
rules, isRepoOwner, err := c.fetchRules(ctx, session, targetRepo)
if err != nil {
return nil, fmt.Errorf("failed to fetch protection rules: %w", err)
}
out, _, err := rules.CreatePullReqVerify(ctx, protection.CreatePullReqVerifyInput{
ResolveUserGroupID: c.userGroupService.ListUserIDsByGroupIDs,
Actor: &session.Principal,
AllowBypass: in.BypassRules,
IsRepoOwner: isRepoOwner,
DefaultBranch: targetRepo.DefaultBranch,
TargetBranch: in.TargetBranch,
})
if err != nil {
return nil, fmt.Errorf("failed to verify protection rules: %w", err)
}
if !out.RequestCodeOwners {
return []*types.Principal{}, nil
}
codeowners, err := c.codeOwners.GetApplicableCodeOwners(
ctx, targetRepo, in.TargetBranch, mergeBaseSHA, sourceSHA,
)
if err != nil {
return nil, fmt.Errorf("failed to get applicable code owners: %w", err)
}
var emails []string
for _, entry := range codeowners.Entries {
emails = append(emails, entry.Owners...)
}
principals, err := c.principalStore.FindManyByEmail(ctx, emails)
if err != nil {
return nil, fmt.Errorf("failed to find many principals by email: %w", err)
} }
return principals, nil return principals, nil
@ -314,7 +380,7 @@ func (c *Controller) prepareReviewers(
func (c *Controller) createReviewers( func (c *Controller) createReviewers(
ctx context.Context, ctx context.Context,
session *auth.Session, session *auth.Session,
principals []*types.Principal, principals map[string]*types.Principal,
repo *types.RepositoryCore, repo *types.RepositoryCore,
pr *types.PullReq, pr *types.PullReq,
) ([]*types.PullReqReviewer, error) { ) ([]*types.PullReqReviewer, error) {
@ -324,7 +390,8 @@ func (c *Controller) createReviewers(
reviewers := make([]*types.PullReqReviewer, len(principals)) reviewers := make([]*types.PullReqReviewer, len(principals))
for i, principal := range principals { var i int
for _, principal := range principals {
reviewer := newPullReqReviewer( reviewer := newPullReqReviewer(
session, pr, repo, session, pr, repo,
principal.ToPrincipalInfo(), principal.ToPrincipalInfo(),
@ -340,6 +407,7 @@ func (c *Controller) createReviewers(
} }
reviewers[i] = reviewer reviewers[i] = reviewer
i++
} }
return reviewers, nil return reviewers, nil

View File

@ -356,20 +356,22 @@ func (s *Service) getCodeOwnerFileNode(
return nil, fmt.Errorf("no codeowner file found: %w", ErrNotFound) return nil, fmt.Errorf("no codeowner file found: %w", ErrNotFound)
} }
func (s *Service) getApplicableCodeOwnersForPR( func (s *Service) GetApplicableCodeOwners(
ctx context.Context, ctx context.Context,
repo *types.RepositoryCore, repo *types.RepositoryCore,
pr *types.PullReq, targetBranch string,
baseRef string,
headRef string,
) (*CodeOwners, error) { ) (*CodeOwners, error) {
codeOwners, err := s.get(ctx, repo, pr.TargetBranch) codeOwners, err := s.get(ctx, repo, targetBranch)
if err != nil { if err != nil {
return nil, err return nil, err
} }
diffFileStats, err := s.git.DiffFileNames(ctx, &git.DiffParams{ diffFileStats, err := s.git.DiffFileNames(ctx, &git.DiffParams{
ReadParams: git.CreateReadParams(repo), ReadParams: git.CreateReadParams(repo),
BaseRef: pr.MergeBaseSHA, BaseRef: baseRef, // MergeBaseSHA,
HeadRef: pr.SourceSHA, HeadRef: headRef, // SourceSHA,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get diff file stat: %w", err) return nil, fmt.Errorf("failed to get diff file stat: %w", err)
@ -415,7 +417,9 @@ func (s *Service) Evaluate(
pr *types.PullReq, pr *types.PullReq,
reviewers []*types.PullReqReviewer, reviewers []*types.PullReqReviewer,
) (*Evaluation, error) { ) (*Evaluation, error) {
owners, err := s.getApplicableCodeOwnersForPR(ctx, repo, pr) owners, err := s.GetApplicableCodeOwners(
ctx, repo, pr.TargetBranch, pr.MergeBaseSHA, pr.SourceSHA,
)
if err != nil { if err != nil {
return &Evaluation{}, fmt.Errorf("failed to get codeOwners: %w", err) return &Evaluation{}, fmt.Errorf("failed to get codeOwners: %w", err)
} }

View File

@ -33,6 +33,7 @@ type Branch struct {
var ( var (
// ensures that the Branch type implements Definition interface. // ensures that the Branch type implements Definition interface.
_ Definition = (*Branch)(nil) _ Definition = (*Branch)(nil)
_ Protection = (*Branch)(nil)
) )
func (v *Branch) MergeVerify( func (v *Branch) MergeVerify(
@ -85,6 +86,27 @@ func (v *Branch) RequiredChecks(
}, nil }, nil
} }
func (v *Branch) CreatePullReqVerify(
ctx context.Context,
in CreatePullReqVerifyInput,
) (CreatePullReqVerifyOutput, []types.RuleViolations, error) {
var out CreatePullReqVerifyOutput
out, violations, err := v.PullReq.CreatePullReqVerify(ctx, in)
if err != nil {
return CreatePullReqVerifyOutput{}, nil, err
}
bypassable := v.Bypass.matches(ctx, in.Actor, in.IsRepoOwner, in.ResolveUserGroupID)
bypassed := in.AllowBypass && bypassable
for i := range violations {
violations[i].Bypassable = bypassable
violations[i].Bypassed = bypassed
}
return out, violations, nil
}
func (v *Branch) RefChangeVerify( func (v *Branch) RefChangeVerify(
ctx context.Context, ctx context.Context,
in RefChangeVerifyInput, in RefChangeVerifyInput,

View File

@ -35,6 +35,7 @@ type (
Protection interface { Protection interface {
MergeVerifier MergeVerifier
RefChangeVerifier RefChangeVerifier
CreatePullReqVerifier
UserIDs() ([]int64, error) UserIDs() ([]int64, error)
UserGroupIDs() ([]int64, error) UserGroupIDs() ([]int64, error)
} }

View File

@ -104,6 +104,33 @@ func (s ruleSet) RequiredChecks(
}, nil }, nil
} }
func (s ruleSet) CreatePullReqVerify(
ctx context.Context,
in CreatePullReqVerifyInput,
) (CreatePullReqVerifyOutput, []types.RuleViolations, error) {
var out CreatePullReqVerifyOutput
var violations []types.RuleViolations
err := s.forEachRuleMatchBranch(in.DefaultBranch, in.TargetBranch,
func(r *types.RuleInfoInternal, p Protection) error {
rOut, rVs, err := p.CreatePullReqVerify(ctx, in)
if err != nil {
return err
}
// combine output across rules
violations = append(violations, backFillRule(rVs, r.RuleInfo)...)
out.RequestCodeOwners = out.RequestCodeOwners || rOut.RequestCodeOwners
return nil
})
if err != nil {
return out, nil, fmt.Errorf("failed to process each rule in ruleSet: %w", err)
}
return out, violations, nil
}
func (s ruleSet) RefChangeVerify(ctx context.Context, in RefChangeVerifyInput) ([]types.RuleViolations, error) { func (s ruleSet) RefChangeVerify(ctx context.Context, in RefChangeVerifyInput) ([]types.RuleViolations, error) {
var violations []types.RuleViolations var violations []types.RuleViolations

View File

@ -71,12 +71,33 @@ type (
RequiredIdentifiers map[string]struct{} RequiredIdentifiers map[string]struct{}
BypassableIdentifiers map[string]struct{} BypassableIdentifiers map[string]struct{}
} }
CreatePullReqVerifier interface {
CreatePullReqVerify(
ctx context.Context,
in CreatePullReqVerifyInput,
) (CreatePullReqVerifyOutput, []types.RuleViolations, error)
}
CreatePullReqVerifyInput struct {
ResolveUserGroupID func(ctx context.Context, userGroupIDs []int64) ([]int64, error)
Actor *types.Principal
AllowBypass bool
IsRepoOwner bool
DefaultBranch string
TargetBranch string
}
CreatePullReqVerifyOutput struct {
RequestCodeOwners bool
}
) )
// ensures that the DefPullReq type implements Sanitizer and MergeVerifier interface. // Ensures that the DefPullReq type implements Sanitizer, MergeVerifier and CreatePullReqVerifier interface.
var ( var (
_ Sanitizer = (*DefPullReq)(nil) _ Sanitizer = (*DefPullReq)(nil)
_ MergeVerifier = (*DefPullReq)(nil) _ MergeVerifier = (*DefPullReq)(nil)
_ CreatePullReqVerifier = (*DefPullReq)(nil)
) )
const ( const (
@ -271,6 +292,15 @@ func (v *DefPullReq) RequiredChecks(
}, nil }, nil
} }
func (v *DefPullReq) CreatePullReqVerify(
context.Context,
CreatePullReqVerifyInput,
) (CreatePullReqVerifyOutput, []types.RuleViolations, error) {
return CreatePullReqVerifyOutput{
RequestCodeOwners: v.Reviewers.RequestCodeOwners,
}, nil, nil
}
type DefApprovals struct { type DefApprovals struct {
RequireCodeOwners bool `json:"require_code_owners,omitempty"` RequireCodeOwners bool `json:"require_code_owners,omitempty"`
RequireMinimumCount int `json:"require_minimum_count,omitempty"` RequireMinimumCount int `json:"require_minimum_count,omitempty"`
@ -370,6 +400,10 @@ func (v *DefMerge) Sanitize() error {
return nil return nil
} }
type DefReviewers struct {
RequestCodeOwners bool `json:"request_code_owners,omitempty"`
}
type DefPush struct { type DefPush struct {
Block bool `json:"block,omitempty"` Block bool `json:"block,omitempty"`
} }
@ -383,6 +417,7 @@ type DefPullReq struct {
Comments DefComments `json:"comments"` Comments DefComments `json:"comments"`
StatusChecks DefStatusChecks `json:"status_checks"` StatusChecks DefStatusChecks `json:"status_checks"`
Merge DefMerge `json:"merge"` Merge DefMerge `json:"merge"`
Reviewers DefReviewers `json:"reviewers"`
} }
func (v *DefPullReq) Sanitize() error { func (v *DefPullReq) Sanitize() error {

View File

@ -44,6 +44,9 @@ type (
// FindByEmail finds the principal by email. // FindByEmail finds the principal by email.
FindByEmail(ctx context.Context, email string) (*types.Principal, error) FindByEmail(ctx context.Context, email string) (*types.Principal, error)
// FindManyByEmail finds all principals for the provided emails.
FindManyByEmail(ctx context.Context, emails []string) ([]*types.Principal, error)
/* /*
* USER RELATED OPERATIONS. * USER RELATED OPERATIONS.
*/ */

View File

@ -162,6 +162,35 @@ func (s *PrincipalStore) FindByEmail(ctx context.Context, email string) (*types.
return s.mapDBPrincipal(dst), nil return s.mapDBPrincipal(dst), nil
} }
func (s *PrincipalStore) FindManyByEmail(
ctx context.Context,
emails []string,
) ([]*types.Principal, error) {
lowerCaseEmails := make([]string, len(emails))
for i := range emails {
lowerCaseEmails[i] = strings.ToLower(emails[i])
}
stmt := database.Builder.
Select(principalColumns).
From("principals").
Where(squirrel.Eq{"principal_email": lowerCaseEmails})
db := dbtx.GetAccessor(ctx, s.db)
sqlQuery, params, err := stmt.ToSql()
if err != nil {
return nil, database.ProcessSQLErrorf(ctx, err, "failed to generate find many principal query")
}
dst := []*principal{}
if err := db.SelectContext(ctx, &dst, sqlQuery, params...); err != nil {
return nil, database.ProcessSQLErrorf(ctx, err, "find many by email for principal query failed")
}
return s.mapDBPrincipals(dst), nil
}
// List lists the principals matching the provided filter. // List lists the principals matching the provided filter.
func (s *PrincipalStore) List(ctx context.Context, func (s *PrincipalStore) List(ctx context.Context,
opts *types.PrincipalFilter) ([]*types.Principal, error) { opts *types.PrincipalFilter) ([]*types.Principal, error) {