Skip to content

Commit 331e5d0

Browse files
authored
Merge pull request #68 from thomas-vilte/dev
feat(gemini): enforce JSON schema for structured AI output
2 parents 6699d85 + 5267fdc commit 331e5d0

14 files changed

Lines changed: 192 additions & 1019 deletions

internal/ai/gemini/commit_summarizer_service.go

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,71 @@ type (
4646
}
4747
)
4848

49+
// getCommitSuggestionSchema returns the JSON schema for commit suggestions
50+
func getCommitSuggestionSchema() *genai.Schema {
51+
return &genai.Schema{
52+
Type: genai.TypeArray,
53+
Items: &genai.Schema{
54+
Type: genai.TypeObject,
55+
Required: []string{"title", "desc", "files"},
56+
Properties: map[string]*genai.Schema{
57+
"title": {
58+
Type: genai.TypeString,
59+
Description: "Commit title (type(scope): message)",
60+
},
61+
"desc": {
62+
Type: genai.TypeString,
63+
Description: "Detailed explanation in first person",
64+
},
65+
"files": {
66+
Type: genai.TypeArray,
67+
Items: &genai.Schema{
68+
Type: genai.TypeString,
69+
},
70+
Description: "Array of file paths as strings",
71+
},
72+
"analysis": {
73+
Type: genai.TypeObject,
74+
Required: []string{"overview", "purpose", "impact"},
75+
Properties: map[string]*genai.Schema{
76+
"overview": {Type: genai.TypeString},
77+
"purpose": {Type: genai.TypeString},
78+
"impact": {Type: genai.TypeString},
79+
},
80+
},
81+
"requirements": {
82+
Type: genai.TypeObject,
83+
Required: []string{"status", "missing", "completed_indices", "suggestions"},
84+
Properties: map[string]*genai.Schema{
85+
"status": {
86+
Type: genai.TypeString,
87+
Enum: []string{"full_met", "partially_met", "not_met"},
88+
},
89+
"missing": {
90+
Type: genai.TypeArray,
91+
Items: &genai.Schema{
92+
Type: genai.TypeString,
93+
},
94+
},
95+
"completed_indices": {
96+
Type: genai.TypeArray,
97+
Items: &genai.Schema{
98+
Type: genai.TypeInteger,
99+
},
100+
},
101+
"suggestions": {
102+
Type: genai.TypeArray,
103+
Items: &genai.Schema{
104+
Type: genai.TypeString,
105+
},
106+
},
107+
},
108+
},
109+
},
110+
},
111+
}
112+
}
113+
49114
func NewGeminiCommitSummarizer(ctx context.Context, cfg *config.Config, onConfirmation ai.ConfirmationCallback) (*GeminiCommitSummarizer, error) {
50115
providerCfg, exists := cfg.AIProviders["gemini"]
51116
if !exists || providerCfg.APIKey == "" {
@@ -97,13 +162,11 @@ func NewGeminiCommitSummarizer(ctx context.Context, cfg *config.Config, onConfir
97162

98163
func (s *GeminiCommitSummarizer) defaultGenerate(ctx context.Context, mName string, p string) (interface{}, *models.TokenUsage, error) {
99164
log := logger.FromContext(ctx)
100-
101165
log.Debug("calling gemini API",
102166
"model", mName,
103167
"prompt_length", len(p))
104-
105-
genConfig := GetGenerateConfig(mName, "application/json")
106-
168+
schema := getCommitSuggestionSchema()
169+
genConfig := GetGenerateConfig(mName, "application/json", schema)
107170
resp, err := s.Client.Models.GenerateContent(ctx, mName, genai.Text(p), genConfig)
108171
if err != nil {
109172
log.Error("gemini API call failed",
@@ -116,23 +179,24 @@ func (s *GeminiCommitSummarizer) defaultGenerate(ctx context.Context, mName stri
116179
strings.Contains(errMsg, "resource exhausted") {
117180
return nil, nil, domainErrors.ErrGeminiQuotaExceeded.WithError(err)
118181
}
119-
120182
if strings.Contains(errMsg, "invalid") ||
121183
strings.Contains(errMsg, "unauthorized") ||
122184
strings.Contains(errMsg, "api key") {
123185
return nil, nil, domainErrors.ErrGeminiAPIKeyInvalid.WithError(err)
124186
}
125-
126187
return nil, nil, domainErrors.ErrAIGeneration.WithError(err)
127188
}
128-
129189
usage := extractUsage(resp)
130-
131-
log.Debug("gemini API response received",
132-
"input_tokens", usage.InputTokens,
133-
"output_tokens", usage.OutputTokens,
134-
"candidates", len(resp.Candidates))
135-
190+
if usage != nil {
191+
log.Debug("gemini API response received",
192+
"input_tokens", usage.InputTokens,
193+
"output_tokens", usage.OutputTokens,
194+
"candidates", len(resp.Candidates))
195+
} else {
196+
log.Debug("gemini API response received",
197+
"candidates", len(resp.Candidates),
198+
"usage", "nil")
199+
}
136200
return resp, usage, nil
137201
}
138202

@@ -216,42 +280,33 @@ func (s *GeminiCommitSummarizer) parseSuggestionsJSON(responseText string) ([]mo
216280
if responseText == "" {
217281
return nil, fmt.Errorf("empty response text from AI")
218282
}
219-
220-
responseText = ExtractJSON(responseText)
221-
222283
var jsonSuggestions []CommitSuggestionJSON
223284
if err := json.Unmarshal([]byte(responseText), &jsonSuggestions); err != nil {
224-
// Log at default level (no context available here)
225285
return nil, fmt.Errorf("error parsing JSON: %w", err)
226286
}
227-
228287
suggestions := make([]models.CommitSuggestion, 0, len(jsonSuggestions))
229288
for _, js := range jsonSuggestions {
230289
suggestion := models.CommitSuggestion{
231290
CommitTitle: js.Title,
232291
Explanation: js.Desc,
233292
Files: js.Files,
234293
}
235-
236294
if js.Analysis != nil {
237295
suggestion.CodeAnalysis = models.CodeAnalysis{
238296
ChangesOverview: js.Analysis.OverView,
239297
PrimaryPurpose: js.Analysis.Purpose,
240298
TechnicalImpact: js.Analysis.Impact,
241299
}
242300
}
243-
244301
if js.Requirements != nil {
245302
suggestion.RequirementsAnalysis = models.RequirementsAnalysis{
246303
CriteriaStatus: models.CriteriaStatus(js.Requirements.Status),
247304
MissingCriteria: js.Requirements.Missing,
248305
ImprovementSuggestions: js.Requirements.Suggestions,
249306
}
250307
}
251-
252308
suggestions = append(suggestions, suggestion)
253309
}
254-
255310
return suggestions, nil
256311
}
257312

internal/ai/gemini/commit_summarizer_service_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ func TestGeminiCommitSummarizer(t *testing.T) {
207207
// assert
208208
assert.Contains(t, prompt, "commit", "El prompt debería contener 'commit'")
209209
assert.Contains(t, prompt, "Archivos Modificados", "El prompt debería contener 'Archivos modificados'")
210-
assert.Contains(t, prompt, "Explicación", "El prompt debería contener 'Explicación'")
211210
assert.Contains(t, prompt, "feat", "El prompt debería contener tipos de commit")
212211
assert.Contains(t, prompt, "fix", "El prompt debería contener tipos de commit")
213212
assert.Contains(t, prompt, "refactor", "El prompt debería contener tipos de commit")
@@ -237,7 +236,6 @@ func TestGeminiCommitSummarizer(t *testing.T) {
237236
// assert
238237
assert.Contains(t, prompt, "commit", "The prompt should contain 'commit'")
239238
assert.Contains(t, prompt, "Modified Files", "The prompt should contain 'Modified files'")
240-
assert.Contains(t, prompt, "explanation", "The prompt should contain 'Explanation'")
241239
assert.Contains(t, prompt, "feat", "The prompt should contain commit types")
242240
assert.Contains(t, prompt, "fix", "The prompt should contain commit types")
243241
assert.Contains(t, prompt, "refactor", "The prompt should contain commit types")

internal/ai/gemini/helper.go

Lines changed: 4 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
package gemini
22

33
import (
4-
"encoding/json"
54
"strings"
65

76
"github.com/thomas-vilte/matecommit/internal/models"
8-
"github.com/thomas-vilte/matecommit/internal/regex"
97
"google.golang.org/genai"
108
)
119

@@ -22,7 +20,7 @@ func extractUsage(resp *genai.GenerateContentResponse) *models.TokenUsage {
2220
}
2321

2422
// GetGenerateConfig returns the optimal configuration for the model, enabling Thinking Mode if compatible.
25-
func GetGenerateConfig(modelName string, responseType string) *genai.GenerateContentConfig {
23+
func GetGenerateConfig(modelName string, responseType string, schema *genai.Schema) *genai.GenerateContentConfig {
2624
config := &genai.GenerateContentConfig{
2725
Temperature: float32Ptr(0.3),
2826
MaxOutputTokens: int32(10000),
@@ -31,6 +29,9 @@ func GetGenerateConfig(modelName string, responseType string) *genai.GenerateCon
3129

3230
if responseType == "application/json" {
3331
config.ResponseMIMEType = "application/json"
32+
if schema != nil {
33+
config.ResponseJsonSchema = schema
34+
}
3435
}
3536

3637
if strings.HasPrefix(modelName, "gemini-3") {
@@ -43,108 +44,6 @@ func GetGenerateConfig(modelName string, responseType string) *genai.GenerateCon
4344
return config
4445
}
4546

46-
// ExtractJSON attempts to extract a valid JSON block from text, handling Markdown code blocks
47-
// and possible extra text that models with "Thinking" mode might generate.
48-
func ExtractJSON(text string) string {
49-
text = strings.TrimSpace(text)
50-
51-
matches := regex.MarkdownJSONBlock.FindAllStringSubmatch(text, -1)
52-
var bestMarkdown string
53-
for _, m := range matches {
54-
if len(m) > 1 {
55-
content := strings.TrimSpace(m[1])
56-
sanitized := SanitizeJSON(content)
57-
if json.Valid([]byte(sanitized)) {
58-
if len(sanitized) > len(bestMarkdown) {
59-
bestMarkdown = sanitized
60-
}
61-
}
62-
}
63-
}
64-
if bestMarkdown != "" {
65-
return bestMarkdown
66-
}
67-
68-
var bestBlock string
69-
for i := 0; i < len(text); {
70-
startIdx := strings.IndexAny(text[i:], "{[")
71-
if startIdx == -1 {
72-
break
73-
}
74-
startIdx += i
75-
76-
opener := text[startIdx]
77-
var closer byte
78-
if opener == '{' {
79-
closer = '}'
80-
} else {
81-
closer = ']'
82-
}
83-
84-
count := 0
85-
inString := false
86-
escaped := false
87-
foundEnd := false
88-
endIdx := -1
89-
90-
for j := startIdx; j < len(text); j++ {
91-
char := text[j]
92-
if escaped {
93-
escaped = false
94-
continue
95-
}
96-
if char == '\\' {
97-
escaped = true
98-
continue
99-
}
100-
if char == '"' {
101-
inString = !inString
102-
continue
103-
}
104-
105-
if !inString {
106-
if char == opener {
107-
count++
108-
} else if char == closer {
109-
count--
110-
if count == 0 {
111-
foundEnd = true
112-
endIdx = j
113-
break
114-
}
115-
}
116-
}
117-
}
118-
119-
if foundEnd {
120-
block := text[startIdx : endIdx+1]
121-
sanitized := SanitizeJSON(block)
122-
if json.Valid([]byte(sanitized)) {
123-
if len(sanitized) > len(bestBlock) {
124-
bestBlock = sanitized
125-
}
126-
}
127-
i = endIdx + 1
128-
} else {
129-
i = startIdx + 1
130-
}
131-
}
132-
133-
if bestBlock != "" {
134-
return bestBlock
135-
}
136-
137-
return SanitizeJSON(text)
138-
}
139-
140-
// SanitizeJSON cleans malformed JSON that LLMs sometimes generate,
141-
// such as unescaped newlines within String Literals.
142-
func SanitizeJSON(s string) string {
143-
return regex.JSONString.ReplaceAllStringFunc(s, func(m string) string {
144-
return strings.ReplaceAll(m, "\n", "\\n")
145-
})
146-
}
147-
14847
// float32Ptr returns a pointer to a newly allocated float32 holding f.
// Each call yields a distinct pointer, so callers may mutate the result
// without affecting other values produced by this helper.
func float32Ptr(f float32) *float32 {
	p := new(float32)
	*p = f
	return p
}

0 commit comments

Comments
 (0)