package main

import (
	"bufio"
	"context"
	"fmt"
	"log"
	"os"
	"os/signal"
	"strings"
	"syscall"

	"github.com/sashabaranov/go-openai"

	"tell-me/config"
	"tell-me/llm"
	"tell-me/mcp"
)

func main() {
|
||
// Load configuration
|
||
cfg, err := config.Load()
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "Error loading configuration: %v\n", err)
|
||
fmt.Fprintf(os.Stderr, "Please create ~/.config/tell-me.yaml from tell-me.yaml.example\n")
|
||
os.Exit(1)
|
||
}
|
||
|
||
ctx := context.Background()
|
||
|
||
// Initialize MCP manager
|
||
mcpManager := mcp.NewManager(ctx)
|
||
defer mcpManager.Close()
|
||
|
||
// Connect to MCP servers if configured
|
||
if len(cfg.MCPServers) > 0 {
|
||
if err := mcpManager.ConnectServers(cfg.MCPServers); err != nil {
|
||
log.Printf("Warning: Failed to connect to some MCP servers: %v", err)
|
||
}
|
||
}
|
||
|
||
// Create LLM client with MCP manager
|
||
client := llm.NewClient(
|
||
cfg.APIURL,
|
||
cfg.APIKey,
|
||
cfg.Model,
|
||
cfg.ContextSize,
|
||
cfg.SearXNGURL,
|
||
mcpManager,
|
||
)
|
||
|
||
// Initialize conversation with system prompt from config
|
||
messages := []openai.ChatCompletionMessage{
|
||
{
|
||
Role: openai.ChatMessageRoleSystem,
|
||
Content: llm.GetSystemPrompt(cfg.Prompt),
|
||
},
|
||
}
|
||
|
||
// Check if arguments are provided (non-interactive mode)
|
||
if len(os.Args) > 1 {
|
||
query := strings.Join(os.Args[1:], " ")
|
||
// Display MCP status in non-interactive mode if servers are configured
|
||
if len(cfg.MCPServers) > 0 {
|
||
displayMCPStatusInline(mcpManager)
|
||
}
|
||
processQuery(ctx, client, messages, query)
|
||
return
|
||
}
|
||
|
||
// Setup signal handling for Ctrl-C
|
||
sigChan := make(chan os.Signal, 1)
|
||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||
go func() {
|
||
<-sigChan
|
||
fmt.Println("\n\nGoodbye!")
|
||
os.Exit(0)
|
||
}()
|
||
|
||
// Print welcome message with MCP status
|
||
fmt.Println("╔════════════════════════════════════════════════════════════════╗")
|
||
fmt.Println("║ Tell-Me CLI ║")
|
||
fmt.Println("║ AI-powered search with local LLM support ║")
|
||
fmt.Println("╚════════════════════════════════════════════════════════════════╝")
|
||
fmt.Println()
|
||
fmt.Printf("Using model: %s\n", cfg.Model)
|
||
fmt.Printf("SearXNG: %s\n", cfg.SearXNGURL)
|
||
|
||
// Display MCP server status
|
||
if len(cfg.MCPServers) > 0 {
|
||
fmt.Println()
|
||
displayMCPStatusInline(mcpManager)
|
||
}
|
||
|
||
fmt.Println()
|
||
fmt.Println("Type your questions below. Type 'exit' or 'quit' to exit, or press Ctrl-C.")
|
||
fmt.Println("────────────────────────────────────────────────────────────────")
|
||
fmt.Println()
|
||
|
||
// Create scanner for user input
|
||
scanner := bufio.NewScanner(os.Stdin)
|
||
|
||
for {
|
||
// Prompt for user input
|
||
fmt.Print("❯ ")
|
||
if !scanner.Scan() {
|
||
break
|
||
}
|
||
|
||
userInput := strings.TrimSpace(scanner.Text())
|
||
|
||
// Check for exit commands
|
||
if userInput == "exit" || userInput == "quit" {
|
||
fmt.Println("\nGoodbye!")
|
||
break
|
||
}
|
||
|
||
// Skip empty input
|
||
if userInput == "" {
|
||
continue
|
||
}
|
||
|
||
// Process the query
|
||
messages = processQuery(ctx, client, messages, userInput)
|
||
}
|
||
|
||
if err := scanner.Err(); err != nil {
|
||
fmt.Fprintf(os.Stderr, "Error reading input: %v\n", err)
|
||
os.Exit(1)
|
||
}
|
||
}
// processQuery handles a single query and returns updated messages
|
||
func processQuery(ctx context.Context, client *llm.Client, messages []openai.ChatCompletionMessage, userInput string) []openai.ChatCompletionMessage {
|
||
// Add user message to conversation
|
||
messages = append(messages, openai.ChatCompletionMessage{
|
||
Role: openai.ChatMessageRoleUser,
|
||
Content: userInput,
|
||
})
|
||
|
||
// Print blank line before streaming starts
|
||
fmt.Println()
|
||
|
||
// Get response from LLM with streaming
|
||
_, updatedMessages, err := client.Chat(ctx, messages, func(chunk string) {
|
||
fmt.Print(chunk)
|
||
})
|
||
if err != nil {
|
||
fmt.Fprintf(os.Stderr, "\nError: %v\n\n", err)
|
||
// Remove the failed user message
|
||
return messages[:len(messages)-1]
|
||
}
|
||
|
||
// Update messages with the full conversation history
|
||
messages = updatedMessages
|
||
|
||
// Print newline after streaming completes
|
||
fmt.Println()
|
||
fmt.Println()
|
||
|
||
return messages
|
||
}
// displayMCPStatusInline shows MCP server status in the header
|
||
func displayMCPStatusInline(manager *mcp.Manager) {
|
||
statuses := manager.GetDetailedStatus()
|
||
|
||
if len(statuses) == 0 {
|
||
return
|
||
}
|
||
|
||
fmt.Print("MCP Servers: ")
|
||
|
||
for i, status := range statuses {
|
||
if i > 0 {
|
||
fmt.Print(", ")
|
||
}
|
||
|
||
if status.Error != "" {
|
||
// Red X for error
|
||
fmt.Printf("\033[31m✗\033[0m %s", status.Name)
|
||
} else {
|
||
// Green checkmark for OK
|
||
fmt.Printf("\033[32m✓\033[0m %s (%d tools)", status.Name, len(status.Tools))
|
||
}
|
||
}
|
||
fmt.Println()
|
||
}