Fix for non-streamed gemini responses

This commit is contained in:
Matt Low 2024-05-19 02:59:43 +00:00
parent b82f3019f0
commit 62d98289e8

View File

@ -144,18 +144,16 @@ func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
return converted
}
func convertToolCallToAPI(parts []ContentPart) []model.ToolCall {
converted := make([]model.ToolCall, len(parts))
for i, part := range parts {
if part.FunctionCall != nil {
func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
converted := make([]model.ToolCall, len(functionCalls))
for i, call := range functionCalls {
params := make(map[string]interface{})
for k, v := range part.FunctionCall.Args {
for k, v := range call.Args {
params[k] = v
}
converted[i].Name = part.FunctionCall.Name
converted[i].Name = call.Name
converted[i].Parameters = params
}
}
return converted
}
@ -361,15 +359,21 @@ func (c *Client) CreateChatCompletion(
content = lastMessage.Content
}
var toolCalls []FunctionCall
for _, part := range choice.Content.Parts {
if part.Text != "" {
content += part.Text
}
if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall)
}
}
toolCalls := convertToolCallToAPI(choice.Content.Parts)
if len(toolCalls) > 0 {
messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
messages, err := handleToolCalls(
params, content, convertToolCallToAPI(toolCalls), callback, messages,
)
if err != nil {
return content, err
}
@ -384,7 +388,6 @@ func (c *Client) CreateChatCompletion(
})
}
// Return the user-facing message.
return content, nil
}
@ -430,7 +433,7 @@ func (c *Client) CreateChatCompletionStream(
content.WriteString(lastMessage.Content)
}
var toolCalls []model.ToolCall
var toolCalls []FunctionCall
reader := bufio.NewReader(resp.Body)
for {
@ -458,7 +461,7 @@ func (c *Client) CreateChatCompletionStream(
for _, candidate := range streamResp.Candidates {
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...)
toolCalls = append(toolCalls, *part.FunctionCall)
} else if part.Text != "" {
output <- part.Text
content.WriteString(part.Text)
@ -469,7 +472,9 @@ func (c *Client) CreateChatCompletionStream(
// If there are function calls, handle them and recurse
if len(toolCalls) > 0 {
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
messages, err := handleToolCalls(
params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
)
if err != nil {
return content.String(), err
}