Gemini WIP
This commit is contained in:
parent
75bf9f6125
commit
cbcd3b1ba9
406
pkg/lmcli/provider/google/google.go
Normal file
406
pkg/lmcli/provider/google/google.go
Normal file
@ -0,0 +1,406 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
type ContentPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Args map[string]string `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response interface{} `json:"response"`
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Role string `json:"role"`
|
||||
Parts []ContentPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GenerateContentRequest struct {
|
||||
Contents []Content `json:"contents"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type Candidate struct {
|
||||
Content Content `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type GenerateContentResponse struct {
|
||||
Candidates []Candidate `json:"candidates"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
|
||||
}
|
||||
|
||||
type FunctionDeclaration struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters ToolParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
type ToolParameters struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
func convertTools(tools []model.Tool) []Tool {
|
||||
geminiTools := make([]Tool, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
params := make(map[string]ToolParameter)
|
||||
var required []string
|
||||
|
||||
for _, param := range tool.Parameters {
|
||||
params[param.Name] = ToolParameter{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Enum: param.Enum,
|
||||
}
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
geminiTools = append(geminiTools, Tool{
|
||||
FunctionDeclarations: []FunctionDeclaration{
|
||||
{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: ToolParameters{
|
||||
Type: "OBJECT",
|
||||
Properties: params,
|
||||
Required: required,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
return geminiTools
|
||||
}
|
||||
|
||||
func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
|
||||
converted := make([]ContentPart, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
args := make(map[string]string)
|
||||
for k, v := range call.Parameters {
|
||||
args[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
converted[i].FunctionCall = &FunctionCall{
|
||||
Name: call.Name,
|
||||
Args: args,
|
||||
}
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(parts []ContentPart) []model.ToolCall {
|
||||
converted := make([]model.ToolCall, len(parts))
|
||||
for i, part := range parts {
|
||||
if part.FunctionCall != nil {
|
||||
params := make(map[string]interface{})
|
||||
for k, v := range part.FunctionCall.Args {
|
||||
params[k] = v
|
||||
}
|
||||
converted[i].Name = part.FunctionCall.Name
|
||||
converted[i].Parameters = params
|
||||
}
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func createGenerateContentRequest(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
) GenerateContentRequest {
|
||||
requestContents := make([]Content, 0, len(messages))
|
||||
|
||||
for _, m := range messages {
|
||||
switch m.Role {
|
||||
case "tool_call":
|
||||
content := Content{
|
||||
Role: "model",
|
||||
Parts: convertToolCallToGemini(m.ToolCalls),
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
case "tool_result":
|
||||
// expand tool_result messages' results into multiple gemini messages
|
||||
for _, result := range m.ToolResults {
|
||||
content := Content{
|
||||
Role: "function",
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
FunctionResp: &FunctionResponse{
|
||||
Name: result.ToolCallID,
|
||||
Response: result.Result,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
}
|
||||
default:
|
||||
content := Content{
|
||||
Role: string(m.Role),
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
Text: m.Content,
|
||||
},
|
||||
},
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
}
|
||||
}
|
||||
|
||||
request := GenerateContentRequest{
|
||||
Contents: requestContents,
|
||||
}
|
||||
|
||||
if len(params.ToolBag) > 0 {
|
||||
request.Tools = convertTools(params.ToolBag)
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func handleToolCalls(
|
||||
params model.RequestParameters,
|
||||
content string,
|
||||
toolCalls []model.ToolCall,
|
||||
callback provider.ReplyCallback,
|
||||
messages []model.Message,
|
||||
) ([]model.Message, error) {
|
||||
lastMessage := messages[len(messages)-1]
|
||||
continuation := false
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
continuation = true
|
||||
}
|
||||
|
||||
toolCall := model.Message{
|
||||
Role: model.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
|
||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolResult := model.Message{
|
||||
Role: model.MessageRoleToolResult,
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(toolCall)
|
||||
callback(toolResult)
|
||||
}
|
||||
|
||||
if continuation {
|
||||
messages[len(messages)-1] = toolCall
|
||||
} else {
|
||||
messages = append(messages, toolCall)
|
||||
}
|
||||
messages = append(messages, toolResult)
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *Client) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
) (string, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createGenerateContentRequest(params, messages)
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:generateContent", c.BaseURL, params.Model), bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp GenerateContentResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
choice := completionResp.Candidates[0]
|
||||
|
||||
var content string
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content = lastMessage.Content
|
||||
}
|
||||
|
||||
for _, part := range choice.Content.Parts {
|
||||
if part.Text != "" {
|
||||
content += part.Text
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls := convertToolCallToAPI(choice.Content.Parts)
|
||||
if len(toolCalls) > 0 {
|
||||
messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
|
||||
if err != nil {
|
||||
return content, err
|
||||
}
|
||||
|
||||
return c.CreateChatCompletion(ctx, params, messages, callback)
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
// Return the user-facing message.
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
output chan<- string,
|
||||
) (string, error) {
|
||||
if len(params.ToolBag) > 0 {
|
||||
return "", fmt.Errorf("Tool calling is not supported in streaming mode.")
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createGenerateContentRequest(params, messages)
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", c.BaseURL, params.Model), bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content.WriteString(lastMessage.Content)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||
continue
|
||||
}
|
||||
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
|
||||
var streamResp GenerateContentResponse
|
||||
err = json.Unmarshal(line, &streamResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, candidate := range streamResp.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
output <- part.Text
|
||||
content.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
return content.String(), nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user