-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathtools.go
More file actions
97 lines (83 loc) · 3.11 KB
/
tools.go
File metadata and controls
97 lines (83 loc) · 3.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
package contextwindow
import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
)
// TODO(tqbf): this is all pretty gnarly and half-baked, a consequence of
// having implemented it only for OpenAI's client library; it'll stay gnarly
// until I do something with Claude.
// ToolDefinition pairs a tool's name with the provider-specific schema that
// describes the tool to the model.
type ToolDefinition struct {
	// Name is the tool's identifier (RegisterTool stores the definition
	// under this same name).
	Name string `json:"name"`
	// Definition is the model-specific tool definition (for example an
	// OpenAI FunctionDefinitionParam); its concrete type is opaque here.
	Definition any `json:"definition"`
}
// ToolRunner is the contract for anything that can carry out a tool call:
// Run receives the call's raw JSON arguments and returns the tool's textual
// output, or an error if the call failed.
type ToolRunner interface {
	Run(ctx context.Context, args json.RawMessage) (string, error)
}
// ToolRunnerFunc adapts an ordinary function with the right signature so it
// can be used wherever a ToolRunner is expected.
type ToolRunnerFunc func(ctx context.Context, args json.RawMessage) (string, error)

// Run invokes fn with ctx and args and returns its results unchanged.
func (fn ToolRunnerFunc) Run(ctx context.Context, args json.RawMessage) (string, error) {
	return fn(ctx, args)
}
// ToolExecutor is implemented by types that can both enumerate their tool
// definitions and dispatch a tool call by name.
type ToolExecutor interface {
	// GetRegisteredTools reports the definitions of the tools available
	// for execution.
	GetRegisteredTools() []ToolDefinition
	// ExecuteTool runs the tool called name with the given raw JSON
	// arguments and returns its output.
	ExecuteTool(ctx context.Context, name string, args json.RawMessage) (string, error)
}
// RegisterTool registers a tool with this ContextWindow instance and stores
// the tool name as a hint in the database for the current context.
//
// name is the key later used by GetTool and ExecuteTool. definition is the
// model-specific tool definition, kept as-is in ToolDefinition.Definition.
// runner executes the tool when it is called.
//
// An empty name or a nil runner is rejected up front; a nil runner would
// otherwise make a later ExecuteTool call panic. The database hint is written
// before the in-memory tables are touched, so when an error is returned the
// in-memory registration is left unchanged.
func (cw *ContextWindow) RegisterTool(name string, definition interface{}, runner ToolRunner) error {
	if name == "" {
		return fmt.Errorf("register tool: empty tool name")
	}
	if runner == nil {
		return fmt.Errorf("register tool %q: nil runner", name)
	}

	// Persist the hint first so a DB failure cannot leave a half-registered tool.
	contextID, err := getContextIDByName(cw.db, cw.currentContext)
	if err != nil {
		return fmt.Errorf("register tool %q: %w", name, err)
	}
	if _, err := AddContextTool(cw.db, contextID, name); err != nil {
		return fmt.Errorf("register tool %q: %w", name, err)
	}

	cw.registeredTools[name] = ToolDefinition{
		Name:       name,
		Definition: definition,
	}
	cw.toolRunners[name] = runner
	return nil
}
// GetTool looks up the runner registered under name. The boolean reports
// whether an entry exists; when it is false the returned runner is nil.
func (cw *ContextWindow) GetTool(name string) (ToolRunner, bool) {
	if r, ok := cw.toolRunners[name]; ok {
		return r, true
	}
	return nil, false
}
// ExecuteTool implements ToolExecutor. It looks up the runner registered
// under name and invokes it with ctx and args, returning the runner's output
// and error verbatim. If no runner is registered under name, nothing runs and
// an error is returned.
func (cw *ContextWindow) ExecuteTool(ctx context.Context, name string, args json.RawMessage) (string, error) {
	if r, ok := cw.toolRunners[name]; ok {
		return r.Run(ctx, args)
	}
	return "", fmt.Errorf("tool '%s' not registered", name)
}
// GetRegisteredTools returns every tool definition registered on this
// ContextWindow, sorted by tool name.
//
// Go randomizes map iteration order, so without sorting the same registry
// would produce a different ordering on each call. A stable order makes the
// list reproducible wherever it is forwarded.
//
// It returns nil when no tools are registered. The slice is freshly
// allocated, so reordering or truncating it does not affect the registry;
// the Definition values themselves are shared, not copied.
func (cw *ContextWindow) GetRegisteredTools() []ToolDefinition {
	if len(cw.registeredTools) == 0 {
		return nil
	}
	tools := make([]ToolDefinition, 0, len(cw.registeredTools))
	for _, def := range cw.registeredTools {
		tools = append(tools, def)
	}
	sort.Slice(tools, func(i, j int) bool { return tools[i].Name < tools[j].Name })
	return tools
}
// ListTools returns the tool names stored in the database for the current
// context, as reported by ListContextToolNames. These are presumably the
// hints recorded via AddContextTool (as RegisterTool does); verify against
// ListContextToolNames. Because this reads the database, the result may
// include names that have no runner registered in this process.
func (cw *ContextWindow) ListTools() ([]string, error) {
	id, err := getContextIDByName(cw.db, cw.currentContext)
	if err == nil {
		return ListContextToolNames(cw.db, id)
	}
	return nil, fmt.Errorf("list tools: %w", err)
}
// HasTool reports whether name is recorded as a tool for the current context,
// delegating to HasContextTool. It consults the database rather than the
// in-memory runner table, so a true result does not by itself guarantee that
// ExecuteTool will find a runner for name.
func (cw *ContextWindow) HasTool(name string) (bool, error) {
	id, err := getContextIDByName(cw.db, cw.currentContext)
	if err == nil {
		return HasContextTool(cw.db, id, name)
	}
	return false, fmt.Errorf("has tool: %w", err)
}