Skip to content

Commit 084fdb4

Browse files
authored
Merge pull request #71 from thomas-vilte/dev
feat(ai): Leverage Repository Labels for Enhanced AI Generation
2 parents 270874b + a5675fe commit 084fdb4

13 files changed

Lines changed: 218 additions & 158 deletions

internal/ai/gemini/helpers.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package gemini
2+
3+
import (
4+
"strings"
5+
)
6+
7+
// CleanLabels cleans and validates labels, keeping only the allowed ones.
8+
// It accepts a list of labels to clean and a list of available labels from the repository.
9+
// If availableLabels is empty, it falls back to a default list of common labels.
10+
func CleanLabels(labels []string, availableLabels []string) []string {
11+
allowedLabels := make(map[string]bool)
12+
13+
if len(availableLabels) > 0 {
14+
for _, l := range availableLabels {
15+
allowedLabels[strings.ToLower(l)] = true
16+
}
17+
} else {
18+
// Fallback to default list if no repo labels provided
19+
defaultLabels := []string{
20+
"feature", "fix", "refactor", "docs", "test", "infra",
21+
"enhancement", "bug", "good first issue", "help wanted",
22+
"chore", "performance", "security", "tech-debt", "breaking-change",
23+
}
24+
for _, l := range defaultLabels {
25+
allowedLabels[l] = true
26+
}
27+
}
28+
29+
cleaned := make([]string, 0)
30+
seen := make(map[string]bool)
31+
32+
for _, label := range labels {
33+
trimmed := strings.TrimSpace(strings.ToLower(label))
34+
if trimmed != "" && allowedLabels[trimmed] && !seen[trimmed] {
35+
cleaned = append(cleaned, trimmed)
36+
seen[trimmed] = true
37+
}
38+
}
39+
40+
return cleaned
41+
}

internal/ai/gemini/issue_content_generator.go

Lines changed: 33 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,33 @@ func NewGeminiIssueContentGenerator(ctx context.Context, cfg *config.Config, onC
7272
return service, nil
7373
}
7474

75+
func getIssueSchema() *genai.Schema {
76+
return &genai.Schema{
77+
Type: genai.TypeObject,
78+
Required: []string{"title", "description", "labels"},
79+
Properties: map[string]*genai.Schema{
80+
"title": {
81+
Type: genai.TypeString,
82+
Description: "The title of the issue",
83+
},
84+
"description": {
85+
Type: genai.TypeString,
86+
Description: "The body of the issue in markdown format",
87+
},
88+
"labels": {
89+
Type: genai.TypeArray,
90+
Items: &genai.Schema{
91+
Type: genai.TypeString,
92+
},
93+
Description: "List of labels (e.g. bug, feature, refactor, good first issue)",
94+
},
95+
},
96+
}
97+
}
98+
7599
func (s *GeminiIssueContentGenerator) defaultGenerate(ctx context.Context, mName string, p string) (interface{}, *models.TokenUsage, error) {
76-
genConfig := GetGenerateConfig(mName, "", nil)
100+
schema := getIssueSchema()
101+
genConfig := GetGenerateConfig(mName, "application/json", schema)
77102
log := logger.FromContext(ctx)
78103

79104
resp, err := s.Client.Models.GenerateContent(ctx, mName, genai.Text(p), genConfig)
@@ -167,6 +192,7 @@ func (s *GeminiIssueContentGenerator) GenerateIssueContent(ctx context.Context,
167192
return nil, domainErrors.NewAppError(domainErrors.TypeAI, "error parsing AI response", err)
168193
}
169194

195+
result.Labels = CleanLabels(result.Labels, request.AvailableLabels)
170196
result.Usage = usage
171197

172198
log.Info("issue content generated successfully via gemini",
@@ -179,7 +205,7 @@ func (s *GeminiIssueContentGenerator) GenerateIssueContent(ctx context.Context,
179205
// buildIssuePrompt builds the prompt to generate issue content.
180206
func (s *GeminiIssueContentGenerator) buildIssuePrompt(request models.IssueGenerationRequest) string {
181207
if request.Description != "" && request.Diff == "" && request.Hint == "" &&
182-
request.Template == nil && len(request.ChangedFiles) == 0 {
208+
request.Template == nil && len(request.ChangedFiles) == 0 && len(request.AvailableLabels) == 0 {
183209
return request.Description
184210
}
185211

@@ -227,37 +253,8 @@ func (s *GeminiIssueContentGenerator) buildIssuePrompt(request models.IssueGener
227253
return ""
228254
}
229255

230-
if request.Template != nil {
231-
rendered += `
232-
233-
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
234-
235-
🚨 FINAL REMINDER - CRITICAL OUTPUT REQUIREMENT 🚨
236-
237-
YOU MUST OUTPUT **ONLY** VALID JSON.
238-
239-
The template structure above should be used to FILL the "description" field with markdown content.
240-
241-
BUT your actual response MUST be a JSON object like this:
242-
{
243-
"title": "string here",
244-
"description": "markdown content following the template structure",
245-
"labels": ["array", "of", "strings"]
246-
}
247-
248-
❌ DO NOT output prose like "Here is a high-quality GitHub issue..."
249-
❌ DO NOT output markdown text directly
250-
❌ DO NOT output explanations
251-
252-
✅ ONLY output the JSON object
253-
✅ Use the template to structure the markdown in the "description" field
254-
✅ Return valid parseable JSON
255-
256-
BEGIN YOUR JSON OUTPUT NOW:`
257-
258-
logger.Debug(context.Background(), "full prompt with template and final JSON reminder",
259-
"prompt_length", len(rendered),
260-
"prompt", rendered)
256+
if len(request.AvailableLabels) > 0 {
257+
rendered += fmt.Sprintf("\n\nAvailable Labels (Select ONLY from this list):\n%s", strings.Join(request.AvailableLabels, ", "))
261258
}
262259

263260
return rendered
@@ -270,6 +267,8 @@ func (s *GeminiIssueContentGenerator) parseIssueResponse(content string) (*model
270267
return nil, domainErrors.NewAppError(domainErrors.TypeAI, "empty response from AI", nil)
271268
}
272269

270+
content = strings.TrimSpace(content)
271+
273272
if len(content) > 0 {
274273
preview := content
275274
if len(content) > 200 {
@@ -307,40 +306,12 @@ func (s *GeminiIssueContentGenerator) parseIssueResponse(content string) (*model
307306
result := &models.IssueGenerationResult{
308307
Title: strings.TrimSpace(jsonResult.Title),
309308
Description: strings.TrimSpace(jsonResult.Description),
310-
Labels: s.cleanLabels(jsonResult.Labels),
309+
Labels: jsonResult.Labels,
311310
}
312311

313312
if result.Title == "" {
314313
result.Title = "Generated Issue"
315314
}
316-
if result.Description == "" {
317-
result.Description = content
318-
}
319315

320316
return result, nil
321317
}
322-
323-
// cleanLabels cleans and validates labels, keeping only the allowed ones.
324-
func (s *GeminiIssueContentGenerator) cleanLabels(labels []string) []string {
325-
allowedLabels := map[string]bool{
326-
"feature": true,
327-
"fix": true,
328-
"refactor": true,
329-
"docs": true,
330-
"test": true,
331-
"infra": true,
332-
}
333-
334-
cleaned := make([]string, 0)
335-
seen := make(map[string]bool)
336-
337-
for _, label := range labels {
338-
trimmed := strings.TrimSpace(strings.ToLower(label))
339-
if trimmed != "" && allowedLabels[trimmed] && !seen[trimmed] {
340-
cleaned = append(cleaned, trimmed)
341-
seen[trimmed] = true
342-
}
343-
}
344-
345-
return cleaned
346-
}

internal/ai/gemini/issue_content_generator_test.go

Lines changed: 36 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ func TestBuildIssuePrompt(t *testing.T) {
7171
},
7272
contains: []string{"Code Changes (git diff)", "user description", "special hint"},
7373
},
74+
{
75+
name: "with available labels",
76+
request: models.IssueGenerationRequest{
77+
Description: "user description",
78+
Language: "en",
79+
AvailableLabels: []string{"bug", "enhancement"},
80+
},
81+
contains: []string{"Available Labels", "bug, enhancement"},
82+
},
7483
}
7584

7685
for _, tt := range tests {
@@ -106,26 +115,6 @@ func TestBuildIssuePrompt_WithTemplate(t *testing.T) {
106115

107116
// Should contain the template
108117
assert.Contains(t, prompt, "Bug Report")
109-
110-
// Should contain the final JSON reminder
111-
assert.Contains(t, prompt, "🚨 FINAL REMINDER - CRITICAL OUTPUT REQUIREMENT 🚨")
112-
assert.Contains(t, prompt, "YOU MUST OUTPUT **ONLY** VALID JSON")
113-
assert.Contains(t, prompt, "BEGIN YOUR JSON OUTPUT NOW:")
114-
115-
// Should contain instructions about using template in description field
116-
assert.Contains(t, prompt, "The template structure above should be used to FILL the \"description\" field")
117-
118-
// Should contain prohibitions
119-
assert.Contains(t, prompt, "❌ DO NOT output prose like \"Here is a high-quality GitHub issue...\"")
120-
assert.Contains(t, prompt, "❌ DO NOT output markdown text directly")
121-
122-
// Verify the reminder is at the end
123-
lastIndex := len(prompt) - 500
124-
if lastIndex < 0 {
125-
lastIndex = 0
126-
}
127-
finalSection := prompt[lastIndex:]
128-
assert.Contains(t, finalSection, "BEGIN YOUR JSON OUTPUT NOW:")
129118
})
130119

131120
t.Run("does NOT add final reminder when no template", func(t *testing.T) {
@@ -137,9 +126,8 @@ func TestBuildIssuePrompt_WithTemplate(t *testing.T) {
137126

138127
prompt := gen.buildIssuePrompt(request)
139128

140-
// Should NOT contain the final JSON reminder
141-
assert.NotContains(t, prompt, "🚨 FINAL REMINDER - CRITICAL OUTPUT REQUIREMENT 🚨")
142-
assert.NotContains(t, prompt, "BEGIN YOUR JSON OUTPUT NOW:")
129+
// Verification is just that prompt exists and is relevant
130+
assert.Contains(t, prompt, "Code Changes")
143131
})
144132

145133
t.Run("includes template in Spanish", func(t *testing.T) {
@@ -159,10 +147,6 @@ func TestBuildIssuePrompt_WithTemplate(t *testing.T) {
159147

160148
// Should contain the template
161149
assert.Contains(t, prompt, "Reporte de Bug")
162-
163-
// Should still contain the final JSON reminder (in English for consistency)
164-
assert.Contains(t, prompt, "🚨 FINAL REMINDER - CRITICAL OUTPUT REQUIREMENT 🚨")
165-
assert.Contains(t, prompt, "BEGIN YOUR JSON OUTPUT NOW:")
166150
})
167151

168152
t.Run("handles template with all fields", func(t *testing.T) {
@@ -188,28 +172,11 @@ func TestBuildIssuePrompt_WithTemplate(t *testing.T) {
188172
// Should contain changed files
189173
assert.Contains(t, prompt, "main.go")
190174
assert.Contains(t, prompt, "test.go")
191-
192-
// Should contain the final reminder
193-
assert.Contains(t, prompt, "🚨 FINAL REMINDER - CRITICAL OUTPUT REQUIREMENT 🚨")
194-
assert.Contains(t, prompt, "BEGIN YOUR JSON OUTPUT NOW:")
195175
})
196176

197177
t.Run("reminder contains complete JSON structure example", func(t *testing.T) {
198-
template := &models.IssueTemplate{
199-
Name: "Test Template",
200-
}
201-
202-
request := models.IssueGenerationRequest{
203-
Template: template,
204-
Language: "en",
205-
}
206-
207-
prompt := gen.buildIssuePrompt(request)
208-
209-
// Should show the expected JSON structure
210-
assert.Contains(t, prompt, `"title": "string here"`)
211-
assert.Contains(t, prompt, `"description": "markdown content following the template structure"`)
212-
assert.Contains(t, prompt, `"labels": ["array", "of", "strings"]`)
178+
// This test is now obsolete as structure is enforced by Schema, not prompt text.
179+
// We can remove it or just check nothing.
213180
})
214181
}
215182

@@ -268,33 +235,42 @@ func TestParseIssueResponse(t *testing.T) {
268235
}
269236

270237
func TestCleanLabels(t *testing.T) {
271-
gen := &GeminiIssueContentGenerator{}
272238

273239
tests := []struct {
274-
name string
275-
input []string
276-
expected []string
240+
name string
241+
input []string
242+
availableLabels []string
243+
expected []string
277244
}{
278245
{
279-
name: "only allowed labels",
280-
input: []string{"fix", "feature", "bug", "invalid"},
281-
expected: []string{"fix", "feature"},
246+
name: "default whitelist - allowed",
247+
input: []string{"fix", "feature", "bug", "invalid"},
248+
availableLabels: nil,
249+
expected: []string{"fix", "feature", "bug"},
250+
},
251+
{
252+
name: "default whitelist - mixed case",
253+
input: []string{" Fix ", "FEATURE", "test"},
254+
availableLabels: nil,
255+
expected: []string{"fix", "feature", "test"},
282256
},
283257
{
284-
name: "mixed case and spaces",
285-
input: []string{" Fix ", "FEATURE", "test"},
286-
expected: []string{"fix", "feature", "test"},
258+
name: "strict available labels",
259+
input: []string{"custom-1", "custom-2", "fix"},
260+
availableLabels: []string{"custom-1", "custom-2"},
261+
expected: []string{"custom-1", "custom-2"},
287262
},
288263
{
289-
name: "duplicates",
290-
input: []string{"fix", "fix", "FIX"},
291-
expected: []string{"fix"},
264+
name: "strict available labels - excludes non-existent",
265+
input: []string{"custom-1", "random"},
266+
availableLabels: []string{"custom-1"},
267+
expected: []string{"custom-1"},
292268
},
293269
}
294270

295271
for _, tt := range tests {
296272
t.Run(tt.name, func(t *testing.T) {
297-
result := gen.cleanLabels(tt.input)
273+
result := CleanLabels(tt.input, tt.availableLabels)
298274
assert.ElementsMatch(t, tt.expected, result)
299275
})
300276
}

internal/ai/gemini/pull_requests_summarizer_service.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,14 @@ func (gps *GeminiPRSummarizer) defaultGenerate(ctx context.Context, mName string
126126
return resp, usage, nil
127127
}
128128

129-
func (gps *GeminiPRSummarizer) GeneratePRSummary(ctx context.Context, prContent string) (models.PRSummary, error) {
129+
func (gps *GeminiPRSummarizer) GeneratePRSummary(ctx context.Context, prContent string, availableLabels []string) (models.PRSummary, error) {
130130
log := logger.FromContext(ctx)
131131

132132
log.Info("generating PR summary via gemini",
133-
"content_length", len(prContent))
133+
"content_length", len(prContent),
134+
"available_labels_count", len(availableLabels))
134135

135-
prompt := gps.generatePRPrompt(prContent)
136+
prompt := gps.generatePRPrompt(prContent, availableLabels)
136137

137138
log.Debug("calling gemini API for PR summary",
138139
"prompt_length", len(prompt))
@@ -199,12 +200,12 @@ func (gps *GeminiPRSummarizer) GeneratePRSummary(ctx context.Context, prContent
199200
return models.PRSummary{
200201
Title: jsonSummary.Title,
201202
Body: jsonSummary.Body,
202-
Labels: jsonSummary.Labels,
203+
Labels: CleanLabels(jsonSummary.Labels, availableLabels),
203204
Usage: usage,
204205
}, nil
205206
}
206207

207-
func (gps *GeminiPRSummarizer) generatePRPrompt(prContent string) string {
208+
func (gps *GeminiPRSummarizer) generatePRPrompt(prContent string, availableLabels []string) string {
208209
templateStr := ai.GetPRPromptTemplate(gps.config.Language)
209210
data := ai.PromptData{
210211
PRContent: prContent,
@@ -215,5 +216,9 @@ func (gps *GeminiPRSummarizer) generatePRPrompt(prContent string) string {
215216
return ""
216217
}
217218

219+
if len(availableLabels) > 0 {
220+
rendered += fmt.Sprintf("\n\nAvailable Labels (Select ONLY from this list):\n%s", strings.Join(availableLabels, ", "))
221+
}
222+
218223
return rendered
219224
}

0 commit comments

Comments
 (0)