Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 46 additions & 15 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (
)

type Assister interface {
GetTerminalCommand(ctx context.Context, userMessage string) (string, error)
GetTerminalCommand(ctx context.Context, conversations *[]Conversation) (string, error)
}

var _ Assister = (*OpenAIAssister)(nil)
Expand All @@ -35,20 +35,27 @@ type OpenAIAssister struct {
AiParameters
}

func (o *OpenAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) {
func (o *OpenAIAssister) GetTerminalCommand(ctx context.Context, conversations *[]Conversation) (string, error) {
apiKey := o.AiParameters.apiKey
var client openai.Client
if apiKey != "" {
client = openai.NewClient(openaiOption.WithAPIKey(apiKey))
} else {
client = openai.NewClient()
}
messages := []openai.ChatCompletionMessageParamUnion{
openai.SystemMessage(o.AiParameters.systemPrompt),
}
for _, conversation := range *conversations {
if conversation.Role == UserRole {
messages = append(messages, openai.UserMessage(conversation.Text))
} else if conversation.Role == AssistantRole {
messages = append(messages, openai.AssistantMessage(conversation.Text))
}
}
chatCompletion, err := client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage(userMessage),
openai.SystemMessage(o.AiParameters.systemPrompt),
},
Model: o.AiParameters.model,
Messages: messages,
Model: o.AiParameters.model,
})
if err != nil {
return "", err
Expand All @@ -60,7 +67,7 @@ type AnthropicAIAssister struct {
AiParameters
}

func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) {
func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, conversations *[]Conversation) (string, error) {
apiKey := c.AiParameters.apiKey
maxTokens := c.AiParameters.maxTokens
if maxTokens == 0 {
Expand All @@ -72,15 +79,23 @@ func (c *AnthropicAIAssister) GetTerminalCommand(ctx context.Context, userMessag
} else {
client = anthropic.NewClient()
}

var messages []anthropic.MessageParam
for _, conversation := range *conversations {
if conversation.Role == UserRole {
messages = append(messages, anthropic.NewUserMessage(anthropic.NewTextBlock(conversation.Text)))
} else if conversation.Role == AssistantRole {
messages = append(messages, anthropic.NewAssistantMessage(anthropic.NewTextBlock(conversation.Text)))
}
}

message, err := client.Messages.New(ctx, anthropic.MessageNewParams{
MaxTokens: maxTokens,
System: []anthropic.TextBlockParam{
{Text: c.AiParameters.systemPrompt},
},
Messages: []anthropic.MessageParam{
anthropic.NewUserMessage(anthropic.NewTextBlock(userMessage)),
},
Model: anthropic.Model(c.model),
Messages: messages,
Model: anthropic.Model(c.model),
})
if err != nil {
return "", err
Expand All @@ -99,7 +114,7 @@ type GeminiAIAssister struct {
AiParameters
}

func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, userMessage string) (string, error) {
func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, conversations *[]Conversation) (string, error) {
apiKey := g.AiParameters.apiKey
client, err := genai.NewClient(ctx, &genai.ClientConfig{
Backend: genai.BackendGeminiAPI,
Expand All @@ -115,11 +130,27 @@ func (g *GeminiAIAssister) GetTerminalCommand(ctx context.Context, userMessage s
},
},
}
chat, err := client.Chats.Create(ctx, g.model, config, nil)
var history []*genai.Content
var lastUserMessage string
for _, conversation := range *conversations {
if conversation.Role == UserRole {
history = append(history, &genai.Content{
Role: genai.RoleUser,
Parts: []*genai.Part{{Text: conversation.Text}},
})
lastUserMessage = conversation.Text
} else if conversation.Role == AssistantRole {
history = append(history, &genai.Content{
Role: genai.RoleModel,
Parts: []*genai.Part{{Text: conversation.Text}},
})
}
}
chat, err := client.Chats.Create(ctx, g.model, config, history)
if err != nil {
return "", err
}
result, err := chat.SendMessage(ctx, genai.Part{Text: userMessage})
result, err := chat.SendMessage(ctx, genai.Part{Text: lastUserMessage})
if err != nil {
return "", err
}
Expand Down
6 changes: 5 additions & 1 deletion agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,11 @@ func TestGetTerminalCommand(t *testing.T) {
},
}
}
command, err := assister.GetTerminalCommand(t.Context(), "I want to list all files in current directory")
conversations := &[]Conversation{
{Text: systemPrompt, Role: SystemRole},
{Text: "I want to list all files in current directory", Role: UserRole},
}
command, err := assister.GetTerminalCommand(t.Context(), conversations)
require.NoError(t, err)
require.Contains(t, command, "ls")
})
Expand Down
140 changes: 94 additions & 46 deletions gromit.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,47 @@ package main
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"os"
"os/exec"
"regexp"
"runtime"
"sort"
"strings"

"github.com/urfave/cli/v3"
)

const systemPrompt = `You are an assistant providing terminal commands based on user's questions.
You will be given a question about how to do something in the CLI environment.
You will then find out what command to execute and provide the command.
Make sure to enclose the actual command inside *** marker.
For example, if question is about listing all files in a directory for linux, respond with "***ls***".
If no question is asked by user, continue the conversation. If they want to exit, respond with "***exit***".`
const systemPrompt = `You are an assistant providing terminal commands based on user's questions.
Also you should keep up the conversation with the user in case there is no terminal command to execute.
You will be given a question about how to do something in the CLI environment and then find out what command to execute and provide the command.
Always provide your response in the following json format:
{
"command": "the command to execute, can be empty if a response is provided",
"response": "the response to the user, can be empty if a command is provided"
"exit": "true if the user wants to exit, false otherwise"
}
For example, if question is about listing all files in a directory for linux, respond with
{
"command": "ls",
"response": "",
"exit": false
}
If the question is about telling a joke, respond with:
{
"command": "",
"response": "Some funny joke!",
"exit": false
}
If no question is asked by user, continue the conversation. If they want to exit, respond with
{
"command": "",
"response": "",
"exit": true
}`

type Gromit struct {
cli.Command
Expand Down Expand Up @@ -106,7 +128,10 @@ func getAvailablePathExecutables() []string {
}

func (m *messagePrinter) print(s string) {
fmt.Fprintf(m.w, "%s %s %s", m.promptPrefix, s, m.delimiter)
_, err := fmt.Fprintf(m.w, "%s %s %s", m.promptPrefix, s, m.delimiter)
if err != nil {
log.Fatal(err)
}
}

func WithPromptPrefix(prefix string) ConfigurationModifier {
Expand All @@ -130,80 +155,87 @@ func WithAskForConfirmation(confirm bool) ConfigurationModifier {
}
}

var conversations []Conversation
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recording the conversation with AI to maintain history during the chat.


func (g *Gromit) actionGromit(ctx context.Context, command *cli.Command) error {
commandArgs := command.Args().Slice()
query := strings.Join(commandArgs, " ")
if query == "" {
query = "Can you please introduce yourself or continue the conversation?"
}
prompt := g.String("systemPrompt")
if prompt == "" {
prompt = systemPrompt
}
prompt = addEnvironmentInfo(g.configuration.systemInfo, prompt)
conversations = append(conversations, Conversation{
Role: SystemRole,
Text: prompt,
}, Conversation{
Role: UserRole,
Text: query,
})
g.configuration.AiParameters = AiParameters{
maxTokens: g.Int64("maxTokens"),
apiKey: g.String("apiKey"),
agent: g.String("agent"),
model: g.String("model"),
systemPrompt: prompt,
}
terminalCommand, err := g.extractCommandForQuery(ctx, query)
if err != nil {
return err
}
if terminalCommand != "" {
err = g.handleTerminalCommand(ctx, terminalCommand)
if err != nil {
return err
}
}

for ctx.Err() == nil {
//read the user input, pass it to AI
reader := bufio.NewReader(g.Reader)
query, err := reader.ReadString('\n')
response, err := g.extractResponseForQuery(ctx, &conversations)
if err != nil {
return err
}
terminalCommand, err := g.extractCommandForQuery(ctx, query)
exit, err := g.handleAiResponse(ctx, response)
if err != nil {
return err
}
if terminalCommand == "exit" {
break
if exit {
return nil
}
if terminalCommand != "" {
err = g.handleTerminalCommand(ctx, terminalCommand)
if err != nil {
return err
}
reader := bufio.NewReader(g.Reader)
query, err = reader.ReadString('\n')
if err != nil {
return err
}
conversations = append(conversations, Conversation{
Role: UserRole,
Text: query,
})
}
return nil
}

func (g *Gromit) extractCommandForQuery(ctx context.Context, query string) (string, error) {
func (g *Gromit) extractResponseForQuery(ctx context.Context, conversations *[]Conversation) (AiResponse, error) {
var result AiResponse
assister, err := g.AssisterCreator.GetAssister(g.configuration.AiParameters)
if err != nil {
return "", err
return result, err
}
if query == "" {
query = "Can you please introduce yourself or continue the conversation?"
}
response, err := assister.GetTerminalCommand(ctx, query)

response, err := assister.GetTerminalCommand(ctx, conversations)
if err != nil {
return "", err
return result, err
}
//command is enclosed in *** marker
regexp := regexp.MustCompile(`\*\*\*(.*?)\*\*\*`)
commands := regexp.FindStringSubmatch(response)
var command string
if len(commands) > 0 {
command = commands[1]
} else {
g.print(response)
for _, s := range []string{"json", "```"} {
response = strings.ReplaceAll(response, s, "")
}
if !json.Valid([]byte(response)) {
return result, fmt.Errorf("received invalid json response: %s", response)
}
if err = json.Unmarshal([]byte(response), &result); err != nil {
return result, fmt.Errorf("failed to unmarshal json response: \n %s \n error: %s", response, err.Error())
}
return command, nil
*conversations = append(*conversations, Conversation{
Role: AssistantRole,
Text: response,
})
return result, nil
}

func (g *Gromit) handleTerminalCommand(ctx context.Context, terminalCommand string) error {
func (g *Gromit) handleTerminalCommand(_ context.Context, terminalCommand string) error {
g.print("In order to do that, you need to run:")
g.print(terminalCommand)
confirmation, err := g.askConfirmation("Would you like to run this command?")
Expand All @@ -222,6 +254,22 @@ func (g *Gromit) handleTerminalCommand(ctx context.Context, terminalCommand stri
return nil
}

func (g *Gromit) handleAiResponse(ctx context.Context, aiResponse AiResponse) (shouldExit bool, error error) {
if aiResponse.Response != "" {
g.print(aiResponse.Response)
}
if aiResponse.Command != "" {
err := g.handleTerminalCommand(ctx, aiResponse.Command)
if err != nil {
return false, err
}
}
if aiResponse.Exit {
return true, nil
}
return false, nil
}

// adds environment info such as OS, available shells, etc to the system prompt for the AI
func addEnvironmentInfo(systemInfo systemInfo, systemPrompt string) string {
result := fmt.Sprintf("%s. User's operating system is %s", systemPrompt, systemInfo.operatingSystem)
Expand Down
Loading