Gemini WIP

This commit is contained in:
Matt Low 2024-05-18 21:15:15 +00:00
parent 75bf9f6125
commit cbcd3b1ba9
1 changed files with 406 additions and 0 deletions

View File

@ -0,0 +1,406 @@
package google
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)
type Client struct {
APIKey string
BaseURL string
}
type ContentPart struct {
Text string `json:"text,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
}
type FunctionCall struct {
Name string `json:"name"`
Args map[string]string `json:"args"`
}
type FunctionResponse struct {
Name string `json:"name"`
Response interface{} `json:"response"`
}
type Content struct {
Role string `json:"role"`
Parts []ContentPart `json:"parts"`
}
type GenerateContentRequest struct {
Contents []Content `json:"contents"`
Tools []Tool `json:"tools,omitempty"`
}
type Candidate struct {
Content Content `json:"content"`
FinishReason string `json:"finishReason"`
Index int `json:"index"`
}
type GenerateContentResponse struct {
Candidates []Candidate `json:"candidates"`
}
type Tool struct {
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
}
type FunctionDeclaration struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters ToolParameters `json:"parameters"`
}
type ToolParameters struct {
Type string `json:"type"`
Properties map[string]ToolParameter `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
type ToolParameter struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
func convertTools(tools []model.Tool) []Tool {
geminiTools := make([]Tool, 0, len(tools))
for _, tool := range tools {
params := make(map[string]ToolParameter)
var required []string
for _, param := range tool.Parameters {
params[param.Name] = ToolParameter{
Type: param.Type,
Description: param.Description,
Enum: param.Enum,
}
if param.Required {
required = append(required, param.Name)
}
}
geminiTools = append(geminiTools, Tool{
FunctionDeclarations: []FunctionDeclaration{
{
Name: tool.Name,
Description: tool.Description,
Parameters: ToolParameters{
Type: "OBJECT",
Properties: params,
Required: required,
},
},
},
})
}
return geminiTools
}
func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
converted := make([]ContentPart, len(toolCalls))
for i, call := range toolCalls {
args := make(map[string]string)
for k, v := range call.Parameters {
args[k] = fmt.Sprintf("%v", v)
}
converted[i].FunctionCall = &FunctionCall{
Name: call.Name,
Args: args,
}
}
return converted
}
func convertToolCallToAPI(parts []ContentPart) []model.ToolCall {
converted := make([]model.ToolCall, len(parts))
for i, part := range parts {
if part.FunctionCall != nil {
params := make(map[string]interface{})
for k, v := range part.FunctionCall.Args {
params[k] = v
}
converted[i].Name = part.FunctionCall.Name
converted[i].Parameters = params
}
}
return converted
}
func createGenerateContentRequest(
params model.RequestParameters,
messages []model.Message,
) GenerateContentRequest {
requestContents := make([]Content, 0, len(messages))
for _, m := range messages {
switch m.Role {
case "tool_call":
content := Content{
Role: "model",
Parts: convertToolCallToGemini(m.ToolCalls),
}
requestContents = append(requestContents, content)
case "tool_result":
// expand tool_result messages' results into multiple gemini messages
for _, result := range m.ToolResults {
content := Content{
Role: "function",
Parts: []ContentPart{
{
FunctionResp: &FunctionResponse{
Name: result.ToolCallID,
Response: result.Result,
},
},
},
}
requestContents = append(requestContents, content)
}
default:
content := Content{
Role: string(m.Role),
Parts: []ContentPart{
{
Text: m.Content,
},
},
}
requestContents = append(requestContents, content)
}
}
request := GenerateContentRequest{
Contents: requestContents,
}
if len(params.ToolBag) > 0 {
request.Tools = convertTools(params.ToolBag)
}
return request
}
func handleToolCalls(
params model.RequestParameters,
content string,
toolCalls []model.ToolCall,
callback provider.ReplyCallback,
messages []model.Message,
) ([]model.Message, error) {
lastMessage := messages[len(messages)-1]
continuation := false
if lastMessage.Role.IsAssistant() {
continuation = true
}
toolCall := model.Message{
Role: model.MessageRoleToolCall,
Content: content,
ToolCalls: toolCalls,
}
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return nil, err
}
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return messages, nil
}
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
client := &http.Client{}
resp, err := client.Do(req.WithContext(ctx))
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, err
}
func (c *Client) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback provider.ReplyCallback,
) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
req := createGenerateContentRequest(params, messages)
jsonData, err := json.Marshal(req)
if err != nil {
return "", err
}
httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:generateContent", c.BaseURL, params.Model), bytes.NewBuffer(jsonData))
if err != nil {
return "", err
}
resp, err := c.sendRequest(ctx, httpReq)
if err != nil {
return "", err
}
defer resp.Body.Close()
var completionResp GenerateContentResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return "", err
}
choice := completionResp.Candidates[0]
var content string
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content = lastMessage.Content
}
for _, part := range choice.Content.Parts {
if part.Text != "" {
content += part.Text
}
}
toolCalls := convertToolCallToAPI(choice.Content.Parts)
if len(toolCalls) > 0 {
messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
if err != nil {
return content, err
}
return c.CreateChatCompletion(ctx, params, messages, callback)
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content,
})
}
// Return the user-facing message.
return content, nil
}
func (c *Client) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
callback provider.ReplyCallback,
output chan<- string,
) (string, error) {
if len(params.ToolBag) > 0 {
return "", fmt.Errorf("Tool calling is not supported in streaming mode.")
}
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
req := createGenerateContentRequest(params, messages)
jsonData, err := json.Marshal(req)
if err != nil {
return "", err
}
httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", c.BaseURL, params.Model), bytes.NewBuffer(jsonData))
if err != nil {
return "", err
}
resp, err := c.sendRequest(ctx, httpReq)
if err != nil {
return "", err
}
defer resp.Body.Close()
content := strings.Builder{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content)
}
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return "", err
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
continue
}
line = bytes.TrimPrefix(line, []byte("data: "))
var streamResp GenerateContentResponse
err = json.Unmarshal(line, &streamResp)
if err != nil {
return "", err
}
for _, candidate := range streamResp.Candidates {
for _, part := range candidate.Content.Parts {
if part.Text != "" {
output <- part.Text
content.WriteString(part.Text)
}
}
}
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(),
})
}
return content.String(), nil
}