diff --git a/go/go.mod b/go/go.mod index 0659353cd1..92868fbe32 100644 --- a/go/go.mod +++ b/go/go.mod @@ -27,7 +27,7 @@ require ( github.com/jackc/pgx/v5 v5.7.5 github.com/jba/slog v0.2.0 github.com/lib/pq v1.10.9 - github.com/mark3labs/mcp-go v0.29.0 + github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/pgvector/pgvector-go v0.3.0 github.com/weaviate/weaviate v1.30.0 github.com/weaviate/weaviate-go-client/v5 v5.1.0 @@ -50,7 +50,6 @@ require ( github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/pierrec/lz4/v4 v4.1.18 // indirect - github.com/spf13/cast v1.7.1 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/mod v0.25.0 // indirect @@ -59,6 +58,7 @@ require ( require ( cloud.google.com/go/alloydb v1.16.1 // indirect + github.com/google/jsonschema-go v0.3.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect diff --git a/go/go.sum b/go/go.sum index c832ade7c6..1d5322054c 100644 --- a/go/go.sum +++ b/go/go.sum @@ -95,8 +95,6 @@ github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfU github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= -github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-jose/go-jose/v4 v4.1.0 h1:cYSYxd3pw5zd2FSXk2vGdn9igQU2PS8MuxrCOCl0FdY= github.com/go-jose/go-jose/v4 v4.1.0/go.mod h1:GG/vqmYm3Von2nYiB2vGTXzdoNKE5tix5tuc6iAd+sw= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -179,6 +177,8 @@ github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7Lk github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= @@ -215,6 +215,8 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= @@ -289,8 +291,6 @@ github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJ github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= -github.com/mark3labs/mcp-go v0.29.0 h1:sH1NBcumKskhxqYzhXfGc201D7P76TVXiT0fGVhabeI= -github.com/mark3labs/mcp-go v0.29.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= github.com/mbleigh/raymond v0.0.0-20250414171441-6b3a58ab9e0a h1:v2cBA3xWKv2cIOVhnzX/gNgkNXqiHfUgJtA3r61Hf7A= @@ -302,6 +302,8 @@ github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RR github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= +github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= @@ -333,8 +335,6 @@ github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWN github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= -github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spiffe/go-spiffe/v2 v2.5.0 h1:N2I01KCUkv1FAjZXJMwh95KK1ZIQLYbPfhaxw8WS0hE= diff --git a/go/plugins/mcp/README.md b/go/plugins/mcp/README.md index 238509621a..977814ee00 100644 --- a/go/plugins/mcp/README.md +++ b/go/plugins/mcp/README.md @@ -24,7 +24,8 @@ func main() { g := genkit.Init(ctx) // Connect to the MCP everything server - client, err := mcp.NewGenkitMCPClient(mcp.MCPClientOptions{ + // NewClient uses the context to manage the connection lifecycle. + client, err := mcp.NewClient(ctx, mcp.MCPClientOptions{ Name: "everything-server", Stdio: &mcp.StdioConfig{ Command: "npx", @@ -49,7 +50,9 @@ func main() { } ``` -## GenkitMCPManager - Multiple Server Management +> **Note:** `NewGenkitMCPClient` is deprecated in favor of `NewClient`, which supports context propagation. + +## MCPHost - Multiple Server Management Manage connections to multiple MCP servers: @@ -66,24 +69,28 @@ import ( func main() { ctx := context.Background() - g, _ := genkit.Init(ctx) + g := genkit.Init(ctx) - // Create manager with multiple servers - manager, err := mcp.NewMCPManager(mcp.MCPManagerOptions{ + // Create host with multiple servers + host, err := mcp.NewMCPHost(g, mcp.MCPHostOptions{ Name: "my-app", - MCPServers: map[string]mcp.MCPClientOptions{ - "everything": { - Name: "everything-server", - Stdio: &mcp.StdioConfig{ - Command: "npx", - Args: []string{"-y", "@modelcontextprotocol/server-everything"}, + MCPServers: []mcp.MCPServerConfig{ + { + Name: "everything", + Config: mcp.MCPClientOptions{ + Stdio: &mcp.StdioConfig{ + Command: "npx", + Args: []string{"-y", "@modelcontextprotocol/server-everything"}, + }, }, }, - "filesystem": { - Name: "fs-server", - Stdio: &mcp.StdioConfig{ - Command: "npx", - Args: []string{"@modelcontextprotocol/server-filesystem", "/tmp"}, + { + Name: "filesystem", + Config: mcp.MCPClientOptions{ + Stdio: &mcp.StdioConfig{ + Command: "npx", + Args: []string{"@modelcontextprotocol/server-filesystem", "/tmp"}, + }, }, }, }, @@ -93,7 +100,7 @@ func main() { } // Connect to new server at runtime - err = manager.ConnectServer(ctx, "weather", mcp.MCPClientOptions{ + err = host.Connect(ctx, g, "weather", mcp.MCPClientOptions{ Name: "weather-server", Stdio: &mcp.StdioConfig{ Command: "python", @@ -105,14 +112,10 @@ func main() { } // Temporarily disable/enable servers - manager.DisableServer("filesystem") - manager.EnableServer("filesystem") - - // Disconnect server - manager.DisconnectServer("weather") + host.Disconnect(ctx, "weather") // Get tools from all active servers - tools, err := manager.GetActiveTools(ctx, g) + tools, err := host.GetActiveTools(ctx, g) if err != nil { log.Fatal(err) } @@ -121,7 +124,7 @@ func main() { ## GenkitMCPServer - Expose Genkit Tools -Turn your Genkit app into an MCP server: +Turn your Genkit app into an MCP server that others can connect to: ```go package main @@ -130,69 +133,67 @@ import ( "context" "log" + "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/mcp" ) func main() { - ctx := context.Background() - g := genkit.Init(ctx) + g := genkit.Init(context.Background()) - // Create a host with multiple servers - host, err := mcp.NewMCPHost(g, mcp.MCPHostOptions{ - Name: "my-app", - MCPServers: []mcp.MCPServerConfig{ - { - Name: "everything-server", - Config: mcp.MCPClientOptions{ - Name: "everything-server", - Stdio: &mcp.StdioConfig{ - Command: "npx", - Args: []string{"-y", "@modelcontextprotocol/server-everything"}, - }, - }, - }, - { - Name: "fs-server", - Config: mcp.MCPClientOptions{ - Name: "fs-server", - Stdio: &mcp.StdioConfig{ - Command: "npx", - Args: []string{"@modelcontextprotocol/server-filesystem", "/tmp"}, - }, - }, - }, - }, + // Define tools and resources you want to expose + genkit.DefineTool(g, "hello", "says hello", func(ctx *ai.ToolContext, input any) (string, error) { + return "Hello from Genkit!", nil }) - if err != nil { - log.Fatal(err) - } - // Connect to new server at runtime - err = host.Connect(ctx, g, "weather", mcp.MCPClientOptions{ - Name: "weather-server", - Stdio: &mcp.StdioConfig{ - Command: "python", - Args: []string{"weather_server.py"}, - }, + // Create the MCP server + server := mcp.NewMCPServer(g, mcp.MCPServerOptions{ + Name: "my-genkit-server", + Version: "1.0.0", }) - if err != nil { + + // Start serving over Stdio + // Use ServeStdioWithContext(ctx) for graceful shutdown support. + if err := server.ServeStdio(); err != nil { log.Fatal(err) } +} +``` - // Reconnect server - host.Reconnect(ctx, "fs-server") +### Exposing as an HTTP Server (SSE) - // Disconnect server - host.Disconnect(ctx, "weather") +You can also expose your Genkit tools over HTTP using Server-Sent Events (SSE): - // Get tools from all active servers - tools, err := host.GetActiveTools(ctx, g) - if err != nil { - log.Fatal(err) - } -} +```go +package main + +import ( + "context" + "log" + "net/http" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/mcp" +) + +func main() { + g := genkit.Init(context.Background()) + + // Define tools... + + server := mcp.NewMCPServer(g, mcp.MCPServerOptions{ + Name: "my-genkit-http-server", + }) + + handler, err := server.HTTPHandler() + if err != nil { + log.Fatal(err) + } + + http.Handle("/mcp", handler) + log.Printf("MCP server listening on http://localhost:8080/mcp") + log.Fatal(http.ListenAndServe(":8080", nil)) +} ``` ## Testing Your Server @@ -222,5 +223,6 @@ Stdio: &mcp.StdioConfig{ ```go SSE: &mcp.SSEConfig{ BaseURL: "http://localhost:3000/sse", + Headers: map[string]string{"Authorization": "Bearer token"}, } ``` diff --git a/go/plugins/mcp/client.go b/go/plugins/mcp/client.go index 732e6b9b67..0fa39a5bbc 100644 --- a/go/plugins/mcp/client.go +++ b/go/plugins/mcp/client.go @@ -1,16 +1,18 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 // Package mcp provides a client for integration with the Model Context Protocol. package mcp @@ -19,15 +21,16 @@ import ( "context" "fmt" "net/http" + "os/exec" "time" "github.com/firebase/genkit/go/core/logger" - "github.com/mark3labs/mcp-go/client" - "github.com/mark3labs/mcp-go/client/transport" - "github.com/mark3labs/mcp-go/mcp" + "github.com/modelcontextprotocol/go-sdk/mcp" ) -// StdioConfig holds configuration for a stdio-based MCP server process. +const DefaultHTTPClientTimeout = 30 + +// StdioConfig holds configuration for a stdio-based MCP server process type StdioConfig struct { Command string Env []string @@ -49,17 +52,17 @@ type StreamableHTTPConfig struct { Timeout time.Duration // HTTP request timeout } -// MCPClientOptions holds configuration for the MCPClient. +// MCPClientOptions contains options for the Streamable HTTP transport type MCPClientOptions struct { - // Name for this client instance - ideally a nickname for the server + // Name for this client instance Name string - // Version number for this client (defaults to "1.0.0" if empty) + // Version number for this client (defaults to "1.0.0") Version string // Disabled flag to temporarily disable this client Disabled bool - // Transport options - only one should be provided + // Transport options -- only one should be provided // Stdio contains config for starting a local server process using stdio transport Stdio *StdioConfig @@ -73,147 +76,162 @@ type MCPClientOptions struct { // ServerRef represents an active connection to an MCP server type ServerRef struct { - Client *client.Client - Transport transport.Interface - Error string + Session *mcp.ClientSession + Error error } -// GenkitMCPClient represents a client for interacting with MCP servers. +// GenkitMCPClient represents a client for interacting with MCP servers type GenkitMCPClient struct { options MCPClientOptions + client *mcp.Client server *ServerRef } -// NewGenkitMCPClient creates a new GenkitMCPClient with the given options. -// Returns an error if the initial connection fails. -func NewGenkitMCPClient(options MCPClientOptions) (*GenkitMCPClient, error) { - // Set default values - if options.Name == "" { - options.Name = "unnamed" +// NewClient creates a new GenkitMCPClient with the given options. +func NewClient(ctx context.Context, opts MCPClientOptions) (*GenkitMCPClient, error) { + if opts.Name == "" { + opts.Name = "unnamed" } - if options.Version == "" { - options.Version = "1.0.0" + if opts.Version == "" { + opts.Version = "1.0.0" } - client := &GenkitMCPClient{ - options: options, + c := &GenkitMCPClient{ + options: opts, + client: mcp.NewClient(&mcp.Implementation{ + Name: opts.Name, + Version: opts.Version, + }, nil), } - if err := client.connect(options); err != nil { + if err := c.connect(ctx); err != nil { return nil, fmt.Errorf("failed to initialize MCP client: %w", err) } + if c.server.Error != nil { + return nil, c.server.Error + } + + return c, nil +} - return client, nil +// NewGenkitMCPClient creates a new [GenkitMCPClient] with the given options. +// Deprecated: Use NewClient(ctx, opts) instead. +func NewGenkitMCPClient(opts MCPClientOptions) (*GenkitMCPClient, error) { + return NewClient(context.Background(), opts) } // connect establishes a connection to an MCP server -func (c *GenkitMCPClient) connect(options MCPClientOptions) error { - // Close existing connection if any - if c.server != nil { - if err := c.server.Client.Close(); err != nil { - ctx := context.Background() - logger.FromContext(ctx).Warn("Error closing previous MCP transport", "client", c.options.Name, "error", err) +func (c *GenkitMCPClient) connect(ctx context.Context) error { + if c.server != nil && c.server.Session != nil { + if err := c.server.Session.Close(); err != nil { + logger.FromContext(ctx).Warn("Error closing previous MCP session", "client", c.options.Name, "error", err) } } - // Create and configure transport - transport, err := c.createTransport(options) - if err != nil { - return err + // if disabled, return without establishing a session + if c.options.Disabled { + c.server = nil + return nil } - // Start the transport - ctx := context.Background() - if err := transport.Start(ctx); err != nil { - return fmt.Errorf("failed to start transport: %w", err) + transport, err := c.createTransport() + if err != nil { + // no transport means no ability to create a server + c.server = &ServerRef{Error: err} + return err } - // Create MCP client - mcpClient := client.NewClient(transport) - - // Initialize the client if not disabled - var serverError string - if !options.Disabled { - serverError = c.initializeClient(ctx, mcpClient, options.Version) + session, err := c.client.Connect(ctx, transport, nil) + if err != nil { + c.server = &ServerRef{ + Error: err, + } + return fmt.Errorf("failed to connect to MCP server: %w", err) } c.server = &ServerRef{ - Client: mcpClient, - Transport: transport, - Error: serverError, + Session: session, } return nil } -// createTransport creates the appropriate transport based on client options -func (c *GenkitMCPClient) createTransport(options MCPClientOptions) (transport.Interface, error) { - if options.Stdio != nil { - return transport.NewStdio(options.Stdio.Command, options.Stdio.Env, options.Stdio.Args...), nil +// headerTransport is a [http.RoundTripper] that adds custom headers to every request +type headerTransport struct { + rt http.RoundTripper + headers map[string]string +} + +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + for k, v := range t.headers { + req.Header.Set(k, v) } + return t.rt.RoundTrip(req) +} - if options.SSE != nil { - var sseOptions []transport.ClientOption - if options.SSE.Headers != nil { - sseOptions = append(sseOptions, transport.WithHeaders(options.SSE.Headers)) +// wrapHTTPClient wraps an existing client with custom headers +func wrapHTTPClient(client *http.Client, headers map[string]string) *http.Client { + if len(headers) == 0 { + if client == nil { + return http.DefaultClient } - if options.SSE.HTTPClient != nil { - sseOptions = append(sseOptions, transport.WithHTTPClient(options.SSE.HTTPClient)) - } - - return transport.NewSSE(options.SSE.BaseURL, sseOptions...) + return client } - if options.StreamableHTTP != nil { - var streamableHTTPOptions []transport.StreamableHTTPCOption - if options.StreamableHTTP.Headers != nil { - streamableHTTPOptions = append(streamableHTTPOptions, transport.WithHTTPHeaders(options.StreamableHTTP.Headers)) - } - if options.StreamableHTTP.Timeout > 0 { - streamableHTTPOptions = append(streamableHTTPOptions, transport.WithHTTPTimeout(options.StreamableHTTP.Timeout)) - } + newClient := &http.Client{} + if client != nil { + *newClient = *client + } else { + newClient.Timeout = DefaultHTTPClientTimeout * time.Second + } - transportImpl, err := transport.NewStreamableHTTP(options.StreamableHTTP.BaseURL, streamableHTTPOptions...) - if err != nil { - return nil, fmt.Errorf("failed to create streamable HTTP transport: %w", err) - } - return transportImpl, nil + transport := newClient.Transport + if transport == nil { + transport = http.DefaultTransport } - return nil, fmt.Errorf("no valid transport configuration provided: must specify Stdio, SSE, or StreamableHTTP") + newClient.Transport = &headerTransport{ + rt: transport, + headers: headers, + } + return newClient } -// initializeClient initializes the MCP client connection -func (c *GenkitMCPClient) initializeClient(ctx context.Context, mcpClient *client.Client, version string) string { - initReq := mcp.InitializeRequest{ - Params: struct { - ProtocolVersion string `json:"protocolVersion"` - Capabilities mcp.ClientCapabilities `json:"capabilities"` - ClientInfo mcp.Implementation `json:"clientInfo"` - }{ - ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, - ClientInfo: mcp.Implementation{ - Name: "genkit-mcp-client", - Version: version, - }, - Capabilities: mcp.ClientCapabilities{}, - }, +// createTransport creates the appropriate transport based on client options +func (c *GenkitMCPClient) createTransport() (mcp.Transport, error) { + if c.options.Stdio != nil { + cmd := exec.Command(c.options.Stdio.Command, c.options.Stdio.Args...) + cmd.Env = c.options.Stdio.Env + return &mcp.CommandTransport{ + Command: cmd, + }, nil } - _, err := mcpClient.Initialize(ctx, initReq) - if err != nil { - return err.Error() + if c.options.SSE != nil { + httpClient := wrapHTTPClient(c.options.SSE.HTTPClient, c.options.SSE.Headers) + return &mcp.SSEClientTransport{ + Endpoint: c.options.SSE.BaseURL, + HTTPClient: httpClient, + }, nil } - return "" + if c.options.StreamableHTTP != nil { + httpClient := wrapHTTPClient(c.options.StreamableHTTP.HTTPClient, c.options.StreamableHTTP.Headers) + return &mcp.StreamableClientTransport{ + Endpoint: c.options.StreamableHTTP.BaseURL, + HTTPClient: httpClient, + }, nil + } + + return nil, fmt.Errorf("no valid transport configuration provided: must specify Stdio, SSE or StreamableHTTP") } -// Name returns the client name +// Name returns the name of the client func (c *GenkitMCPClient) Name() string { return c.options.Name } -// IsEnabled returns whether the client is enabled +// IsEnabled return whether the client is enabled func (c *GenkitMCPClient) IsEnabled() bool { return !c.options.Disabled } @@ -226,26 +244,35 @@ func (c *GenkitMCPClient) Disable() { } } -// Reenable re-enables a previously disabled client by reconnecting -func (c *GenkitMCPClient) Reenable() { +// ReenableWithContext re-enables a previously disabled client by reconnecting it. +func (c *GenkitMCPClient) ReenableWithContext(ctx context.Context) error { if c.options.Disabled { c.options.Disabled = false - c.connect(c.options) + return c.connect(ctx) + } + return nil +} + +// Reenable re-enables a previously disabled client by reconnecting it. +// Deprecated: Use ReenableWithContext instead. +func (c *GenkitMCPClient) Reenable() { + if err := c.ReenableWithContext(context.Background()); err != nil { + logger.FromContext(context.Background()).Warn("failed to re-enable MCP client", "client", c.options.Name, "error", err) } } // Restart restarts the transport connection func (c *GenkitMCPClient) Restart(ctx context.Context) error { if err := c.Disconnect(); err != nil { - logger.FromContext(ctx).Warn("Error closing MCP transport during restart", "client", c.options.Name, "error", err) + logger.FromContext(ctx).Warn("Error closing MCP session during restart", "client", c.options.Name, "error", err) } - return c.connect(c.options) + return c.connect(ctx) } -// Disconnect closes the connection to the MCP server +// Disconnect closes the session with the MCP server func (c *GenkitMCPClient) Disconnect() error { - if c.server != nil { - err := c.server.Client.Close() + if c.server != nil && c.server.Session != nil { + err := c.server.Session.Close() c.server = nil return err } diff --git a/go/plugins/mcp/client_test.go b/go/plugins/mcp/client_test.go new file mode 100644 index 0000000000..47f1cb4f41 --- /dev/null +++ b/go/plugins/mcp/client_test.go @@ -0,0 +1,58 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package mcp + +import ( + "net/http" + "testing" +) + +func TestWrapHTTPClient(t *testing.T) { + t.Run("nil headers returns client as is", func(t *testing.T) { + original := &http.Client{} + got := wrapHTTPClient(original, nil) + if got != original { + t.Errorf("wrapHTTPClient(nil headers) got different pointer, want same") + } + }) + + t.Run("nil client returns default with timeout", func(t *testing.T) { + got := wrapHTTPClient(nil, map[string]string{"X-Test": "Value"}) + if got == nil { + t.Fatal("wrapHTTPClient(nil client) returned nil") + } + if got.Transport == nil { + t.Fatal("transport is nil, want headerTransport") + } + + _, ok := got.Transport.(*headerTransport) + if !ok { + t.Errorf("transport type got = %T, want *headerTransport", got.Transport) + } + }) +} + +func TestMCPClientDefaults(t *testing.T) { + opts := MCPClientOptions{ + Name: "my-mcp", + } + + c := &GenkitMCPClient{options: opts} + if got := c.Name(); got != "my-mcp" { + t.Errorf("c.Name() got = %q, want %q", got, "my-mcp") + } +} diff --git a/go/plugins/mcp/common.go b/go/plugins/mcp/common.go index c942f6c5bc..4f1734912a 100644 --- a/go/plugins/mcp/common.go +++ b/go/plugins/mcp/common.go @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,39 +11,34 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 -// Package mcp provides a client for integration with the Model Context Protocol. package mcp import ( - "fmt" - - "github.com/mark3labs/mcp-go/mcp" + "github.com/modelcontextprotocol/go-sdk/mcp" ) -// GetPromptNameWithNamespace returns a prompt name prefixed with the client's namespace -func (c *GenkitMCPClient) GetPromptNameWithNamespace(promptName string) string { - return fmt.Sprintf("%s_%s", c.options.Name, promptName) -} - -// GetToolNameWithNamespace returns a tool name prefixed with the client's namespace -func (c *GenkitMCPClient) GetToolNameWithNamespace(toolName string) string { - return fmt.Sprintf("%s_%s", c.options.Name, toolName) -} - -// GetResourceNameWithNamespace returns a resource name prefixed with the client's namespace -func (c *GenkitMCPClient) GetResourceNameWithNamespace(resourceName string) string { - return fmt.Sprintf("%s_%s", c.options.Name, resourceName) -} +const ( + RoleUser mcp.Role = "user" + RoleAssistant mcp.Role = "assistant" +) // ExtractTextFromContent extracts text content from MCP Content func ExtractTextFromContent(content mcp.Content) string { - if textContent, ok := content.(mcp.TextContent); ok && textContent.Type == "text" { - return textContent.Text - } else if resourceContent, ok := content.(mcp.EmbeddedResource); ok { - if textResource, ok := resourceContent.Resource.(mcp.TextResourceContents); ok { - return textResource.Text + if content == nil { + return "" + } + + switch c := content.(type) { + case *mcp.TextContent: + return c.Text + case *mcp.EmbeddedResource: + if c.Resource != nil { + return c.Resource.Text } } + return "" } diff --git a/go/plugins/mcp/host.go b/go/plugins/mcp/host.go index d5b73cc6bb..d081c4fb7a 100644 --- a/go/plugins/mcp/host.go +++ b/go/plugins/mcp/host.go @@ -49,8 +49,9 @@ type MCPHost struct { clients map[string]*GenkitMCPClient // Internal map for efficient lookups } -// NewMCPHost creates a new MCPHost with the given options -func NewMCPHost(g *genkit.Genkit, options MCPHostOptions) (*MCPHost, error) { +// NewMCPHostWithContext creates a new MCPHost with the given options and context. +// It connects to the configured servers using the provided context. +func NewMCPHostWithContext(ctx context.Context, g *genkit.Genkit, options MCPHostOptions) (*MCPHost, error) { // Set default values if options.Name == "" { options.Name = "genkit-mcp" @@ -66,7 +67,6 @@ func NewMCPHost(g *genkit.Genkit, options MCPHostOptions) (*MCPHost, error) { } // Connect to all servers synchronously during initialization - ctx := context.Background() for _, serverConfig := range options.MCPServers { if err := host.Connect(ctx, g, serverConfig.Name, serverConfig.Config); err != nil { logger.FromContext(ctx).Error("Failed to connect to MCP server", "server", serverConfig.Name, "host", host.name, "error", err) @@ -77,6 +77,12 @@ func NewMCPHost(g *genkit.Genkit, options MCPHostOptions) (*MCPHost, error) { return host, nil } +// NewMCPHost creates a new MCPHost with the given options. +// Deprecated: Use NewMCPHostWithContext instead. +func NewMCPHost(g *genkit.Genkit, options MCPHostOptions) (*MCPHost, error) { + return NewMCPHostWithContext(context.Background(), g, options) +} + // Connect connects to a single MCP server with the provided configuration // and automatically registers tools, prompts, and resources from the server func (h *MCPHost) Connect(ctx context.Context, g *genkit.Genkit, serverName string, config MCPClientOptions) error { @@ -95,7 +101,7 @@ func (h *MCPHost) Connect(ctx context.Context, g *genkit.Genkit, serverName stri } // Create and connect the client - client, err := NewGenkitMCPClient(config) + client, err := NewClient(ctx, config) if err != nil { return fmt.Errorf("error connecting to server %s: %w", serverName, err) } diff --git a/go/plugins/mcp/http_test.go b/go/plugins/mcp/http_test.go new file mode 100644 index 0000000000..bbf0d2a1b5 --- /dev/null +++ b/go/plugins/mcp/http_test.go @@ -0,0 +1,106 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcp + +import ( + "context" + "encoding/json" + "math" + "net/http" + "net/http/httptest" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func TestHTTPServerIntegration(t *testing.T) { + ctx := context.Background() + g := genkit.Init(ctx) + + genkit.DefineTool(g, "gablorken", "calculates a gablorken", + func(ctx *ai.ToolContext, input struct { + Value int + Over float64 + }, + ) (float64, error) { + return math.Pow(float64(input.Value), input.Over), nil + }, + ) + + server := NewMCPServer(g, MCPServerOptions{Name: "http-server"}) + handler, err := server.HTTPHandler() + if err != nil { + t.Fatalf("HTTPHandler failed: %v", err) + } + + mux := http.NewServeMux() + mux.Handle("/mcp", handler) + ts := httptest.NewServer(mux) + defer ts.Close() + + client, err := NewClient(ctx, MCPClientOptions{ + Name: "remote-mcp", + SSE: &SSEConfig{ + BaseURL: ts.URL + "/mcp", + }, + }) + if err != nil { + t.Fatalf("NewGenkitMCPClient failed: %v", err) + } + defer client.Disconnect() + + tools, err := client.GetActiveTools(ctx, g) + if err != nil { + t.Fatalf("GetActiveTools failed: %v", err) + } + + var gablorken ai.Tool + for _, tool := range tools { + if tool.Name() == "remote-mcp_gablorken" { + gablorken = tool + break + } + } + + if gablorken == nil { + t.Fatal("gablorken tool not found via SSE") + } + + args := map[string]any{"Value": 2, "Over": 3.0} + rawRes, err := gablorken.RunRaw(ctx, args) + if err != nil { + t.Fatalf("gablorken.RunRaw failed: %v", err) + } + bytes, err := json.Marshal(rawRes) + if err != nil { + t.Fatalf("failed to marshal result: %v", err) + } + + var res mcp.CallToolResult + if err := json.Unmarshal(bytes, &res); err != nil { + t.Fatalf("failed to unmarshal into CallToolResult: %v", err) + } + if len(res.Content) == 0 { + t.Fatal("expected result content, got none") + } + + gotText := ExtractTextFromContent(res.Content[0]) + wantText := "8" + if gotText != wantText { + t.Errorf("result text got = %q, want %q", gotText, wantText) + } +} diff --git a/go/plugins/mcp/prompts.go b/go/plugins/mcp/prompts.go index c0782aa429..841954f241 100644 --- a/go/plugins/mcp/prompts.go +++ b/go/plugins/mcp/prompts.go @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,8 +11,9 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 -// Package mcp provides a client for integration with the Model Context Protocol. package mcp import ( @@ -21,140 +22,74 @@ import ( "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/genkit" - "github.com/mark3labs/mcp-go/mcp" + "github.com/modelcontextprotocol/go-sdk/mcp" ) -// GetPrompt retrieves a prompt from the MCP server -func (c *GenkitMCPClient) GetPrompt(ctx context.Context, g *genkit.Genkit, promptName string, args map[string]string) (ai.Prompt, error) { - if !c.IsEnabled() || c.server == nil { - return nil, fmt.Errorf("MCP client is disabled or not connected") +func (c *GenkitMCPClient) GetActivePrompts(ctx context.Context) ([]mcp.Prompt, error) { + if !c.IsEnabled() || c.server == nil || c.server.Session == nil { + return nil, nil } - - // Check if prompt already exists - namespacedPromptName := c.GetPromptNameWithNamespace(promptName) - if existingPrompt := genkit.LookupPrompt(g, namespacedPromptName); existingPrompt != nil { - return existingPrompt, nil + if c.server.Error != nil { + return nil, c.server.Error } - // Fetch prompt from MCP server - mcpPrompt, err := c.fetchMCPPrompt(ctx, promptName, args) - if err != nil { - return nil, err + var prompts []mcp.Prompt + for p, err := range c.server.Session.Prompts(ctx, nil) { + if err != nil { + return nil, fmt.Errorf("failed to list prompts: %w", err) + } + prompts = append(prompts, *p) } - // Convert and register the prompt - return c.createGenkitPrompt(g, namespacedPromptName, mcpPrompt) + return prompts, nil } -// fetchMCPPrompt retrieves a prompt from the MCP server -func (c *GenkitMCPClient) fetchMCPPrompt(ctx context.Context, promptName string, args map[string]string) (*mcp.GetPromptResult, error) { - req := mcp.GetPromptRequest{ - Params: struct { - Name string `json:"name"` - Arguments map[string]string `json:"arguments,omitempty"` - }{ - Name: promptName, - Arguments: args, - }, - } - - result, err := c.server.Client.GetPrompt(ctx, req) - if err != nil { - return nil, fmt.Errorf("failed to get prompt %s: %w", promptName, err) +func (c *GenkitMCPClient) GetPrompt(ctx context.Context, g *genkit.Genkit, name string, args map[string]string) (ai.Prompt, error) { + if !c.IsEnabled() || c.server == nil || c.server.Session == nil { + return nil, fmt.Errorf("MCP client is disabled or not connected") } - return result, nil -} - -// createGenkitPrompt converts MCP prompt to Genkit prompt and registers it -func (c *GenkitMCPClient) createGenkitPrompt(g *genkit.Genkit, promptName string, mcpPrompt *mcp.GetPromptResult) (ai.Prompt, error) { - messages := c.convertMCPMessages(mcpPrompt.Messages) - - promptOpts := []ai.PromptOption{ - ai.WithDescription(mcpPrompt.Description), + promptName := fmt.Sprintf("%s_%s", c.options.Name, name) + if prompt := genkit.LookupPrompt(g, promptName); prompt != nil { + return prompt, nil } - if len(messages) > 0 { - promptOpts = append(promptOpts, ai.WithMessages(messages...)) + res, err := c.server.Session.GetPrompt(ctx, &mcp.GetPromptParams{ + Name: name, + Arguments: args, + }) + if err != nil { + return nil, fmt.Errorf("failed to get prompt %s: %w", name, err) } - prompt := genkit.DefinePrompt(g, promptName, promptOpts...) + msgs := c.toGenkitMessages(res.Messages) + prompt := genkit.DefinePrompt(g, promptName, + ai.WithDescription(res.Description), + ai.WithMessages(msgs...), + ) return prompt, nil } -// convertMCPMessages converts MCP messages to Genkit messages -func (c *GenkitMCPClient) convertMCPMessages(mcpMessages []mcp.PromptMessage) []*ai.Message { +func (c *GenkitMCPClient) toGenkitMessages(mcpMessages []*mcp.PromptMessage) []*ai.Message { var messages []*ai.Message for _, msg := range mcpMessages { + role := ai.RoleUser + // "assistant" as per the MCP spec + if msg.Role == "assistant" { + role = ai.RoleModel + } + text := ExtractTextFromContent(msg.Content) if text == "" { continue } - - switch msg.Role { - case mcp.RoleUser: - messages = append(messages, ai.NewUserTextMessage(text)) - case mcp.RoleAssistant: - messages = append(messages, ai.NewModelTextMessage(text)) - } + messages = append(messages, &ai.Message{ + Role: role, + Content: []*ai.Part{ai.NewTextPart(text)}, + }) } return messages } - -// GetActivePrompts retrieves all prompts available from the MCP server -func (c *GenkitMCPClient) GetActivePrompts(ctx context.Context) ([]mcp.Prompt, error) { - if !c.IsEnabled() || c.server == nil { - return nil, nil - } - - // Get all MCP prompts - return c.getPrompts(ctx) -} - -// getPrompts retrieves all prompts from the MCP server by paginating through results -func (c *GenkitMCPClient) getPrompts(ctx context.Context) ([]mcp.Prompt, error) { - var allMcpPrompts []mcp.Prompt - var cursor mcp.Cursor - - // Paginate through all available prompts from the MCP server - for { - // Fetch a page of prompts - mcpPrompts, nextCursor, err := c.fetchPromptsPage(ctx, cursor) - if err != nil { - return nil, err - } - - allMcpPrompts = append(allMcpPrompts, mcpPrompts...) - - // Check if we've reached the last page - cursor = nextCursor - if cursor == "" { - break - } - } - - return allMcpPrompts, nil -} - -// fetchPromptsPage retrieves a single page of prompts from the MCP server -func (c *GenkitMCPClient) fetchPromptsPage(ctx context.Context, cursor mcp.Cursor) ([]mcp.Prompt, mcp.Cursor, error) { - listReq := mcp.ListPromptsRequest{ - PaginatedRequest: mcp.PaginatedRequest{ - Params: struct { - Cursor mcp.Cursor `json:"cursor,omitempty"` - }{ - Cursor: cursor, - }, - }, - } - - result, err := c.server.Client.ListPrompts(ctx, listReq) - if err != nil { - return nil, "", fmt.Errorf("failed to list prompts: %w", err) - } - - return result.Prompts, result.NextCursor, nil -} diff --git a/go/plugins/mcp/prompts_test.go b/go/plugins/mcp/prompts_test.go new file mode 100644 index 0000000000..1e05d3e321 --- /dev/null +++ b/go/plugins/mcp/prompts_test.go @@ -0,0 +1,61 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package mcp + +import ( + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func TestToGenkitMessages(t *testing.T) { + client := &GenkitMCPClient{} + + mcpMessages := []*mcp.PromptMessage{ + { + Role: "user", + Content: &mcp.TextContent{Text: "how are you?"}, + }, + { + Role: "assistant", + Content: &mcp.TextContent{Text: "I am fine"}, + }, + } + + got := client.toGenkitMessages(mcpMessages) + + if len(got) != 2 { + t.Fatalf("len(messages) got = %d, want 2", len(got)) + } + + // Test User Message + if got[0].Role != ai.RoleUser { + t.Errorf("msg[0].Role got = %v, want %v", got[0].Role, ai.RoleUser) + } + if got[0].Content[0].Text != "how are you?" { + t.Errorf("msg[0].Text got = %q, want %q", got[0].Content[0].Text, "how are you?") + } + + // Test Assistant -> Model Role Mapping + if got[1].Role != ai.RoleModel { + t.Errorf("msg[1].Role got = %v, want %v (RoleModel)", got[1].Role, ai.RoleModel) + } + if got[1].Content[0].Text != "I am fine" { + t.Errorf("msg[1].Text got = %q, want %q", got[1].Content[0].Text, "I am fine") + } +} diff --git a/go/plugins/mcp/resources.go b/go/plugins/mcp/resources.go index 5715721b1e..7931d46a4d 100644 --- a/go/plugins/mcp/resources.go +++ b/go/plugins/mcp/resources.go @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,281 +11,127 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package mcp import ( "context" + "encoding/base64" "fmt" "github.com/firebase/genkit/go/ai" - "github.com/mark3labs/mcp-go/mcp" + "github.com/modelcontextprotocol/go-sdk/mcp" ) -// GetActiveResources fetches resources from the MCP server func (c *GenkitMCPClient) GetActiveResources(ctx context.Context) ([]ai.Resource, error) { - if !c.IsEnabled() || c.server == nil { - return nil, fmt.Errorf("MCP client is disabled or not connected") - } - - var resources []ai.Resource - - // Fetch static resources - staticResources, err := c.getStaticResources(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get resources from %s: %w", c.options.Name, err) - } - resources = append(resources, staticResources...) - - // Fetch template resources (optional - not all servers support templates) - templateResources, err := c.getTemplateResources(ctx) - if err != nil { - // Templates not supported by all servers, continue without them - return resources, nil + if !c.IsEnabled() || c.server == nil || c.server.Session == nil { + return nil, nil } - resources = append(resources, templateResources...) - - return resources, nil -} - -// getStaticResources retrieves and converts static MCP resources to Genkit resources -func (c *GenkitMCPClient) getStaticResources(ctx context.Context) ([]ai.Resource, error) { - mcpResources, err := c.getResources(ctx) - if err != nil { - return nil, err + if c.server.Error != nil { + return nil, c.server.Error } var resources []ai.Resource - for _, mcpResource := range mcpResources { - resource, err := c.toGenkitResource(mcpResource) - if err != nil { - return nil, fmt.Errorf("failed to create resource %s: %w", mcpResource.Name, err) - } - resources = append(resources, resource) - } - return resources, nil -} - -// getTemplateResources retrieves and converts MCP resource templates to Genkit resources -func (c *GenkitMCPClient) getTemplateResources(ctx context.Context) ([]ai.Resource, error) { - mcpTemplates, err := c.getResourceTemplates(ctx) - if err != nil { - return nil, err - } - var resources []ai.Resource - for _, mcpTemplate := range mcpTemplates { - resource, err := c.toGenkitResourceTemplate(mcpTemplate) + // fetch static resources (URIs like "file:///logs/today.txt") + for res, err := range c.server.Session.Resources(ctx, nil) { if err != nil { - return nil, fmt.Errorf("failed to create resource template %s: %w", mcpTemplate.Name, err) + return nil, fmt.Errorf("failed to list resources: %w", err) } - resources = append(resources, resource) + resources = append(resources, c.toGenkitResource(res)) } - return resources, nil -} - -// toGenkitResource creates a Genkit resource from an MCP static resource -func (c *GenkitMCPClient) toGenkitResource(mcpResource mcp.Resource) (ai.Resource, error) { - // Create namespaced resource name - resourceName := c.GetResourceNameWithNamespace(mcpResource.Name) - // Create Genkit resource that bridges to MCP - return ai.NewResource(resourceName, &ai.ResourceOptions{ - URI: mcpResource.URI, - Description: mcpResource.Description, - Metadata: map[string]any{ - "mcp_server": c.options.Name, - "mcp_name": mcpResource.Name, - "source": "mcp", - "mime_type": mcpResource.MIMEType, - }, - }, func(ctx context.Context, input *ai.ResourceInput) (*ai.ResourceOutput, error) { - output, err := c.readMCPResource(ctx, input.URI) + // fetch resource templates (dynamic URIS like "db://{table}/{id}") + for res, err := range c.server.Session.ResourceTemplates(ctx, nil) { if err != nil { - return nil, err + return nil, fmt.Errorf("failed to list resource templates: %w", err) } - return &ai.ResourceOutput{Content: output.Content}, nil - }), nil -} - -// toGenkitResourceTemplate creates a Genkit template resource from MCP template -func (c *GenkitMCPClient) toGenkitResourceTemplate(mcpTemplate mcp.ResourceTemplate) (ai.Resource, error) { - resourceName := c.GetResourceNameWithNamespace(mcpTemplate.Name) - - // Convert URITemplate to string - extract the raw template string - var templateStr string - if mcpTemplate.URITemplate != nil && mcpTemplate.URITemplate.Template != nil { - templateStr = mcpTemplate.URITemplate.Template.Raw() - } - - // Validate template - return error instead of panicking - if templateStr == "" { - return nil, fmt.Errorf("MCP resource template %s has empty URI template", mcpTemplate.Name) + resources = append(resources, c.toGenkitResourceTemplate(res)) } - return ai.NewResource(resourceName, &ai.ResourceOptions{ - Template: templateStr, - Description: mcpTemplate.Description, - Metadata: map[string]any{ - "mcp_server": c.options.Name, - "mcp_name": mcpTemplate.Name, - "mcp_template": templateStr, - "source": "mcp", - "mime_type": mcpTemplate.MIMEType, - }, - }, func(ctx context.Context, input *ai.ResourceInput) (*ai.ResourceOutput, error) { - output, err := c.readMCPResource(ctx, input.URI) - if err != nil { - return nil, err - } - return &ai.ResourceOutput{Content: output.Content}, nil - }), nil + return resources, nil } -// readMCPResource fetches content from MCP server for a given URI -func (c *GenkitMCPClient) readMCPResource(ctx context.Context, uri string) (ai.ResourceOutput, error) { - if !c.IsEnabled() || c.server == nil { - return ai.ResourceOutput{}, fmt.Errorf("MCP client is disabled or not connected") - } - - // Create ReadResource request - readReq := mcp.ReadResourceRequest{ - Params: struct { - URI string `json:"uri"` - Arguments map[string]interface{} `json:"arguments,omitempty"` - }{ - URI: uri, - Arguments: nil, - }, - } - - // Call the MCP server to read the resource - readResp, err := c.server.Client.ReadResource(ctx, readReq) - if err != nil { - return ai.ResourceOutput{}, fmt.Errorf("failed to read resource from MCP server %s: %w", c.options.Name, err) - } +func (c *GenkitMCPClient) toGenkitResource(r *mcp.Resource) ai.Resource { + name := fmt.Sprintf("%s_%s", c.options.Name, r.Name) - // Convert MCP ResourceContents to Genkit Parts - parts, err := convertMCPResourceContentsToGenkitParts(readResp.Contents) - if err != nil { - return ai.ResourceOutput{}, fmt.Errorf("failed to convert MCP resource contents to Genkit parts: %w", err) + metadata := map[string]any{ + "mcp_server": c.options.Name, + "mcp_uri": r.URI, } - return ai.ResourceOutput{Content: parts}, nil -} - -// getResources retrieves all resources from the MCP server by paginating through results -func (c *GenkitMCPClient) getResources(ctx context.Context) ([]mcp.Resource, error) { - var allResources []mcp.Resource - var cursor mcp.Cursor - - // Paginate through all available resources from the MCP server - for { - // Fetch a page of resources - resources, nextCursor, err := c.fetchResourcesPage(ctx, cursor) - if err != nil { - return nil, err + if r.Annotations != nil { + if r.Annotations.Audience != nil { + metadata["audience"] = r.Annotations.Audience } - - allResources = append(allResources, resources...) - - // Check if we've reached the last page - cursor = nextCursor - if cursor == "" { - break + if r.Annotations.Priority != 0 { + metadata["priority"] = r.Annotations.Priority + } + if r.Annotations.LastModified != "" { + metadata["last_modified"] = r.Annotations.LastModified } } - return allResources, nil + return ai.NewResource(name, &ai.ResourceOptions{ + URI: r.URI, + Description: r.Description, + Metadata: metadata, + }, c.readResourceHandler) } -// fetchResourcesPage retrieves a single page of resources from the MCP server -func (c *GenkitMCPClient) fetchResourcesPage(ctx context.Context, cursor mcp.Cursor) ([]mcp.Resource, mcp.Cursor, error) { - // Build the list request - include cursor if we have one for pagination - listReq := mcp.ListResourcesRequest{} - listReq.PaginatedRequest = mcp.PaginatedRequest{ - Params: struct { - Cursor mcp.Cursor `json:"cursor,omitempty"` - }{ - Cursor: cursor, - }, - } - - // Ask the MCP server for resources - result, err := c.server.Client.ListResources(ctx, listReq) - if err != nil { - return nil, "", fmt.Errorf("failed to list resources from MCP server %s: %w", c.options.Name, err) - } +func (c *GenkitMCPClient) toGenkitResourceTemplate(rt *mcp.ResourceTemplate) ai.Resource { + name := fmt.Sprintf("%s_%s", c.options.Name, rt.Name) - return result.Resources, result.NextCursor, nil + return ai.NewResource(name, &ai.ResourceOptions{ + Template: rt.URITemplate, + Description: rt.Description, + }, c.readResourceHandler) } -// getResourceTemplates retrieves all resource templates from the MCP server by paginating through results -func (c *GenkitMCPClient) getResourceTemplates(ctx context.Context) ([]mcp.ResourceTemplate, error) { - var allTemplates []mcp.ResourceTemplate - var cursor mcp.Cursor - - // Paginate through all available resource templates from the MCP server - for { - // Fetch a page of resource templates - templates, nextCursor, err := c.fetchResourceTemplatesPage(ctx, cursor) - if err != nil { - return nil, err - } - - allTemplates = append(allTemplates, templates...) - - // Check if we've reached the last page - cursor = nextCursor - if cursor == "" { - break - } +func (c *GenkitMCPClient) readResourceHandler(ctx context.Context, input *ai.ResourceInput) (*ai.ResourceOutput, error) { + if c.server == nil || c.server.Session == nil { + return nil, fmt.Errorf("MCP session is closed") + } + if c.server.Error != nil { + return nil, c.server.Error } - return allTemplates, nil -} - -// fetchResourceTemplatesPage retrieves a single page of resource templates from the MCP server -func (c *GenkitMCPClient) fetchResourceTemplatesPage(ctx context.Context, cursor mcp.Cursor) ([]mcp.ResourceTemplate, mcp.Cursor, error) { - listReq := mcp.ListResourceTemplatesRequest{ - PaginatedRequest: mcp.PaginatedRequest{ - Params: struct { - Cursor mcp.Cursor `json:"cursor,omitempty"` - }{ - Cursor: cursor, - }, - }, + params := &mcp.ReadResourceParams{ + URI: input.URI, } - result, err := c.server.Client.ListResourceTemplates(ctx, listReq) + res, err := c.server.Session.ReadResource(ctx, params) if err != nil { - return nil, "", fmt.Errorf("failed to list resource templates from MCP server %s: %w", c.options.Name, err) + return nil, fmt.Errorf("failed to read MCP resource %s: %w", input.URI, err) } - return result.ResourceTemplates, result.NextCursor, nil + parts := c.toGenkitParts(res.Contents) + + return &ai.ResourceOutput{ + Content: parts, + }, nil } -// convertMCPResourceContentsToGenkitParts converts MCP ResourceContents to Genkit Parts -func convertMCPResourceContentsToGenkitParts(mcpContents []mcp.ResourceContents) ([]*ai.Part, error) { +func (c *GenkitMCPClient) toGenkitParts(contents []*mcp.ResourceContents) []*ai.Part { var parts []*ai.Part - for _, content := range mcpContents { - // Handle TextResourceContents - if textContent, ok := content.(mcp.TextResourceContents); ok { - parts = append(parts, ai.NewTextPart(textContent.Text)) + for _, cont := range contents { + if cont.Text != "" { + if cont.MIMEType == "application/json" { + parts = append(parts, ai.NewDataPart(cont.Text)) + continue + } + parts = append(parts, ai.NewTextPart(cont.Text)) continue } - // Handle BlobResourceContents - if blobContent, ok := content.(mcp.BlobResourceContents); ok { - // Create media part using ai.NewMediaPart for binary data - parts = append(parts, ai.NewMediaPart(blobContent.MIMEType, blobContent.Blob)) - continue + if len(cont.Blob) > 0 { + encodedString := base64.StdEncoding.EncodeToString(cont.Blob) + parts = append(parts, ai.NewMediaPart(cont.MIMEType, encodedString)) } - - // Handle unknown resource content types as text - parts = append(parts, ai.NewTextPart(fmt.Sprintf("[Unknown MCP resource content type: %T]", content))) } - return parts, nil + return parts } diff --git a/go/plugins/mcp/resources_test.go b/go/plugins/mcp/resources_test.go index 7446e52f3d..7074190f29 100644 --- a/go/plugins/mcp/resources_test.go +++ b/go/plugins/mcp/resources_test.go @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,306 +11,115 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package mcp import ( - "fmt" + "encoding/base64" "testing" "github.com/firebase/genkit/go/ai" - "github.com/mark3labs/mcp-go/mcp" + "github.com/modelcontextprotocol/go-sdk/mcp" ) -// TestMCPTemplateTranslation tests the translation of MCP ResourceTemplate -// objects to Genkit ai.Resource objects. -// -// This test validates: -// 1. Template string extraction from MCP ResourceTemplate objects -// 2. Working Genkit ai.Resource objects -// 3. URI pattern matching with extracted templates -// 4. Variable extraction from matched URIs -// -// This translation step happens inside GetActiveResources() -// when users fetch resources from MCP servers. If template extraction fails, -// the resulting resources won't match any URIs and will be unusable. -func TestMCPTemplateTranslation(t *testing.T) { - testCases := []struct { - name string - templateURI string - testURI string - shouldMatch bool - expectedVars map[string]string - }{ - { - name: "user profile template", - templateURI: "user://profile/{id}", - testURI: "user://profile/alice", - shouldMatch: true, - expectedVars: map[string]string{"id": "alice"}, - }, - { - name: "user profile no match", - templateURI: "user://profile/{id}", - testURI: "api://different/path", - shouldMatch: false, - }, - { - name: "api service template", - templateURI: "api://{service}/{version}", - testURI: "api://users/v1", - shouldMatch: true, - expectedVars: map[string]string{"service": "users", "version": "v1"}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Simulates what GetActiveResources() receives from MCP server - mcpTemplate := mcp.NewResourceTemplate(tc.templateURI, "test-resource") - - if mcpTemplate.URITemplate != nil && mcpTemplate.URITemplate.Template != nil { - rawString := mcpTemplate.URITemplate.Template.Raw() - if rawString != tc.templateURI { - t.Errorf("Raw() extraction failed: expected %q, got %q", tc.templateURI, rawString) - t.Errorf("This indicates the MCP SDK Raw() method is broken!") - } - } else { - t.Fatal("URITemplate structure is nil - MCP SDK structure changed!") - } - - // Create client for testing translation - client := &GenkitMCPClient{ - options: MCPClientOptions{Name: "test-client"}, - } - - // Test the MCP → Genkit translation step - detachedResource, err := client.toGenkitResourceTemplate(mcpTemplate) - if err != nil { - t.Fatalf("MCP → Genkit translation failed: %v", err) - } +func TestToGenkitParts(t *testing.T) { + client := &GenkitMCPClient{} - // Verify the translated resource can match URIs correctly - actualMatch := detachedResource.Matches(tc.testURI) - if actualMatch != tc.shouldMatch { - t.Errorf("Template matching failed: template %s vs URI %s: expected match=%v, got %v", - tc.templateURI, tc.testURI, tc.shouldMatch, actualMatch) - t.Errorf("This indicates template extraction or URI matching is broken!") - } - - if tc.shouldMatch && tc.expectedVars != nil { - variables, err := detachedResource.ExtractVariables(tc.testURI) - if err != nil { - t.Errorf("Variable extraction failed after translation: %v", err) - } + t.Run("text part", func(t *testing.T) { + contents := []*mcp.ResourceContents{ + { + Text: "hello world", + MIMEType: "text/plain", + }, + } + + parts := client.toGenkitParts(contents) + if got := len(parts); got != 1 { + t.Fatalf("len(parts) got = %d, want 1", got) + } + + if got := parts[0].Text; got != "hello world" { + t.Errorf("parts[0].Text got = %q, want %q", got, "hello world") + } + + if got := parts[0].Kind; got != ai.PartText { + t.Errorf("parts[0].Kind got = %v, want %v (PartText)", got, ai.PartText) + } + }) + + t.Run("json data part", func(t *testing.T) { + jsonData := `{"id": 123, "status": "ok"}` + contents := []*mcp.ResourceContents{ + { + Text: jsonData, + MIMEType: "application/json", + }, + } + + parts := client.toGenkitParts(contents) + if got := len(parts); got != 1 { + t.Fatalf("len(parts) got = %d, want 1", got) + } + + // In resources_mcp.go, application/json becomes a DataPart + if got := parts[0].Kind; got != ai.PartData { + t.Errorf("parts[0].Kind got = %v, want %v (PartData)", got, ai.PartData) + } + + if got := parts[0].Text; got != jsonData { + t.Errorf("parts[0].Text got = %q, want %q", got, jsonData) + } + }) + + t.Run("binary blob part", func(t *testing.T) { + blobData := []byte{0x00, 0x01, 0x02, 0x03} + contents := []*mcp.ResourceContents{ + { + Blob: blobData, + MIMEType: "image/png", + }, + } + + parts := client.toGenkitParts(contents) + if got := len(parts); got != 1 { + t.Fatalf("len(parts) got = %d, want 1", got) + } + + if got := parts[0].Kind; got != ai.PartMedia { + t.Errorf("parts[0].Kind got = %v, want %v (PartMedia)", got, ai.PartMedia) + } + + if got := parts[0].ContentType; got != "image/png" { + t.Errorf("parts[0].ContentType got = %q, want %q", got, "image/png") + } + + wantBase64 := base64.StdEncoding.EncodeToString(blobData) + if got := parts[0].Text; got != wantBase64 { + t.Errorf("parts[0].Text got = %q, want %q (base64 encoded)", got, wantBase64) + } + }) +} - for key, expectedValue := range tc.expectedVars { - if variables[key] != expectedValue { - t.Errorf("Variable %s: expected %s, got %s", key, expectedValue, variables[key]) - } - } - } - }) +func TestToGenkitResource(t *testing.T) { + client := &GenkitMCPClient{ + options: MCPClientOptions{Name: "srv"}, } -} -// TestMCPTemplateEdgeCases tests malformed inputs -func TestMCPTemplateEdgeCases(t *testing.T) { - testCases := []struct { - name string - templateURI string - testURI string - expectError bool - expectMatch bool - expectedVars map[string]string - description string - }{ - { - name: "empty template", - templateURI: "", - testURI: "user://profile/alice", - expectError: true, - description: "Should fail with empty template", - }, - { - name: "malformed template - missing closing brace", - templateURI: "user://profile/{id", - testURI: "user://profile/alice", - expectError: true, - description: "Should fail with malformed template syntax", - }, - { - name: "malformed template - missing opening brace", - templateURI: "user://profile/id}", - testURI: "user://profile/alice", - expectError: true, - description: "Should fail with malformed template syntax", - }, - { - name: "template with special characters", - templateURI: "api://v1/{resource-name}/data", - testURI: "api://v1/user-profiles/data", - expectError: true, // MCP SDK rejects this template - description: "Should handle SDK template rejections gracefully", - }, - { - name: "template with encoded characters", - templateURI: "file://docs/{filename}", - testURI: "file://docs/hello%20world.pdf", - expectMatch: true, - expectedVars: map[string]string{"filename": "hello world.pdf"}, - description: "URL decoding occurs during variable extraction", - }, - { - name: "URI with query parameters", - templateURI: "api://search/{query}", - testURI: "api://search/hello?limit=10&offset=0", - expectMatch: true, // Query parameters are stripped before matching - expectedVars: map[string]string{"query": "hello"}, - description: "Query parameters are stripped, template matches path portion", - }, - { - name: "case sensitivity", - templateURI: "user://profile/{id}", - testURI: "USER://PROFILE/ALICE", - expectMatch: false, // URI schemes are case-sensitive - description: "Should be case-sensitive for scheme", - }, - { - name: "multiple variables same pattern", - templateURI: "api://{service}/{service}", - testURI: "api://users/users", - expectMatch: true, - expectedVars: map[string]string{"service": ""}, // BUG: Returns empty instead of "users" - description: "Duplicate variable names have buggy behavior (should return 'users', not '')", - }, - { - name: "empty variable value", - templateURI: "api://{service}/data", - testURI: "api:///data", // Empty service name - expectMatch: true, // RFC 6570 allows empty variables - expectedVars: map[string]string{"service": ""}, - description: "Empty variable values are valid per RFC 6570", - }, - { - name: "nested path variables", - templateURI: "file:///{folder}/{subfolder}/{filename}", - testURI: "file:///docs/api/readme.md", - expectMatch: true, - expectedVars: map[string]string{ - "folder": "docs", - "subfolder": "api", - "filename": "readme.md", - }, - description: "Should handle multiple path segments", - }, - { - name: "trailing slash in URI (common user issue)", - templateURI: "api://users/{id}", - testURI: "api://users/123/", // User adds trailing slash - expectMatch: true, // Fixed! Trailing slashes are now stripped - expectedVars: map[string]string{"id": "123"}, - description: "Trailing slashes are stripped for better UX", - }, - { - name: "URI with fragment (common in docs/web)", - templateURI: "docs://page/{id}", - testURI: "docs://page/intro#section1", // Common in documentation - expectMatch: true, // Fixed! Fragments are now stripped - expectedVars: map[string]string{"id": "intro"}, - description: "URI fragments are stripped like query parameters", - }, - { - name: "file extension in template", - templateURI: "file://docs/{filename}", - testURI: "file://docs/README.md", - expectMatch: true, - expectedVars: map[string]string{"filename": "README.md"}, - description: "File extensions should be captured in variables", + mcpRes := &mcp.Resource{ + Name: "logs", + URI: "file:///var/log/app.log", + Description: "Application logs", + Annotations: &mcp.Annotations{ + Priority: 0.8, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Handle empty template as special case - if tc.templateURI == "" { - client := &GenkitMCPClient{ - options: MCPClientOptions{Name: "test-client"}, - } - - mcpTemplate := mcp.NewResourceTemplate("", "test-resource") - _, err := client.toGenkitResourceTemplate(mcpTemplate) - - if tc.expectError && err == nil { - t.Error("Expected error for empty template, but got none") - } - if !tc.expectError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - return - } - - // Test template creation (may panic for malformed templates) - var mcpTemplate mcp.ResourceTemplate - var templateErr error - - func() { - defer func() { - if r := recover(); r != nil { - templateErr = fmt.Errorf("template creation panicked: %v", r) - } - }() - mcpTemplate = mcp.NewResourceTemplate(tc.templateURI, "test-resource") - }() - - // Create client for testing translation - client := &GenkitMCPClient{ - options: MCPClientOptions{Name: "test-client"}, - } - - // Test the MCP → Genkit translation step - var resource ai.Resource - var err error - - if templateErr != nil { - err = templateErr - } else { - resource, err = client.toGenkitResourceTemplate(mcpTemplate) - } - - if tc.expectError { - if err == nil { - t.Errorf("Expected error for %s, but got none", tc.description) - } - return - } - - if err != nil { - t.Errorf("Unexpected error for %s: %v", tc.description, err) - return - } - - // Test URI matching - actualMatch := resource.Matches(tc.testURI) - if actualMatch != tc.expectMatch { - t.Errorf("URI matching failed for %s: template %s vs URI %s: expected match=%v, got %v", - tc.description, tc.templateURI, tc.testURI, tc.expectMatch, actualMatch) - } - - // Test variable extraction if match is expected - if tc.expectMatch && tc.expectedVars != nil { - variables, err := resource.ExtractVariables(tc.testURI) - if err != nil { - t.Errorf("Variable extraction failed for %s: %v", tc.description, err) - return - } + res := client.toGenkitResource(mcpRes) - for key, expectedValue := range tc.expectedVars { - if variables[key] != expectedValue { - t.Errorf("Variable %s: expected %q, got %q", key, expectedValue, variables[key]) - } - } - } - }) + wantName := "srv_logs" + if got := res.Name(); got != wantName { + t.Errorf("res.Name() got = %q, want %q", got, wantName) } } diff --git a/go/plugins/mcp/server.go b/go/plugins/mcp/server.go index 2dc5289cc6..06aec855e5 100644 --- a/go/plugins/mcp/server.go +++ b/go/plugins/mcp/server.go @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,299 +11,306 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package mcp import ( "context" + "encoding/json" "fmt" - "log/slog" - "strings" + "net/http" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/genkit" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/firebase/genkit/go/plugins/internal/uri" + "github.com/modelcontextprotocol/go-sdk/mcp" ) // MCPServerOptions holds configuration for GenkitMCPServer type MCPServerOptions struct { - // Name for this server instance - used for MCP identification - Name string - // Version number for this server (defaults to "1.0.0" if empty) + Name string + Title string Version string } -// GenkitMCPServer represents an MCP server that exposes Genkit tools, prompts, and resources +// GenkitMCPServer represents an MCP server that exposes Genkit tools and resources type GenkitMCPServer struct { - genkit *genkit.Genkit - options MCPServerOptions - mcpServer *server.MCPServer - - // Discovered actions from Genkit registry - toolActions []ai.Tool - resourceActions []api.Action - actionsResolved bool + genkit *genkit.Genkit + options MCPServerOptions + server *mcp.Server } -// NewMCPServer creates a new GenkitMCPServer with the provided options -func NewMCPServer(g *genkit.Genkit, options MCPServerOptions) *GenkitMCPServer { - // Set default values - if options.Version == "" { - options.Version = "1.0.0" +func NewMCPServer(g *genkit.Genkit, opts MCPServerOptions) *GenkitMCPServer { + if opts.Version == "" { + opts.Version = "1.0.0" } - server := &GenkitMCPServer{ + return &GenkitMCPServer{ genkit: g, - options: options, + options: opts, } - - return server } -// setup initializes the MCP server and discovers actions func (s *GenkitMCPServer) setup() error { - if s.actionsResolved { + if s.server != nil { return nil } - // Create MCP server with all capabilities - s.mcpServer = server.NewMCPServer( - s.options.Name, - s.options.Version, - server.WithToolCapabilities(true), - server.WithResourceCapabilities(true, true), // subscribe and listChanged capabilities - ) - - // Discover and categorize actions from Genkit registry - toolActions, resourceActions, err := s.discoverAndCategorizeActions() - if err != nil { - return fmt.Errorf("failed to discover actions: %w", err) + s.server = mcp.NewServer(&mcp.Implementation{ + Name: s.options.Name, + Title: s.options.Title, + Version: s.options.Version, + }, &mcp.ServerOptions{ + Capabilities: &mcp.ServerCapabilities{ + Resources: &mcp.ResourceCapabilities{ + ListChanged: true, + }, + Tools: &mcp.ToolCapabilities{ + ListChanged: true, + }, + }, + }) + + tools := genkit.ListTools(s.genkit) + for _, t := range tools { + mcpTool := s.toMCPTool(t) + s.server.AddTool(mcpTool, s.createToolHandler(t)) } - // Store discovered actions - s.toolActions = toolActions - s.resourceActions = resourceActions - - // Register tools with the MCP server - for _, tool := range toolActions { - mcpTool := s.convertGenkitToolToMCP(tool) - s.mcpServer.AddTool(mcpTool, s.createToolHandler(tool)) - } - - // Register resources with the MCP server - for _, resourceAction := range resourceActions { - if err := s.registerResourceWithMCP(resourceAction); err != nil { - slog.Warn("Failed to register resource", "resource", resourceAction.Desc().Name, "error", err) + resources := genkit.ListResources(s.genkit) + for _, r := range resources { + if err := s.registerResource(r); err != nil { + return err } } - s.actionsResolved = true - slog.Info("MCP Server setup complete", - "name", s.options.Name, - "tools", len(s.toolActions), - "resources", len(s.resourceActions)) + // TODO: add prompt + return nil } -// discoverAndCategorizeActions discovers all actions from Genkit registry and categorizes them -func (s *GenkitMCPServer) discoverAndCategorizeActions() ([]ai.Tool, []api.Action, error) { - // Use the existing List functions which properly handle the registry access - toolActions := genkit.ListTools(s.genkit) - resources := genkit.ListResources(s.genkit) - - // Convert ai.Resource to api.Action - resourceActions := make([]api.Action, len(resources)) - for i, resource := range resources { - if resourceAction, ok := resource.(api.Action); ok { - resourceActions[i] = resourceAction - } else { - return nil, nil, fmt.Errorf("resource %s does not implement api.Action", resource.Name()) - } +func (s *GenkitMCPServer) toMCPTool(t ai.Tool) *mcp.Tool { + def := t.Definition() + return &mcp.Tool{ + Name: def.Name, + Description: def.Description, + InputSchema: def.InputSchema, } - - return toolActions, resourceActions, nil } -// convertGenkitToolToMCP converts a Genkit tool to MCP format -func (s *GenkitMCPServer) convertGenkitToolToMCP(tool ai.Tool) mcp.Tool { - def := tool.Definition() - - // Start with basic options - options := []mcp.ToolOption{mcp.WithDescription(def.Description)} - - // Convert input schema if available - if def.InputSchema != nil { - // Parse the JSON schema and convert to MCP tool options - if properties, ok := def.InputSchema["properties"].(map[string]interface{}); ok { - // Convert each property to appropriate MCP option - for propName, propDef := range properties { - if propMap, ok := propDef.(map[string]interface{}); ok { - propType, _ := propMap["type"].(string) - - switch propType { - case "string": - options = append(options, mcp.WithString(propName)) - case "integer", "number": - options = append(options, mcp.WithNumber(propName)) - case "boolean": - options = append(options, mcp.WithBoolean(propName)) - } +func (s *GenkitMCPServer) createToolHandler(t ai.Tool) mcp.ToolHandler { + return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return tracing.RunInNewSpan(ctx, &tracing.SpanMetadata{ + Name: "mcp.server.call_tool", + Type: "action", + Subtype: "tool", + IsRoot: true, + Metadata: map[string]string{ + "tool": t.Name(), + }, + }, req, func(ctx context.Context, _ *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var args map[string]any + + if len(req.Params.Arguments) > 0 { + if err := json.Unmarshal(req.Params.Arguments, &args); err != nil { + return nil, fmt.Errorf("invalid arguments: %w", err) } } - } - } - - return mcp.NewTool(def.Name, options...) -} -// createToolHandler creates an MCP tool handler for a Genkit tool -func (s *GenkitMCPServer) createToolHandler(tool ai.Tool) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // Execute the Genkit tool - result, err := tool.RunRaw(ctx, request.Params.Arguments) - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + res, err := t.RunRaw(ctx, args) + if err != nil { + return &mcp.CallToolResult{ + IsError: true, + Content: []mcp.Content{&mcp.TextContent{Text: err.Error()}}, + }, nil + } - // Convert result to MCP format - switch v := result.(type) { - case string: - return mcp.NewToolResultText(v), nil - case nil: - return mcp.NewToolResultText(""), nil - default: - // Convert complex types to string - return mcp.NewToolResultText(fmt.Sprintf("%v", v)), nil - } + var text string + if s, ok := res.(string); ok { + text = s + } else { + bytes, err := json.Marshal(res) + if err != nil { + return nil, fmt.Errorf("failed to marshal tool result: %w", err) + } + text = string(bytes) + } + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: text}}, + }, nil + }) } } -// registerResourceWithMCP registers a Genkit resource with the MCP server -func (s *GenkitMCPServer) registerResourceWithMCP(resourceAction api.Action) error { - desc := resourceAction.Desc() - resourceName := strings.TrimPrefix(desc.Key, "/resource/") +func (s *GenkitMCPServer) registerResource(resource ai.Resource) error { + action, ok := resource.(api.Action) + if !ok { + return nil + } + desc := action.Desc() - // Extract original URI/template from metadata - var originalURI string + var uri string var isTemplate bool - if resourceMeta, ok := desc.Metadata["resource"].(map[string]any); ok { - if uri, ok := resourceMeta["uri"].(string); ok && uri != "" { - originalURI = uri - isTemplate = false - } else if template, ok := resourceMeta["template"].(string); ok && template != "" { - originalURI = template - isTemplate = true + // Check metadata for URI/Template. + if resourceMeta := desc.Metadata["resource"]; resourceMeta != nil { + if meta, ok := resourceMeta.(map[string]any); ok { + if u, ok := meta["uri"].(string); ok && u != "" { + uri = u + } else if t, ok := meta["template"].(string); ok && t != "" { + uri = t + isTemplate = true + } + } else { + bytes, err := json.Marshal(resourceMeta) + if err == nil { + var meta struct { + URI string `json:"uri"` + Template string `json:"template"` + } + if err := json.Unmarshal(bytes, &meta); err == nil { + if meta.URI != "" { + uri = meta.URI + } else if meta.Template != "" { + uri = meta.Template + isTemplate = true + } + } + } } } - // Fallback to synthetic URI if no original URI found (shouldn't happen normally) - if originalURI == "" { - originalURI = fmt.Sprintf("genkit://%s", resourceName) - isTemplate = false + if uri == "" { + return nil } - // Create resource handler - handler := func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + h := s.createResourceHandler(resource) + if isTemplate { + s.server.AddResourceTemplate(&mcp.ResourceTemplate{ + URITemplate: uri, + Name: desc.Name, + Description: desc.Description, + }, h) + } else { + s.server.AddResource(&mcp.Resource{ + URI: uri, + Name: desc.Name, + Description: desc.Description, + }, h) + } + return nil +} - // Find matching resource for the URI and execute it - resourceAction, input, err := genkit.FindMatchingResource(s.genkit, request.Params.URI) - if err != nil { - return nil, fmt.Errorf("no resource found for URI %s: %w", request.Params.URI, err) - } +func (s *GenkitMCPServer) createResourceHandler(resource ai.Resource) mcp.ResourceHandler { + return func(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + return tracing.RunInNewSpan(ctx, &tracing.SpanMetadata{ + Name: "mcp.server.read_resource", + Type: "action", + Subtype: "resource", + IsRoot: true, + Metadata: map[string]string{ + "uri": req.Params.URI, + }, + }, req, func(ctx context.Context, _ *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + _, input, err := genkit.FindMatchingResource(s.genkit, req.Params.URI) + if err != nil { + return nil, mcp.ResourceNotFoundError(req.Params.URI) + } - // Execute the resource - result, err := resourceAction.Execute(ctx, input) - if err != nil { - return nil, fmt.Errorf("resource execution failed: %w", err) - } + out, err := resource.Execute(ctx, input) + if err != nil { + return nil, err + } - // Convert result to MCP content format - var contents []mcp.ResourceContents - for _, part := range result.Content { - if part.Text != "" { - contents = append(contents, mcp.TextResourceContents{ - URI: request.Params.URI, - MIMEType: "text/plain", - Text: part.Text, - }) + contents, err := s.toMCPResourceContents(req.Params.URI, out.Content) + if err != nil { + return nil, err } - // Handle other part types (media, data, etc.) if needed - } - return contents, nil + return &mcp.ReadResourceResult{Contents: contents}, nil + }) } +} - // Register as template resource or static resource based on type - if isTemplate { - // Create MCP template resource - mcpTemplate := mcp.NewResourceTemplate( - originalURI, // Template URI like "user://profile/{id}" - resourceName, // Name - mcp.WithTemplateDescription(desc.Description), - ) - s.mcpServer.AddResourceTemplate(mcpTemplate, handler) - } else { - // Create MCP static resource - mcpResource := mcp.NewResource( - originalURI, // Static URI - resourceName, // Name - mcp.WithResourceDescription(desc.Description), - ) - s.mcpServer.AddResource(mcpResource, handler) +// toMCPResourceContents translates a slice of [ai.Part] into a slice of [mcp.ResourceContents] +func (s *GenkitMCPServer) toMCPResourceContents(requestURI string, parts []*ai.Part) ([]*mcp.ResourceContents, error) { + var contents []*mcp.ResourceContents + for _, p := range parts { + switch { + case p.IsText(): + contents = append(contents, &mcp.ResourceContents{ + URI: requestURI, + MIMEType: "text/plain", + Text: p.Text, + }) + case p.IsMedia(): + contentType, blob, err := uri.Data(p) + if err != nil { + return nil, err + } + contents = append(contents, &mcp.ResourceContents{ + URI: requestURI, + MIMEType: contentType, + Blob: blob, + }) + case p.IsData(): + contents = append(contents, &mcp.ResourceContents{ + URI: requestURI, + MIMEType: "application/json", + Text: p.Text, + }) + } } - - return nil + return contents, nil } -// ServeStdio starts the MCP server using stdio transport +// ServeStdio runs the server over Stdio func (s *GenkitMCPServer) ServeStdio() error { + return s.ServeStdioWithContext(context.Background()) +} + +// ServeStdioWithContext runs the server over Stdio with the given context. +// Canceling the context will stop the server. +func (s *GenkitMCPServer) ServeStdioWithContext(ctx context.Context) error { if err := s.setup(); err != nil { - return fmt.Errorf("setup failed: %w", err) + return err } - - return server.ServeStdio(s.mcpServer) + return s.server.Run(ctx, &mcp.StdioTransport{}) } -// Serve starts the MCP server with a custom transport -func (s *GenkitMCPServer) Serve(transport interface{}) error { +// HTTPHandler creates an HTTP handler that serves MCP using SSE. +func (s *GenkitMCPServer) HTTPHandler() (http.Handler, error) { if err := s.setup(); err != nil { - return fmt.Errorf("setup failed: %w", err) + return nil, err } - // For now, only stdio is supported through the server.ServeStdio function - return server.ServeStdio(s.mcpServer) -} + sseHandler := mcp.NewSSEHandler(func(r *http.Request) *mcp.Server { + return s.server + }, nil) -// Close shuts down the MCP server -func (s *GenkitMCPServer) Close() error { - // The mcp-go server handles cleanup internally - return nil + return sseHandler, nil } -// GetServer returns the underlying MCP server instance -func (s *GenkitMCPServer) GetServer() *server.MCPServer { - return s.mcpServer -} - -// ListRegisteredTools returns the names of all discovered tools +// ListRegisteredTools returns the names of all tools registered in the server's Genkit instance. func (s *GenkitMCPServer) ListRegisteredTools() []string { - var toolNames []string - for _, tool := range s.toolActions { - toolNames = append(toolNames, tool.Name()) + tools := genkit.ListTools(s.genkit) + var names []string + for _, t := range tools { + names = append(names, t.Name()) } - return toolNames + return names } -// ListRegisteredResources returns the names of all discovered resources +// ListRegisteredResources returns the names of all resources registered in the server's Genkit instance. func (s *GenkitMCPServer) ListRegisteredResources() []string { - var resourceNames []string - for _, resourceAction := range s.resourceActions { - desc := resourceAction.Desc() - resourceName := strings.TrimPrefix(desc.Key, "/resource/") - resourceNames = append(resourceNames, resourceName) + resources := genkit.ListResources(s.genkit) + var names []string + for _, r := range resources { + names = append(names, r.Name()) } - return resourceNames + return names } diff --git a/go/plugins/mcp/server_test.go b/go/plugins/mcp/server_test.go new file mode 100644 index 0000000000..6b01f85009 --- /dev/null +++ b/go/plugins/mcp/server_test.go @@ -0,0 +1,59 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mcp + +import ( + "context" + "testing" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" +) + +func TestToMCPTool(t *testing.T) { + ctx := context.Background() + g := genkit.Init(ctx) + server := &GenkitMCPServer{genkit: g} + + genkitTool := genkit.DefineTool(g, "gablorken", "calculates a gablorken", + func(ctx *ai.ToolContext, input struct { + Value int + Over float64 + }, + ) (float64, error) { + return 0, nil + }, + ) + + got := server.toMCPTool(genkitTool) + + if got.Name != "gablorken" { + t.Errorf("mcpTool.Name got = %q, want %q", got.Name, "gablorken") + } + if got.Description != "calculates a gablorken" { + t.Errorf("mcpTool.Description got = %q, want %q", got.Description, "calculates a gablorken") + } + if got.InputSchema == nil { + t.Fatal("mcpTool.InputSchema is nil") + } +} + +func TestNewMCPServer(t *testing.T) { + s := NewMCPServer(nil, MCPServerOptions{Name: "test-server"}) + + if got := s.options.Version; got != "1.0.0" { + t.Errorf("default version got = %q, want %q", got, "1.0.0") + } +} diff --git a/go/plugins/mcp/tools.go b/go/plugins/mcp/tools.go index 931f619b0e..236bbd9e6d 100644 --- a/go/plugins/mcp/tools.go +++ b/go/plugins/mcp/tools.go @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,8 +11,9 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 -// Package mcp provides a client for integration with the Model Context Protocol. package mcp import ( @@ -22,212 +23,86 @@ import ( "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/genkit" - "github.com/mark3labs/mcp-go/client" - "github.com/mark3labs/mcp-go/mcp" + "github.com/modelcontextprotocol/go-sdk/mcp" ) -// GetActiveTools retrieves all tools available from the MCP server +// GetActiveTools retrieves all available tools from the MCP server func (c *GenkitMCPClient) GetActiveTools(ctx context.Context, g *genkit.Genkit) ([]ai.Tool, error) { - if !c.IsEnabled() || c.server == nil { + if !c.IsEnabled() || c.server == nil || c.server.Session == nil { return nil, nil } - - // Get all MCP tools - mcpTools, err := c.getTools(ctx) - if err != nil { - return nil, err + if c.server.Error != nil { + return nil, c.server.Error } - // Create tools from MCP server - return c.createTools(mcpTools) -} - -// createTools creates Genkit tools from MCP tools -func (c *GenkitMCPClient) createTools(mcpTools []mcp.Tool) ([]ai.Tool, error) { var tools []ai.Tool - for _, mcpTool := range mcpTools { - tool, err := c.createTool(mcpTool) + for mt, err := range c.server.Session.Tools(ctx, nil) { if err != nil { - return nil, err + return nil, fmt.Errorf("failed to list tools: %w", err) } - if tool != nil { - tools = append(tools, tool) + tool, err := c.createTool(mt) + if err != nil { + return nil, err } + tools = append(tools, tool) } return tools, nil } -// getInputSchema returns the MCP input schema as a generic map for Genkit -func (c *GenkitMCPClient) getInputSchema(mcpTool mcp.Tool) (map[string]any, error) { - var out map[string]any - schemaBytes, err := json.Marshal(mcpTool.InputSchema) - if err != nil { - return nil, fmt.Errorf("failed to marshal MCP input schema for tool %s: %w", mcpTool.Name, err) - } - if err := json.Unmarshal(schemaBytes, &out); err != nil { - // Fall back to empty map if unmarshalling fails - out = map[string]any{} - } - if out == nil { - out = map[string]any{} +// parseInputSchema converts a given schema into a map[string]any +func (c *GenkitMCPClient) parseInputSchema(schema any) (map[string]any, error) { + if schema == nil { + return map[string]any{"type": "object"}, nil } - return out, nil -} - -// createTool converts a single MCP tool to a Genkit tool -func (c *GenkitMCPClient) createTool(mcpTool mcp.Tool) (ai.Tool, error) { - // Use namespaced tool name - namespacedToolName := c.GetToolNameWithNamespace(mcpTool.Name) - toolFunc := c.createToolFunction(mcpTool) - inputSchema, err := c.getInputSchema(mcpTool) + bytes, err := json.Marshal(schema) if err != nil { - return nil, fmt.Errorf("failed to get input schema for tool %s: %w", mcpTool.Name, err) - } - var tool ai.Tool - if len(inputSchema) > 0 { - tool = ai.NewTool( - namespacedToolName, - mcpTool.Description, - toolFunc, - ai.WithInputSchema(inputSchema), - ) - } else { - tool = ai.NewTool( - namespacedToolName, - mcpTool.Description, - toolFunc, - ) + return nil, err } - return tool, nil -} - -// getTools retrieves all tools from the MCP server by paginating through results -func (c *GenkitMCPClient) getTools(ctx context.Context) ([]mcp.Tool, error) { - var allMcpTools []mcp.Tool - var cursor mcp.Cursor - - // Paginate through all available tools from the MCP server - for { - // Fetch a page of tools - mcpTools, nextCursor, err := c.fetchToolsPage(ctx, cursor) - if err != nil { - return nil, err - } - allMcpTools = append(allMcpTools, mcpTools...) - - // Check if we've reached the last page - cursor = nextCursor - if cursor == "" { - break - } + var res map[string]any + if err := json.Unmarshal(bytes, &res); err != nil { + return nil, err } - return allMcpTools, nil + return res, nil } -// fetchToolsPage retrieves a single page of tools from the MCP server -func (c *GenkitMCPClient) fetchToolsPage(ctx context.Context, cursor mcp.Cursor) ([]mcp.Tool, mcp.Cursor, error) { - listReq := mcp.ListToolsRequest{ - PaginatedRequest: mcp.PaginatedRequest{ - Params: struct { - Cursor mcp.Cursor `json:"cursor,omitempty"` - }{ - Cursor: cursor, - }, - }, - } +func (c *GenkitMCPClient) createTool(mt *mcp.Tool) (ai.Tool, error) { + namespaceName := fmt.Sprintf("%s_%s", c.options.Name, mt.Name) - result, err := c.server.Client.ListTools(ctx, listReq) + toolFunc := c.createToolFunction(mt.Name) + + inputSchema, err := c.parseInputSchema(mt.InputSchema) if err != nil { - return nil, "", fmt.Errorf("failed to list tools: %w", err) + return nil, fmt.Errorf("failed to parse schema for tool %s: %w", mt.Name, err) } - return result.Tools, result.NextCursor, nil + return ai.NewTool(namespaceName, mt.Description, toolFunc, ai.WithInputSchema(inputSchema)), nil } -// createToolFunction creates a Genkit tool function that will execute the MCP tool -func (c *GenkitMCPClient) createToolFunction(mcpTool mcp.Tool) func(*ai.ToolContext, interface{}) (interface{}, error) { - // Capture mcpTool by value for the closure - currentMCPTool := mcpTool - client := c.server.Client - - return func(toolCtx *ai.ToolContext, args interface{}) (interface{}, error) { - ctx := toolCtx.Context // Get context from tool context - - // Convert the arguments to the format expected by MCP - callToolArgs, err := prepareToolArguments(currentMCPTool, args) - if err != nil { - return nil, err +func (c *GenkitMCPClient) createToolFunction(toolName string) func(*ai.ToolContext, any) (any, error) { + return func(toolCtx *ai.ToolContext, args any) (any, error) { + if c.server == nil || c.server.Session == nil { + return nil, fmt.Errorf("MCP session is closed") } - - // Create and execute the MCP tool call request - mcpResult, err := executeToolCall(ctx, client, currentMCPTool.Name, callToolArgs) - if err != nil { - return nil, fmt.Errorf("failed to call tool %s: %w", currentMCPTool.Name, err) + if c.server.Error != nil { + return nil, c.server.Error } - return mcpResult, nil - } -} - -// prepareToolArguments converts Genkit tool arguments to MCP format -// and validates required fields based on the tool's schema -func prepareToolArguments(mcpTool mcp.Tool, args interface{}) (map[string]interface{}, error) { - var callToolArgs map[string]interface{} - if args != nil { - jsonBytes, err := json.Marshal(args) + params := &mcp.CallToolParams{ + Name: toolName, + Arguments: args, + } + result, err := c.server.Session.CallTool(toolCtx.Context, params) if err != nil { - return nil, fmt.Errorf("tool arguments must be marshallable to map[string]interface{}, got %T: %w", args, err) + return nil, fmt.Errorf("MCP tool call failed: %w", err) } - if err := json.Unmarshal(jsonBytes, &callToolArgs); err != nil { - return nil, fmt.Errorf("tool arguments could not be converted to map[string]interface{} for tool %s: %w", mcpTool.Name, err) + if result.IsError { + // in mcp, errors are often returned as text + return nil, fmt.Errorf("tool execution error: %v", result.Content) } - } else { - callToolArgs = make(map[string]interface{}) + return result, nil } - - // Validate required fields - if err := validateRequiredArguments(mcpTool, callToolArgs); err != nil { - return nil, err - } - - return callToolArgs, nil -} - -// validateRequiredArguments checks if all required arguments are present -func validateRequiredArguments(mcpTool mcp.Tool, args map[string]interface{}) error { - if mcpTool.InputSchema.Required != nil { - for _, required := range mcpTool.InputSchema.Required { - if _, exists := args[required]; !exists { - return fmt.Errorf("required field %q missing for tool %q", required, mcpTool.Name) - } - } - } - return nil -} - -// executeToolCall makes the actual MCP tool call -func executeToolCall(ctx context.Context, client *client.Client, toolName string, args map[string]interface{}) (*mcp.CallToolResult, error) { - callReq := mcp.CallToolRequest{ - Params: struct { - Name string `json:"name"` - Arguments any `json:"arguments,omitempty"` - Meta *mcp.Meta `json:"_meta,omitempty"` - }{ - Name: toolName, - Arguments: args, - Meta: nil, - }, - } - - result, err := client.CallTool(ctx, callReq) - - if err != nil { - return nil, err - } - - return result, nil } diff --git a/go/plugins/mcp/tools_test.go b/go/plugins/mcp/tools_test.go index 8d27470e06..cf064cb04d 100644 --- a/go/plugins/mcp/tools_test.go +++ b/go/plugins/mcp/tools_test.go @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,153 +11,93 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package mcp import ( - "encoding/json" + "reflect" "testing" - "github.com/mark3labs/mcp-go/mcp" + "github.com/modelcontextprotocol/go-sdk/mcp" ) -func asMap(t *testing.T, v any, label string) map[string]any { - t.Helper() - m, ok := v.(map[string]any) - if !ok { - t.Fatalf("%s: want map[string]any, got %T", label, v) - } - return m -} +func TestParseInputSchema(t *testing.T) { + client := &GenkitMCPClient{} -func eqStr(t *testing.T, got any, want string, label string) { - t.Helper() - s, ok := got.(string) - if !ok || s != want { - t.Fatalf("%s: got %v", label, got) - } -} + t.Run("valid schema", func(t *testing.T) { + input := map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + } -func eqNum(t *testing.T, got any, want int, label string) { - t.Helper() - f, ok := got.(float64) - if !ok || int(f) != want { - t.Fatalf("%s: got %v", label, got) - } -} + got, err := client.parseInputSchema(input) + if err != nil { + t.Fatalf("parseInputSchema() error = %v, want nil", err) + } -func reqContains(t *testing.T, v any, want string) { - t.Helper() - switch arr := v.(type) { - case []any: - for _, x := range arr { - if s, ok := x.(string); ok && s == want { - return - } + if !reflect.DeepEqual(got, input) { + t.Errorf("parseInputSchema() got = %v, want %v", got, input) } - case []string: - for _, s := range arr { - if s == want { - return - } + }) + + t.Run("nil schema returns empty object", func(t *testing.T) { + got, err := client.parseInputSchema(nil) + if err != nil { + t.Fatalf("parseInputSchema() error = %v, want nil", err) } - default: - t.Fatalf("required: unexpected type %T", v) - } - t.Fatalf("required does not contain %q: %v", want, v) + + want := map[string]any{"type": "object"} + if !reflect.DeepEqual(got, want) { + t.Errorf("parseInputSchema() got = %v, want %v", got, want) + } + }) } -// TestCreateTool tests the createTool function. func TestCreateTool(t *testing.T) { - client := &GenkitMCPClient{options: MCPClientOptions{Name: "srv"}} - client.server = &ServerRef{} // avoid nil deref in createToolFunction - - var m mcp.Tool - toolJSON := `{ - "name": "echo", - "description": "Echo", - "inputSchema": { - "type": "object", - "required": ["q"], - "properties": { - "q": {"type": "string", "description": "query"}, - "n": {"type": "number", "minimum": 1, "maximum": 10}, - "arr": {"type": "array", "minItems": 2, "maxItems": 4} - } - } - }` - if err := json.Unmarshal([]byte(toolJSON), &m); err != nil { - t.Fatalf("failed to unmarshal tool JSON: %v", err) + client := &GenkitMCPClient{ + options: MCPClientOptions{Name: "test-srv"}, } - tool, err := client.createTool(m) - if err != nil { - t.Fatalf("createTool error: %v", err) + mcpTool := &mcp.Tool{ + Name: "get_weather", + Description: "Fetches weather data", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "city": map[string]any{"type": "string"}, + }, + }, } - // Validate that the tool is namespaced - def := tool.Definition() - if def.Name != "srv_echo" { - t.Fatalf("namespacing failed: got %q", def.Name) - } - if def.Description != "Echo" { - t.Fatalf("description mismatch: %q", def.Description) - } - if def.InputSchema == nil { - t.Fatalf("input schema missing") + tool, err := client.createTool(mcpTool) + if err != nil { + t.Fatalf("createTool() error = %v, want nil", err) } - // Validate that the schema is propagated correctly. - eqStr(t, def.InputSchema["type"], "object", "schema.type") - props := asMap(t, def.InputSchema["properties"], "schema.properties") - - q := asMap(t, props["q"], "properties.q") - eqStr(t, q["type"], "string", "q.type") - eqStr(t, q["description"], "query", "q.description") - - n := asMap(t, props["n"], "properties.n") - eqStr(t, n["type"], "number", "n.type") - eqNum(t, n["minimum"], 1, "n.minimum") - eqNum(t, n["maximum"], 10, "n.maximum") - - arr := asMap(t, props["arr"], "properties.arr") - eqStr(t, arr["type"], "array", "arr.type") - eqNum(t, arr["minItems"], 2, "arr.minItems") - eqNum(t, arr["maxItems"], 4, "arr.maxItems") - - reqContains(t, def.InputSchema["required"], "q") -} - -// TestPrepareToolArguments tests the prepareToolArguments function. -// Ensures that required fields are validated. -func TestPrepareToolArguments(t *testing.T) { - var tool mcp.Tool - toolJSON := `{ - "name": "echo", - "inputSchema": { - "type": "object", - "required": ["q"] - } - }` - if err := json.Unmarshal([]byte(toolJSON), &tool); err != nil { - t.Fatalf("failed to unmarshal tool JSON: %v", err) + // Test Namespacing + wantName := "test-srv_get_weather" + if got := tool.Name(); got != wantName { + t.Errorf("tool.Name() got = %q, want %q", got, wantName) } - okArgs := map[string]any{"q": "hi"} - got, err := prepareToolArguments(tool, okArgs) - if err != nil { - t.Fatalf("unexpected error for valid args: %v", err) - } - if got["q"] != "hi" { - t.Fatalf("args not preserved: %v", got) + // Test Description + def := tool.Definition() + if got := def.Description; got != mcpTool.Description { + t.Errorf("tool.Description got = %q, want %q", got, mcpTool.Description) } - _, err = prepareToolArguments(tool, map[string]any{}) - if err == nil { - t.Fatalf("expected error for missing required field") + // Test Schema presence + if def.InputSchema == nil { + t.Error("tool.InputSchema is nil, want valid schema") } - _, err = prepareToolArguments(tool, nil) - if err == nil { - t.Fatalf("expected error for nil args with required field") + + // Test Schema content + gotCityType := def.InputSchema["properties"].(map[string]any)["city"].(map[string]any)["type"] + if gotCityType != "string" { + t.Errorf("schema city type got = %v, want %q", gotCityType, "string") } }