Add tool calling support to streamed requests
parent b229c42811
commit fa27f83630
@@ -71,7 +71,6 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
 		// Append the assistant's reply with its request for tool calls
 		toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
 		messages = append(messages, Message{
-			ConversationID: messages[0].ConversationID,
 			Role: "assistant",
 			ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
 		})
@@ -103,7 +102,10 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
 	}
 	defer stream.Close()
 
-	sb := strings.Builder{}
+	content := strings.Builder{}
+	toolCalls := []openai.ToolCall{}
+
+	// Iterate stream segments
 	for {
 		response, e := stream.Recv()
 		if errors.Is(e, io.EOF) {
@@ -114,9 +116,47 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
 			err = e
 			break
 		}
-		chunk := response.Choices[0].Delta.Content
-		output <- chunk
-		sb.WriteString(chunk)
-	}
-	return sb.String(), err
+
+		delta := response.Choices[0].Delta
+		if len(delta.ToolCalls) > 0 {
+			// Construct streamed tool_call arguments
+			for _, tc := range delta.ToolCalls {
+				if tc.Index == nil {
+					return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
+				}
+				if len(toolCalls) <= *tc.Index {
+					toolCalls = append(toolCalls, tc)
+				} else {
+					toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
+				}
+			}
+		} else {
+			output <- delta.Content
+			content.WriteString(delta.Content)
+		}
+	}
+
+	if len(toolCalls) > 0 {
+		if content.String() != "" {
+			return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
+		}
+
+		// Append the assistant's reply with its request for tool calls
+		toolCallJson, _ := json.Marshal(toolCalls)
+		messages = append(messages, Message{
+			Role: "assistant",
+			ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
+		})
+
+		toolReplies, err := ExecuteToolCalls(toolCalls)
+		if err != nil {
+			return "", err
+		}
+
+		// Recurse into CreateChatCompletionStream with the tool call replies
+		// added to the original messages
+		return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
+	}
+
+	return content.String(), err
 }
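
The heart of the streamed path is the merge rule for tool call deltas: the API streams each tool call's arguments in fragments keyed by Index, and the loop above concatenates those fragments until the stream ends. Below is a minimal, self-contained sketch of that rule (not part of the commit); ToolCall and FunctionCall are simplified stand-ins for the openai.ToolCall fields the diff relies on, and mergeToolCallDelta is a hypothetical helper named only for illustration.

// Sketch (assumed types/helper, not from this commit) of how streamed
// tool call fragments are folded together by Index.
package main

import "fmt"

type FunctionCall struct {
	Name      string
	Arguments string
}

type ToolCall struct {
	Index    *int
	Function FunctionCall
}

// mergeToolCallDelta folds one streamed fragment into the accumulated calls,
// mirroring the if/else in the stream loop above.
func mergeToolCallDelta(calls []ToolCall, tc ToolCall) ([]ToolCall, error) {
	if tc.Index == nil {
		return nil, fmt.Errorf("unexpected nil index for streamed tool call")
	}
	if len(calls) <= *tc.Index {
		// First fragment for this index: keep it whole (name + first argument chunk).
		return append(calls, tc), nil
	}
	// Later fragments carry partial JSON arguments, so append to what we have.
	calls[*tc.Index].Function.Arguments += tc.Function.Arguments
	return calls, nil
}

func main() {
	idx := 0
	fragments := []ToolCall{
		{Index: &idx, Function: FunctionCall{Name: "read_file", Arguments: `{"pa`}},
		{Index: &idx, Function: FunctionCall{Arguments: `th": "go.mod"}`}},
	}

	calls := []ToolCall{}
	for _, tc := range fragments {
		var err error
		if calls, err = mergeToolCallDelta(calls, tc); err != nil {
			panic(err)
		}
	}
	fmt.Println(calls[0].Function.Name, calls[0].Function.Arguments)
	// read_file {"path": "go.mod"}
}

Once the stream is drained, the new code serializes the accumulated calls, executes them via ExecuteToolCalls, and recurses with the tool replies appended to the conversation; the recursion bottoms out on the first streamed response that carries plain content rather than tool calls. Responses that mix user-facing content with tool calls are rejected outright rather than partially handled.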
|
||||||
|
Loading…
Reference in New Issue
Block a user