260 lines
6.1 KiB
Go
260 lines
6.1 KiB
Go
package mcp
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"os"
	"os/exec"
	"sort"
	"strings"
	"sync"

	"tell-me/config"

	"github.com/modelcontextprotocol/go-sdk/mcp"
	"github.com/sashabaranov/go-openai"
)
|
|
|
|
// Manager manages multiple MCP server connections
|
|
type Manager struct {
|
|
servers map[string]*ServerConnection
|
|
mu sync.RWMutex
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// ServerConnection represents a connection to an MCP server
|
|
type ServerConnection struct {
|
|
Name string
|
|
Config config.MCPServer
|
|
Client *mcp.Client
|
|
Session *mcp.ClientSession
|
|
Tools []*mcp.Tool
|
|
Error string // Connection error if any
|
|
}
|
|
|
|
// NewManager creates a new MCP manager
|
|
func NewManager(ctx context.Context) *Manager {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
return &Manager{
|
|
servers: make(map[string]*ServerConnection),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
|
|
// ConnectServers connects to all configured MCP servers
|
|
func (m *Manager) ConnectServers(servers map[string]config.MCPServer) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
for name, serverCfg := range servers {
|
|
if err := m.connectServer(name, serverCfg); err != nil {
|
|
// Store the error in the connection
|
|
m.servers[name] = &ServerConnection{
|
|
Name: name,
|
|
Config: serverCfg,
|
|
Error: err.Error(),
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// connectServer connects to a single MCP server
|
|
func (m *Manager) connectServer(name string, serverCfg config.MCPServer) error {
|
|
// Create MCP client
|
|
client := mcp.NewClient(&mcp.Implementation{
|
|
Name: "tell-me",
|
|
Version: "1.0.0",
|
|
}, nil)
|
|
|
|
// Only stdio transport is supported for local servers
|
|
if serverCfg.Command == "" {
|
|
return fmt.Errorf("command is required for MCP server")
|
|
}
|
|
|
|
cmd := exec.CommandContext(m.ctx, serverCfg.Command, serverCfg.Args...)
|
|
|
|
// Set environment variables if provided
|
|
if len(serverCfg.Env) > 0 {
|
|
cmd.Env = append(cmd.Env, m.envMapToSlice(serverCfg.Env)...)
|
|
}
|
|
|
|
transport := &mcp.CommandTransport{Command: cmd}
|
|
|
|
// Connect to the server
|
|
session, err := client.Connect(m.ctx, transport, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect: %w", err)
|
|
}
|
|
|
|
// List available tools
|
|
toolsResult, err := session.ListTools(m.ctx, &mcp.ListToolsParams{})
|
|
if err != nil {
|
|
session.Close()
|
|
return fmt.Errorf("failed to list tools: %w", err)
|
|
}
|
|
|
|
// Store the connection
|
|
m.servers[name] = &ServerConnection{
|
|
Name: name,
|
|
Config: serverCfg,
|
|
Client: client,
|
|
Session: session,
|
|
Tools: toolsResult.Tools,
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// envMapToSlice converts environment map to slice format
|
|
func (m *Manager) envMapToSlice(envMap map[string]string) []string {
|
|
result := make([]string, 0, len(envMap))
|
|
for key, value := range envMap {
|
|
result = append(result, fmt.Sprintf("%s=%s", key, value))
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetAllTools returns all tools from all connected servers as OpenAI tool definitions
|
|
func (m *Manager) GetAllTools() []openai.Tool {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
var tools []openai.Tool
|
|
|
|
for serverName, conn := range m.servers {
|
|
for _, mcpTool := range conn.Tools {
|
|
// Convert MCP tool to OpenAI tool format
|
|
tool := openai.Tool{
|
|
Type: openai.ToolTypeFunction,
|
|
Function: &openai.FunctionDefinition{
|
|
Name: fmt.Sprintf("%s_%s", serverName, mcpTool.Name),
|
|
Description: mcpTool.Description,
|
|
Parameters: mcpTool.InputSchema,
|
|
},
|
|
}
|
|
tools = append(tools, tool)
|
|
}
|
|
}
|
|
|
|
return tools
|
|
}
|
|
|
|
// CallTool calls a tool on the appropriate MCP server
|
|
func (m *Manager) CallTool(ctx context.Context, toolName string, arguments string) (string, error) {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
// Parse the tool name to extract server name and actual tool name
|
|
// Format: serverName_toolName
|
|
var serverName, actualToolName string
|
|
for sName := range m.servers {
|
|
prefix := sName + "_"
|
|
if len(toolName) > len(prefix) && toolName[:len(prefix)] == prefix {
|
|
serverName = sName
|
|
actualToolName = toolName[len(prefix):]
|
|
break
|
|
}
|
|
}
|
|
|
|
if serverName == "" {
|
|
return "", fmt.Errorf("unknown tool: %s", toolName)
|
|
}
|
|
|
|
conn, exists := m.servers[serverName]
|
|
if !exists {
|
|
return "", fmt.Errorf("server not found: %s", serverName)
|
|
}
|
|
|
|
// Parse arguments
|
|
var args map[string]interface{}
|
|
if arguments != "" {
|
|
if err := json.Unmarshal([]byte(arguments), &args); err != nil {
|
|
return "", fmt.Errorf("failed to parse arguments: %w", err)
|
|
}
|
|
}
|
|
|
|
// Call the tool
|
|
result, err := conn.Session.CallTool(ctx, &mcp.CallToolParams{
|
|
Name: actualToolName,
|
|
Arguments: args,
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("tool call failed: %w", err)
|
|
}
|
|
|
|
if result.IsError {
|
|
return "", fmt.Errorf("tool returned error")
|
|
}
|
|
|
|
// Format the result
|
|
var response string
|
|
for _, content := range result.Content {
|
|
switch c := content.(type) {
|
|
case *mcp.TextContent:
|
|
response += c.Text + "\n"
|
|
case *mcp.ImageContent:
|
|
response += fmt.Sprintf("[Image: %s]\n", c.MIMEType)
|
|
case *mcp.EmbeddedResource:
|
|
response += fmt.Sprintf("[Resource: %s]\n", c.Resource.URI)
|
|
}
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// GetServerInfo returns information about connected servers
|
|
func (m *Manager) GetServerInfo() map[string][]string {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
info := make(map[string][]string)
|
|
for name, conn := range m.servers {
|
|
if conn.Error == "" {
|
|
toolNames := make([]string, len(conn.Tools))
|
|
for i, tool := range conn.Tools {
|
|
toolNames[i] = tool.Name
|
|
}
|
|
info[name] = toolNames
|
|
}
|
|
}
|
|
return info
|
|
}
|
|
|
|
// GetDetailedStatus returns detailed status information for all servers
|
|
func (m *Manager) GetDetailedStatus() []*ServerConnection {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
statuses := make([]*ServerConnection, 0, len(m.servers))
|
|
for _, conn := range m.servers {
|
|
statuses = append(statuses, conn)
|
|
}
|
|
return statuses
|
|
}
|
|
|
|
// Close closes all MCP server connections
|
|
func (m *Manager) Close() error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var lastErr error
|
|
for name, conn := range m.servers {
|
|
if conn.Session != nil {
|
|
if err := conn.Session.Close(); err != nil {
|
|
log.Printf("Error closing connection to %s: %v", name, err)
|
|
lastErr = err
|
|
}
|
|
}
|
|
}
|
|
|
|
// Cancel context after closing all sessions
|
|
m.cancel()
|
|
|
|
m.servers = make(map[string]*ServerConnection)
|
|
return lastErr
|
|
}
|