Added MCP support
This commit is contained in:
261
mcp/manager.go
Normal file
261
mcp/manager.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
||||
"tell-me/config"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// Manager manages multiple MCP server connections
|
||||
type Manager struct {
|
||||
servers map[string]*ServerConnection
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// ServerConnection represents a connection to an MCP server
|
||||
type ServerConnection struct {
|
||||
Name string
|
||||
Config config.MCPServer
|
||||
Client *mcp.Client
|
||||
Session *mcp.ClientSession
|
||||
Tools []*mcp.Tool
|
||||
Error string // Connection error if any
|
||||
}
|
||||
|
||||
// NewManager creates a new MCP manager
|
||||
func NewManager(ctx context.Context) *Manager {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Manager{
|
||||
servers: make(map[string]*ServerConnection),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectServers connects to all configured MCP servers
|
||||
func (m *Manager) ConnectServers(servers map[string]config.MCPServer) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for name, serverCfg := range servers {
|
||||
if err := m.connectServer(name, serverCfg); err != nil {
|
||||
log.Printf("Warning: Failed to connect to MCP server %s: %v", name, err)
|
||||
// Store the error in the connection
|
||||
m.servers[name] = &ServerConnection{
|
||||
Name: name,
|
||||
Config: serverCfg,
|
||||
Error: err.Error(),
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf("Successfully connected to MCP server: %s", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectServer connects to a single MCP server
|
||||
func (m *Manager) connectServer(name string, serverCfg config.MCPServer) error {
|
||||
// Create MCP client
|
||||
client := mcp.NewClient(&mcp.Implementation{
|
||||
Name: "tell-me",
|
||||
Version: "1.0.0",
|
||||
}, nil)
|
||||
|
||||
// Only stdio transport is supported for local servers
|
||||
if serverCfg.Command == "" {
|
||||
return fmt.Errorf("command is required for MCP server")
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(m.ctx, serverCfg.Command, serverCfg.Args...)
|
||||
|
||||
// Set environment variables if provided
|
||||
if len(serverCfg.Env) > 0 {
|
||||
cmd.Env = append(cmd.Env, m.envMapToSlice(serverCfg.Env)...)
|
||||
}
|
||||
|
||||
transport := &mcp.CommandTransport{Command: cmd}
|
||||
|
||||
// Connect to the server
|
||||
session, err := client.Connect(m.ctx, transport, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
// List available tools
|
||||
toolsResult, err := session.ListTools(m.ctx, &mcp.ListToolsParams{})
|
||||
if err != nil {
|
||||
session.Close()
|
||||
return fmt.Errorf("failed to list tools: %w", err)
|
||||
}
|
||||
|
||||
// Store the connection
|
||||
m.servers[name] = &ServerConnection{
|
||||
Name: name,
|
||||
Config: serverCfg,
|
||||
Client: client,
|
||||
Session: session,
|
||||
Tools: toolsResult.Tools,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// envMapToSlice converts environment map to slice format
|
||||
func (m *Manager) envMapToSlice(envMap map[string]string) []string {
|
||||
result := make([]string, 0, len(envMap))
|
||||
for key, value := range envMap {
|
||||
result = append(result, fmt.Sprintf("%s=%s", key, value))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetAllTools returns all tools from all connected servers as OpenAI tool definitions
|
||||
func (m *Manager) GetAllTools() []openai.Tool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var tools []openai.Tool
|
||||
|
||||
for serverName, conn := range m.servers {
|
||||
for _, mcpTool := range conn.Tools {
|
||||
// Convert MCP tool to OpenAI tool format
|
||||
tool := openai.Tool{
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: &openai.FunctionDefinition{
|
||||
Name: fmt.Sprintf("%s_%s", serverName, mcpTool.Name),
|
||||
Description: mcpTool.Description,
|
||||
Parameters: mcpTool.InputSchema,
|
||||
},
|
||||
}
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
// CallTool calls a tool on the appropriate MCP server
|
||||
func (m *Manager) CallTool(ctx context.Context, toolName string, arguments string) (string, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// Parse the tool name to extract server name and actual tool name
|
||||
// Format: serverName_toolName
|
||||
var serverName, actualToolName string
|
||||
for sName := range m.servers {
|
||||
prefix := sName + "_"
|
||||
if len(toolName) > len(prefix) && toolName[:len(prefix)] == prefix {
|
||||
serverName = sName
|
||||
actualToolName = toolName[len(prefix):]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if serverName == "" {
|
||||
return "", fmt.Errorf("unknown tool: %s", toolName)
|
||||
}
|
||||
|
||||
conn, exists := m.servers[serverName]
|
||||
if !exists {
|
||||
return "", fmt.Errorf("server not found: %s", serverName)
|
||||
}
|
||||
|
||||
// Parse arguments
|
||||
var args map[string]interface{}
|
||||
if arguments != "" {
|
||||
if err := json.Unmarshal([]byte(arguments), &args); err != nil {
|
||||
return "", fmt.Errorf("failed to parse arguments: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Call the tool
|
||||
result, err := conn.Session.CallTool(ctx, &mcp.CallToolParams{
|
||||
Name: actualToolName,
|
||||
Arguments: args,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("tool call failed: %w", err)
|
||||
}
|
||||
|
||||
if result.IsError {
|
||||
return "", fmt.Errorf("tool returned error")
|
||||
}
|
||||
|
||||
// Format the result
|
||||
var response string
|
||||
for _, content := range result.Content {
|
||||
switch c := content.(type) {
|
||||
case *mcp.TextContent:
|
||||
response += c.Text + "\n"
|
||||
case *mcp.ImageContent:
|
||||
response += fmt.Sprintf("[Image: %s]\n", c.MIMEType)
|
||||
case *mcp.EmbeddedResource:
|
||||
response += fmt.Sprintf("[Resource: %s]\n", c.Resource.URI)
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GetServerInfo returns information about connected servers
|
||||
func (m *Manager) GetServerInfo() map[string][]string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
info := make(map[string][]string)
|
||||
for name, conn := range m.servers {
|
||||
if conn.Error == "" {
|
||||
toolNames := make([]string, len(conn.Tools))
|
||||
for i, tool := range conn.Tools {
|
||||
toolNames[i] = tool.Name
|
||||
}
|
||||
info[name] = toolNames
|
||||
}
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// GetDetailedStatus returns detailed status information for all servers
|
||||
func (m *Manager) GetDetailedStatus() []*ServerConnection {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
statuses := make([]*ServerConnection, 0, len(m.servers))
|
||||
for _, conn := range m.servers {
|
||||
statuses = append(statuses, conn)
|
||||
}
|
||||
return statuses
|
||||
}
|
||||
|
||||
// Close closes all MCP server connections
|
||||
func (m *Manager) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var lastErr error
|
||||
for name, conn := range m.servers {
|
||||
if conn.Session != nil {
|
||||
if err := conn.Session.Close(); err != nil {
|
||||
log.Printf("Error closing connection to %s: %v", name, err)
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel context after closing all sessions
|
||||
m.cancel()
|
||||
|
||||
m.servers = make(map[string]*ServerConnection)
|
||||
return lastErr
|
||||
}
|
||||
Reference in New Issue
Block a user