Switched to streaming
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -111,25 +112,90 @@ func (c *Client) GetTools() []openai.Tool {
|
|||||||
return tools
|
return tools
|
||||||
}
|
}
|
||||||
|
|
||||||
// Chat sends a message and handles tool calls
|
// Chat sends a message and streams the response, handling tool calls
|
||||||
func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessage) (string, []openai.ChatCompletionMessage, error) {
|
func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessage, streamFunc func(string)) (string, []openai.ChatCompletionMessage, error) {
|
||||||
req := openai.ChatCompletionRequest{
|
req := openai.ChatCompletionRequest{
|
||||||
Model: c.model,
|
Model: c.model,
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
Tools: c.GetTools(),
|
Tools: c.GetTools(),
|
||||||
|
Stream: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.client.CreateChatCompletion(ctx, req)
|
stream, err := c.client.CreateChatCompletionStream(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", messages, fmt.Errorf("chat completion failed: %w", err)
|
return "", messages, fmt.Errorf("chat completion stream failed: %w", err)
|
||||||
|
}
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
var fullContent strings.Builder
|
||||||
|
var toolCalls []openai.ToolCall
|
||||||
|
|
||||||
|
for {
|
||||||
|
response, err := stream.Recv()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "", messages, fmt.Errorf("stream error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
choice := resp.Choices[0]
|
if len(response.Choices) == 0 {
|
||||||
messages = append(messages, choice.Message)
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
delta := response.Choices[0].Delta
|
||||||
|
|
||||||
|
// Handle content streaming
|
||||||
|
if delta.Content != "" {
|
||||||
|
fullContent.WriteString(delta.Content)
|
||||||
|
if streamFunc != nil {
|
||||||
|
streamFunc(delta.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Handle tool calls
|
// Handle tool calls
|
||||||
if len(choice.Message.ToolCalls) > 0 {
|
if len(delta.ToolCalls) > 0 {
|
||||||
for _, toolCall := range choice.Message.ToolCalls {
|
for _, tc := range delta.ToolCalls {
|
||||||
|
if tc.Index != nil {
|
||||||
|
idx := *tc.Index
|
||||||
|
|
||||||
|
// Ensure we have enough space in the slice
|
||||||
|
for len(toolCalls) <= idx {
|
||||||
|
toolCalls = append(toolCalls, openai.ToolCall{})
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.ID != "" {
|
||||||
|
toolCalls[idx].ID = tc.ID
|
||||||
|
}
|
||||||
|
if tc.Type != "" {
|
||||||
|
toolCalls[idx].Type = tc.Type
|
||||||
|
}
|
||||||
|
if tc.Function.Name != "" {
|
||||||
|
toolCalls[idx].Function.Name = tc.Function.Name
|
||||||
|
}
|
||||||
|
if tc.Function.Arguments != "" {
|
||||||
|
toolCalls[idx].Function.Arguments += tc.Function.Arguments
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the assistant message
|
||||||
|
assistantMsg := openai.ChatCompletionMessage{
|
||||||
|
Role: openai.ChatMessageRoleAssistant,
|
||||||
|
Content: fullContent.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
assistantMsg.ToolCalls = toolCalls
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, assistantMsg)
|
||||||
|
|
||||||
|
// Handle tool calls if present
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
for _, toolCall := range toolCalls {
|
||||||
result := c.handleToolCall(ctx, toolCall)
|
result := c.handleToolCall(ctx, toolCall)
|
||||||
|
|
||||||
// Add tool response to messages
|
// Add tool response to messages
|
||||||
@@ -140,11 +206,14 @@ func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessa
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make another call with tool results
|
// Print blank line before streaming the final response
|
||||||
return c.Chat(ctx, messages)
|
fmt.Println()
|
||||||
|
|
||||||
|
// Make another streaming call with tool results
|
||||||
|
return c.Chat(ctx, messages, streamFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
return choice.Message.Content, messages, nil
|
return fullContent.String(), messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleToolCall routes tool calls to the appropriate handler
|
// handleToolCall routes tool calls to the appropriate handler
|
||||||
|
|||||||
11
main.go
11
main.go
@@ -138,9 +138,13 @@ func processQuery(ctx context.Context, client *llm.Client, messages []openai.Cha
|
|||||||
Content: userInput,
|
Content: userInput,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Get response from LLM
|
// Print blank line before streaming starts
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
response, updatedMessages, err := client.Chat(ctx, messages)
|
|
||||||
|
// Get response from LLM with streaming
|
||||||
|
_, updatedMessages, err := client.Chat(ctx, messages, func(chunk string) {
|
||||||
|
fmt.Print(chunk)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "\nError: %v\n\n", err)
|
fmt.Fprintf(os.Stderr, "\nError: %v\n\n", err)
|
||||||
// Remove the failed user message
|
// Remove the failed user message
|
||||||
@@ -150,9 +154,8 @@ func processQuery(ctx context.Context, client *llm.Client, messages []openai.Cha
|
|||||||
// Update messages with the full conversation history
|
// Update messages with the full conversation history
|
||||||
messages = updatedMessages
|
messages = updatedMessages
|
||||||
|
|
||||||
// Print response with empty line before it
|
// Print newline after streaming completes
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
fmt.Println(response)
|
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|||||||
Reference in New Issue
Block a user