Fix double reply callback on tool calls

This commit is contained in:
Matt Low 2024-03-17 01:07:52 +00:00
parent ec1f326c2a
commit 62f07dd240
2 changed files with 21 additions and 15 deletions

View File

@ -197,6 +197,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
sb := strings.Builder{} sb := strings.Builder{}
isToolCall := false
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
@ -271,6 +273,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found") return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
} }
isToolCall = true
funcCallXml := content[start:] funcCallXml := content[start:]
funcCallXml += FUNCTION_STOP_SEQUENCE funcCallXml += FUNCTION_STOP_SEQUENCE
@ -316,11 +320,13 @@ func (c *AnthropicClient) CreateChatCompletionStream(
case "message_stop": case "message_stop":
// return the completed message // return the completed message
if callback != nil { if callback != nil {
if !isToolCall {
callback(model.Message{ callback(model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: sb.String(), Content: sb.String(),
}) })
} }
}
return sb.String(), nil return sb.String(), nil
case "error": case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])

View File

@ -204,7 +204,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callbback provider.ReplyCallback, callback provider.ReplyCallback,
output chan<- string, output chan<- string,
) (string, error) { ) (string, error) {
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
@ -256,23 +256,23 @@ func (c *OpenAIClient) CreateChatCompletionStream(
return content.String(), err return content.String(), err
} }
if callbback != nil { if callback != nil {
for _, result := range results { for _, result := range results {
callbback(result) callback(result)
} }
} }
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
messages = append(messages, results...) messages = append(messages, results...)
return c.CreateChatCompletionStream(ctx, params, messages, callbback, output) return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} } else {
if callback != nil {
if callbback != nil { callback(model.Message{
callbback(model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: content.String(), Content: content.String(),
}) })
} }
}
return content.String(), err return content.String(), err
} }