Add tool calling support to streamed requests

This commit is contained in:
Matt Low 2023-11-26 02:46:38 +00:00
parent bf1f23b1d6
commit 3e59702c80

View File

@ -71,7 +71,6 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
// Append the assistant's reply with its request for tool calls // Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls) toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
messages = append(messages, Message{ messages = append(messages, Message{
ConversationID: messages[0].ConversationID,
Role: "assistant", Role: "assistant",
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
}) })
@ -103,7 +102,10 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
} }
defer stream.Close() defer stream.Close()
sb := strings.Builder{} content := strings.Builder{}
toolCalls := []openai.ToolCall{}
// Iterate stream segments
for { for {
response, e := stream.Recv() response, e := stream.Recv()
if errors.Is(e, io.EOF) { if errors.Is(e, io.EOF) {
@ -114,9 +116,47 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
err = e err = e
break break
} }
chunk := response.Choices[0].Delta.Content
output <- chunk delta := response.Choices[0].Delta
sb.WriteString(chunk) if len(delta.ToolCalls) > 0 {
// Construct streamed tool_call arguments
for _, tc := range delta.ToolCalls {
if tc.Index == nil {
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
} }
return sb.String(), err if len(toolCalls) <= *tc.Index {
toolCalls = append(toolCalls, tc)
} else {
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
}
}
} else {
output <- delta.Content
content.WriteString(delta.Content)
}
}
if len(toolCalls) > 0 {
if content.String() != "" {
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
}
// Append the assistant's reply with its request for tool calls
toolCallJson, _ := json.Marshal(toolCalls)
messages = append(messages, Message{
Role: "assistant",
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
})
toolReplies, err := ExecuteToolCalls(toolCalls)
if err != nil {
return "", err
}
// Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages
return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
}
return content.String(), err
} }