From 1b8d04c96d0d4ea1322b120694fc0f58d8f18de3 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sat, 18 May 2024 23:18:53 +0000 Subject: [PATCH] Gemini fixes, tool calling --- pkg/lmcli/lmcli.go | 18 +++++++---- pkg/lmcli/provider/google/google.go | 46 +++++++++++++++++++++++------ 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index 6a5b4c0..a4d2012 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -8,6 +8,7 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/util" @@ -75,21 +76,28 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl if p.BaseURL != nil { url = *p.BaseURL } - anthropic := &anthropic.AnthropicClient{ + return &anthropic.AnthropicClient{ BaseURL: url, APIKey: *p.APIKey, + }, nil + case "google": + url := "https://generativelanguage.googleapis.com" + if p.BaseURL != nil { + url = *p.BaseURL } - return anthropic, nil + return &google.Client{ + BaseURL: url, + APIKey: *p.APIKey, + }, nil case "openai": url := "https://api.openai.com/v1" if p.BaseURL != nil { url = *p.BaseURL } - openai := &openai.OpenAIClient{ + return &openai.OpenAIClient{ BaseURL: url, APIKey: *p.APIKey, - } - return openai, nil + }, nil default: return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind) } diff --git a/pkg/lmcli/provider/google/google.go b/pkg/lmcli/provider/google/google.go index 1bf2a0c..88eb7bc 100644 --- a/pkg/lmcli/provider/google/google.go +++ b/pkg/lmcli/provider/google/google.go @@ -173,8 +173,20 @@ func createGenerateContentRequest( requestContents = append(requestContents, content) } default: + var role string + switch m.Role { + case model.MessageRoleAssistant: + role = "model" + case model.MessageRoleUser: + role = "user" + } + + if role == "" { + panic("Unhandled role: " + m.Role) + } + content := Content{ - Role: string(m.Role), + Role: role, Parts: []ContentPart{ { Text: m.Content, @@ -242,7 +254,6 @@ func handleToolCalls( func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) { req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.APIKey) client := &http.Client{} resp, err := client.Do(req.WithContext(ctx)) @@ -271,7 +282,11 @@ func (c *Client) CreateChatCompletion( return "", err } - httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:generateContent", c.BaseURL, params.Model), bytes.NewBuffer(jsonData)) + url := fmt.Sprintf( + "%s/v1beta/models/%s:generateContent?key=%s", + c.BaseURL, params.Model, c.APIKey, + ) + httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { return "", err } @@ -330,10 +345,6 @@ func (c *Client) CreateChatCompletionStream( callback provider.ReplyCallback, output chan<- string, ) (string, error) { - if len(params.ToolBag) > 0 { - return "", fmt.Errorf("Tool calling is not supported in streaming mode.") - } - if len(messages) == 0 { return "", fmt.Errorf("Can't create completion from no messages") } @@ -344,7 +355,11 @@ func (c *Client) CreateChatCompletionStream( return "", err } - httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", c.BaseURL, params.Model), bytes.NewBuffer(jsonData)) + url := fmt.Sprintf( + "%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse", + c.BaseURL, params.Model, c.APIKey, + ) + httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { return "", err } @@ -386,12 +401,25 @@ func (c *Client) CreateChatCompletionStream( } for _, candidate := range streamResp.Candidates { + var toolCalls []model.ToolCall + for _, part := range candidate.Content.Parts { - if part.Text != "" { + if part.FunctionCall != nil { + toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...) + } else if part.Text != "" { output <- part.Text content.WriteString(part.Text) } } + + // If there are function calls, handle them and recurse + if len(toolCalls) > 0 { + messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages) + if err != nil { + return content.String(), err + } + return c.CreateChatCompletionStream(ctx, params, messages, callback, output) + } } }