Fix double reply callback on tool calls

This commit is contained in:
Matt Low 2024-03-17 01:07:52 +00:00
parent ec1f326c2a
commit 62f07dd240
2 changed files with 21 additions and 15 deletions

View File

@ -197,6 +197,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
scanner := bufio.NewScanner(resp.Body)
sb := strings.Builder{}
isToolCall := false
for scanner.Scan() {
line := scanner.Text()
line = strings.TrimSpace(line)
@ -271,6 +273,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
}
isToolCall = true
funcCallXml := content[start:]
funcCallXml += FUNCTION_STOP_SEQUENCE
@ -316,10 +320,12 @@ func (c *AnthropicClient) CreateChatCompletionStream(
case "message_stop":
// return the completed message
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: sb.String(),
})
if !isToolCall {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: sb.String(),
})
}
}
return sb.String(), nil
case "error":

View File

@ -204,7 +204,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callbback provider.ReplyCallback,
callback provider.ReplyCallback,
output chan<- string,
) (string, error) {
client := openai.NewClient(c.APIKey)
@ -256,22 +256,22 @@ func (c *OpenAIClient) CreateChatCompletionStream(
return content.String(), err
}
if callbback != nil {
if callback != nil {
for _, result := range results {
callbback(result)
callback(result)
}
}
// Recurse into CreateChatCompletionStream with the tool call replies
messages = append(messages, results...)
return c.CreateChatCompletionStream(ctx, params, messages, callbback, output)
}
if callbback != nil {
callbback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(),
})
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} else {
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(),
})
}
}
return content.String(), err