Files
tell-me/mcp/manager.go
Pavel Pivovarov 1d659006ed Fixing verbosity
2025-12-15 15:49:49 +11:00

260 lines
6.1 KiB
Go

package mcp
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"os"
	"os/exec"
	"sort"
	"strings"
	"sync"

	"github.com/modelcontextprotocol/go-sdk/mcp"
	"github.com/sashabaranov/go-openai"

	"tell-me/config"
)
// Manager keeps track of several named MCP server connections and the
// context that bounds their lifetime. Create one with NewManager.
type Manager struct {
	// servers maps a configured server name to its connection record,
	// including records for servers that failed to connect.
	servers map[string]*ServerConnection
	// mu guards servers.
	mu sync.RWMutex
	// ctx is derived from the context given to NewManager; server
	// subprocesses and handshakes are started under it.
	ctx context.Context
	// cancel cancels ctx; Close calls it.
	cancel context.CancelFunc
}
// ServerConnection records one configured MCP server and the result of
// connecting to it.
//
// After a successful connection Client, Session and Tools are populated and
// Error is empty. After a failed attempt only Name, Config and Error are set.
type ServerConnection struct {
	// Name is the key the server is registered under.
	Name string
	// Config is the configuration used to launch the server.
	Config config.MCPServer
	// Client is the MCP client that opened Session.
	Client *mcp.Client
	// Session is the live MCP session; nil when the connection failed.
	Session *mcp.ClientSession
	// Tools lists the tools the server reported when it connected.
	Tools []*mcp.Tool
	// Error holds the connection error message, if any.
	Error string
}
// NewManager returns a Manager with no registered servers. Its connections
// run under a cancellable child of ctx, which the Manager's Close cancels.
func NewManager(ctx context.Context) *Manager {
	derived, stop := context.WithCancel(ctx)

	mgr := new(Manager)
	mgr.servers = map[string]*ServerConnection{}
	mgr.ctx = derived
	mgr.cancel = stop
	return mgr
}
// ConnectServers attempts to connect to every server in servers and records
// the outcome under the server's name.
//
// A server that fails to connect is still recorded, with its Error field set,
// so it can be reported through GetDetailedStatus. If a server with the same
// name is already registered, its session is closed before the entry is
// replaced, so repeated calls do not orphan the old session.
//
// Per-server failures are not returned; the method currently always returns
// nil. The write lock is held for the whole call, so readers wait until every
// server has been attempted.
func (m *Manager) ConnectServers(servers map[string]config.MCPServer) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	for name, serverCfg := range servers {
		// The map entry for name is overwritten below whether or not the new
		// attempt succeeds, so release the previous session first.
		if prev, ok := m.servers[name]; ok && prev.Session != nil {
			if err := prev.Session.Close(); err != nil {
				log.Printf("Error closing connection to %s: %v", name, err)
			}
		}

		if err := m.connectServer(name, serverCfg); err != nil {
			// Keep a record of the failure so it can be surfaced to the user.
			m.servers[name] = &ServerConnection{
				Name:   name,
				Config: serverCfg,
				Error:  err.Error(),
			}
		}
	}
	return nil
}
// connectServer launches serverCfg.Command as an MCP server over stdio. It
// performs the MCP handshake, fetches the server's complete tool list, and
// registers the resulting connection under name.
//
// The caller must hold m.mu for writing. The subprocess is created with
// exec.CommandContext(m.ctx, ...), so it is killed when the Manager's context
// is cancelled. On error nothing is registered and any session that was
// opened is closed.
func (m *Manager) connectServer(name string, serverCfg config.MCPServer) error {
	// Only stdio transport is supported for local servers.
	if serverCfg.Command == "" {
		return fmt.Errorf("command is required for MCP server")
	}

	client := mcp.NewClient(&mcp.Implementation{
		Name:    "tell-me",
		Version: "1.0.0",
	}, nil)

	cmd := exec.CommandContext(m.ctx, serverCfg.Command, serverCfg.Args...)
	if len(serverCfg.Env) > 0 {
		// A non-nil cmd.Env replaces the child's environment entirely, so
		// start from the parent's environment (PATH, HOME, ...) and append
		// the configured variables. os/exec uses the last value for a
		// duplicated key, so configured values override inherited ones.
		cmd.Env = append(os.Environ(), m.envMapToSlice(serverCfg.Env)...)
	}

	session, err := client.Connect(m.ctx, &mcp.CommandTransport{Command: cmd}, nil)
	if err != nil {
		return fmt.Errorf("failed to connect: %w", err)
	}

	// tools/list may be paginated; follow NextCursor until the server stops
	// returning one. A cursor that does not advance also ends the loop, so a
	// misbehaving server cannot make it spin forever.
	var tools []*mcp.Tool
	params := &mcp.ListToolsParams{}
	for {
		page, err := session.ListTools(m.ctx, params)
		if err != nil {
			// Best effort: the listing error is the one worth reporting.
			_ = session.Close()
			return fmt.Errorf("failed to list tools: %w", err)
		}
		tools = append(tools, page.Tools...)
		if page.NextCursor == "" || page.NextCursor == params.Cursor {
			break
		}
		params = &mcp.ListToolsParams{Cursor: page.NextCursor}
	}

	m.servers[name] = &ServerConnection{
		Name:    name,
		Config:  serverCfg,
		Client:  client,
		Session: session,
		Tools:   tools,
	}
	return nil
}
// envMapToSlice renders envMap as "KEY=VALUE" strings, the form used by
// exec.Cmd.Env. The order follows map iteration and is therefore unspecified.
func (m *Manager) envMapToSlice(envMap map[string]string) []string {
	pairs := make([]string, len(envMap))
	i := 0
	for k, v := range envMap {
		pairs[i] = k + "=" + v
		i++
	}
	return pairs
}
// GetAllTools returns the tools of every registered server as OpenAI function
// tools. Each tool is named "<server>_<tool>" so CallTool can route a call
// back to its server.
//
// Servers are visited in sorted name order, and each server's tools keep the
// order the server reported them in. The result is therefore stable across
// calls, whereas Go map iteration order is randomized. Servers that failed to
// connect have no tools and contribute nothing. Returns nil when there are no
// tools.
func (m *Manager) GetAllTools() []openai.Tool {
	m.mu.RLock()
	defer m.mu.RUnlock()

	serverNames := make([]string, 0, len(m.servers))
	total := 0
	for name, conn := range m.servers {
		serverNames = append(serverNames, name)
		total += len(conn.Tools)
	}
	if total == 0 {
		return nil
	}
	sort.Strings(serverNames)

	tools := make([]openai.Tool, 0, total)
	for _, serverName := range serverNames {
		for _, mcpTool := range m.servers[serverName].Tools {
			tools = append(tools, openai.Tool{
				Type: openai.ToolTypeFunction,
				Function: &openai.FunctionDefinition{
					// The server prefix is what CallTool uses to route the call.
					Name:        serverName + "_" + mcpTool.Name,
					Description: mcpTool.Description,
					Parameters:  mcpTool.InputSchema,
				},
			})
		}
	}
	return tools
}
// CallTool invokes a tool previously advertised by GetAllTools.
//
// toolName has the form "<server>_<tool>". Several registered server names
// may be prefixes of toolName (e.g. "git" and "git_hub"). In that case a
// server that actually lists the remaining tool name is preferred, and among
// equals the longest server name wins, so routing no longer depends on map
// iteration order.
//
// arguments is a JSON object; an empty string means no arguments. The text
// form of the tool's content is returned. If the server marks the result as
// an error, that text is included in the returned error.
func (m *Manager) CallTool(ctx context.Context, toolName string, arguments string) (string, error) {
	session, actualToolName, err := m.resolveTool(toolName)
	if err != nil {
		return "", err
	}

	var args map[string]interface{}
	if arguments != "" {
		if err := json.Unmarshal([]byte(arguments), &args); err != nil {
			return "", fmt.Errorf("failed to parse arguments: %w", err)
		}
	}

	// m.mu is deliberately not held across the remote call. Otherwise a slow
	// tool would block Close and ConnectServers, which need the write lock.
	result, err := session.CallTool(ctx, &mcp.CallToolParams{
		Name:      actualToolName,
		Arguments: args,
	})
	if err != nil {
		return "", fmt.Errorf("tool call failed: %w", err)
	}

	response := formatToolContent(result.Content)
	if result.IsError {
		// Surface the server's explanation instead of discarding it.
		if detail := strings.TrimSpace(response); detail != "" {
			return "", fmt.Errorf("tool returned error: %s", detail)
		}
		return "", fmt.Errorf("tool returned error")
	}
	return response, nil
}

// resolveTool maps a "<server>_<tool>" name to the session of the connected
// server that should handle it, plus the server-side tool name. The read lock
// is held only for the lookup. Servers without a live session (failed
// connections) are never selected.
func (m *Manager) resolveTool(toolName string) (*mcp.ClientSession, string, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var (
		bestServer  string
		bestSession *mcp.ClientSession
		bestTool    string
		bestListed  bool
		failed      string // a matching server that has no live session
	)
	for serverName, conn := range m.servers {
		prefix := serverName + "_"
		if len(toolName) <= len(prefix) || !strings.HasPrefix(toolName, prefix) {
			continue
		}
		if conn.Session == nil {
			// Calling into a failed connection would dereference a nil session.
			failed = serverName
			continue
		}
		candidate := toolName[len(prefix):]
		listed := false
		for _, t := range conn.Tools {
			if t.Name == candidate {
				listed = true
				break
			}
		}
		better := bestSession == nil ||
			(listed && !bestListed) ||
			(listed == bestListed && len(serverName) > len(bestServer))
		if better {
			bestServer, bestSession, bestTool, bestListed = serverName, conn.Session, candidate, listed
		}
	}

	if bestSession == nil {
		if failed != "" {
			return nil, "", fmt.Errorf("server not connected: %s", failed)
		}
		return nil, "", fmt.Errorf("unknown tool: %s", toolName)
	}
	return bestSession, bestTool, nil
}

// formatToolContent renders MCP content items as newline-terminated text.
// Text is copied verbatim; images are summarized by MIME type and embedded
// resources by URI. Other content kinds are skipped.
func formatToolContent(contents []mcp.Content) string {
	var b strings.Builder
	for _, content := range contents {
		switch c := content.(type) {
		case *mcp.TextContent:
			b.WriteString(c.Text)
			b.WriteByte('\n')
		case *mcp.ImageContent:
			fmt.Fprintf(&b, "[Image: %s]\n", c.MIMEType)
		case *mcp.EmbeddedResource:
			if c.Resource != nil {
				fmt.Fprintf(&b, "[Resource: %s]\n", c.Resource.URI)
			}
		}
	}
	return b.String()
}
// GetServerInfo maps each successfully connected server's name to the names
// of its tools. Servers whose connection failed (non-empty Error) are left
// out of the map.
func (m *Manager) GetServerInfo() map[string][]string {
	m.mu.RLock()
	defer m.mu.RUnlock()

	info := map[string][]string{}
	for name, conn := range m.servers {
		if conn.Error != "" {
			continue
		}
		names := make([]string, 0, len(conn.Tools))
		for _, t := range conn.Tools {
			names = append(names, t.Name)
		}
		info[name] = names
	}
	return info
}
// GetDetailedStatus returns the connection record of every registered server,
// including failed ones, in unspecified order. The returned pointers are the
// Manager's own records, not copies.
func (m *Manager) GetDetailedStatus() []*ServerConnection {
	m.mu.RLock()
	defer m.mu.RUnlock()

	out := make([]*ServerConnection, len(m.servers))
	i := 0
	for _, conn := range m.servers {
		out[i] = conn
		i++
	}
	return out
}
// Close closes every open session, cancels the Manager's context and forgets
// all registered servers. Each failure is logged. If any occurred, the last
// one encountered (in map iteration order) is returned.
func (m *Manager) Close() error {
	m.mu.Lock()
	defer m.mu.Unlock()

	var closeErr error
	for name, conn := range m.servers {
		if conn.Session == nil {
			continue
		}
		err := conn.Session.Close()
		if err == nil {
			continue
		}
		log.Printf("Error closing connection to %s: %v", name, err)
		closeErr = err
	}

	// Cancel only after every session has been closed. Cancelling m.ctx kills
	// any subprocess still running from exec.CommandContext.
	m.cancel()
	m.servers = make(map[string]*ServerConnection)
	return closeErr
}