From 62f07dd240567a8ae6055e4b956691c351220eff Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Sun, 17 Mar 2024 01:07:52 +0000
Subject: [PATCH] Fix double reply callback on tool calls

---
 pkg/lmcli/provider/anthropic/anthropic.go | 14 ++++++++++----
 pkg/lmcli/provider/openai/openai.go       | 22 +++++++++++-----------
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/pkg/lmcli/provider/anthropic/anthropic.go b/pkg/lmcli/provider/anthropic/anthropic.go
index 7ed1645..8a410de 100644
--- a/pkg/lmcli/provider/anthropic/anthropic.go
+++ b/pkg/lmcli/provider/anthropic/anthropic.go
@@ -197,6 +197,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	scanner := bufio.NewScanner(resp.Body)
 	sb := strings.Builder{}
 
+	isToolCall := false
+
 	for scanner.Scan() {
 		line := scanner.Text()
 		line = strings.TrimSpace(line)
@@ -271,6 +273,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 					return content, fmt.Errorf("reached stop sequence but no opening tag found")
 				}
 
+				isToolCall = true
+
 				funcCallXml := content[start:]
 				funcCallXml += FUNCTION_STOP_SEQUENCE
 
@@ -316,10 +320,12 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 			case "message_stop":
 				// return the completed message
 				if callback != nil {
-					callback(model.Message{
-						Role:    model.MessageRoleAssistant,
-						Content: sb.String(),
-					})
+					if !isToolCall {
+						callback(model.Message{
+							Role:    model.MessageRoleAssistant,
+							Content: sb.String(),
+						})
+					}
 				}
 				return sb.String(), nil
 			case "error":
diff --git a/pkg/lmcli/provider/openai/openai.go b/pkg/lmcli/provider/openai/openai.go
index 8a01149..89d1309 100644
--- a/pkg/lmcli/provider/openai/openai.go
+++ b/pkg/lmcli/provider/openai/openai.go
@@ -204,7 +204,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callbback provider.ReplyCallback,
+	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
@@ -256,22 +256,22 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 			return content.String(), err
 		}
 
-		if callbback != nil {
+		if callback != nil {
 			for _, result := range results {
-				callbback(result)
+				callback(result)
 			}
 		}
 
 		// Recurse into CreateChatCompletionStream with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletionStream(ctx, params, messages, callbback, output)
-	}
-
-	if callbback != nil {
-		callbback(model.Message{
-			Role:    model.MessageRoleAssistant,
-			Content: content.String(),
-		})
+		return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
+	} else {
+		if callback != nil {
+			callback(model.Message{
+				Role:    model.MessageRoleAssistant,
+				Content: content.String(),
+			})
+		}
 	}
 
 	return content.String(), err
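
(Not part of the patch.) A minimal, self-contained Go sketch of the guard pattern the Anthropic change introduces: a flag, named isToolCall as in the patch, is set once the tool-call path has handled the reply, so the message_stop handler does not invoke the reply callback a second time. The Message type, the streamEvents helper, and the event strings below are simplified stand-ins for illustration only, not the provider's real API.

package main

import "fmt"

// Message is a simplified stand-in for model.Message in this sketch.
type Message struct {
	Role    string
	Content string
}

// streamEvents imitates the shape of the Anthropic streaming loop: text
// events are accumulated, a tool-call event marks the stream as a tool
// call, and message_stop ends the stream. The reply callback fires at
// message_stop only when no tool call was handled earlier, which is the
// behaviour the patch adds via isToolCall.
func streamEvents(events []string, callback func(Message)) string {
	var sb string
	isToolCall := false

	for _, ev := range events {
		switch ev {
		case "tool_call":
			// In the real provider, this path executes the tool call and
			// reports it through the callback itself.
			isToolCall = true
		case "message_stop":
			if callback != nil && !isToolCall {
				callback(Message{Role: "assistant", Content: sb})
			}
			return sb
		default:
			sb += ev
		}
	}
	return sb
}

func main() {
	calls := 0
	cb := func(Message) { calls++ }

	// Plain completion: the callback fires exactly once at message_stop.
	streamEvents([]string{"Hello ", "world", "message_stop"}, cb)

	// Tool-call stream: the final callback is suppressed, avoiding the
	// double invocation the patch fixes.
	streamEvents([]string{"thinking...", "tool_call", "message_stop"}, cb)

	fmt.Println("callback invocations:", calls) // prints 1
}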