From fa27f8363050a839d34a20b4499ae8932a6b47d8 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sun, 26 Nov 2023 02:46:38 +0000 Subject: [PATCH] Add tool calling support to streamed requests --- pkg/cli/openai.go | 56 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 8 deletions(-) diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go index d490e44..c03cf30 100644 --- a/pkg/cli/openai.go +++ b/pkg/cli/openai.go @@ -71,9 +71,8 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri // Append the assistant's reply with its request for tool calls toolCallJson, _ := json.Marshal(choice.Message.ToolCalls) messages = append(messages, Message{ - ConversationID: messages[0].ConversationID, - Role: "assistant", - ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, + Role: "assistant", + ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, }) toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls) @@ -103,7 +102,10 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int, } defer stream.Close() - sb := strings.Builder{} + content := strings.Builder{} + toolCalls := []openai.ToolCall{} + + // Iterate stream segments for { response, e := stream.Recv() if errors.Is(e, io.EOF) { @@ -114,9 +116,47 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int, err = e break } - chunk := response.Choices[0].Delta.Content - output <- chunk - sb.WriteString(chunk) + + delta := response.Choices[0].Delta + if len(delta.ToolCalls) > 0 { + // Construct streamed tool_call arguments + for _, tc := range delta.ToolCalls { + if tc.Index == nil { + return "", fmt.Errorf("Unexpected nil index for streamed tool call.") + } + if len(toolCalls) <= *tc.Index { + toolCalls = append(toolCalls, tc) + } else { + toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments + } + } + } else { + output <- delta.Content + content.WriteString(delta.Content) + } } - return sb.String(), err + + if len(toolCalls) > 0 { + if content.String() != "" { + return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.") + } + + // Append the assistant's reply with its request for tool calls + toolCallJson, _ := json.Marshal(toolCalls) + messages = append(messages, Message{ + Role: "assistant", + ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true}, + }) + + toolReplies, err := ExecuteToolCalls(toolCalls) + if err != nil { + return "", err + } + + // Recurse into CreateChatCompletionStream with the tool call replies + // added to the original messages + return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output) + } + + return content.String(), err }