Gemini fixes, tool calling
This commit is contained in:
parent
cbcd3b1ba9
commit
1b8d04c96d
@ -8,6 +8,7 @@ import (
|
|||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
@ -75,21 +76,28 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
|
|||||||
if p.BaseURL != nil {
|
if p.BaseURL != nil {
|
||||||
url = *p.BaseURL
|
url = *p.BaseURL
|
||||||
}
|
}
|
||||||
anthropic := &anthropic.AnthropicClient{
|
return &anthropic.AnthropicClient{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
APIKey: *p.APIKey,
|
APIKey: *p.APIKey,
|
||||||
|
}, nil
|
||||||
|
case "google":
|
||||||
|
url := "https://generativelanguage.googleapis.com"
|
||||||
|
if p.BaseURL != nil {
|
||||||
|
url = *p.BaseURL
|
||||||
}
|
}
|
||||||
return anthropic, nil
|
return &google.Client{
|
||||||
|
BaseURL: url,
|
||||||
|
APIKey: *p.APIKey,
|
||||||
|
}, nil
|
||||||
case "openai":
|
case "openai":
|
||||||
url := "https://api.openai.com/v1"
|
url := "https://api.openai.com/v1"
|
||||||
if p.BaseURL != nil {
|
if p.BaseURL != nil {
|
||||||
url = *p.BaseURL
|
url = *p.BaseURL
|
||||||
}
|
}
|
||||||
openai := &openai.OpenAIClient{
|
return &openai.OpenAIClient{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
APIKey: *p.APIKey,
|
APIKey: *p.APIKey,
|
||||||
}
|
}, nil
|
||||||
return openai, nil
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
|
return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
|
||||||
}
|
}
|
||||||
|
@ -173,8 +173,20 @@ func createGenerateContentRequest(
|
|||||||
requestContents = append(requestContents, content)
|
requestContents = append(requestContents, content)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
var role string
|
||||||
|
switch m.Role {
|
||||||
|
case model.MessageRoleAssistant:
|
||||||
|
role = "model"
|
||||||
|
case model.MessageRoleUser:
|
||||||
|
role = "user"
|
||||||
|
}
|
||||||
|
|
||||||
|
if role == "" {
|
||||||
|
panic("Unhandled role: " + m.Role)
|
||||||
|
}
|
||||||
|
|
||||||
content := Content{
|
content := Content{
|
||||||
Role: string(m.Role),
|
Role: role,
|
||||||
Parts: []ContentPart{
|
Parts: []ContentPart{
|
||||||
{
|
{
|
||||||
Text: m.Content,
|
Text: m.Content,
|
||||||
@ -242,7 +254,6 @@ func handleToolCalls(
|
|||||||
|
|
||||||
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req.WithContext(ctx))
|
resp, err := client.Do(req.WithContext(ctx))
|
||||||
@ -271,7 +282,11 @@ func (c *Client) CreateChatCompletion(
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:generateContent", c.BaseURL, params.Model), bytes.NewBuffer(jsonData))
|
url := fmt.Sprintf(
|
||||||
|
"%s/v1beta/models/%s:generateContent?key=%s",
|
||||||
|
c.BaseURL, params.Model, c.APIKey,
|
||||||
|
)
|
||||||
|
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -330,10 +345,6 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
callback provider.ReplyCallback,
|
callback provider.ReplyCallback,
|
||||||
output chan<- string,
|
output chan<- string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
if len(params.ToolBag) > 0 {
|
|
||||||
return "", fmt.Errorf("Tool calling is not supported in streaming mode.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return "", fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
@ -344,7 +355,11 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", c.BaseURL, params.Model), bytes.NewBuffer(jsonData))
|
url := fmt.Sprintf(
|
||||||
|
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
||||||
|
c.BaseURL, params.Model, c.APIKey,
|
||||||
|
)
|
||||||
|
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -386,12 +401,25 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, candidate := range streamResp.Candidates {
|
for _, candidate := range streamResp.Candidates {
|
||||||
|
var toolCalls []model.ToolCall
|
||||||
|
|
||||||
for _, part := range candidate.Content.Parts {
|
for _, part := range candidate.Content.Parts {
|
||||||
if part.Text != "" {
|
if part.FunctionCall != nil {
|
||||||
|
toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...)
|
||||||
|
} else if part.Text != "" {
|
||||||
output <- part.Text
|
output <- part.Text
|
||||||
content.WriteString(part.Text)
|
content.WriteString(part.Text)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If there are function calls, handle them and recurse
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
|
||||||
|
if err != nil {
|
||||||
|
return content.String(), err
|
||||||
|
}
|
||||||
|
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user