diff --git a/go.mod b/go.mod index f803a3b9..fc68968e 100644 --- a/go.mod +++ b/go.mod @@ -18,10 +18,11 @@ require ( github.com/gptscript-ai/chat-completion-client v0.0.0-20250224164718-139cb4507b1d github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 - github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee + github.com/gptscript-ai/tui v0.0.0-20250419050840-5e79e16786c9 github.com/hexops/autogold/v2 v2.2.1 github.com/hexops/valast v1.4.4 github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 + github.com/mark3labs/mcp-go v0.25.0 github.com/mholt/archives v0.1.0 github.com/pkoukk/tiktoken-go v0.1.7 github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699 @@ -113,6 +114,7 @@ require ( github.com/skeema/knownhosts v1.2.2 // indirect github.com/sorairolake/lzip-go v0.3.5 // indirect github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect github.com/therootcompany/xz v1.0.1 // indirect github.com/tidwall/match v1.1.1 // indirect @@ -122,6 +124,7 @@ require ( github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/goldmark v1.5.4 // indirect github.com/yuin/goldmark-emoji v1.0.2 // indirect go4.org v0.0.0-20230225012048-214862532bf5 // indirect diff --git a/go.sum b/go.sum index 74341af5..7ce2cd38 100644 --- a/go.sum +++ b/go.sum @@ -203,8 +203,8 @@ github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7J github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw= github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61 h1:QxLjsLOYlsVLPwuRkP0Q8EcAoZT1s8vU2ZBSX0+R6CI= github.com/gptscript-ai/go-gptscript v0.9.6-0.20250204133419-744b25b84a61/go.mod h1:/FVuLwhz+sIfsWUgUHWKi32qT0i6+IXlUlzs70KKt/Q= -github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee h1:70PHW6Xw70yNNZ5aX936XqcMLwNmfMZpCV3FCOGKpxE= -github.com/gptscript-ai/tui v0.0.0-20250204145344-33cd15de4cee/go.mod h1:iwHxuueg2paOak7zIg0ESBWx7A0wIHGopAratbgaPNY= +github.com/gptscript-ai/tui v0.0.0-20250419050840-5e79e16786c9 h1:wQC8sKyeGA50WnCEG+Jo5FNRIkuX3HX8d3ubyWCCoI8= +github.com/gptscript-ai/tui v0.0.0-20250419050840-5e79e16786c9/go.mod h1:iwHxuueg2paOak7zIg0ESBWx7A0wIHGopAratbgaPNY= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -270,6 +270,8 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69 github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.25.0 h1:UUpcMT3L5hIhuDy7aifj4Bphw4Pfx1Rf8mzMXDe8RQw= +github.com/mark3labs/mcp-go v0.25.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -361,6 +363,8 @@ github.com/sorairolake/lzip-go v0.3.5 h1:ms5Xri9o1JBIWvOFAorYtUNik6HI3HgBTkISiqu github.com/sorairolake/lzip-go v0.3.5/go.mod h1:N0KYq5iWrMXI0ZEXKXaS9hCyOjZUQdBDEIbXfoUwbdk= github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e h1:H+jDTUeF+SVd4ApwnSFoew8ZwGNRfgb9EsZc7LcocAg= github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e/go.mod h1:VsUklG6OQo7Ctunu0gS3AtEOCEc2kMB6r5rKzxAes58= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= @@ -406,6 +410,8 @@ github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavM github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.3.7/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.5.4 h1:2uY/xC0roWy8IBEGLgB1ywIoEJFGmRrX21YQcvGZzjU= diff --git a/pkg/cli/gptscript.go b/pkg/cli/gptscript.go index 4b0642d2..b5a823b2 100644 --- a/pkg/cli/gptscript.go +++ b/pkg/cli/gptscript.go @@ -215,7 +215,7 @@ func (r *GPTScript) listTools(ctx context.Context, gptScript *gptscript.GPTScrip // Don't print instructions tool.Instructions = "" - lines = append(lines, tool.String()) + lines = append(lines, tool.Print()) } fmt.Println(strings.Join(lines, "\n---\n")) return nil diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index abf45e8c..c7867512 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -41,6 +41,11 @@ type Engine struct { RuntimeManager RuntimeManager Env []string Progress chan<- types.CompletionStatus + MCPRunner MCPRunner +} + +type MCPRunner interface { + Run(ctx context.Context, progress chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) } type State struct { @@ -307,6 +312,17 @@ func populateMessageParams(ctx Context, completion *types.CompletionRequest, too return nil } +func (e *Engine) runMCPInvoke(ctx Context, tool types.Tool, input string) (*Return, error) { + output, err := e.MCPRunner.Run(ctx.Ctx, e.Progress, tool, input) + if err != nil { + return nil, fmt.Errorf("failed to run MCP invoke: %w", err) + } + + return &Return{ + Result: &output, + }, nil +} + func (e *Engine) runCommandTools(ctx Context, tool types.Tool, input string) (*Return, error) { if tool.IsHTTP() { return e.runHTTP(ctx, tool, input) @@ -342,6 +358,10 @@ func (e *Engine) Start(ctx Context, input string) (ret *Return, err error) { } }() + if tool.IsMCPInvoke() { + return e.runMCPInvoke(ctx, tool, input) + } + if tool.IsCommand() { return e.runCommandTools(ctx, tool, input) } @@ -378,6 +398,7 @@ func addUpdateSystem(ctx Context, tool types.Tool, msgs []types.CompletionMessag instructions = append(instructions, context.Content) } + tool.Instructions = strings.TrimPrefix(tool.Instructions, types.PromptPrefix) if tool.Instructions != "" { instructions = append(instructions, tool.Instructions) } diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index e70827c6..626cc87f 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -20,6 +20,7 @@ import ( "github.com/gptscript-ai/gptscript/pkg/builtin" "github.com/gptscript-ai/gptscript/pkg/cache" "github.com/gptscript-ai/gptscript/pkg/hash" + "github.com/gptscript-ai/gptscript/pkg/mcp" "github.com/gptscript-ai/gptscript/pkg/openapi" "github.com/gptscript-ai/gptscript/pkg/parser" "github.com/gptscript-ai/gptscript/pkg/system" @@ -155,7 +156,23 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T { return openAPIDocument } -func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName, defaultModel string) ([]types.Tool, error) { +func processMCP(ctx context.Context, tool []types.Tool, mcpLoader MCPLoader) (result []types.Tool, _ error) { + for _, t := range tool { + if t.IsMCP() { + mcpTools, err := mcpLoader.Load(ctx, t) + if err != nil { + return nil, fmt.Errorf("error loading MCP tools: %w", err) + } + result = append(result, mcpTools...) + } else { + result = append(result, t) + } + } + + return result, nil +} + +func readTool(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, targetToolName, defaultModel string) ([]types.Tool, error) { data := base.Content var ( @@ -212,6 +229,11 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base return nil, fmt.Errorf("no tools found in %s", base) } + tools, err := processMCP(ctx, tools, mcp) + if err != nil { + return nil, err + } + var ( localTools = types.ToolSet{} targetTools []types.Tool @@ -279,17 +301,17 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base localTools[strings.ToLower(tool.Name)] = tool } - return linkAll(ctx, cache, prg, base, targetTools, localTools, defaultModel) + return linkAll(ctx, cache, mcp, prg, base, targetTools, localTools, defaultModel) } -func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet, defaultModel string) (result []types.Tool, _ error) { +func linkAll(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet, defaultModel string) (result []types.Tool, _ error) { localToolsMapping := make(map[string]string, len(tools)) for _, localTool := range localTools { localToolsMapping[strings.ToLower(localTool.Name)] = localTool.ID } for _, tool := range tools { - tool, err := link(ctx, cache, prg, base, tool, localTools, localToolsMapping, defaultModel) + tool, err := link(ctx, cache, mcp, prg, base, tool, localTools, localToolsMapping, defaultModel) if err != nil { return nil, err } @@ -298,7 +320,7 @@ func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base return } -func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string, defaultModel string) (types.Tool, error) { +func link(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string, defaultModel string) (types.Tool, error) { if existing, ok := prg.ToolSet[tool.ID]; ok { return existing, nil } @@ -323,7 +345,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so linkedTool = existing } else { var err error - linkedTool, err = link(ctx, cache, prg, base, localTool, localTools, localToolsMapping, defaultModel) + linkedTool, err = link(ctx, cache, mcp, prg, base, localTool, localTools, localToolsMapping, defaultModel) if err != nil { return types.Tool{}, fmt.Errorf("failed linking %s at %s: %w", targetToolName, base, err) } @@ -333,7 +355,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so toolNames[targetToolName] = struct{}{} } else { toolName, subTool := types.SplitToolRef(targetToolName) - resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool, defaultModel) + resolvedTools, err := resolve(ctx, cache, mcp, prg, base, toolName, subTool, defaultModel) if err != nil { return types.Tool{}, fmt.Errorf("failed resolving %s from %s: %w", targetToolName, base, err) } @@ -373,7 +395,7 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts .. prg := types.Program{ ToolSet: types.ToolSet{}, } - tools, err := readTool(ctx, opt.Cache, &prg, &source{ + tools, err := readTool(ctx, opt.Cache, opt.MCPLoader, &prg, &source{ Content: []byte(content), Path: locationPath, Name: locationName, @@ -390,6 +412,12 @@ type Options struct { Cache *cache.Client Location string DefaultModel string + MCPLoader MCPLoader +} + +type MCPLoader interface { + Load(ctx context.Context, tool types.Tool) ([]types.Tool, error) + Close() error } func complete(opts ...Options) (result Options) { @@ -397,6 +425,7 @@ func complete(opts ...Options) (result Options) { result.Cache = types.FirstSet(opt.Cache, result.Cache) result.Location = types.FirstSet(opt.Location, result.Location) result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel) + result.MCPLoader = types.FirstSet(opt.MCPLoader, result.MCPLoader) } if result.Location == "" { @@ -407,6 +436,10 @@ func complete(opts ...Options) (result Options) { result.DefaultModel = builtin.GetDefaultModel() } + if result.MCPLoader == nil { + result.MCPLoader = mcp.DefaultLoader + } + return } @@ -430,7 +463,7 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty Name: name, ToolSet: types.ToolSet{}, } - tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName, opt.DefaultModel) + tools, err := resolve(ctx, opt.Cache, opt.MCPLoader, &prg, &source{}, name, subToolName, opt.DefaultModel) if err != nil { return types.Program{}, err } @@ -438,7 +471,7 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty return prg, nil } -func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool, defaultModel string) ([]types.Tool, error) { +func resolve(ctx context.Context, cache *cache.Client, mcp MCPLoader, prg *types.Program, base *source, name, subTool, defaultModel string) ([]types.Tool, error) { if subTool == "" { t, ok := builtin.DefaultModel(name, defaultModel) if ok { @@ -452,7 +485,7 @@ func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base return nil, err } - result, err := readTool(ctx, cache, prg, s, subTool, defaultModel) + result, err := readTool(ctx, cache, mcp, prg, s, subTool, defaultModel) if err != nil { return nil, err } diff --git a/pkg/loader/openapi_test.go b/pkg/loader/openapi_test.go index 423246d1..594d8cf7 100644 --- a/pkg/loader/openapi_test.go +++ b/pkg/loader/openapi_test.go @@ -26,7 +26,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err, "failed to read openapi v3") require.Equal(t, 3, numOpenAPITools(prgv3.ToolSet), "expected 3 openapi tools") @@ -35,7 +35,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.json") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2json, &source{Content: datav2}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2json, &source{Content: datav2}, "", "") require.NoError(t, err, "failed to read openapi v2") require.Equal(t, 3, numOpenAPITools(prgv2json.ToolSet), "expected 3 openapi tools") @@ -44,7 +44,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav2, err = os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2yaml, &source{Content: datav2}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2yaml, &source{Content: datav2}, "", "") require.NoError(t, err, "failed to read openapi v2 (yaml)") require.Equal(t, 3, numOpenAPITools(prgv2yaml.ToolSet), "expected 3 openapi tools") @@ -57,7 +57,7 @@ func TestOpenAPIv3(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -69,7 +69,7 @@ func TestOpenAPIv3NoOperationIDs(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3_no_operation_ids.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -81,7 +81,7 @@ func TestOpenAPIv2(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2, &source{Content: datav2}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi")) @@ -94,7 +94,7 @@ func TestOpenAPIv3Revamp(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -107,7 +107,7 @@ func TestOpenAPIv3NoOperationIDsRevamp(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3_no_operation_ids.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -120,8 +120,18 @@ func TestOpenAPIv2Revamp(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "", "") + _, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2, &source{Content: datav2}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi")) } + +type fakeMCPLoader struct{} + +func (fakeMCPLoader) Load(context.Context, types.Tool) ([]types.Tool, error) { + return nil, nil +} + +func (fakeMCPLoader) Close() error { + return nil +} diff --git a/pkg/mcp/loader.go b/pkg/mcp/loader.go new file mode 100644 index 00000000..0eb713e5 --- /dev/null +++ b/pkg/mcp/loader.go @@ -0,0 +1,312 @@ +package mcp + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "maps" + "slices" + "strings" + "sync" + + "github.com/getkin/kin-openapi/openapi3" + "github.com/gptscript-ai/gptscript/pkg/hash" + "github.com/gptscript-ai/gptscript/pkg/mvl" + "github.com/gptscript-ai/gptscript/pkg/types" + "github.com/gptscript-ai/gptscript/pkg/version" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" +) + +var ( + DefaultLoader = &Local{} + DefaultRunner = DefaultLoader + + logger = mvl.Package() +) + +type Local struct { + lock sync.Mutex + sessions map[string]*Session + sessionCtx context.Context + cancel context.CancelFunc +} + +type Session struct { + ID string + InitResult *mcp.InitializeResult + Client client.MCPClient + Config ServerConfig +} + +type Config struct { + MCPServers map[string]ServerConfig `json:"mcpServers"` +} + +// ServerConfig represents an MCP server configuration for tools calls. +// It is important that this type doesn't have any maps. +type ServerConfig struct { + DisableInstruction bool `json:"disableInstruction"` + Command string `json:"command"` + Args []string `json:"args"` + Env []string `json:"env"` + Server string `json:"server"` + URL string `json:"url"` + BaseURL string `json:"baseURL,omitempty"` + Headers []string `json:"headers"` + Scope string `json:"scope"` +} + +func (s *ServerConfig) GetBaseURL() string { + if s.BaseURL != "" { + return s.BaseURL + } + if s.Server != "" { + return s.Server + } + return s.URL +} + +func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool, _ error) { + if !tool.IsMCP() { + return nil, nil + } + + _, configData, _ := strings.Cut(tool.Instructions, "\n") + + var servers Config + if err := json.Unmarshal([]byte(strings.TrimSpace(configData)), &servers); err != nil { + return nil, fmt.Errorf("failed to parse MCP configuration: %w\n%s", err, configData) + } + + if len(servers.MCPServers) == 0 { + // Try to load just one server + var server ServerConfig + if err := json.Unmarshal([]byte(strings.TrimSpace(configData)), &server); err != nil { + return nil, fmt.Errorf("failed to parse single MCP server configuration: %w\n%s", err, configData) + } + if server.Command == "" && server.URL == "" && server.Server == "" { + return nil, fmt.Errorf("no MCP server configuration found in tool instructions: %s", configData) + } + servers.MCPServers = map[string]ServerConfig{ + "default": server, + } + } + + if len(servers.MCPServers) > 1 { + return nil, fmt.Errorf("only a single MCP server definition is supported") + } + + for server := range maps.Keys(servers.MCPServers) { + session, err := l.loadSession(servers.MCPServers[server]) + if err != nil { + return nil, fmt.Errorf("failed to load MCP session for server %s: %w", server, err) + } + + return l.sessionToTools(ctx, session, tool.Name) + } + + // This should never happen, but just in case + return nil, fmt.Errorf("no MCP server configuration found in tool instructions: %s", configData) +} + +func (l *Local) Close() error { + if l == nil { + return nil + } + + l.lock.Lock() + defer l.lock.Unlock() + + if l.sessionCtx == nil { + return nil + } + + defer func() { + l.cancel() + l.sessionCtx = nil + }() + + var errs []error + for id, session := range l.sessions { + logger.Infof("closing MCP session %s", id) + if err := session.Client.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close MCP client %s: %w", id, err)) + } + } + + return errors.Join(errs...) +} + +func (l *Local) sessionToTools(ctx context.Context, session *Session, toolName string) ([]types.Tool, error) { + tools, err := session.Client.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to list tools: %w", err) + } + + toolDefs := []types.Tool{{ /* this is a placeholder for main tool */ }} + var toolNames []string + + for _, tool := range tools.Tools { + var schema openapi3.Schema + + schemaData, err := json.Marshal(tool.InputSchema) + if err != nil { + panic(err) + } + + if tool.Name == "" { + // I dunno, bad tool? + continue + } + + if err := json.Unmarshal(schemaData, &schema); err != nil { + return nil, fmt.Errorf("failed to unmarshal tool input schema: %w", err) + } + + annotations, err := json.Marshal(tool.Annotations) + if err != nil { + return nil, fmt.Errorf("failed to marshal tool annotations: %w", err) + } + + toolDef := types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: tool.Name, + Description: tool.Description, + Arguments: &schema, + }, + Instructions: types.MCPInvokePrefix + tool.Name + " " + session.ID, + }, + } + + if string(annotations) != "{}" { + toolDef.MetaData = map[string]string{ + "mcp-tool-annotations": string(annotations), + } + } + + if tool.Annotations.Title != "" && !slices.Contains(strings.Fields(tool.Annotations.Title), "as") { + toolDef.Name = tool.Annotations.Title + " as " + tool.Name + } + + toolDefs = append(toolDefs, toolDef) + toolNames = append(toolNames, tool.Name) + } + + main := types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: toolName, + Description: session.InitResult.ServerInfo.Name, + Export: toolNames, + }, + MetaData: map[string]string{ + "bundle": "true", + }, + }, + } + + if session.InitResult.Instructions != "" { + data, _ := json.Marshal(map[string]any{ + "tools": toolNames, + "instructions": session.InitResult.Instructions, + }) + toolDefs = append(toolDefs, types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: session.ID, + Type: "context", + }, + Instructions: types.EchoPrefix + "\n" + `# START MCP SERVER INFO: ` + session.InitResult.ServerInfo.Name + "\n" + + `You have available the following tools from an MCP Server that has provided the following additional instructions` + "\n" + + string(data) + "\n" + + `# END MCP SERVER INFO` + "\n", + }, + }) + + main.ExportContext = append(main.ExportContext, session.ID) + } + + toolDefs[0] = main + return toolDefs, nil +} + +func (l *Local) loadSession(server ServerConfig) (*Session, error) { + id := hash.Digest(server) + l.lock.Lock() + existing, ok := l.sessions[id] + if l.sessionCtx == nil { + l.sessionCtx, l.cancel = context.WithCancel(context.Background()) + } + ctx := l.sessionCtx + l.lock.Unlock() + + if ok { + return existing, nil + } + + var ( + c *client.Client + err error + ) + if server.Command != "" { + c, err = client.NewStdioMCPClient(server.Command, server.Env, server.Args...) + if err != nil { + return nil, fmt.Errorf("failed to create MCP stdio client: %w", err) + } + } else { + url := server.URL + if url == "" { + url = server.Server + } + + headers := make(map[string]string, len(server.Headers)) + for _, h := range server.Headers { + k, v, _ := strings.Cut(h, "=") + headers[k] = v + } + + c, err = client.NewSSEMCPClient(url, client.WithHeaders(headers)) + if err != nil { + return nil, fmt.Errorf("failed to create MCP HTTP client: %w", err) + } + + // We expect the client to outlive this one request. + if err = c.Start(ctx); err != nil { + return nil, fmt.Errorf("failed to start MCP client: %w", err) + } + } + + var initRequest mcp.InitializeRequest + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: version.ProgramName, + Version: version.Get().String(), + } + + initResult, err := c.Initialize(ctx, initRequest) + if err != nil { + return nil, fmt.Errorf("failed to initialize MCP client: %w", err) + } + + result := &Session{ + ID: id, + InitResult: initResult, + Client: c, + Config: server, + } + + l.lock.Lock() + defer l.lock.Unlock() + + if existing, ok = l.sessions[id]; ok { + return existing, c.Close() + } + + if l.sessions == nil { + l.sessions = make(map[string]*Session) + } + l.sessions[id] = result + return result, nil +} diff --git a/pkg/mcp/runner.go b/pkg/mcp/runner.go new file mode 100644 index 00000000..448d58a7 --- /dev/null +++ b/pkg/mcp/runner.go @@ -0,0 +1,55 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/gptscript-ai/gptscript/pkg/types" + "github.com/mark3labs/mcp-go/mcp" +) + +func (l *Local) Run(ctx context.Context, _ chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) { + fields := strings.Fields(tool.Instructions) + if len(fields) < 2 { + return "", fmt.Errorf("invalid mcp call, invalid number of fields in %s", tool.Instructions) + } + + id := fields[1] + toolName, ok := strings.CutPrefix(fields[0], types.MCPInvokePrefix) + if !ok { + return "", fmt.Errorf("invalid mcp call, invalid tool name in %s", tool.Instructions) + } + + arguments := map[string]any{} + + if input != "" { + if err := json.Unmarshal([]byte(input), &arguments); err != nil { + return "", fmt.Errorf("failed to unmarshal input: %w", err) + } + } + + l.lock.Lock() + session, ok := l.sessions[id] + l.lock.Unlock() + if !ok { + return "", fmt.Errorf("session not found for MCP server %s", id) + } + + request := mcp.CallToolRequest{} + request.Params.Name = toolName + request.Params.Arguments = arguments + + result, err := session.Client.CallTool(ctx, request) + if err != nil { + return "", fmt.Errorf("failed to call tool %s: %w", toolName, err) + } + + str, err := json.Marshal(result) + if err != nil { + return "", fmt.Errorf("failed to marshal result: %w", err) + } + + return string(str), nil +} diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 6d4e7598..200c453b 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -14,6 +14,7 @@ import ( context2 "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/engine" + "github.com/gptscript-ai/gptscript/pkg/mcp" "github.com/gptscript-ai/gptscript/pkg/types" "golang.org/x/exp/maps" ) @@ -37,6 +38,7 @@ type Options struct { CredentialOverrides []string `usage:"-"` Sequential bool `usage:"-"` Authorizer AuthorizerFunc `usage:"-"` + MCPRunner engine.MCPRunner `usage:"-"` } type RunOptions struct { @@ -69,6 +71,9 @@ func Complete(opts ...Options) (result Options) { if opt.CredentialOverrides != nil { result.CredentialOverrides = append(result.CredentialOverrides, opt.CredentialOverrides...) } + if opt.MCPRunner != nil { + result.MCPRunner = opt.MCPRunner + } } return } @@ -87,6 +92,9 @@ func complete(opts ...Options) Options { if result.Authorizer == nil { result.Authorizer = DefaultAuthorizer } + if result.MCPRunner == nil { + result.MCPRunner = mcp.DefaultRunner + } return result } @@ -99,6 +107,7 @@ type Runner struct { credOverrides []string credStore credentials.CredentialStore sequential bool + mcpRunner engine.MCPRunner } func New(client engine.Model, credStore credentials.CredentialStore, opts ...Options) (*Runner, error) { @@ -113,6 +122,7 @@ func New(client engine.Model, credStore credentials.CredentialStore, opts ...Opt credStore: credStore, sequential: opt.Sequential, auth: opt.Authorizer, + mcpRunner: opt.MCPRunner, } if opt.StartPort != 0 { @@ -326,6 +336,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en e := engine.Engine{ Model: r.c, + MCPRunner: r.mcpRunner, RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager), Progress: progress, Env: env, @@ -524,6 +535,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s e := engine.Engine{ Model: r.c, + MCPRunner: r.mcpRunner, RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager), Progress: progress, Env: env, diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index 1a4e28ea..52a06994 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -29,6 +29,7 @@ type server struct { datasetTool, workspaceTool string serverToolsEnv []string client *gptscript.GPTScript + mcpLoader loader.MCPLoader events *broadcaster.Broadcaster[event] runtimeManager engine.RuntimeManager @@ -283,11 +284,20 @@ func (s *server) load(w http.ResponseWriter, r *http.Request) { } if reqObject.Content != "" { - prg, err = loader.ProgramFromSource(ctx, reqObject.Content, reqObject.SubTool, loader.Options{Cache: s.client.Cache}) + prg, err = loader.ProgramFromSource(ctx, reqObject.Content, reqObject.SubTool, loader.Options{ + Cache: s.client.Cache, + MCPLoader: s.mcpLoader, + }) } else if reqObject.File != "" { - prg, err = loader.Program(ctx, reqObject.File, reqObject.SubTool, loader.Options{Cache: s.client.Cache}) + prg, err = loader.Program(ctx, reqObject.File, reqObject.SubTool, loader.Options{ + Cache: s.client.Cache, + MCPLoader: s.mcpLoader, + }) } else { - prg, err = loader.ProgramFromSource(ctx, reqObject.ToolDefs.String(), reqObject.SubTool, loader.Options{Cache: s.client.Cache}) + prg, err = loader.ProgramFromSource(ctx, reqObject.ToolDefs.String(), reqObject.SubTool, loader.Options{ + Cache: s.client.Cache, + MCPLoader: s.mcpLoader, + }) } if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) diff --git a/pkg/sdkserver/run.go b/pkg/sdkserver/run.go index fda4a215..a2c0d505 100644 --- a/pkg/sdkserver/run.go +++ b/pkg/sdkserver/run.go @@ -36,7 +36,11 @@ func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, lo if defaultModel == "" { defaultModel = s.gptscriptOpts.OpenAI.DefaultModel } - prg, err := programLoader(ctx, toolDef.String(), subTool, loader.Options{Cache: g.Cache, DefaultModel: defaultModel}) + prg, err := programLoader(ctx, toolDef.String(), subTool, loader.Options{ + Cache: g.Cache, + DefaultModel: defaultModel, + MCPLoader: s.mcpLoader, + }) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) return diff --git a/pkg/sdkserver/server.go b/pkg/sdkserver/server.go index f15cc68f..52e9ec1c 100644 --- a/pkg/sdkserver/server.go +++ b/pkg/sdkserver/server.go @@ -16,6 +16,8 @@ import ( "github.com/google/uuid" "github.com/gptscript-ai/broadcaster" "github.com/gptscript-ai/gptscript/pkg/gptscript" + "github.com/gptscript-ai/gptscript/pkg/loader" + "github.com/gptscript-ai/gptscript/pkg/mcp" "github.com/gptscript-ai/gptscript/pkg/mvl" "github.com/gptscript-ai/gptscript/pkg/repos/runtimes" "github.com/gptscript-ai/gptscript/pkg/runner" @@ -26,6 +28,7 @@ import ( type Options struct { gptscript.Options + MCPLoader loader.MCPLoader ListenAddress string DatasetTool, WorkspaceTool string ServerToolsEnv []string @@ -114,6 +117,7 @@ func run(ctx context.Context, listener net.Listener, opts Options) error { serverToolsEnv: opts.ServerToolsEnv, client: g, + mcpLoader: opts.MCPLoader, events: events, runtimeManager: runtimes.Default(opts.Cache.CacheDir, opts.SystemToolsDir), waitingToConfirm: make(map[string]chan runner.AuthorizerResponse), @@ -168,6 +172,7 @@ func complete(opts ...Options) Options { result.WorkspaceTool = types.FirstSet(opt.WorkspaceTool, result.WorkspaceTool) result.Debug = types.FirstSet(opt.Debug, result.Debug) result.DisableServerErrorLogging = types.FirstSet(opt.DisableServerErrorLogging, result.DisableServerErrorLogging) + result.MCPLoader = types.FirstSet(opt.MCPLoader, result.MCPLoader) } if result.ListenAddress == "" { @@ -183,6 +188,9 @@ func complete(opts ...Options) Options { if len(result.ServerToolsEnv) == 0 { result.ServerToolsEnv = os.Environ() } + if result.MCPLoader == nil { + result.MCPLoader = mcp.DefaultLoader + } return result } diff --git a/pkg/tests/runner2_test.go b/pkg/tests/runner2_test.go index f5de8e10..c531c661 100644 --- a/pkg/tests/runner2_test.go +++ b/pkg/tests/runner2_test.go @@ -3,11 +3,13 @@ package tests import ( "context" "encoding/json" + "runtime" "testing" "github.com/gptscript-ai/gptscript/pkg/loader" "github.com/gptscript-ai/gptscript/pkg/runner" "github.com/gptscript-ai/gptscript/pkg/tests/tester" + "github.com/gptscript-ai/gptscript/pkg/types" "github.com/hexops/autogold/v2" "github.com/stretchr/testify/require" ) @@ -203,3 +205,358 @@ echo "${GPTSCRIPT_INPUT}" require.NoError(t, err) autogold.Expect(map[string]interface{}{"foo": "baz", "start": true}).Equal(t, data) } + +func TestMCPLoad(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows") + } + + r := tester.NewRunner(t) + prg, err := loader.ProgramFromSource(context.Background(), ` +name: mcp + +#!mcp + +{ + "mcpServers": { + "sqlite": { + "command": "docker", + "args": [ + "run", + "--rm", + "-i", + "-v", + "mcp-test:/mcp", + "mcp/sqlite@sha256:007ccae941a6f6db15b26ee41d92edda50ce157176d9273449e8b3f51d979c70", + "--db-path", + "/mcp/test.db" + ] + } + } +} +`, "") + require.NoError(t, err) + + autogold.Expect(types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Name: "mcp", + Description: "sqlite", + ModelName: "gpt-4o", + Export: []string{ + "read_query", + "write_query", + "create_table", + "list_tables", + "describe_table", + "append_insight", + }, + }, + MetaData: map[string]string{"bundle": "true"}, + }, + ID: "inline:mcp", + ToolMapping: map[string][]types.ToolReference{ + "append_insight": {{ + Reference: "append_insight", + ToolID: "inline:append_insight", + }}, + "create_table": {{ + Reference: "create_table", + ToolID: "inline:create_table", + }}, + "describe_table": {{ + Reference: "describe_table", + ToolID: "inline:describe_table", + }}, + "list_tables": {{ + Reference: "list_tables", + ToolID: "inline:list_tables", + }}, + "read_query": {{ + Reference: "read_query", + ToolID: "inline:read_query", + }}, + "write_query": {{ + Reference: "write_query", + ToolID: "inline:write_query", + }}, + }, + LocalTools: map[string]string{ + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query", + }, + Source: types.ToolSource{Location: "inline"}, + WorkingDir: ".", + }).Equal(t, prg.ToolSet[prg.EntryToolID]) + autogold.Expect(7).Equal(t, len(prg.ToolSet[prg.EntryToolID].LocalTools)) + data, _ := json.MarshalIndent(prg.ToolSet, "", " ") + autogold.Expect(`{ + "inline:append_insight": { + "name": "append_insight", + "description": "Add a business insight to the memo", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "insight": { + "description": "Business insight discovered from data analysis", + "type": "string" + } + }, + "required": [ + "insight" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke.append_insight 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", + "id": "inline:append_insight", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:create_table": { + "name": "create_table", + "description": "Create a new table in the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "query": { + "description": "CREATE TABLE SQL statement", + "type": "string" + } + }, + "required": [ + "query" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke.create_table 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", + "id": "inline:create_table", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:describe_table": { + "name": "describe_table", + "description": "Get the schema information for a specific table", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "table_name": { + "description": "Name of the table to describe", + "type": "string" + } + }, + "required": [ + "table_name" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke.describe_table 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", + "id": "inline:describe_table", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:list_tables": { + "name": "list_tables", + "description": "List all tables in the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "type": "object" + }, + "instructions": "#!sys.mcp.invoke.list_tables 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", + "id": "inline:list_tables", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:mcp": { + "name": "mcp", + "description": "sqlite", + "modelName": "gpt-4o", + "internalPrompt": null, + "export": [ + "read_query", + "write_query", + "create_table", + "list_tables", + "describe_table", + "append_insight" + ], + "metaData": { + "bundle": "true" + }, + "id": "inline:mcp", + "toolMapping": { + "append_insight": [ + { + "reference": "append_insight", + "toolID": "inline:append_insight" + } + ], + "create_table": [ + { + "reference": "create_table", + "toolID": "inline:create_table" + } + ], + "describe_table": [ + { + "reference": "describe_table", + "toolID": "inline:describe_table" + } + ], + "list_tables": [ + { + "reference": "list_tables", + "toolID": "inline:list_tables" + } + ], + "read_query": [ + { + "reference": "read_query", + "toolID": "inline:read_query" + } + ], + "write_query": [ + { + "reference": "write_query", + "toolID": "inline:write_query" + } + ] + }, + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:read_query": { + "name": "read_query", + "description": "Execute a SELECT query on the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "query": { + "description": "SELECT SQL query to execute", + "type": "string" + } + }, + "required": [ + "query" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke.read_query 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", + "id": "inline:read_query", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + }, + "inline:write_query": { + "name": "write_query", + "description": "Execute an INSERT, UPDATE, or DELETE query on the SQLite database", + "modelName": "gpt-4o", + "internalPrompt": null, + "arguments": { + "properties": { + "query": { + "description": "SQL query to execute", + "type": "string" + } + }, + "required": [ + "query" + ], + "type": "object" + }, + "instructions": "#!sys.mcp.invoke.write_query 607ca64476abf0288ef49061557243e43735fd4de4bc5fdcd51d93049ffa023e", + "id": "inline:write_query", + "localTools": { + "append_insight": "inline:append_insight", + "create_table": "inline:create_table", + "describe_table": "inline:describe_table", + "list_tables": "inline:list_tables", + "mcp": "inline:mcp", + "read_query": "inline:read_query", + "write_query": "inline:write_query" + }, + "source": { + "location": "inline" + }, + "workingDir": "." + } +}`).Equal(t, string(data)) + + prg.EntryToolID = prg.ToolSet[prg.EntryToolID].LocalTools["read_query"] + resp, err := r.Chat(context.Background(), nil, prg, nil, `{"query": "SELECT 1"}`, runner.RunOptions{}) + r.AssertStep(t, resp, err) +} diff --git a/pkg/tests/testdata/TestMCPLoad/call1-resp.golden b/pkg/tests/testdata/TestMCPLoad/call1-resp.golden new file mode 100644 index 00000000..2861a036 --- /dev/null +++ b/pkg/tests/testdata/TestMCPLoad/call1-resp.golden @@ -0,0 +1,9 @@ +`{ + "role": "assistant", + "content": [ + { + "text": "TEST RESULT CALL: 1" + } + ], + "usage": {} +}` diff --git a/pkg/tests/testdata/TestMCPLoad/call1.golden b/pkg/tests/testdata/TestMCPLoad/call1.golden new file mode 100644 index 00000000..31048a88 --- /dev/null +++ b/pkg/tests/testdata/TestMCPLoad/call1.golden @@ -0,0 +1,3 @@ +`{ + "model": "gpt-4o" +}` diff --git a/pkg/tests/testdata/TestMCPLoad/step1.golden b/pkg/tests/testdata/TestMCPLoad/step1.golden new file mode 100644 index 00000000..ae20c8ed --- /dev/null +++ b/pkg/tests/testdata/TestMCPLoad/step1.golden @@ -0,0 +1,6 @@ +`{ + "done": true, + "content": "{\"content\":[{\"type\":\"text\",\"text\":\"[{'1': 1}]\"}]}", + "toolID": "", + "state": null +}` diff --git a/pkg/types/tool.go b/pkg/types/tool.go index 3d48c6e1..10b47c77 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -16,11 +16,14 @@ import ( ) const ( - DaemonPrefix = "#!sys.daemon" - OpenAPIPrefix = "#!sys.openapi" - EchoPrefix = "#!sys.echo" - CallPrefix = "#!sys.call" - CommandPrefix = "#!" + DaemonPrefix = "#!sys.daemon" + OpenAPIPrefix = "#!sys.openapi" + EchoPrefix = "#!sys.echo" + CallPrefix = "#!sys.call" + MCPPrefix = "#!mcp" + MCPInvokePrefix = "#!sys.mcp.invoke." + CommandPrefix = "#!" + PromptPrefix = "!!" ) var ( @@ -876,6 +879,14 @@ func (t Tool) IsDaemon() bool { return strings.HasPrefix(t.Instructions, DaemonPrefix) } +func (t Tool) IsMCP() bool { + return strings.HasPrefix(t.Instructions, MCPPrefix) +} + +func (t Tool) IsMCPInvoke() bool { + return strings.HasPrefix(t.Instructions, MCPInvokePrefix) +} + func (t Tool) IsOpenAPI() bool { return strings.HasPrefix(t.Instructions, OpenAPIPrefix) } diff --git a/pkg/types/toolstring.go b/pkg/types/toolstring.go index b5e0d1d5..8d379f14 100644 --- a/pkg/types/toolstring.go +++ b/pkg/types/toolstring.go @@ -44,6 +44,10 @@ func ToDisplayText(tool Tool, input string) string { } func ToSysDisplayString(id string, args map[string]string) (string, error) { + if suffix, ok := strings.CutPrefix(id, MCPInvokePrefix); ok { + return fmt.Sprintf("Invoking MCP `%s`", suffix), nil + } + switch id { case "sys.append": return fmt.Sprintf("Appending to file `%s`", args["filename"]), nil