diff --git a/agent.go b/agent.go index 5a8c801..04cf8e3 100644 --- a/agent.go +++ b/agent.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "github.com/anthropics/anthropic-sdk-go" + anthropicOption "github.com/anthropics/anthropic-sdk-go/option" "github.com/openai/openai-go" + openaiOption "github.com/openai/openai-go/option" "google.golang.org/genai" ) @@ -12,6 +14,7 @@ const ( openAIAgent = "openai" anthropicAIAgent = "anthropic" geminiAIAgent = "gemini" + defaultMaxTokens = 1024 ) // Gemini models @@ -21,7 +24,7 @@ const ( ) type Assister interface { - GetTerminalCommand(ctx context.Context, userMessage string, systemMessage string) (string, error) + GetTerminalCommand(ctx context.Context, userMessage string) (string, error) } var _ Assister = (*OpenAIAssister)(nil) @@ -29,17 +32,23 @@ var _ Assister = (*AnthropicAIAssister)(nil) var _ Assister = (*GeminiAIAssister)(nil) type OpenAIAssister struct { - model string + AiParameters } -func (o *OpenAIAssister) GetTerminalCommand(ctx context.Context, userMessage string, systemMessage string) (string, error) { - client := openai.NewClient() +func (o *OpenAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) { + apiKey := o.AiParameters.apiKey + var client openai.Client + if apiKey != "" { + client = openai.NewClient(openaiOption.WithAPIKey(apiKey)) + } else { + client = openai.NewClient() + } chatCompletion, err := client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ Messages: []openai.ChatCompletionMessageParamUnion{ openai.UserMessage(userMessage), - openai.SystemMessage(systemMessage), + openai.SystemMessage(o.AiParameters.systemPrompt), }, - Model: o.model, + Model: o.AiParameters.model, }) if err != nil { return "", err @@ -48,15 +57,25 @@ func (o *OpenAIAssister) GetTerminalCommand(ctx context.Context, userMessage str } type AnthropicAIAssister struct { - model string + AiParameters } -func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, userMessage string, systemMessage string) (string, error) { - client := anthropic.NewClient() //defaults to os.LookupEnv("ANTHROPIC_API_KEY") +func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) { + apiKey := c.AiParameters.apiKey + maxTokens := c.AiParameters.maxTokens + if maxTokens == 0 { + maxTokens = defaultMaxTokens + } + var client anthropic.Client + if apiKey != "" { + client = anthropic.NewClient(anthropicOption.WithAPIKey(apiKey)) + } else { + client = anthropic.NewClient() + } message, err := client.Messages.New(ctx, anthropic.MessageNewParams{ - MaxTokens: 1024, + MaxTokens: maxTokens, System: []anthropic.TextBlockParam{ - {Text: systemMessage}, + {Text: c.AiParameters.systemPrompt}, }, Messages: []anthropic.MessageParam{ anthropic.NewUserMessage(anthropic.NewTextBlock(userMessage)), @@ -77,12 +96,14 @@ func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, userMessag } type GeminiAIAssister struct { - model string + AiParameters } -func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, userMessage string, systemMessage string) (string, error) { +func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) { + apiKey := g.AiParameters.apiKey client, err := genai.NewClient(ctx, &genai.ClientConfig{ Backend: genai.BackendGeminiAPI, + APIKey: apiKey, }) if err != nil { return "", err @@ -90,7 +111,7 @@ func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, userMessage s config := &genai.GenerateContentConfig{ SystemInstruction: &genai.Content{ Parts: []*genai.Part{ - {Text: systemMessage}, + {Text: g.AiParameters.systemPrompt}, }, }, } @@ -106,38 +127,32 @@ func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, userMessage s } type AssisterCreator interface { - GetAssister(agent string, model string) (Assister, error) + GetAssister(parameters AiParameters) (Assister, error) } var _ AssisterCreator = (*defaultAIAssisterCreator)(nil) type defaultAIAssisterCreator struct{} -func (d *defaultAIAssisterCreator) GetAssister(agent, model string) (Assister, error) { +func (d *defaultAIAssisterCreator) GetAssister(p AiParameters) (Assister, error) { switch { - case agent == "" || agent == openAIAgent: - if model == "" { - model = openai.ChatModelGPT4o + case p.agent == "" || p.agent == openAIAgent: + if p.model == "" { + p.model = openai.ChatModelGPT4o } - return &OpenAIAssister{ - model: model, - }, nil + return &OpenAIAssister{p}, nil - case agent == anthropicAIAgent: - if model == "" { - model = string(anthropic.ModelClaude3_5HaikuLatest) + case p.agent == anthropicAIAgent: + if p.model == "" { + p.model = string(anthropic.ModelClaude3_5HaikuLatest) } - return &AnthropicAIAssister{ - model: model, - }, nil - case agent == geminiAIAgent: - if model == "" { - model = geminiFlashLite + return &AnthropicAIAssister{p}, nil + case p.agent == geminiAIAgent: + if p.model == "" { + p.model = geminiFlashLite } - return &GeminiAIAssister{ - model, - }, nil + return &GeminiAIAssister{p}, nil default: - return nil, fmt.Errorf("cannot create AI agent for %s and model %s", agent, model) + return nil, fmt.Errorf("cannot create AI agent for %s and model %s", p.agent, p.model) } } diff --git a/agent_test.go b/agent_test.go index 60ce144..813d2c8 100644 --- a/agent_test.go +++ b/agent_test.go @@ -10,10 +10,10 @@ import ( ) type assisterTestCase struct { - name string - inputAgent, inputModel string - expectedModel string - err string + name string + inputAiParameters AiParameters + expectedModel string + err string } func TestGetOpenAIAssister(t *testing.T) { @@ -23,28 +23,38 @@ func TestGetOpenAIAssister(t *testing.T) { expectedModel: openai.ChatModelGPT4o, }, { - name: "OpenAI agent with given model should create correct assister", - inputAgent: openAIAgent, - inputModel: "gpt-5o-mini", + name: "OpenAI agent with given model should create correct assister", + inputAiParameters: AiParameters{ + agent: openAIAgent, + model: "gpt-5o-mini", + apiKey: "key1234", + maxTokens: defaultMaxTokens, + }, expectedModel: "gpt-5o-mini", }, { - name: "Unknown agent should result in error", - inputAgent: "Unknown agent", - inputModel: "unknown model", - err: "cannot create AI agent for Unknown agent and model unknown model", + name: "Unknown agent should result in error", + inputAiParameters: AiParameters{ + agent: "Unknown agent", + model: "unknown model", + apiKey: "key1234", + maxTokens: defaultMaxTokens, + }, + err: "cannot create AI agent for Unknown agent and model unknown model", }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { var creator defaultAIAssisterCreator - assister, err := creator.GetAssister(test.inputAgent, test.inputModel) + assister, err := creator.GetAssister(test.inputAiParameters) if test.err != "" { require.EqualError(t, err, test.err) } else { require.NoError(t, err) if openAIAssister, ok := assister.(*OpenAIAssister); ok { require.Equal(t, test.expectedModel, openAIAssister.model) + require.Equal(t, test.inputAiParameters.apiKey, openAIAssister.apiKey) + require.Equal(t, test.inputAiParameters.maxTokens, openAIAssister.maxTokens) } else { require.Fail(t, "Expected OpenAIAssister") } @@ -56,27 +66,37 @@ func TestGetOpenAIAssister(t *testing.T) { func TestGetAnthropicAssister(t *testing.T) { tests := []assisterTestCase{ { - name: "Anthropic agent with no model should default to Claude 3.5 Haiku latest", - inputAgent: anthropicAIAgent, + name: "Anthropic agent with no model should default to Claude 3.5 Haiku latest", + inputAiParameters: AiParameters{ + agent: anthropicAIAgent, + apiKey: "key1234", + maxTokens: defaultMaxTokens, + }, expectedModel: string(anthropic.ModelClaude3_5HaikuLatest), }, { - name: "Anthropic agent with given model should create correct assister", - inputAgent: anthropicAIAgent, - inputModel: string(anthropic.ModelClaudeOpus4_0), + name: "Anthropic agent with given model should create correct assister", + inputAiParameters: AiParameters{ + agent: anthropicAIAgent, + model: string(anthropic.ModelClaudeOpus4_0), + apiKey: "key1234", + maxTokens: defaultMaxTokens, + }, expectedModel: string(anthropic.ModelClaudeOpus4_0), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { var creator defaultAIAssisterCreator - assister, err := creator.GetAssister(test.inputAgent, test.inputModel) + assister, err := creator.GetAssister(test.inputAiParameters) if test.err != "" { require.EqualError(t, err, test.err) } else { require.NoError(t, err) if anthropicAssister, ok := assister.(*AnthropicAIAssister); ok { require.Equal(t, test.expectedModel, anthropicAssister.model) + require.Equal(t, test.inputAiParameters.apiKey, anthropicAssister.apiKey) + require.Equal(t, test.inputAiParameters.maxTokens, anthropicAssister.maxTokens) } else { require.Fail(t, "Expected Anthropic AI assister") } @@ -88,27 +108,37 @@ func TestGetAnthropicAssister(t *testing.T) { func TestGetGeminiAssister(t *testing.T) { tests := []assisterTestCase{ { - name: "Gemini agent with no model should default to gemini-2.5-flash-lite", - inputAgent: geminiAIAgent, + name: "Gemini agent with no model should default to gemini-2.5-flash-lite", + inputAiParameters: AiParameters{ + agent: geminiAIAgent, + apiKey: "key1234", + maxTokens: defaultMaxTokens, + }, expectedModel: geminiFlashLite, }, { - name: "Gemini agent with given model should create correct assister", - inputAgent: geminiAIAgent, - inputModel: "gemini-2.5-flash-preview-tts", + name: "Gemini agent with given model should create correct assister", + inputAiParameters: AiParameters{ + agent: geminiAIAgent, + model: "gemini-2.5-flash-preview-tts", + apiKey: "key1234", + maxTokens: defaultMaxTokens, + }, expectedModel: "gemini-2.5-flash-preview-tts", }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { var creator defaultAIAssisterCreator - assister, err := creator.GetAssister(test.inputAgent, test.inputModel) + assister, err := creator.GetAssister(test.inputAiParameters) if test.err != "" { require.EqualError(t, err, test.err) } else { require.NoError(t, err) if geminiAssister, ok := assister.(*GeminiAIAssister); ok { require.Equal(t, test.expectedModel, geminiAssister.model) + require.Equal(t, test.inputAiParameters.apiKey, geminiAssister.apiKey) + require.Equal(t, test.inputAiParameters.maxTokens, geminiAssister.maxTokens) } else { require.Fail(t, "Expected Gemini AI assister") } @@ -145,18 +175,27 @@ func TestGetTerminalCommand(t *testing.T) { switch test.agent { case openAIAgent: assister = &OpenAIAssister{ - model: openai.ChatModelGPT4o, + AiParameters{ + model: openai.ChatModelGPT4o, + systemPrompt: systemPrompt, + }, } case anthropicAIAgent: assister = &AnthropicAIAssister{ - model: string(anthropic.ModelClaude3_5HaikuLatest), + AiParameters{ + model: string(anthropic.ModelClaude3_5HaikuLatest), + systemPrompt: systemPrompt, + }, } case geminiAIAgent: assister = &GeminiAIAssister{ - model: geminiFlash, + AiParameters{ + model: geminiFlash, + systemPrompt: systemPrompt, + }, } } - command, err := assister.GetTerminalCommand(t.Context(), "I want to list all files in current directory", systemPrompt) + command, err := assister.GetTerminalCommand(t.Context(), "I want to list all files in current directory") require.NoError(t, err) require.Contains(t, command, "ls") }) diff --git a/gromit.go b/gromit.go index 49f93be..d492017 100644 --- a/gromit.go +++ b/gromit.go @@ -27,13 +27,6 @@ type Gromit struct { *configuration } -type systemInfo struct { - operatingSystem string - currentShell string - delimiter string - kernelInfo string -} - func getSystemInfo() systemInfo { o := runtime.GOOS var eol, shell, kernelInfo string @@ -57,25 +50,10 @@ func getSystemInfo() systemInfo { } } -type messagePrinter struct { - w io.Writer - promptPrefix string - delimiter string -} - -type configuration struct { - promptPrefix string - w io.Writer - askForConfirmation bool - systemInfo -} - func (m *messagePrinter) print(s string) { fmt.Fprintf(m.w, "%s %s %s", m.promptPrefix, s, m.delimiter) } -type ConfigurationModifier func(*configuration) error - func WithPromptPrefix(prefix string) ConfigurationModifier { return func(c *configuration) error { c.promptPrefix = prefix @@ -104,6 +82,18 @@ func (g *Gromit) actionGromit(ctx context.Context, command *cli.Command) error { g.print("Please run ./gromit --help to see usage") return nil } + prompt := g.String("systemPrompt") + if prompt == "" { + prompt = systemPrompt + } + prompt = addEnvironmentInfo(g.configuration.systemInfo, prompt) + g.configuration.AiParameters = AiParameters{ + maxTokens: g.Int64("maxTokens"), + apiKey: g.String("apiKey"), + agent: g.String("agent"), + model: g.String("model"), + systemPrompt: prompt, + } err := g.handleUserQuery(ctx, query) if err != nil { return err @@ -131,16 +121,11 @@ func (g *Gromit) actionGromit(ctx context.Context, command *cli.Command) error { } func (g *Gromit) handleUserQuery(ctx context.Context, query string) error { - assister, err := g.AssisterCreator.GetAssister(g.String("agent"), g.String("model")) + assister, err := g.AssisterCreator.GetAssister(g.configuration.AiParameters) if err != nil { return err } - prompt := g.String("systemPrompt") - if prompt == "" { - prompt = systemPrompt - } - prompt = addEnvironmentInfo(g.configuration.systemInfo, prompt) - exeCommand, err := assister.GetTerminalCommand(ctx, query, prompt) + exeCommand, err := assister.GetTerminalCommand(ctx, query) if err != nil { return err } @@ -175,10 +160,6 @@ func addEnvironmentInfo(systemInfo systemInfo, systemPrompt string) string { return result } -type userConfirmation struct { - confirmed bool -} - func (g *Gromit) askConfirmation(message string) (userConfirmation, error) { if !g.configuration.askForConfirmation { return userConfirmation{ @@ -257,12 +238,21 @@ func NewGromit(a AssisterCreator, mods ...ConfigurationModifier) (*Gromit, error Name: "systemPrompt", Usage: "The system prompt for the AI agent. Defaults to command line helper in a linux environment.", }, + &cli.StringFlag{ + Name: "apiKey", + Usage: "The API key to use for given AI agent. By default it is read from environment variables.", + }, + &cli.Int64Flag{ + Name: "maxTokens", + Usage: "Maximum number of tokens for AI agents to generate", + }, } config := configuration{ promptPrefix: "⚡️🐶", w: os.Stdout, askForConfirmation: true, systemInfo: getSystemInfo(), + AiParameters: AiParameters{}, } gromit := Gromit{ AssisterCreator: a, diff --git a/gromit_test.go b/gromit_test.go index 2470d52..84c6371 100644 --- a/gromit_test.go +++ b/gromit_test.go @@ -80,18 +80,22 @@ func TestAIAssisterFindingCorrectCommand(t *testing.T) { g, err := NewGromit(m, WithWriter(&buff), WithPromptPrefix("🐶"), WithAskForConfirmation(false)) require.NoError(t, err) - g.Run(t.Context(), []string{"gromit", "--model", "myModel", "--agent", "myAgent", "--systemPrompt", "myPrompt", "I", "want", "to", "list", "all", "files", "in", "current", "directory"}) + g.Run(t.Context(), []string{"gromit", "--model", "myModel", "--agent", "myAgent", + "--apiKey=key1234", "--maxTokens=2000", + "--systemPrompt", "myPrompt", "I", "want", "to", "list", "all", "files", "in", "current", "directory"}) result := buff.String() require.Contains(t, result, "🐶 In order to do that, you need to run") require.Contains(t, result, "🐶 ls") require.Contains(t, result, "README.md") require.Contains(t, result, "🐶 How can I help?") - require.Equal(t, "myAgent", m.actualAgent) - require.Equal(t, "myModel", m.actualModel) - require.Contains(t, m.actualSystemMessage, "myPrompt") - require.Contains(t, m.actualSystemMessage, "User's operating system is") - require.Contains(t, m.actualSystemMessage, "User's current shell is") + require.Equal(t, "myAgent", m.actualAiParameters.agent) + require.Equal(t, "myModel", m.actualAiParameters.model) + require.Equal(t, "key1234", m.actualAiParameters.apiKey) + require.Equal(t, int64(2000), m.actualAiParameters.maxTokens) + require.Contains(t, m.actualAiParameters.systemPrompt, "myPrompt") + require.Contains(t, m.actualAiParameters.systemPrompt, "User's operating system is") + require.Contains(t, m.actualAiParameters.systemPrompt, "User's current shell is") require.Equal(t, "I want to list all files in current directory", m.actualUserMessage) } @@ -100,23 +104,20 @@ type mockAIProvider struct { commandError error commandResult string - actualAgent string - actualModel string - actualSystemMessage string - actualUserMessage string + actualUserMessage string + + actualAiParameters AiParameters } -func (m *mockAIProvider) GetAssister(agent string, model string) (Assister, error) { - m.actualAgent = agent - m.actualModel = model +func (m *mockAIProvider) GetAssister(p AiParameters) (Assister, error) { + m.actualAiParameters = p if m.assisterError != nil { return nil, m.assisterError } return m, nil } -func (m *mockAIProvider) GetTerminalCommand(ctx context.Context, userMessage string, systemMessage string) (string, error) { - m.actualSystemMessage = systemMessage +func (m *mockAIProvider) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) { m.actualUserMessage = userMessage if m.commandError != nil { return "", m.commandError diff --git a/types.go b/types.go new file mode 100644 index 0000000..5ffa30c --- /dev/null +++ b/types.go @@ -0,0 +1,38 @@ +package main + +import "io" + +type systemInfo struct { + operatingSystem string + currentShell string + delimiter string + kernelInfo string +} + +type messagePrinter struct { + w io.Writer + promptPrefix string + delimiter string +} + +type configuration struct { + AiParameters + promptPrefix string + w io.Writer + askForConfirmation bool + systemInfo +} + +type userConfirmation struct { + confirmed bool +} + +type ConfigurationModifier func(*configuration) error + +type AiParameters struct { + systemPrompt string + agent string + model string + apiKey string + maxTokens int64 +}