Fix for non-streamed gemini responses

This commit is contained in:
Matt Low 2024-05-19 02:59:43 +00:00
parent b82f3019f0
commit 62d98289e8
1 changed files with 22 additions and 17 deletions

View File

@ -144,17 +144,15 @@ func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
return converted return converted
} }
func convertToolCallToAPI(parts []ContentPart) []model.ToolCall { func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
converted := make([]model.ToolCall, len(parts)) converted := make([]model.ToolCall, len(functionCalls))
for i, part := range parts { for i, call := range functionCalls {
if part.FunctionCall != nil { params := make(map[string]interface{})
params := make(map[string]interface{}) for k, v := range call.Args {
for k, v := range part.FunctionCall.Args { params[k] = v
params[k] = v
}
converted[i].Name = part.FunctionCall.Name
converted[i].Parameters = params
} }
converted[i].Name = call.Name
converted[i].Parameters = params
} }
return converted return converted
} }
@ -239,7 +237,7 @@ func createGenerateContentRequest(
} }
request := &GenerateContentRequest{ request := &GenerateContentRequest{
Contents: requestContents, Contents: requestContents,
SystemInstructions: system, SystemInstructions: system,
GenerationConfig: &GenerationConfig{ GenerationConfig: &GenerationConfig{
MaxOutputTokens: &params.MaxTokens, MaxOutputTokens: &params.MaxTokens,
@ -361,15 +359,21 @@ func (c *Client) CreateChatCompletion(
content = lastMessage.Content content = lastMessage.Content
} }
var toolCalls []FunctionCall
for _, part := range choice.Content.Parts { for _, part := range choice.Content.Parts {
if part.Text != "" { if part.Text != "" {
content += part.Text content += part.Text
} }
if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall)
}
} }
toolCalls := convertToolCallToAPI(choice.Content.Parts)
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
messages, err := handleToolCalls(params, content, toolCalls, callback, messages) messages, err := handleToolCalls(
params, content, convertToolCallToAPI(toolCalls), callback, messages,
)
if err != nil { if err != nil {
return content, err return content, err
} }
@ -384,7 +388,6 @@ func (c *Client) CreateChatCompletion(
}) })
} }
// Return the user-facing message.
return content, nil return content, nil
} }
@ -430,7 +433,7 @@ func (c *Client) CreateChatCompletionStream(
content.WriteString(lastMessage.Content) content.WriteString(lastMessage.Content)
} }
var toolCalls []model.ToolCall var toolCalls []FunctionCall
reader := bufio.NewReader(resp.Body) reader := bufio.NewReader(resp.Body)
for { for {
@ -458,7 +461,7 @@ func (c *Client) CreateChatCompletionStream(
for _, candidate := range streamResp.Candidates { for _, candidate := range streamResp.Candidates {
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil { if part.FunctionCall != nil {
toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...) toolCalls = append(toolCalls, *part.FunctionCall)
} else if part.Text != "" { } else if part.Text != "" {
output <- part.Text output <- part.Text
content.WriteString(part.Text) content.WriteString(part.Text)
@ -469,7 +472,9 @@ func (c *Client) CreateChatCompletionStream(
// If there are function calls, handle them and recurse // If there are function calls, handle them and recurse
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages) messages, err := handleToolCalls(
params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
)
if err != nil { if err != nil {
return content.String(), err return content.String(), err
} }