Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 50 additions & 35 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@ import (
"context"
"fmt"
"github.com/anthropics/anthropic-sdk-go"
anthropicOption "github.com/anthropics/anthropic-sdk-go/option"
"github.com/openai/openai-go"
openaiOption "github.com/openai/openai-go/option"
"google.golang.org/genai"
)

const (
openAIAgent = "openai"
anthropicAIAgent = "anthropic"
geminiAIAgent = "gemini"
defaultMaxTokens = 1024
)

// Gemini models
Expand All @@ -21,25 +24,31 @@ const (
)

type Assister interface {
GetTerminalCommand(ctx context.Context, userMessage string, systemMessage string) (string, error)
GetTerminalCommand(ctx context.Context, userMessage string) (string, error)
}

var _ Assister = (*OpenAIAssister)(nil)
var _ Assister = (*AnthropicAIAssister)(nil)
var _ Assister = (*GeminiAIAssister)(nil)

type OpenAIAssister struct {
model string
AiParameters
}

func (o *OpenAIAssister) GetTerminalCommand(ctx context.Context, userMessage string, systemMessage string) (string, error) {
client := openai.NewClient()
func (o *OpenAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) {
apiKey := o.AiParameters.apiKey
var client openai.Client
if apiKey != "" {
client = openai.NewClient(openaiOption.WithAPIKey(apiKey))
} else {
client = openai.NewClient()
}
chatCompletion, err := client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage(userMessage),
openai.SystemMessage(systemMessage),
openai.SystemMessage(o.AiParameters.systemPrompt),
},
Model: o.model,
Model: o.AiParameters.model,
})
if err != nil {
return "", err
Expand All @@ -48,15 +57,25 @@ func (o *OpenAIAssister) GetTerminalCommand(ctx context.Context, userMessage str
}

type AnthropicAIAssister struct {
model string
AiParameters
}

func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, userMessage string, systemMessage string) (string, error) {
client := anthropic.NewClient() //defaults to os.LookupEnv("ANTHROPIC_API_KEY")
func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) {
apiKey := c.AiParameters.apiKey
maxTokens := c.AiParameters.maxTokens
if maxTokens == 0 {
maxTokens = defaultMaxTokens
}
var client anthropic.Client
if apiKey != "" {
client = anthropic.NewClient(anthropicOption.WithAPIKey(apiKey))
} else {
client = anthropic.NewClient()
}
message, err := client.Messages.New(ctx, anthropic.MessageNewParams{
MaxTokens: 1024,
MaxTokens: maxTokens,
System: []anthropic.TextBlockParam{
{Text: systemMessage},
{Text: c.AiParameters.systemPrompt},
},
Messages: []anthropic.MessageParam{
anthropic.NewUserMessage(anthropic.NewTextBlock(userMessage)),
Expand All @@ -77,20 +96,22 @@ func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, userMessag
}

type GeminiAIAssister struct {
model string
AiParameters
}

func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, userMessage string, systemMessage string) (string, error) {
func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) {
apiKey := g.AiParameters.apiKey
client, err := genai.NewClient(ctx, &genai.ClientConfig{
Backend: genai.BackendGeminiAPI,
APIKey: apiKey,
})
if err != nil {
return "", err
}
config := &genai.GenerateContentConfig{
SystemInstruction: &genai.Content{
Parts: []*genai.Part{
{Text: systemMessage},
{Text: g.AiParameters.systemPrompt},
},
},
}
Expand All @@ -106,38 +127,32 @@ func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, userMessage s
}

type AssisterCreator interface {
GetAssister(agent string, model string) (Assister, error)
GetAssister(parameters AiParameters) (Assister, error)
}

var _ AssisterCreator = (*defaultAIAssisterCreator)(nil)

type defaultAIAssisterCreator struct{}

func (d *defaultAIAssisterCreator) GetAssister(agent, model string) (Assister, error) {
func (d *defaultAIAssisterCreator) GetAssister(p AiParameters) (Assister, error) {
switch {
case agent == "" || agent == openAIAgent:
if model == "" {
model = openai.ChatModelGPT4o
case p.agent == "" || p.agent == openAIAgent:
if p.model == "" {
p.model = openai.ChatModelGPT4o
}
return &OpenAIAssister{
model: model,
}, nil
return &OpenAIAssister{p}, nil

case agent == anthropicAIAgent:
if model == "" {
model = string(anthropic.ModelClaude3_5HaikuLatest)
case p.agent == anthropicAIAgent:
if p.model == "" {
p.model = string(anthropic.ModelClaude3_5HaikuLatest)
}
return &AnthropicAIAssister{
model: model,
}, nil
case agent == geminiAIAgent:
if model == "" {
model = geminiFlashLite
return &AnthropicAIAssister{p}, nil
case p.agent == geminiAIAgent:
if p.model == "" {
p.model = geminiFlashLite
}
return &GeminiAIAssister{
model,
}, nil
return &GeminiAIAssister{p}, nil
default:
return nil, fmt.Errorf("cannot create AI agent for %s and model %s", agent, model)
return nil, fmt.Errorf("cannot create AI agent for %s and model %s", p.agent, p.model)
}
}
95 changes: 67 additions & 28 deletions agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import (
)

type assisterTestCase struct {
name string
inputAgent, inputModel string
expectedModel string
err string
name string
inputAiParameters AiParameters
expectedModel string
err string
}

func TestGetOpenAIAssister(t *testing.T) {
Expand All @@ -23,28 +23,38 @@ func TestGetOpenAIAssister(t *testing.T) {
expectedModel: openai.ChatModelGPT4o,
},
{
name: "OpenAI agent with given model should create correct assister",
inputAgent: openAIAgent,
inputModel: "gpt-5o-mini",
name: "OpenAI agent with given model should create correct assister",
inputAiParameters: AiParameters{
agent: openAIAgent,
model: "gpt-5o-mini",
apiKey: "key1234",
maxTokens: defaultMaxTokens,
},
expectedModel: "gpt-5o-mini",
},
{
name: "Unknown agent should result in error",
inputAgent: "Unknown agent",
inputModel: "unknown model",
err: "cannot create AI agent for Unknown agent and model unknown model",
name: "Unknown agent should result in error",
inputAiParameters: AiParameters{
agent: "Unknown agent",
model: "unknown model",
apiKey: "key1234",
maxTokens: defaultMaxTokens,
},
err: "cannot create AI agent for Unknown agent and model unknown model",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var creator defaultAIAssisterCreator
assister, err := creator.GetAssister(test.inputAgent, test.inputModel)
assister, err := creator.GetAssister(test.inputAiParameters)
if test.err != "" {
require.EqualError(t, err, test.err)
} else {
require.NoError(t, err)
if openAIAssister, ok := assister.(*OpenAIAssister); ok {
require.Equal(t, test.expectedModel, openAIAssister.model)
require.Equal(t, test.inputAiParameters.apiKey, openAIAssister.apiKey)
require.Equal(t, test.inputAiParameters.maxTokens, openAIAssister.maxTokens)
} else {
require.Fail(t, "Expected OpenAIAssister")
}
Expand All @@ -56,27 +66,37 @@ func TestGetOpenAIAssister(t *testing.T) {
func TestGetAnthropicAssister(t *testing.T) {
tests := []assisterTestCase{
{
name: "Anthropic agent with no model should default to Claude 3.5 Haiku latest",
inputAgent: anthropicAIAgent,
name: "Anthropic agent with no model should default to Claude 3.5 Haiku latest",
inputAiParameters: AiParameters{
agent: anthropicAIAgent,
apiKey: "key1234",
maxTokens: defaultMaxTokens,
},
expectedModel: string(anthropic.ModelClaude3_5HaikuLatest),
},
{
name: "Anthropic agent with given model should create correct assister",
inputAgent: anthropicAIAgent,
inputModel: string(anthropic.ModelClaudeOpus4_0),
name: "Anthropic agent with given model should create correct assister",
inputAiParameters: AiParameters{
agent: anthropicAIAgent,
model: string(anthropic.ModelClaudeOpus4_0),
apiKey: "key1234",
maxTokens: defaultMaxTokens,
},
expectedModel: string(anthropic.ModelClaudeOpus4_0),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var creator defaultAIAssisterCreator
assister, err := creator.GetAssister(test.inputAgent, test.inputModel)
assister, err := creator.GetAssister(test.inputAiParameters)
if test.err != "" {
require.EqualError(t, err, test.err)
} else {
require.NoError(t, err)
if anthropicAssister, ok := assister.(*AnthropicAIAssister); ok {
require.Equal(t, test.expectedModel, anthropicAssister.model)
require.Equal(t, test.inputAiParameters.apiKey, anthropicAssister.apiKey)
require.Equal(t, test.inputAiParameters.maxTokens, anthropicAssister.maxTokens)
} else {
require.Fail(t, "Expected Anthropic AI assister")
}
Expand All @@ -88,27 +108,37 @@ func TestGetAnthropicAssister(t *testing.T) {
func TestGetGeminiAssister(t *testing.T) {
tests := []assisterTestCase{
{
name: "Gemini agent with no model should default to gemini-2.5-flash-lite",
inputAgent: geminiAIAgent,
name: "Gemini agent with no model should default to gemini-2.5-flash-lite",
inputAiParameters: AiParameters{
agent: geminiAIAgent,
apiKey: "key1234",
maxTokens: defaultMaxTokens,
},
expectedModel: geminiFlashLite,
},
{
name: "Gemini agent with given model should create correct assister",
inputAgent: geminiAIAgent,
inputModel: "gemini-2.5-flash-preview-tts",
name: "Gemini agent with given model should create correct assister",
inputAiParameters: AiParameters{
agent: geminiAIAgent,
model: "gemini-2.5-flash-preview-tts",
apiKey: "key1234",
maxTokens: defaultMaxTokens,
},
expectedModel: "gemini-2.5-flash-preview-tts",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var creator defaultAIAssisterCreator
assister, err := creator.GetAssister(test.inputAgent, test.inputModel)
assister, err := creator.GetAssister(test.inputAiParameters)
if test.err != "" {
require.EqualError(t, err, test.err)
} else {
require.NoError(t, err)
if geminiAssister, ok := assister.(*GeminiAIAssister); ok {
require.Equal(t, test.expectedModel, geminiAssister.model)
require.Equal(t, test.inputAiParameters.apiKey, geminiAssister.apiKey)
require.Equal(t, test.inputAiParameters.maxTokens, geminiAssister.maxTokens)
} else {
require.Fail(t, "Expected Gemini AI assister")
}
Expand Down Expand Up @@ -145,18 +175,27 @@ func TestGetTerminalCommand(t *testing.T) {
switch test.agent {
case openAIAgent:
assister = &OpenAIAssister{
model: openai.ChatModelGPT4o,
AiParameters{
model: openai.ChatModelGPT4o,
systemPrompt: systemPrompt,
},
}
case anthropicAIAgent:
assister = &AnthropicAIAssister{
model: string(anthropic.ModelClaude3_5HaikuLatest),
AiParameters{
model: string(anthropic.ModelClaude3_5HaikuLatest),
systemPrompt: systemPrompt,
},
}
case geminiAIAgent:
assister = &GeminiAIAssister{
model: geminiFlash,
AiParameters{
model: geminiFlash,
systemPrompt: systemPrompt,
},
}
}
command, err := assister.GetTerminalCommand(t.Context(), "I want to list all files in current directory", systemPrompt)
command, err := assister.GetTerminalCommand(t.Context(), "I want to list all files in current directory")
require.NoError(t, err)
require.Contains(t, command, "ls")
})
Expand Down
Loading