Gemini fixes, tool calling
This commit is contained in:
parent
cbcd3b1ba9
commit
1b8d04c96d
@ -8,6 +8,7 @@ import (
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
@ -75,21 +76,28 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
|
||||
if p.BaseURL != nil {
|
||||
url = *p.BaseURL
|
||||
}
|
||||
anthropic := &anthropic.AnthropicClient{
|
||||
return &anthropic.AnthropicClient{
|
||||
BaseURL: url,
|
||||
APIKey: *p.APIKey,
|
||||
}, nil
|
||||
case "google":
|
||||
url := "https://generativelanguage.googleapis.com"
|
||||
if p.BaseURL != nil {
|
||||
url = *p.BaseURL
|
||||
}
|
||||
return anthropic, nil
|
||||
return &google.Client{
|
||||
BaseURL: url,
|
||||
APIKey: *p.APIKey,
|
||||
}, nil
|
||||
case "openai":
|
||||
url := "https://api.openai.com/v1"
|
||||
if p.BaseURL != nil {
|
||||
url = *p.BaseURL
|
||||
}
|
||||
openai := &openai.OpenAIClient{
|
||||
return &openai.OpenAIClient{
|
||||
BaseURL: url,
|
||||
APIKey: *p.APIKey,
|
||||
}
|
||||
return openai, nil
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
|
||||
}
|
||||
|
@ -173,8 +173,20 @@ func createGenerateContentRequest(
|
||||
requestContents = append(requestContents, content)
|
||||
}
|
||||
default:
|
||||
var role string
|
||||
switch m.Role {
|
||||
case model.MessageRoleAssistant:
|
||||
role = "model"
|
||||
case model.MessageRoleUser:
|
||||
role = "user"
|
||||
}
|
||||
|
||||
if role == "" {
|
||||
panic("Unhandled role: " + m.Role)
|
||||
}
|
||||
|
||||
content := Content{
|
||||
Role: string(m.Role),
|
||||
Role: role,
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
Text: m.Content,
|
||||
@ -242,7 +254,6 @@ func handleToolCalls(
|
||||
|
||||
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
@ -271,7 +282,11 @@ func (c *Client) CreateChatCompletion(
|
||||
return "", err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:generateContent", c.BaseURL, params.Model), bytes.NewBuffer(jsonData))
|
||||
url := fmt.Sprintf(
|
||||
"%s/v1beta/models/%s:generateContent?key=%s",
|
||||
c.BaseURL, params.Model, c.APIKey,
|
||||
)
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -330,10 +345,6 @@ func (c *Client) CreateChatCompletionStream(
|
||||
callback provider.ReplyCallback,
|
||||
output chan<- string,
|
||||
) (string, error) {
|
||||
if len(params.ToolBag) > 0 {
|
||||
return "", fmt.Errorf("Tool calling is not supported in streaming mode.")
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
@ -344,7 +355,11 @@ func (c *Client) CreateChatCompletionStream(
|
||||
return "", err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", c.BaseURL, params.Model), bytes.NewBuffer(jsonData))
|
||||
url := fmt.Sprintf(
|
||||
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
||||
c.BaseURL, params.Model, c.APIKey,
|
||||
)
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -386,12 +401,25 @@ func (c *Client) CreateChatCompletionStream(
|
||||
}
|
||||
|
||||
for _, candidate := range streamResp.Candidates {
|
||||
var toolCalls []model.ToolCall
|
||||
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...)
|
||||
} else if part.Text != "" {
|
||||
output <- part.Text
|
||||
content.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// If there are function calls, handle them and recurse
|
||||
if len(toolCalls) > 0 {
|
||||
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
|
||||
if err != nil {
|
||||
return content.String(), err
|
||||
}
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user