Compare commits

..

2 Commits

Author SHA1 Message Date
Pavel Pivovarov
35733aa3e2 Switched to streaming 2025-12-15 16:00:58 +11:00
Pavel Pivovarov
1d659006ed Fixing verbosity 2025-12-15 15:49:49 +11:00
3 changed files with 92 additions and 19 deletions

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"strings" "strings"
"time" "time"
@@ -111,25 +112,90 @@ func (c *Client) GetTools() []openai.Tool {
return tools return tools
} }
// Chat sends a message and handles tool calls // Chat sends a message and streams the response, handling tool calls
func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessage) (string, []openai.ChatCompletionMessage, error) { func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessage, streamFunc func(string)) (string, []openai.ChatCompletionMessage, error) {
req := openai.ChatCompletionRequest{ req := openai.ChatCompletionRequest{
Model: c.model, Model: c.model,
Messages: messages, Messages: messages,
Tools: c.GetTools(), Tools: c.GetTools(),
Stream: true,
} }
resp, err := c.client.CreateChatCompletion(ctx, req) stream, err := c.client.CreateChatCompletionStream(ctx, req)
if err != nil { if err != nil {
return "", messages, fmt.Errorf("chat completion failed: %w", err) return "", messages, fmt.Errorf("chat completion stream failed: %w", err)
}
defer stream.Close()
var fullContent strings.Builder
var toolCalls []openai.ToolCall
for {
response, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return "", messages, fmt.Errorf("stream error: %w", err)
} }
choice := resp.Choices[0] if len(response.Choices) == 0 {
messages = append(messages, choice.Message) continue
}
delta := response.Choices[0].Delta
// Handle content streaming
if delta.Content != "" {
fullContent.WriteString(delta.Content)
if streamFunc != nil {
streamFunc(delta.Content)
}
}
// Handle tool calls // Handle tool calls
if len(choice.Message.ToolCalls) > 0 { if len(delta.ToolCalls) > 0 {
for _, toolCall := range choice.Message.ToolCalls { for _, tc := range delta.ToolCalls {
if tc.Index != nil {
idx := *tc.Index
// Ensure we have enough space in the slice
for len(toolCalls) <= idx {
toolCalls = append(toolCalls, openai.ToolCall{})
}
if tc.ID != "" {
toolCalls[idx].ID = tc.ID
}
if tc.Type != "" {
toolCalls[idx].Type = tc.Type
}
if tc.Function.Name != "" {
toolCalls[idx].Function.Name = tc.Function.Name
}
if tc.Function.Arguments != "" {
toolCalls[idx].Function.Arguments += tc.Function.Arguments
}
}
}
}
}
// Create the assistant message
assistantMsg := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: fullContent.String(),
}
if len(toolCalls) > 0 {
assistantMsg.ToolCalls = toolCalls
}
messages = append(messages, assistantMsg)
// Handle tool calls if present
if len(toolCalls) > 0 {
for _, toolCall := range toolCalls {
result := c.handleToolCall(ctx, toolCall) result := c.handleToolCall(ctx, toolCall)
// Add tool response to messages // Add tool response to messages
@@ -140,11 +206,14 @@ func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessa
}) })
} }
// Make another call with tool results // Print blank line before streaming the final response
return c.Chat(ctx, messages) fmt.Println()
// Make another streaming call with tool results
return c.Chat(ctx, messages, streamFunc)
} }
return choice.Message.Content, messages, nil return fullContent.String(), messages, nil
} }
// handleToolCall routes tool calls to the appropriate handler // handleToolCall routes tool calls to the appropriate handler

16
main.go
View File

@@ -34,7 +34,6 @@ func main() {
// Connect to MCP servers if configured // Connect to MCP servers if configured
if len(cfg.MCPServers) > 0 { if len(cfg.MCPServers) > 0 {
fmt.Println("Connecting to MCP servers...")
if err := mcpManager.ConnectServers(cfg.MCPServers); err != nil { if err := mcpManager.ConnectServers(cfg.MCPServers); err != nil {
log.Printf("Warning: Failed to connect to some MCP servers: %v", err) log.Printf("Warning: Failed to connect to some MCP servers: %v", err)
} }
@@ -61,6 +60,10 @@ func main() {
// Check if arguments are provided (non-interactive mode) // Check if arguments are provided (non-interactive mode)
if len(os.Args) > 1 { if len(os.Args) > 1 {
query := strings.Join(os.Args[1:], " ") query := strings.Join(os.Args[1:], " ")
// Display MCP status in non-interactive mode if servers are configured
if len(cfg.MCPServers) > 0 {
displayMCPStatusInline(mcpManager)
}
processQuery(ctx, client, messages, query) processQuery(ctx, client, messages, query)
return return
} }
@@ -135,9 +138,13 @@ func processQuery(ctx context.Context, client *llm.Client, messages []openai.Cha
Content: userInput, Content: userInput,
}) })
// Get response from LLM // Print blank line before streaming starts
fmt.Println() fmt.Println()
response, updatedMessages, err := client.Chat(ctx, messages)
// Get response from LLM with streaming
_, updatedMessages, err := client.Chat(ctx, messages, func(chunk string) {
fmt.Print(chunk)
})
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "\nError: %v\n\n", err) fmt.Fprintf(os.Stderr, "\nError: %v\n\n", err)
// Remove the failed user message // Remove the failed user message
@@ -147,9 +154,8 @@ func processQuery(ctx context.Context, client *llm.Client, messages []openai.Cha
// Update messages with the full conversation history // Update messages with the full conversation history
messages = updatedMessages messages = updatedMessages
// Print response with empty line before it // Print newline after streaming completes
fmt.Println() fmt.Println()
fmt.Println(response)
fmt.Println() fmt.Println()
return messages return messages

View File

@@ -49,7 +49,6 @@ func (m *Manager) ConnectServers(servers map[string]config.MCPServer) error {
for name, serverCfg := range servers { for name, serverCfg := range servers {
if err := m.connectServer(name, serverCfg); err != nil { if err := m.connectServer(name, serverCfg); err != nil {
log.Printf("Warning: Failed to connect to MCP server %s: %v", name, err)
// Store the error in the connection // Store the error in the connection
m.servers[name] = &ServerConnection{ m.servers[name] = &ServerConnection{
Name: name, Name: name,
@@ -58,7 +57,6 @@ func (m *Manager) ConnectServers(servers map[string]config.MCPServer) error {
} }
continue continue
} }
log.Printf("Successfully connected to MCP server: %s", name)
} }
return nil return nil