Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions pkg/runtime/toolexec/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -724,13 +724,21 @@ func buildMultiContent(text string, images []tools.MediaContent) []chat.MessageP
parts := make([]chat.MessagePart, 0, 1+len(images))
parts = append(parts, chat.MessagePart{Type: chat.MessagePartTypeText, Text: text})
for _, img := range images {
parts = append(parts, chat.MessagePart{
Type: chat.MessagePartTypeImageURL,
ImageURL: &chat.MessageImageURL{
URL: "data:" + img.MimeType + ";base64," + img.Data,
Detail: chat.ImageURLDetailAuto,
},
})
switch {
case img.FilePath != "":
parts = append(parts, chat.MessagePart{
Type: chat.MessagePartTypeText,
Text: fmt.Sprintf("[image saved to %s (%s)]", img.FilePath, img.MimeType),
})
case img.Data != "":
parts = append(parts, chat.MessagePart{
Type: chat.MessagePartTypeImageURL,
ImageURL: &chat.MessageImageURL{
URL: "data:" + img.MimeType + ";base64," + img.Data,
Detail: chat.ImageURLDetailAuto,
},
})
}
}
return parts
}
Expand Down
2 changes: 0 additions & 2 deletions pkg/tools/builtin/filesystem/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ type ReadFileArgs struct {

type ReadFileMeta struct {
Path string `json:"path"`
Content string `json:"content"`
LineCount int `json:"lineCount"`
Error string `json:"error,omitempty"`
}
Expand Down Expand Up @@ -1086,7 +1085,6 @@ func (t *ToolSet) handleReadMultipleFiles(ctx context.Context, args ReadMultiple
Path: path,
Content: text,
})
entry.Content = text
entry.LineCount = strings.Count(text, "\n") + 1
meta.Files = append(meta.Files, entry)
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/tools/builtin/filesystem/filesystem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ func TestFilesystemTool_ReadFile_TildePath(t *testing.T) {
require.NoError(t, err)
assert.False(t, result.IsError)
assert.Equal(t, content, result.Output)
assert.Equal(t, ReadFileMeta{LineCount: 1}, result.Meta)
}

func TestFilesystemTool_WriteFile(t *testing.T) {
Expand Down Expand Up @@ -166,6 +167,7 @@ func TestFilesystemTool_ReadFile(t *testing.T) {
})
require.NoError(t, err)
assert.Equal(t, content, result.Output)
assert.Equal(t, ReadFileMeta{LineCount: 1}, result.Meta)

result, err = tool.handleReadFile(t.Context(), ReadFileArgs{
Path: "nonexistent.txt",
Expand Down
38 changes: 33 additions & 5 deletions pkg/tools/lifecycle/supervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ type Supervisor struct {
// fresh channel by Start when transitioning out of a terminal state.
done chan struct{}

// watchDone is closed by the current watcher goroutine. Stop waits on it
// after closing the session so no transport goroutines are left behind.
watchDone chan struct{}

// randFloat is the jitter source; tests may override.
randFloat func() float64
}
Expand Down Expand Up @@ -214,6 +218,9 @@ func (s *Supervisor) Start(ctx context.Context) error {
}
s.session = sess
spawnWatcher := !s.watcherAlive
if spawnWatcher {
s.watchDone = make(chan struct{})
}
s.watcherAlive = true
// Recovering from a terminal state (Failed → Start, or a watcher
// that previously exited): refresh `done` so RestartAndWait callers
Expand Down Expand Up @@ -244,24 +251,40 @@ func (s *Supervisor) Start(ctx context.Context) error {
func (s *Supervisor) Stop(ctx context.Context) error {
s.mu.Lock()
if s.stopping {
watchDone := s.watchDone
s.mu.Unlock()
return nil
return waitForWatcher(ctx, watchDone)
}
s.stopping = true
sess := s.session
s.session = nil
watchDone := s.watchDone
s.mu.Unlock()

s.tracker.Set(StateStopped)
s.signalDone()

if sess == nil {
var closeErr error
if sess != nil {
closeErr = sess.Close(context.WithoutCancel(ctx))
}
waitErr := waitForWatcher(ctx, watchDone)
if closeErr != nil && ctx.Err() == nil {
return closeErr
}
return waitErr
}

func waitForWatcher(ctx context.Context, done <-chan struct{}) error {
if done == nil {
return nil
}
if err := sess.Close(context.WithoutCancel(ctx)); err != nil && ctx.Err() == nil {
return err
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
return nil
}

// RestartAndWait closes the current session (if any) so the watcher
Expand Down Expand Up @@ -326,7 +349,12 @@ func (s *Supervisor) watch(ctx context.Context) {
defer func() {
s.mu.Lock()
s.watcherAlive = false
watchDone := s.watchDone
s.watchDone = nil
s.mu.Unlock()
if watchDone != nil {
close(watchDone)
}
}()

log := s.policy.logger()
Expand Down
15 changes: 15 additions & 0 deletions pkg/tools/lifecycle/supervisor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,18 @@ func TestBackoff_Jitter(t *testing.T) {
d = lifecycle.ExportedBackoffDelay(b, 0, func() float64 { return 0 })
assert.Check(t, d == 50*time.Millisecond)
}

func TestSupervisor_StopWaitsForWatcher(t *testing.T) {
t.Parallel()

sess := newFakeSession()
c := newScriptedConnector(scriptStep{session: sess})
s := lifecycle.New("test", c, lifecycle.Policy{})

assert.NilError(t, s.Start(t.Context()))

start := time.Now()
assert.NilError(t, s.Stop(t.Context()))
assert.Check(t, time.Since(start) < time.Second)
assert.Check(t, is.Equal(s.State().State, lifecycle.StateStopped))
}
55 changes: 50 additions & 5 deletions pkg/tools/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"log/slog"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -714,6 +715,8 @@ func isInitNotificationSendError(err error) bool {
return false
}

// maxInlineMediaBytes is the largest decoded media payload, in bytes, that
// encodeMedia keeps inline as base64. Payloads longer than this are written
// to a file on disk instead.
const maxInlineMediaBytes = 256 * 1024

func processMCPContent(toolResult *mcp.CallToolResult) *tools.ToolCallResult {
var text strings.Builder
var images, audios []tools.MediaContent
Expand Down Expand Up @@ -760,12 +763,54 @@ func processMCPContent(toolResult *mcp.CallToolResult) *tools.ToolCallResult {
}
}

// encodeMedia re-encodes raw bytes (as decoded by the MCP SDK) back to base64
// for our internal MediaContent representation.
// encodeMedia keeps small payloads inline and spools larger ones to disk so the
// session and TUI do not retain duplicate base64 copies.
func encodeMedia(data []byte, mimeType string) tools.MediaContent {
return tools.MediaContent{
Data: base64.StdEncoding.EncodeToString(data),
MimeType: mimeType,
media := tools.MediaContent{MimeType: mimeType}
if len(data) <= maxInlineMediaBytes {
media.Data = base64.StdEncoding.EncodeToString(data)
return media
}

path, err := writeMediaFile(data, mimeType)
if err != nil {
slog.Warn("failed to spool MCP media to disk", "mime_type", mimeType, "bytes", len(data), "error", err)
media.Data = base64.StdEncoding.EncodeToString(data)
return media
}
media.FilePath = path
return media
}

// writeMediaFile stores data in a file named "media<ext>" inside a newly
// created temporary directory (pattern "docker-agent-mcp-media-*" under the
// default temp location) and returns the path of that file. The extension is
// chosen by mediaExtension from mimeType, and the file is written with mode
// 0o600.
//
// If the write fails, the temporary directory is removed on a best-effort
// basis and the write error is returned. On success nothing in this function
// removes the directory; cleanup is left to the caller.
func writeMediaFile(data []byte, mimeType string) (string, error) {
	spoolDir, mkErr := os.MkdirTemp("", "docker-agent-mcp-media-*")
	if mkErr != nil {
		return "", mkErr
	}

	fileName := "media" + mediaExtension(mimeType)
	target := filepath.Join(spoolDir, fileName)

	writeErr := os.WriteFile(target, data, 0o600)
	if writeErr == nil {
		return target, nil
	}

	// Best-effort cleanup: the write error is the one worth reporting.
	_ = os.RemoveAll(spoolDir)
	return "", writeErr
}

// mediaExtension maps a MIME type to the file extension (including the
// leading dot) used when spooling media to disk.
//
// MIME types are case-insensitive and may carry parameters
// (e.g. "image/png; charset=binary"), so the value is normalized before
// matching: everything from the first ';' onward is dropped, surrounding
// whitespace is trimmed, and the result is lower-cased. Types that are not
// recognized, including the empty string, map to ".bin".
func mediaExtension(mimeType string) string {
	base, _, _ := strings.Cut(mimeType, ";")
	switch strings.ToLower(strings.TrimSpace(base)) {
	case "image/png":
		return ".png"
	case "image/jpeg":
		return ".jpg"
	case "image/gif":
		return ".gif"
	case "image/webp":
		return ".webp"
	case "audio/wav", "audio/wave", "audio/x-wav":
		return ".wav"
	case "audio/mpeg", "audio/mp3":
		return ".mp3"
	default:
		return ".bin"
	}
}

Expand Down
19 changes: 19 additions & 0 deletions pkg/tools/mcp/mcp_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package mcp

import (
"bytes"
"context"
"fmt"
"iter"
"os"
"path/filepath"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -536,3 +539,19 @@ func TestCallToolRecoversFromErrSessionMissing(t *testing.T) {
assert.Equal(t, "recovered", result.Output)
assert.Equal(t, int32(2), callCount.Load(), "expected exactly 2 CallTool invocations (1 failed + 1 retry)")
}

// TestProcessMCPContentSpoolsLargeMedia feeds processMCPContent an image one
// byte larger than maxInlineMediaBytes and verifies that the resulting
// MediaContent has no inline Data, keeps its MIME type, and points at a file
// whose contents equal the original bytes.
func TestProcessMCPContentSpoolsLargeMedia(t *testing.T) {
	payload := bytes.Repeat([]byte("x"), maxInlineMediaBytes+1)
	content := &mcp.ImageContent{Data: payload, MIMEType: "image/png"}

	result := processMCPContent(callToolResult(content))
	require.Len(t, result.Images, 1)

	spooled := result.Images[0]
	assert.Empty(t, spooled.Data)
	assert.Equal(t, "image/png", spooled.MimeType)
	require.NotEmpty(t, spooled.FilePath)
	// Remove the directory holding the spooled file once the test finishes.
	defer os.RemoveAll(filepath.Dir(spooled.FilePath))

	onDisk, readErr := os.ReadFile(spooled.FilePath)
	require.NoError(t, readErr)
	assert.Equal(t, payload, onDisk)
}
17 changes: 15 additions & 2 deletions pkg/tools/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ type FunctionCall struct {
// MediaContent represents base64-encoded binary data (image, audio, etc.)
// returned by a tool.
type MediaContent struct {
// Data is the base64-encoded payload.
Data string `json:"data"`
// Data is the base64-encoded payload. It is kept only for small media; large
// MCP payloads are spooled to FilePath to avoid retaining duplicate base64.
Data string `json:"data,omitempty"`
// FilePath is an optional local file containing the decoded media payload.
FilePath string `json:"filePath,omitempty"`
// MimeType identifies the content type (e.g. "image/png", "audio/wav").
MimeType string `json:"mimeType"`
}
Expand All @@ -99,6 +102,16 @@ type ToolCallResult struct {
StructuredContent any `json:"structuredContent,omitempty"`
}

// WithoutPayload returns a new ToolCallResult that carries over only the
// IsError flag and the Meta value of r; every other field is left at its zero
// value. Meta is assigned as-is, not deep-copied. A nil receiver yields nil.
func (r *ToolCallResult) WithoutPayload() *ToolCallResult {
	if r != nil {
		slim := new(ToolCallResult)
		slim.IsError = r.IsError
		slim.Meta = r.Meta
		return slim
	}
	return nil
}

func ResultError(output string) *ToolCallResult {
return &ToolCallResult{
Output: output,
Expand Down
21 changes: 21 additions & 0 deletions pkg/tools/tools_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,24 @@ func TestNewHandler_InvalidArguments(t *testing.T) {
})
require.Error(t, err)
}

// TestToolCallResultWithoutPayload verifies that WithoutPayload keeps IsError
// and Meta, drops Output, Images, Audios and StructuredContent, and returns
// nil when called on a nil receiver.
func TestToolCallResultWithoutPayload(t *testing.T) {
	result := &ToolCallResult{
		Output:            "large output",
		IsError:           true,
		Meta:              "metadata",
		Images:            []MediaContent{{Data: "image", MimeType: "image/png"}},
		Audios:            []MediaContent{{Data: "audio", MimeType: "audio/wav"}},
		StructuredContent: map[string]any{"key": "value"},
	}

	slim := result.WithoutPayload()

	require.NotNil(t, slim)
	assert.Empty(t, slim.Output)
	assert.True(t, slim.IsError)
	assert.Equal(t, "metadata", slim.Meta)
	assert.Nil(t, slim.Images)
	assert.Nil(t, slim.Audios)
	assert.Nil(t, slim.StructuredContent)

	// Callers invoke WithoutPayload directly on a possibly-nil result, so the
	// nil-receiver path must not panic and must stay nil.
	var missing *ToolCallResult
	assert.Nil(t, missing.WithoutPayload())
}
4 changes: 2 additions & 2 deletions pkg/tui/components/messages/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,7 @@ func (m *model) AddToolResult(msg *runtime.ToolCallResponseEvent, status types.T
if m.messages[i].Type == types.MessageTypeAssistantReasoningBlock {
if block, ok := m.views[i].(*reasoningblock.Model); ok {
if block.HasToolCall(msg.ToolCallID) {
cmd := block.UpdateToolResult(msg.ToolCallID, msg.Response, status, msg.Result)
cmd := block.UpdateToolResult(msg.ToolCallID, msg.Response, status, msg.Result.WithoutPayload())
m.invalidateItem(i)
return cmd
}
Expand All @@ -1489,7 +1489,7 @@ func (m *model) AddToolResult(msg *runtime.ToolCallResponseEvent, status types.T
if toolMessage.Type == types.MessageTypeToolCall && toolMessage.ToolCall.ID == msg.ToolCallID {
toolMessage.Content = strings.ReplaceAll(msg.Response, "\t", " ")
toolMessage.ToolStatus = status
toolMessage.ToolResult = msg.Result
toolMessage.ToolResult = msg.Result.WithoutPayload()
m.invalidateItem(i)

view := m.createToolCallView(toolMessage)
Expand Down
2 changes: 1 addition & 1 deletion pkg/tui/components/reasoningblock/reasoningblock.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (m *Model) UpdateToolResult(toolCallID, content string, status types.ToolSt

entry.msg.Content = strings.ReplaceAll(content, "\t", " ")
entry.msg.ToolStatus = status
entry.msg.ToolResult = result
entry.msg.ToolResult = result.WithoutPayload()

// Set grace period if transitioning from in-progress to completed
// Total visible time = completedToolVisibleDuration + completedToolFadeDuration
Expand Down
Loading