diff --git a/agent.go b/agent.go index 04cf8e3..213e0cd 100644 --- a/agent.go +++ b/agent.go @@ -24,7 +24,7 @@ const ( ) type Assister interface { - GetTerminalCommand(ctx context.Context, userMessage string) (string, error) + GetTerminalCommand(ctx context.Context, conversations *[]Conversation) (string, error) } var _ Assister = (*OpenAIAssister)(nil) @@ -35,7 +35,7 @@ type OpenAIAssister struct { AiParameters } -func (o *OpenAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) { +func (o *OpenAIAssister) GetTerminalCommand(ctx context.Context, conversations *[]Conversation) (string, error) { apiKey := o.AiParameters.apiKey var client openai.Client if apiKey != "" { @@ -43,12 +43,19 @@ func (o *OpenAIAssister) GetTerminalCommand(ctx context.Context, userMessage str } else { client = openai.NewClient() } + messages := []openai.ChatCompletionMessageParamUnion{ + openai.SystemMessage(o.AiParameters.systemPrompt), + } + for _, conversation := range *conversations { + if conversation.Role == UserRole { + messages = append(messages, openai.UserMessage(conversation.Text)) + } else if conversation.Role == AssistantRole { + messages = append(messages, openai.AssistantMessage(conversation.Text)) + } + } chatCompletion, err := client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - openai.SystemMessage(o.AiParameters.systemPrompt), - }, - Model: o.AiParameters.model, + Messages: messages, + Model: o.AiParameters.model, }) if err != nil { return "", err @@ -60,7 +67,7 @@ type AnthropicAIAssister struct { AiParameters } -func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) { +func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, conversations *[]Conversation) (string, error) { apiKey := c.AiParameters.apiKey maxTokens := c.AiParameters.maxTokens if maxTokens == 0 { @@ -72,15 +79,23 @@ func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, userMessag } else { client = anthropic.NewClient() } + + var messages []anthropic.MessageParam + for _, conversation := range *conversations { + if conversation.Role == UserRole { + messages = append(messages, anthropic.NewUserMessage(anthropic.NewTextBlock(conversation.Text))) + } else if conversation.Role == AssistantRole { + messages = append(messages, anthropic.NewAssistantMessage(anthropic.NewTextBlock(conversation.Text))) + } + } + message, err := client.Messages.New(ctx, anthropic.MessageNewParams{ MaxTokens: maxTokens, System: []anthropic.TextBlockParam{ {Text: c.AiParameters.systemPrompt}, }, - Messages: []anthropic.MessageParam{ - anthropic.NewUserMessage(anthropic.NewTextBlock(userMessage)), - }, - Model: anthropic.Model(c.model), + Messages: messages, + Model: anthropic.Model(c.model), }) if err != nil { return "", err @@ -99,7 +114,7 @@ type GeminiAIAssister struct { AiParameters } -func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) { +func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, conversations *[]Conversation) (string, error) { apiKey := g.AiParameters.apiKey client, err := genai.NewClient(ctx, &genai.ClientConfig{ Backend: genai.BackendGeminiAPI, @@ -115,11 +130,27 @@ func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, userMessage s }, }, } - chat, err := client.Chats.Create(ctx, g.model, config, nil) + var history []*genai.Content + var lastUserMessage string + for _, conversation := range *conversations { + if conversation.Role == UserRole { + history = append(history, &genai.Content{ + Role: genai.RoleUser, + Parts: []*genai.Part{{Text: conversation.Text}}, + }) + lastUserMessage = conversation.Text + } else if conversation.Role == AssistantRole { + history = append(history, &genai.Content{ + Role: genai.RoleModel, + Parts: []*genai.Part{{Text: conversation.Text}}, + }) + } + } + chat, err := client.Chats.Create(ctx, g.model, config, history) if err != nil { return "", err } - result, err := chat.SendMessage(ctx, genai.Part{Text: userMessage}) + result, err := chat.SendMessage(ctx, genai.Part{Text: lastUserMessage}) if err != nil { return "", err } diff --git a/agent_test.go b/agent_test.go index 813d2c8..6fa5161 100644 --- a/agent_test.go +++ b/agent_test.go @@ -195,7 +195,11 @@ func TestGetTerminalCommand(t *testing.T) { }, } } - command, err := assister.GetTerminalCommand(t.Context(), "I want to list all files in current directory") + conversations := &[]Conversation{ + {Text: systemPrompt, Role: SystemRole}, + {Text: "I want to list all files in current directory", Role: UserRole}, + } + command, err := assister.GetTerminalCommand(t.Context(), conversations) require.NoError(t, err) require.Contains(t, command, "ls") }) diff --git a/gromit.go b/gromit.go index 6bd426b..0957095 100644 --- a/gromit.go +++ b/gromit.go @@ -3,12 +3,13 @@ package main import ( "bufio" "context" + "encoding/json" "errors" "fmt" "io" + "log" "os" "os/exec" - "regexp" "runtime" "sort" "strings" @@ -16,12 +17,33 @@ import ( "github.com/urfave/cli/v3" ) -const systemPrompt = `You are an assistant providing terminal commands based on user's questions. - You will be given a question about how to do something in the CLI environment. - You will then find out what command to execute and provide the command. - Make sure to enclose the actual command inside *** marker. - For example, if question is about listing all files in a directory for linux, respond with "***ls***". - If no question is asked by user, continue the conversation. If they want to exit, respond with "***exit***".` +const systemPrompt = `You are an assistant providing terminal commands based on user's questions. + Also you should keep up the conversation with the user in case there is no terminal command to execute. + You will be given a question about how to do something in the CLI environment and then find out what command to execute and provide the command. + Always provide your response in the following json format: + { + "command": "the command to execute, can be empty if a response is provided", + "response": "the response to the user, can be empty if a command is provided" + "exit": "true if the user wants to exit, false otherwise" + } + For example, if question is about listing all files in a directory for linux, respond with + { + "command": "ls", + "response": "", + "exit": false + } + If the question is about telling a joke, respond with: + { + "command": "", + "response": "Some funny joke!", + "exit": false + } + If no question is asked by user, continue the conversation. If they want to exit, respond with + { + "command": "", + "response": "", + "exit": true + }` type Gromit struct { cli.Command @@ -106,7 +128,10 @@ func getAvailablePathExecutables() []string { } func (m *messagePrinter) print(s string) { - fmt.Fprintf(m.w, "%s %s %s", m.promptPrefix, s, m.delimiter) + _, err := fmt.Fprintf(m.w, "%s %s %s", m.promptPrefix, s, m.delimiter) + if err != nil { + log.Fatal(err) + } } func WithPromptPrefix(prefix string) ConfigurationModifier { @@ -130,14 +155,26 @@ func WithAskForConfirmation(confirm bool) ConfigurationModifier { } } +var conversations []Conversation + func (g *Gromit) actionGromit(ctx context.Context, command *cli.Command) error { commandArgs := command.Args().Slice() query := strings.Join(commandArgs, " ") + if query == "" { + query = "Can you please introduce yourself or continue the conversation?" + } prompt := g.String("systemPrompt") if prompt == "" { prompt = systemPrompt } prompt = addEnvironmentInfo(g.configuration.systemInfo, prompt) + conversations = append(conversations, Conversation{ + Role: SystemRole, + Text: prompt, + }, Conversation{ + Role: UserRole, + Text: query, + }) g.configuration.AiParameters = AiParameters{ maxTokens: g.Int64("maxTokens"), apiKey: g.String("apiKey"), @@ -145,65 +182,60 @@ func (g *Gromit) actionGromit(ctx context.Context, command *cli.Command) error { model: g.String("model"), systemPrompt: prompt, } - terminalCommand, err := g.extractCommandForQuery(ctx, query) - if err != nil { - return err - } - if terminalCommand != "" { - err = g.handleTerminalCommand(ctx, terminalCommand) - if err != nil { - return err - } - } + for ctx.Err() == nil { - //read the user input, pass it to AI - reader := bufio.NewReader(g.Reader) - query, err := reader.ReadString('\n') + response, err := g.extractResponseForQuery(ctx, &conversations) if err != nil { return err } - terminalCommand, err := g.extractCommandForQuery(ctx, query) + exit, err := g.handleAiResponse(ctx, response) if err != nil { return err } - if terminalCommand == "exit" { - break + if exit { + return nil } - if terminalCommand != "" { - err = g.handleTerminalCommand(ctx, terminalCommand) - if err != nil { - return err - } + reader := bufio.NewReader(g.Reader) + query, err = reader.ReadString('\n') + if err != nil { + return err } + conversations = append(conversations, Conversation{ + Role: UserRole, + Text: query, + }) } return nil } -func (g *Gromit) extractCommandForQuery(ctx context.Context, query string) (string, error) { +func (g *Gromit) extractResponseForQuery(ctx context.Context, conversations *[]Conversation) (AiResponse, error) { + var result AiResponse assister, err := g.AssisterCreator.GetAssister(g.configuration.AiParameters) if err != nil { - return "", err + return result, err } - if query == "" { - query = "Can you please introduce yourself or continue the conversation?" - } - response, err := assister.GetTerminalCommand(ctx, query) + + response, err := assister.GetTerminalCommand(ctx, conversations) if err != nil { - return "", err + return result, err } - //command is enclosed in *** marker - regexp := regexp.MustCompile(`\*\*\*(.*?)\*\*\*`) - commands := regexp.FindStringSubmatch(response) - var command string - if len(commands) > 0 { - command = commands[1] - } else { - g.print(response) + for _, s := range []string{"json", "```"} { + response = strings.ReplaceAll(response, s, "") + } + if !json.Valid([]byte(response)) { + return result, fmt.Errorf("received invalid json response: %s", response) + } + if err = json.Unmarshal([]byte(response), &result); err != nil { + return result, fmt.Errorf("failed to unmarshal json response: \n %s \n error: %s", response, err.Error()) } - return command, nil + *conversations = append(*conversations, Conversation{ + Role: AssistantRole, + Text: response, + }) + return result, nil } -func (g *Gromit) handleTerminalCommand(ctx context.Context, terminalCommand string) error { +func (g *Gromit) handleTerminalCommand(_ context.Context, terminalCommand string) error { g.print("In order to do that, you need to run:") g.print(terminalCommand) confirmation, err := g.askConfirmation("Would you like to run this command?") @@ -222,6 +254,22 @@ func (g *Gromit) handleTerminalCommand(ctx context.Context, terminalCommand stri return nil } +func (g *Gromit) handleAiResponse(ctx context.Context, aiResponse AiResponse) (shouldExit bool, error error) { + if aiResponse.Response != "" { + g.print(aiResponse.Response) + } + if aiResponse.Command != "" { + err := g.handleTerminalCommand(ctx, aiResponse.Command) + if err != nil { + return false, err + } + } + if aiResponse.Exit { + return true, nil + } + return false, nil +} + // adds environment info such as OS, available shells, etc to the system prompt for the AI func addEnvironmentInfo(systemInfo systemInfo, systemPrompt string) string { result := fmt.Sprintf("%s. User's operating system is %s", systemPrompt, systemInfo.operatingSystem) diff --git a/gromit_test.go b/gromit_test.go index b2d9934..95c7ea2 100644 --- a/gromit_test.go +++ b/gromit_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "io" "runtime" "strings" "testing" @@ -60,7 +61,7 @@ func TestWhenAIProviderFailsToCreateAssister(t *testing.T) { } g, err := NewGromit(m) require.NoError(t, err) - _, err = g.extractCommandForQuery(t.Context(), "some query") + _, err = g.extractResponseForQuery(t.Context(), &[]Conversation{}) require.EqualError(t, err, "Unable to create assister") } @@ -77,15 +78,26 @@ func TestWhenAIProviderFailsToFindTheCommand(t *testing.T) { func TestAIAssisterFindingCorrectCommand(t *testing.T) { var buff bytes.Buffer m := &mockAIProvider{ - commandResult: "***ls***", + aiResponse: []string{ + "json```{\"Command\": \"\",\"Response\": \"Hello to you as well!\",\"Exit\": false}```", + "json```{\"Command\": \"ls\",\"Response\": \"\",\"Exit\": false}```", + "json```{\"Command\": \"\",\"Response\": \"\",\"Exit\": true}```", + }, } g, err := NewGromit(m, WithWriter(&buff), WithPromptPrefix("🐶"), WithAskForConfirmation(false)) require.NoError(t, err) - g.Reader = strings.NewReader("I want to list all files in current directory\n") - g.Run(t.Context(), []string{"gromit", "--model", "myModel", "--agent", "myAgent", + + readers := []io.Reader{ + strings.NewReader("I want to list all files in current directory\n"), + strings.NewReader("Thank you! bye!\n"), + } + g.Reader = io.MultiReader(readers...) + err = g.Run(t.Context(), []string{"gromit", "--model", "myModel", "--agent", "myAgent", "--apiKey=key1234", "--maxTokens=2000", "--systemPrompt", "myPrompt", "hello", "my", "ai", "friend!"}) + require.NoError(t, err) result := buff.String() + require.Contains(t, result, "🐶 Hello to you as well!") require.Contains(t, result, "🐶 In order to do that, you need to run") require.Contains(t, result, "🐶 ls") require.Contains(t, result, "README.md") @@ -98,15 +110,28 @@ func TestAIAssisterFindingCorrectCommand(t *testing.T) { require.Contains(t, m.actualAiParameters.systemPrompt, "User's operating system is") require.Contains(t, m.actualAiParameters.systemPrompt, "User's current shell is") require.Contains(t, m.actualAiParameters.systemPrompt, "User's available path commands are") - require.Contains(t, m.actualUserMessage, "I want to list all files in current directory") +} + +func TestAIAssisterProvidingInvalidJsonResponse(t *testing.T) { + var buff bytes.Buffer + m := &mockAIProvider{ + aiResponse: []string{ + "json```invalid response```", + }, + } + g, err := NewGromit(m, WithWriter(&buff)) + require.NoError(t, err) + err = g.Run(t.Context(), []string{}) + require.Errorf(t, err, "received invalid json response: invalid response") } type mockAIProvider struct { - assisterError error - commandError error - commandResult string + assisterError error + commandError error + aiResponse []string + aiResponseIndex int - actualUserMessage string + actualConversations *[]Conversation actualAiParameters AiParameters } @@ -119,10 +144,12 @@ func (m *mockAIProvider) GetAssister(p AiParameters) (Assister, error) { return m, nil } -func (m *mockAIProvider) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) { - m.actualUserMessage = userMessage +func (m *mockAIProvider) GetTerminalCommand(ctx context.Context, conversations *[]Conversation) (string, error) { + m.actualConversations = conversations if m.commandError != nil { return "", m.commandError } - return m.commandResult, nil + response := m.aiResponse[m.aiResponseIndex] + m.aiResponseIndex = m.aiResponseIndex + 1 + return response, nil } diff --git a/types.go b/types.go index 80dd9f9..238fbd2 100644 --- a/types.go +++ b/types.go @@ -37,3 +37,22 @@ type AiParameters struct { apiKey string maxTokens int64 } + +type AiResponse struct { + Command string `json:"command"` + Response string `json:"response"` + Exit bool `json:"exit"` +} + +type Conversation struct { + Role Role + Text string +} + +type Role int + +const ( + SystemRole Role = iota + UserRole + AssistantRole +)