diff --git a/pkg/lmcli/provider/google/google.go b/pkg/lmcli/provider/google/google.go index cad3d0c..bf7f8ae 100644 --- a/pkg/lmcli/provider/google/google.go +++ b/pkg/lmcli/provider/google/google.go @@ -144,17 +144,15 @@ func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart { return converted } -func convertToolCallToAPI(parts []ContentPart) []model.ToolCall { - converted := make([]model.ToolCall, len(parts)) - for i, part := range parts { - if part.FunctionCall != nil { - params := make(map[string]interface{}) - for k, v := range part.FunctionCall.Args { - params[k] = v - } - converted[i].Name = part.FunctionCall.Name - converted[i].Parameters = params +func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall { + converted := make([]model.ToolCall, len(functionCalls)) + for i, call := range functionCalls { + params := make(map[string]interface{}) + for k, v := range call.Args { + params[k] = v } + converted[i].Name = call.Name + converted[i].Parameters = params } return converted } @@ -239,7 +237,7 @@ func createGenerateContentRequest( } request := &GenerateContentRequest{ - Contents: requestContents, + Contents: requestContents, SystemInstructions: system, GenerationConfig: &GenerationConfig{ MaxOutputTokens: ¶ms.MaxTokens, @@ -361,15 +359,21 @@ func (c *Client) CreateChatCompletion( content = lastMessage.Content } + var toolCalls []FunctionCall for _, part := range choice.Content.Parts { if part.Text != "" { content += part.Text } + + if part.FunctionCall != nil { + toolCalls = append(toolCalls, *part.FunctionCall) + } } - toolCalls := convertToolCallToAPI(choice.Content.Parts) if len(toolCalls) > 0 { - messages, err := handleToolCalls(params, content, toolCalls, callback, messages) + messages, err := handleToolCalls( + params, content, convertToolCallToAPI(toolCalls), callback, messages, + ) if err != nil { return content, err } @@ -384,7 +388,6 @@ func (c *Client) CreateChatCompletion( }) } - // Return the user-facing message. return content, nil } @@ -430,7 +433,7 @@ func (c *Client) CreateChatCompletionStream( content.WriteString(lastMessage.Content) } - var toolCalls []model.ToolCall + var toolCalls []FunctionCall reader := bufio.NewReader(resp.Body) for { @@ -458,7 +461,7 @@ func (c *Client) CreateChatCompletionStream( for _, candidate := range streamResp.Candidates { for _, part := range candidate.Content.Parts { if part.FunctionCall != nil { - toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...) + toolCalls = append(toolCalls, *part.FunctionCall) } else if part.Text != "" { output <- part.Text content.WriteString(part.Text) @@ -469,7 +472,9 @@ func (c *Client) CreateChatCompletionStream( // If there are function calls, handle them and recurse if len(toolCalls) > 0 { - messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages) + messages, err := handleToolCalls( + params, content.String(), convertToolCallToAPI(toolCalls), callback, messages, + ) if err != nil { return content.String(), err }