Gemini fixes, tool calling

Matt Low 2024-05-18 23:18:53 +00:00
parent cbcd3b1ba9
commit 1b8d04c96d
2 changed files with 50 additions and 14 deletions

View File

@@ -8,6 +8,7 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 	"git.mlow.ca/mlow/lmcli/pkg/util"
@@ -75,21 +76,28 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
 		if p.BaseURL != nil {
 			url = *p.BaseURL
 		}
-		anthropic := &anthropic.AnthropicClient{
+		return &anthropic.AnthropicClient{
 			BaseURL: url,
 			APIKey:  *p.APIKey,
+		}, nil
+	case "google":
+		url := "https://generativelanguage.googleapis.com"
+		if p.BaseURL != nil {
+			url = *p.BaseURL
 		}
-		return anthropic, nil
+		return &google.Client{
+			BaseURL: url,
+			APIKey:  *p.APIKey,
+		}, nil
 	case "openai":
 		url := "https://api.openai.com/v1"
 		if p.BaseURL != nil {
 			url = *p.BaseURL
 		}
-		openai := &openai.OpenAIClient{
+		return &openai.OpenAIClient{
 			BaseURL: url,
 			APIKey:  *p.APIKey,
-		}
-		return openai, nil
+		}, nil
 	default:
 		return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
 	}
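For reference, a minimal sketch of what the new "google" case constructs when no BaseURL is configured. The google.Client type and its BaseURL/APIKey fields come from the diff above; the environment variable name and the standalone wrapper are illustrative assumptions, not part of this commit.

package main

import (
	"os"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
)

func main() {
	// Defaults the "google" case falls back to when p.BaseURL is nil.
	client := &google.Client{
		BaseURL: "https://generativelanguage.googleapis.com",
		APIKey:  os.Getenv("GEMINI_API_KEY"), // assumed source of the key
	}
	_ = client
}

Returning the composite literal directly from each case, rather than assigning it to a local first, keeps the anthropic, google, and openai branches symmetrical.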

View File

@@ -173,8 +173,20 @@ func createGenerateContentRequest(
 				requestContents = append(requestContents, content)
 			}
 		default:
+			var role string
+			switch m.Role {
+			case model.MessageRoleAssistant:
+				role = "model"
+			case model.MessageRoleUser:
+				role = "user"
+			}
+			if role == "" {
+				panic("Unhandled role: " + m.Role)
+			}
 			content := Content{
-				Role: string(m.Role),
+				Role: role,
 				Parts: []ContentPart{
 					{
 						Text: m.Content,
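The role switch above exists because the Gemini API only accepts "user" and "model" as conversation roles, so lmcli's assistant role has to be translated rather than passed through. A sketch of the same mapping pulled into a helper, using the model package identifiers from the diff; the MessageRole type name, the helper name, and the error return in place of the panic are assumptions:

package google

import (
	"fmt"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)

// geminiRole maps lmcli message roles onto the two roles Gemini accepts.
func geminiRole(role model.MessageRole) (string, error) {
	switch role {
	case model.MessageRoleAssistant:
		return "model", nil // Gemini calls the assistant side "model"
	case model.MessageRoleUser:
		return "user", nil
	default:
		return "", fmt.Errorf("unhandled role: %s", role)
	}
}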
@@ -242,7 +254,6 @@ func handleToolCalls(
 func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
 	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("Authorization", "Bearer "+c.APIKey)
 
 	client := &http.Client{}
 	resp, err := client.Do(req.WithContext(ctx))
@@ -271,7 +282,11 @@ func (c *Client) CreateChatCompletion(
 		return "", err
 	}
-	httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:generateContent", c.BaseURL, params.Model), bytes.NewBuffer(jsonData))
+	url := fmt.Sprintf(
+		"%s/v1beta/models/%s:generateContent?key=%s",
+		c.BaseURL, params.Model, c.APIKey,
+	)
+	httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
 	if err != nil {
 		return "", err
 	}
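Dropping the Authorization header in sendRequest and adding ?key= here reflects how the generativelanguage.googleapis.com REST API authenticates: the API key is passed as a key query parameter (or an x-goog-api-key header), not as an OAuth-style Bearer token. A sketch of building the same endpoint with net/url so the model name and key are escaped; the helper name is hypothetical and the parameters correspond to c.BaseURL, params.Model, and c.APIKey above:

package google

import (
	"fmt"
	"net/url"
)

// generateContentURL builds the same v1beta endpoint the diff assembles with
// fmt.Sprintf, but escapes the model name and API key.
func generateContentURL(baseURL, model, apiKey string) string {
	q := url.Values{}
	q.Set("key", apiKey)
	return fmt.Sprintf("%s/v1beta/models/%s:generateContent?%s",
		baseURL, url.PathEscape(model), q.Encode())
}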
@@ -330,10 +345,6 @@ func (c *Client) CreateChatCompletionStream(
 	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
-	if len(params.ToolBag) > 0 {
-		return "", fmt.Errorf("Tool calling is not supported in streaming mode.")
-	}
-
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
 	}
@@ -344,7 +355,11 @@ func (c *Client) CreateChatCompletionStream(
 		return "", err
 	}
-	httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", c.BaseURL, params.Model), bytes.NewBuffer(jsonData))
+	url := fmt.Sprintf(
+		"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
+		c.BaseURL, params.Model, c.APIKey,
+	)
+	httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
 	if err != nil {
 		return "", err
 	}
@@ -386,12 +401,25 @@ func (c *Client) CreateChatCompletionStream(
 		}
 
 		for _, candidate := range streamResp.Candidates {
+			var toolCalls []model.ToolCall
 			for _, part := range candidate.Content.Parts {
-				if part.Text != "" {
+				if part.FunctionCall != nil {
+					toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...)
+				} else if part.Text != "" {
 					output <- part.Text
 					content.WriteString(part.Text)
 				}
 			}
+
+			// If there are function calls, handle them and recurse
+			if len(toolCalls) > 0 {
+				messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
+				if err != nil {
+					return content.String(), err
+				}
+				return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
+			}
 		}
 	}
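With the ToolBag guard removed and this branch added, a streamed candidate can now carry functionCall parts: they are collected into toolCalls, executed through handleToolCalls once the candidate's parts are drained, and the completion is re-requested with the tool results appended, recursing until the model replies with plain text. Two abridged, made-up SSE payloads of the kind this loop parses; the field names follow the public v1beta streamGenerateContent format, which the ContentPart type here is assumed to mirror:

package google

// Illustrative payloads only, not captured output.
const (
	// Plain text chunk: part.Text is non-empty, part.FunctionCall is nil.
	textChunk = `{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}`
	// Tool-call chunk: part.FunctionCall is set and is converted to a model.ToolCall.
	functionCallChunk = `{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"location":"Berlin"}}}]}}]}`
)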