package llm
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"time"
|
|
|
|
"tell-me/mcp"
|
|
"tell-me/tools"
|
|
|
|
"github.com/sashabaranov/go-openai"
|
|
)
|
|
|
|
// Client wraps the OpenAI client for LLM interactions
type Client struct {
	client      *openai.Client // underlying OpenAI-compatible API client
	model       string         // model name sent with every chat request
	contextSize int            // configured context window size; stored at construction, not read in this file
	mcpManager  *mcp.Manager   // optional MCP tool manager; may be nil (built-in tools only)
}
|
|
|
|
// NewClient creates a new LLM client
|
|
func NewClient(apiURL, apiKey, model string, contextSize int, mcpManager *mcp.Manager) *Client {
|
|
config := openai.DefaultConfig(apiKey)
|
|
config.BaseURL = apiURL
|
|
|
|
client := openai.NewClientWithConfig(config)
|
|
|
|
return &Client{
|
|
client: client,
|
|
model: model,
|
|
contextSize: contextSize,
|
|
mcpManager: mcpManager,
|
|
}
|
|
}
|
|
|
|
// GetSystemPrompt returns the system prompt with current date appended
|
|
func GetSystemPrompt(prompt string) string {
|
|
currentDate := time.Now().Format("2006-01-02")
|
|
return fmt.Sprintf("%s\n\nCURRENT DATE: %s", prompt, currentDate)
|
|
}
|
|
|
|
// GetBuiltInTools returns the tool definitions for the LLM (built-in tools only).
// MCP-provided tools are merged in separately by Client.GetTools.
func GetBuiltInTools() []openai.Tool {
	return []openai.Tool{
		{
			Type: openai.ToolTypeFunction,
			Function: &openai.FunctionDefinition{
				Name:        "web_search",
				Description: "Search the internet for information using SearXNG. Use this tool to find current information, facts, news, or any knowledge you need to answer the user's question.",
				// JSON Schema for the tool arguments, passed through verbatim to the API.
				Parameters: json.RawMessage(`{
					"type": "object",
					"properties": {
						"query": {
							"type": "string",
							"description": "The search query to find relevant information"
						}
					},
					"required": ["query"]
				}`),
			},
		},
		{
			Type: openai.ToolTypeFunction,
			Function: &openai.FunctionDefinition{
				Name:        "fetch_articles",
				Description: "Fetch and read content from 1-5 articles at once. Provide both titles and URLs from search results. The HTML will be converted to clean text format and combined. Use this after searching to read the most relevant pages together.",
				// Array-of-objects schema: each entry pairs a title (for display)
				// with a URL (for fetching).
				Parameters: json.RawMessage(`{
					"type": "object",
					"properties": {
						"articles": {
							"type": "array",
							"items": {
								"type": "object",
								"properties": {
									"title": {
										"type": "string",
										"description": "The title of the article from search results"
									},
									"url": {
										"type": "string",
										"description": "The URL to fetch (must start with http:// or https://)"
									}
								},
								"required": ["title", "url"]
							},
							"description": "Array of articles with titles and URLs (1-5 recommended, max 5)"
						}
					},
					"required": ["articles"]
				}`),
			},
		},
	}
}
|
|
|
|
// GetTools returns all available tools (built-in + MCP tools)
|
|
func (c *Client) GetTools() []openai.Tool {
|
|
tools := GetBuiltInTools()
|
|
|
|
// Add MCP tools if manager is available
|
|
if c.mcpManager != nil {
|
|
mcpTools := c.mcpManager.GetAllTools()
|
|
tools = append(tools, mcpTools...)
|
|
}
|
|
|
|
return tools
|
|
}
|
|
|
|
// Chat sends a message and streams the response, handling tool calls.
//
// It issues a streaming chat completion, forwards each content delta to
// streamFunc (if non-nil), and accumulates tool-call fragments by index.
// If the model requested tool calls, each is executed via handleToolCall,
// its result appended as a tool message, and Chat recurses so the model can
// produce a final answer from the tool output.
//
// Returns the final assistant text, the full updated message history
// (assistant + tool messages included), and any stream error.
//
// NOTE(review): the tool-call recursion has no depth limit — a model that
// keeps requesting tools would loop indefinitely; confirm this is acceptable.
func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessage, streamFunc func(string)) (string, []openai.ChatCompletionMessage, error) {
	req := openai.ChatCompletionRequest{
		Model:    c.model,
		Messages: messages,
		Tools:    c.GetTools(),
		Stream:   true,
	}

	stream, err := c.client.CreateChatCompletionStream(ctx, req)
	if err != nil {
		return "", messages, fmt.Errorf("chat completion stream failed: %w", err)
	}
	defer stream.Close()

	var fullContent strings.Builder
	var toolCalls []openai.ToolCall

	// Drain the stream: collect content text and assemble tool-call deltas.
	for {
		response, err := stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			return "", messages, fmt.Errorf("stream error: %w", err)
		}

		// Some providers send keep-alive chunks with no choices.
		if len(response.Choices) == 0 {
			continue
		}

		delta := response.Choices[0].Delta

		// Handle content streaming
		if delta.Content != "" {
			fullContent.WriteString(delta.Content)
			if streamFunc != nil {
				streamFunc(delta.Content)
			}
		}

		// Handle tool calls: fragments arrive incrementally; Index tells us
		// which call each fragment belongs to, ID/Type/Name arrive once,
		// and Arguments is concatenated across chunks.
		// NOTE(review): fragments with a nil Index are silently dropped —
		// confirm every targeted provider always sets Index.
		if len(delta.ToolCalls) > 0 {
			for _, tc := range delta.ToolCalls {
				if tc.Index != nil {
					idx := *tc.Index

					// Ensure we have enough space in the slice
					for len(toolCalls) <= idx {
						toolCalls = append(toolCalls, openai.ToolCall{})
					}

					if tc.ID != "" {
						toolCalls[idx].ID = tc.ID
					}
					if tc.Type != "" {
						toolCalls[idx].Type = tc.Type
					}
					if tc.Function.Name != "" {
						toolCalls[idx].Function.Name = tc.Function.Name
					}
					if tc.Function.Arguments != "" {
						toolCalls[idx].Function.Arguments += tc.Function.Arguments
					}
				}
			}
		}
	}

	// Create the assistant message recording what the model produced,
	// so the follow-up request has the full conversation context.
	assistantMsg := openai.ChatCompletionMessage{
		Role:    openai.ChatMessageRoleAssistant,
		Content: fullContent.String(),
	}

	if len(toolCalls) > 0 {
		assistantMsg.ToolCalls = toolCalls
	}

	messages = append(messages, assistantMsg)

	// Handle tool calls if present
	if len(toolCalls) > 0 {
		for _, toolCall := range toolCalls {
			result := c.handleToolCall(ctx, toolCall)

			// Add tool response to messages; ToolCallID links the result
			// back to the specific call the model made.
			messages = append(messages, openai.ChatCompletionMessage{
				Role:       openai.ChatMessageRoleTool,
				Content:    result,
				ToolCallID: toolCall.ID,
			})
		}

		// Print blank line before streaming the final response
		fmt.Println()

		// Make another streaming call with tool results
		return c.Chat(ctx, messages, streamFunc)
	}

	return fullContent.String(), messages, nil
}
|
|
|
|
// handleToolCall routes tool calls to the appropriate handler
|
|
func (c *Client) handleToolCall(ctx context.Context, toolCall openai.ToolCall) string {
|
|
toolName := toolCall.Function.Name
|
|
|
|
// Check if it's a built-in tool
|
|
switch toolName {
|
|
case "web_search":
|
|
var args struct {
|
|
Query string `json:"query"`
|
|
}
|
|
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
|
return fmt.Sprintf("Error parsing arguments: %v", err)
|
|
}
|
|
fmt.Printf("Searching: %s\n", args.Query)
|
|
result, err := tools.WebSearch(args.Query)
|
|
if err != nil {
|
|
return fmt.Sprintf("Search error: %v", err)
|
|
}
|
|
return result
|
|
|
|
case "fetch_articles":
|
|
var args struct {
|
|
Articles []struct {
|
|
Title string `json:"title"`
|
|
URL string `json:"url"`
|
|
} `json:"articles"`
|
|
}
|
|
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
|
|
return fmt.Sprintf("Error parsing arguments: %v", err)
|
|
}
|
|
fmt.Printf("Reading %d articles:\n", len(args.Articles))
|
|
urls := make([]string, len(args.Articles))
|
|
for i, article := range args.Articles {
|
|
fmt.Printf(" - %s\n", article.Title)
|
|
urls[i] = article.URL
|
|
}
|
|
result, err := tools.FetchArticles(urls)
|
|
if err != nil {
|
|
return fmt.Sprintf("Fetch error: %v", err)
|
|
}
|
|
return result
|
|
|
|
default:
|
|
// Check if it's an MCP tool (format: servername_toolname)
|
|
if c.mcpManager != nil && strings.Contains(toolName, "_") {
|
|
fmt.Printf("Calling MCP tool: %s\n", toolName)
|
|
result, err := c.mcpManager.CallTool(ctx, toolName, toolCall.Function.Arguments)
|
|
if err != nil {
|
|
return fmt.Sprintf("MCP tool error: %v", err)
|
|
}
|
|
return result
|
|
}
|
|
|
|
return fmt.Sprintf("Unknown tool: %s", toolName)
|
|
}
|
|
}
|