diff --git a/llm/client.go b/llm/client.go
index 8d79188..e5d3725 100644
--- a/llm/client.go
+++ b/llm/client.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"strings"
 	"time"
 
@@ -111,25 +112,90 @@ func (c *Client) GetTools() []openai.Tool {
 	return tools
 }
 
-// Chat sends a message and handles tool calls
-func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessage) (string, []openai.ChatCompletionMessage, error) {
+// Chat sends a message and streams the response, handling tool calls
+func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessage, streamFunc func(string)) (string, []openai.ChatCompletionMessage, error) {
 	req := openai.ChatCompletionRequest{
 		Model:    c.model,
 		Messages: messages,
 		Tools:    c.GetTools(),
+		Stream:   true,
 	}
 
-	resp, err := c.client.CreateChatCompletion(ctx, req)
+	stream, err := c.client.CreateChatCompletionStream(ctx, req)
 	if err != nil {
-		return "", messages, fmt.Errorf("chat completion failed: %w", err)
+		return "", messages, fmt.Errorf("chat completion stream failed: %w", err)
+	}
+	defer stream.Close()
+
+	var fullContent strings.Builder
+	var toolCalls []openai.ToolCall
+
+	for {
+		response, err := stream.Recv()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return "", messages, fmt.Errorf("stream error: %w", err)
+		}
+
+		if len(response.Choices) == 0 {
+			continue
+		}
+
+		delta := response.Choices[0].Delta
+
+		// Handle content streaming
+		if delta.Content != "" {
+			fullContent.WriteString(delta.Content)
+			if streamFunc != nil {
+				streamFunc(delta.Content)
+			}
+		}
+
+		// Handle tool calls
+		if len(delta.ToolCalls) > 0 {
+			for _, tc := range delta.ToolCalls {
+				if tc.Index != nil {
+					idx := *tc.Index
+
+					// Ensure we have enough space in the slice
+					for len(toolCalls) <= idx {
+						toolCalls = append(toolCalls, openai.ToolCall{})
+					}
+
+					if tc.ID != "" {
+						toolCalls[idx].ID = tc.ID
+					}
+					if tc.Type != "" {
+						toolCalls[idx].Type = tc.Type
+					}
+					if tc.Function.Name != "" {
+						toolCalls[idx].Function.Name = tc.Function.Name
+					}
+					if tc.Function.Arguments != "" {
+						toolCalls[idx].Function.Arguments += tc.Function.Arguments
+					}
+				}
+			}
+		}
 	}
 
-	choice := resp.Choices[0]
-	messages = append(messages, choice.Message)
+	// Create the assistant message
+	assistantMsg := openai.ChatCompletionMessage{
+		Role:    openai.ChatMessageRoleAssistant,
+		Content: fullContent.String(),
+	}
 
-	// Handle tool calls
-	if len(choice.Message.ToolCalls) > 0 {
-		for _, toolCall := range choice.Message.ToolCalls {
+	if len(toolCalls) > 0 {
+		assistantMsg.ToolCalls = toolCalls
+	}
+
+	messages = append(messages, assistantMsg)
+
+	// Handle tool calls if present
+	if len(toolCalls) > 0 {
+		for _, toolCall := range toolCalls {
 			result := c.handleToolCall(ctx, toolCall)
 
 			// Add tool response to messages
@@ -140,11 +206,14 @@ func (c *Client) Chat(ctx context.Context, messages []openai.ChatCompletionMessa
 			})
 		}
 
-		// Make another call with tool results
-		return c.Chat(ctx, messages)
+		// Print blank line before streaming the final response
+		fmt.Println()
+
+		// Make another streaming call with tool results
+		return c.Chat(ctx, messages, streamFunc)
 	}
 
-	return choice.Message.Content, messages, nil
+	return fullContent.String(), messages, nil
 }
 
 // handleToolCall routes tool calls to the appropriate handler
diff --git a/main.go b/main.go
index bfebe89..4d4cde0 100644
--- a/main.go
+++ b/main.go
@@ -138,9 +138,13 @@ func processQuery(ctx context.Context, client *llm.Client, messages []openai.Cha
 		Content: userInput,
 	})
 
-	// Get response from LLM
+	// Print blank line before streaming starts
 	fmt.Println()
-	response, updatedMessages, err := client.Chat(ctx, messages)
+
+	// Get response from LLM with streaming
+	_, updatedMessages, err := client.Chat(ctx, messages, func(chunk string) {
+		fmt.Print(chunk)
+	})
 	if err != nil {
 		fmt.Fprintf(os.Stderr, "\nError: %v\n\n", err)
 		// Remove the failed user message
@@ -150,9 +154,8 @@ func processQuery(ctx context.Context, client *llm.Client, messages []openai.Cha
 	// Update messages with the full conversation history
 	messages = updatedMessages
 
-	// Print response with empty line before it
+	// Print newline after streaming completes
 	fmt.Println()
-	fmt.Println(response)
 	fmt.Println()
 
 	return messages