Add tool calling support to streamed requests

Matt Low 2023-11-26 02:46:38 +00:00
parent b229c42811
commit fa27f83630


@@ -71,9 +71,8 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
 	// Append the assistant's reply with its request for tool calls
 	toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
 	messages = append(messages, Message{
-		ConversationID: messages[0].ConversationID,
-		Role:           "assistant",
-		ToolCalls:      sql.NullString{String: string(toolCallJson), Valid: true},
+		Role:      "assistant",
+		ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
 	})
 	toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
@@ -103,7 +102,10 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
 	}
 	defer stream.Close()
-	sb := strings.Builder{}
+	content := strings.Builder{}
+	toolCalls := []openai.ToolCall{}
+	// Iterate stream segments
 	for {
 		response, e := stream.Recv()
 		if errors.Is(e, io.EOF) {
@@ -114,9 +116,47 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
 			err = e
 			break
 		}
-		chunk := response.Choices[0].Delta.Content
-		output <- chunk
-		sb.WriteString(chunk)
+		delta := response.Choices[0].Delta
+		if len(delta.ToolCalls) > 0 {
+			// Construct streamed tool_call arguments
+			for _, tc := range delta.ToolCalls {
+				if tc.Index == nil {
+					return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
+				}
+				if len(toolCalls) <= *tc.Index {
+					toolCalls = append(toolCalls, tc)
+				} else {
+					toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
+				}
+			}
+		} else {
+			output <- delta.Content
+			content.WriteString(delta.Content)
+		}
 	}
-	return sb.String(), err
+	if len(toolCalls) > 0 {
+		if content.String() != "" {
+			return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
+		}
+		// Append the assistant's reply with its request for tool calls
+		toolCallJson, _ := json.Marshal(toolCalls)
+		messages = append(messages, Message{
+			Role:      "assistant",
+			ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
+		})
+		toolReplies, err := ExecuteToolCalls(toolCalls)
+		if err != nil {
+			return "", err
+		}
+		// Recurse into CreateChatCompletionStream with the tool call replies
+		// added to the original messages
+		return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
+	}
+	return content.String(), err
 }
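
For readers unfamiliar with how go-openai surfaces streamed tool calls, here is a rough, illustrative sketch of the kind of delta sequence the accumulation loop in this commit folds together. The tool name, call ID, and argument fragments are invented placeholders; only the Index/Arguments mechanics mirror the code in the diff.

// Illustrative only: one tool call at index 0 whose JSON arguments arrive
// split across several stream deltas. The loop in this commit keys on
// *tc.Index and concatenates Function.Arguments until the stream ends.
idx := 0
deltas := []openai.ToolCall{
	{Index: &idx, ID: "call_123", Type: openai.ToolTypeFunction,
		Function: openai.FunctionCall{Name: "read_file", Arguments: ""}},
	{Index: &idx, Function: openai.FunctionCall{Arguments: `{"path":`}},
	{Index: &idx, Function: openai.FunctionCall{Arguments: ` "main.go"}`}},
}
// After accumulation, toolCalls[0].Function.Arguments == `{"path": "main.go"}`.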
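
And a minimal sketch of a caller after this change, assuming the output parameter is an ordinary chan string (the diff only shows strings being sent on it) and that a messages slice has been prepared elsewhere; the model name is a placeholder. Tool calls are executed and recursed on inside CreateChatCompletionStream, so the caller only ever sees user-facing content.

output := make(chan string)
done := make(chan struct{})
go func() {
	// Render chunks as they arrive.
	for chunk := range output {
		fmt.Print(chunk)
	}
	close(done)
}()

// "gpt-4" and messages are placeholders for this sketch.
reply, err := CreateChatCompletionStream("gpt-4", messages, 1024, output)
close(output)
<-done
if err != nil {
	log.Fatal(err)
}
// reply also holds the full accumulated content once streaming finishes.
_ = reply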