Added MCP support
This commit is contained in:
176
llm/client.go
176
llm/client.go
@@ -4,9 +4,12 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.netra.pivpav.com/public/tell-me/tools"
|
||||
"tell-me/mcp"
|
||||
"tell-me/tools"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
@@ -16,10 +19,11 @@ type Client struct {
|
||||
model string
|
||||
contextSize int
|
||||
searxngURL string
|
||||
mcpManager *mcp.Manager
|
||||
}
|
||||
|
||||
// NewClient creates a new LLM client
|
||||
func NewClient(apiURL, apiKey, model string, contextSize int, searxngURL string) *Client {
|
||||
func NewClient(apiURL, apiKey, model string, contextSize int, searxngURL string, mcpManager *mcp.Manager) *Client {
|
||||
config := openai.DefaultConfig(apiKey)
|
||||
config.BaseURL = apiURL
|
||||
|
||||
@@ -30,59 +34,18 @@ func NewClient(apiURL, apiKey, model string, contextSize int, searxngURL string)
|
||||
model: model,
|
||||
contextSize: contextSize,
|
||||
searxngURL: searxngURL,
|
||||
mcpManager: mcpManager,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSystemPrompt returns the system prompt that enforces search-first behavior
|
||||
func GetSystemPrompt() string {
|
||||
// GetSystemPrompt returns the system prompt with current date appended
|
||||
func GetSystemPrompt(prompt string) string {
|
||||
currentDate := time.Now().Format("2006-01-02")
|
||||
|
||||
return fmt.Sprintf(`You are a helpful AI research assistant with access to web search and article fetching capabilities.
|
||||
|
||||
RESEARCH WORKFLOW - MANDATORY STEPS:
|
||||
1. For questions requiring current information, facts, or knowledge beyond your training data:
|
||||
- Perform MULTIPLE searches (typically 2-3) with DIFFERENT query angles to gather comprehensive information
|
||||
- Vary your search terms to capture different perspectives and sources
|
||||
|
||||
2. After completing ALL searches, analyze the combined results:
|
||||
- Review ALL search results from your multiple searches together
|
||||
- Identify the 3-5 MOST relevant and authoritative URLs across ALL searches
|
||||
- Prioritize: official sources, reputable news sites, technical documentation, expert reviews
|
||||
- Look for sources that complement each other (e.g., official specs + expert analysis + user reviews)
|
||||
|
||||
3. Fetch the selected articles:
|
||||
- Use fetch_articles with the 3-5 best URLs you identified from ALL your searches
|
||||
- Read all fetched content thoroughly before formulating your answer
|
||||
- Synthesize information from multiple sources for a comprehensive response
|
||||
|
||||
HANDLING USER CORRECTIONS - CRITICAL:
|
||||
When a user indicates your answer is incorrect, incomplete, or needs clarification:
|
||||
1. NEVER argue or defend your previous answer
|
||||
2. IMMEDIATELY acknowledge the correction: "Let me search for more accurate information"
|
||||
3. Perform NEW searches with DIFFERENT queries based on the user's feedback
|
||||
4. Fetch NEW sources that address the specific correction or clarification needed
|
||||
5. Provide an updated answer based on the new research
|
||||
6. If the user provides specific information, incorporate it and verify with additional searches
|
||||
|
||||
Remember: The user may have more current or specific knowledge. Your role is to research and verify, not to argue.
|
||||
|
||||
OUTPUT FORMATTING RULES:
|
||||
- NEVER include source URLs or citations in your response
|
||||
- DO NOT use Markdown formatting (no **, ##, -, *, [], etc.)
|
||||
- Write in plain text only - use natural language without any special formatting
|
||||
- For emphasis, use CAPITAL LETTERS instead of bold or italics
|
||||
- For lists, use simple numbered lines (1., 2., 3.) or write as flowing paragraphs
|
||||
- Keep output clean and readable for terminal display
|
||||
|
||||
Available tools:
|
||||
- web_search: Search the internet (can be used multiple times with different queries)
|
||||
- fetch_articles: Fetch and read content from 1-5 URLs at once
|
||||
|
||||
CURRENT DATE: %s`, currentDate)
|
||||
return fmt.Sprintf("%s\n\nCURRENT DATE: %s", prompt, currentDate)
|
||||
}
|
||||
|
||||
// GetTools returns the tool definitions for the LLM
|
||||
func GetTools() []openai.Tool {
|
||||
// GetTools returns the tool definitions for the LLM (built-in tools only)
|
||||
func GetBuiltInTools() []openai.Tool {
|
||||
return []openai.Tool{
|
||||
{
|
||||
Type: openai.ToolTypeFunction,
|
||||
@@ -135,12 +98,25 @@ func GetTools() []openai.Tool {
|
||||
}
|
||||
}
|
||||
|
||||
// GetTools returns all available tools (built-in + MCP tools)
|
||||
func (c *Client) GetTools() []openai.Tool {
|
||||
tools := GetBuiltInTools()
|
||||
|
||||
// Add MCP tools if manager is available
|
||||
if c.mcpManager != nil {
|
||||
mcpTools := c.mcpManager.GetAllTools()
|
||||
tools = append(tools, mcpTools...)
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
// Chat sends a message and handles tool calls
|
||||
func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessage) (string, []openai.ChatCompletionMessage, error) {
|
||||
req := openai.ChatCompletionRequest{
|
||||
Model: c.model,
|
||||
Messages: messages,
|
||||
Tools: GetTools(),
|
||||
Tools: c.GetTools(),
|
||||
}
|
||||
|
||||
resp, err := c.client.CreateChatCompletion(ctx, req)
|
||||
@@ -154,48 +130,7 @@ func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessa
|
||||
// Handle tool calls
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
for _, toolCall := range choice.Message.ToolCalls {
|
||||
var result string
|
||||
|
||||
switch toolCall.Function.Name {
|
||||
case "web_search":
|
||||
var args struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
result = fmt.Sprintf("Error parsing arguments: %v", err)
|
||||
} else {
|
||||
fmt.Printf("Searching: %s\n", args.Query)
|
||||
result, err = tools.WebSearch(c.searxngURL, args.Query)
|
||||
if err != nil {
|
||||
result = fmt.Sprintf("Search error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
case "fetch_articles":
|
||||
var args struct {
|
||||
Articles []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
} `json:"articles"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
result = fmt.Sprintf("Error parsing arguments: %v", err)
|
||||
} else {
|
||||
fmt.Printf("Reading %d articles:\n", len(args.Articles))
|
||||
urls := make([]string, len(args.Articles))
|
||||
for i, article := range args.Articles {
|
||||
fmt.Printf(" - %s\n", article.Title)
|
||||
urls[i] = article.URL
|
||||
}
|
||||
result, err = tools.FetchArticles(urls)
|
||||
if err != nil {
|
||||
result = fmt.Sprintf("Fetch error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
result = fmt.Sprintf("Unknown tool: %s", toolCall.Function.Name)
|
||||
}
|
||||
result := c.handleToolCall(ctx, toolCall)
|
||||
|
||||
// Add tool response to messages
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
@@ -211,3 +146,60 @@ func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessa
|
||||
|
||||
return choice.Message.Content, messages, nil
|
||||
}
|
||||
|
||||
// handleToolCall routes tool calls to the appropriate handler
|
||||
func (c *Client) handleToolCall(ctx context.Context, toolCall openai.ToolCall) string {
|
||||
toolName := toolCall.Function.Name
|
||||
|
||||
// Check if it's a built-in tool
|
||||
switch toolName {
|
||||
case "web_search":
|
||||
var args struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
return fmt.Sprintf("Error parsing arguments: %v", err)
|
||||
}
|
||||
fmt.Printf("Searching: %s\n", args.Query)
|
||||
result, err := tools.WebSearch(c.searxngURL, args.Query)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Search error: %v", err)
|
||||
}
|
||||
return result
|
||||
|
||||
case "fetch_articles":
|
||||
var args struct {
|
||||
Articles []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
} `json:"articles"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
||||
return fmt.Sprintf("Error parsing arguments: %v", err)
|
||||
}
|
||||
fmt.Printf("Reading %d articles:\n", len(args.Articles))
|
||||
urls := make([]string, len(args.Articles))
|
||||
for i, article := range args.Articles {
|
||||
fmt.Printf(" - %s\n", article.Title)
|
||||
urls[i] = article.URL
|
||||
}
|
||||
result, err := tools.FetchArticles(urls)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Fetch error: %v", err)
|
||||
}
|
||||
return result
|
||||
|
||||
default:
|
||||
// Check if it's an MCP tool (format: servername_toolname)
|
||||
if c.mcpManager != nil && strings.Contains(toolName, "_") {
|
||||
fmt.Printf("Calling MCP tool: %s\n", toolName)
|
||||
result, err := c.mcpManager.CallTool(ctx, toolName, toolCall.Function.Arguments)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("MCP tool error: %v", err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Unknown tool: %s", toolName)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user