// Package mcp manages connections to Model Context Protocol (MCP) servers and
// exposes their tools in the OpenAI function-calling format.
package mcp

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"os"
	"os/exec"
	"sort"
	"strings"
	"sync"

	"github.com/modelcontextprotocol/go-sdk/mcp"
	"github.com/sashabaranov/go-openai"

	"tell-me/config"
)

// Manager manages multiple MCP server connections.
//
// The servers map is guarded by mu. Server processes are started with the
// Manager's context, which Close cancels, so a Manager must not be reused
// after Close.
type Manager struct {
	servers map[string]*ServerConnection
	mu      sync.RWMutex
	ctx     context.Context
	cancel  context.CancelFunc
}

// ServerConnection represents a connection to an MCP server.
//
// When connecting fails, only Name, Config and Error are populated; Client
// and Session are nil and Tools is empty.
type ServerConnection struct {
	Name    string
	Config  config.MCPServer
	Client  *mcp.Client
	Session *mcp.ClientSession
	Tools   []*mcp.Tool
	Error   string // Connection error if any
}

// NewManager creates a new MCP manager whose server processes are bound to a
// child context of ctx.
func NewManager(ctx context.Context) *Manager {
	ctx, cancel := context.WithCancel(ctx)
	return &Manager{
		servers: make(map[string]*ServerConnection),
		ctx:     ctx,
		cancel:  cancel,
	}
}

// ConnectServers connects to all configured MCP servers.
//
// Failures are non-fatal: each one is logged and recorded as a
// ServerConnection with its Error field set, so the method currently always
// returns nil. If a name is already registered, its previous session is
// closed before being replaced so reconnecting does not leak the old server
// process.
func (m *Manager) ConnectServers(servers map[string]config.MCPServer) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	for name, serverCfg := range servers {
		m.closeServerLocked(name)

		if err := m.connectServer(name, serverCfg); err != nil {
			log.Printf("Warning: Failed to connect to MCP server %s: %v", name, err)
			// Keep a record of the failure so status reporting can show it.
			m.servers[name] = &ServerConnection{
				Name:   name,
				Config: serverCfg,
				Error:  err.Error(),
			}
			continue
		}
		log.Printf("Successfully connected to MCP server: %s", name)
	}
	return nil
}

// closeServerLocked closes and forgets the connection registered under name,
// if any. The caller must hold m.mu for writing.
func (m *Manager) closeServerLocked(name string) {
	old, ok := m.servers[name]
	if !ok {
		return
	}
	delete(m.servers, name)
	if old.Session == nil {
		return
	}
	if err := old.Session.Close(); err != nil {
		log.Printf("Error closing previous connection to %s: %v", name, err)
	}
}

// connectServer launches the configured command, connects to it over the
// command (stdio) transport, fetches its complete tool list and stores the
// resulting connection in m.servers. The caller must hold m.mu for writing.
func (m *Manager) connectServer(name string, serverCfg config.MCPServer) error {
	// Only stdio transport is supported for local servers.
	if serverCfg.Command == "" {
		return fmt.Errorf("command is required for MCP server")
	}

	client := mcp.NewClient(&mcp.Implementation{
		Name:    "tell-me",
		Version: "1.0.0",
	}, nil)

	cmd := exec.CommandContext(m.ctx, serverCfg.Command, serverCfg.Args...)
	if len(serverCfg.Env) > 0 {
		// A non-nil cmd.Env REPLACES the child's environment, so start from
		// the parent's (PATH, HOME, ...) and append the configured overrides;
		// os/exec uses the last value when a key appears more than once.
		cmd.Env = append(os.Environ(), m.envMapToSlice(serverCfg.Env)...)
	}

	session, err := client.Connect(m.ctx, &mcp.CommandTransport{Command: cmd}, nil)
	if err != nil {
		return fmt.Errorf("failed to connect: %w", err)
	}

	tools, err := listAllTools(m.ctx, session)
	if err != nil {
		// Best effort: the listing error is the one worth reporting.
		_ = session.Close()
		return fmt.Errorf("failed to list tools: %w", err)
	}

	m.servers[name] = &ServerConnection{
		Name:    name,
		Config:  serverCfg,
		Client:  client,
		Session: session,
		Tools:   tools,
	}
	return nil
}

// listAllTools fetches every tool the server advertises, following
// NextCursor across pages until the server reports no further page.
func listAllTools(ctx context.Context, session *mcp.ClientSession) ([]*mcp.Tool, error) {
	var tools []*mcp.Tool
	params := &mcp.ListToolsParams{}
	for {
		res, err := session.ListTools(ctx, params)
		if err != nil {
			return nil, err
		}
		tools = append(tools, res.Tools...)
		// Stop on the last page, or if the server hands back the cursor we
		// just sent, which would otherwise loop forever.
		if res.NextCursor == "" || res.NextCursor == params.Cursor {
			return tools, nil
		}
		params = &mcp.ListToolsParams{Cursor: res.NextCursor}
	}
}

// envMapToSlice converts an environment map to "KEY=value" entries, the
// format used by exec.Cmd.Env. Entry order follows map iteration order.
func (m *Manager) envMapToSlice(envMap map[string]string) []string {
	result := make([]string, 0, len(envMap))
	for key, value := range envMap {
		result = append(result, key+"="+value)
	}
	return result
}

// sortedServerNamesLocked returns the registered server names in ascending
// order. The caller must hold m.mu (read or write).
func (m *Manager) sortedServerNamesLocked() []string {
	names := make([]string, 0, len(m.servers))
	for name := range m.servers {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// GetAllTools returns all tools from all connected servers as OpenAI tool
// definitions named "serverName_toolName". Servers are visited in name order
// so the result is deterministic; servers that failed to connect contribute
// nothing because their Tools slice is empty.
func (m *Manager) GetAllTools() []openai.Tool {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var tools []openai.Tool
	for _, serverName := range m.sortedServerNamesLocked() {
		for _, mcpTool := range m.servers[serverName].Tools {
			tools = append(tools, openai.Tool{
				Type: openai.ToolTypeFunction,
				Function: &openai.FunctionDefinition{
					Name:        serverName + "_" + mcpTool.Name,
					Description: mcpTool.Description,
					Parameters:  mcpTool.InputSchema,
				},
			})
		}
	}
	return tools
}

// resolveTool maps a "serverName_toolName" identifier onto a connected server
// and the server-side tool name. Server names may themselves contain '_', so
// a server that actually advertises the remaining tool name is preferred,
// and after that the longer (more specific) server name. Servers without a
// live session are ignored. It returns a nil connection if nothing matches.
func (m *Manager) resolveTool(toolName string) (*ServerConnection, string) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var (
		best      *ServerConnection
		bestName  string
		bestKnown bool
	)
	for name, conn := range m.servers {
		prefix := name + "_"
		if conn.Session == nil || len(toolName) <= len(prefix) || !strings.HasPrefix(toolName, prefix) {
			continue
		}
		rest := toolName[len(prefix):]
		known := false
		for _, t := range conn.Tools {
			if t.Name == rest {
				known = true
				break
			}
		}
		if best == nil || (known && !bestKnown) || (known == bestKnown && len(name) > len(bestName)) {
			best, bestName, bestKnown = conn, name, known
		}
	}
	if best == nil {
		return nil, ""
	}
	return best, toolName[len(bestName)+1:]
}

// CallTool calls a tool on the appropriate MCP server.
//
// toolName must have the "serverName_toolName" form produced by GetAllTools.
// arguments is a JSON object; an empty string or JSON null is sent as {}.
// The text of the result's content items is returned, one item per line.
// If the tool reports an error, the returned error includes its message.
func (m *Manager) CallTool(ctx context.Context, toolName string, arguments string) (string, error) {
	// The lock is only held for the lookup, not for the remote call, so a
	// slow tool cannot block Close or ConnectServers. If the session is
	// closed concurrently, the call below simply fails.
	conn, actualToolName := m.resolveTool(toolName)
	if conn == nil {
		return "", fmt.Errorf("unknown tool: %s", toolName)
	}

	var args map[string]any
	if arguments != "" {
		if err := json.Unmarshal([]byte(arguments), &args); err != nil {
			return "", fmt.Errorf("failed to parse arguments: %w", err)
		}
	}
	if args == nil {
		// The MCP schema types tool arguments as an object; send {} rather
		// than a JSON null.
		args = map[string]any{}
	}

	result, err := conn.Session.CallTool(ctx, &mcp.CallToolParams{
		Name:      actualToolName,
		Arguments: args,
	})
	if err != nil {
		return "", fmt.Errorf("tool call failed: %w", err)
	}

	text := formatToolContent(result.Content)
	if result.IsError {
		if msg := strings.TrimSpace(text); msg != "" {
			return "", fmt.Errorf("tool returned error: %s", msg)
		}
		return "", errors.New("tool returned error")
	}
	return text, nil
}

// formatToolContent renders tool result content as text. Text items are
// emitted verbatim. Images and embedded resources become short placeholders.
// Each emitted item ends with a newline. Other content kinds are skipped.
func formatToolContent(contents []mcp.Content) string {
	var b strings.Builder
	for _, content := range contents {
		switch c := content.(type) {
		case *mcp.TextContent:
			b.WriteString(c.Text)
			b.WriteByte('\n')
		case *mcp.ImageContent:
			fmt.Fprintf(&b, "[Image: %s]\n", c.MIMEType)
		case *mcp.EmbeddedResource:
			if c.Resource != nil {
				fmt.Fprintf(&b, "[Resource: %s]\n", c.Resource.URI)
			}
		}
	}
	return b.String()
}

// GetServerInfo returns, for each successfully connected server, the names
// of the tools it advertises. Servers that failed to connect are omitted.
func (m *Manager) GetServerInfo() map[string][]string {
	m.mu.RLock()
	defer m.mu.RUnlock()

	info := make(map[string][]string, len(m.servers))
	for name, conn := range m.servers {
		if conn.Error != "" {
			continue
		}
		toolNames := make([]string, len(conn.Tools))
		for i, tool := range conn.Tools {
			toolNames[i] = tool.Name
		}
		info[name] = toolNames
	}
	return info
}

// GetDetailedStatus returns every registered server, including failed ones,
// sorted by name. The returned pointers are the Manager's own entries.
func (m *Manager) GetDetailedStatus() []*ServerConnection {
	m.mu.RLock()
	defer m.mu.RUnlock()

	statuses := make([]*ServerConnection, 0, len(m.servers))
	for _, name := range m.sortedServerNamesLocked() {
		statuses = append(statuses, m.servers[name])
	}
	return statuses
}

// Close closes all MCP server connections, cancels the Manager's context and
// forgets every server. Every close failure is logged, and the failures are
// returned joined together; nil means all sessions closed cleanly.
func (m *Manager) Close() error {
	m.mu.Lock()
	defer m.mu.Unlock()

	var errs []error
	for name, conn := range m.servers {
		if conn.Session == nil {
			continue
		}
		if err := conn.Session.Close(); err != nil {
			log.Printf("Error closing connection to %s: %v", name, err)
			errs = append(errs, fmt.Errorf("closing %s: %w", name, err))
		}
	}

	// Cancel only after the sessions had a chance to shut down gracefully;
	// exec.CommandContext kills any server process still running once the
	// context is done.
	m.cancel()
	m.servers = make(map[string]*ServerConnection)
	return errors.Join(errs...)
}