From cbcd3b1ba9e29a8ab68f670ea8f11d275b7df517 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sat, 18 May 2024 21:15:15 +0000 Subject: [PATCH] Gemini WIP --- pkg/lmcli/provider/google/google.go | 406 ++++++++++++++++++++++++++++ 1 file changed, 406 insertions(+) create mode 100644 pkg/lmcli/provider/google/google.go diff --git a/pkg/lmcli/provider/google/google.go b/pkg/lmcli/provider/google/google.go new file mode 100644 index 0000000..1bf2a0c --- /dev/null +++ b/pkg/lmcli/provider/google/google.go @@ -0,0 +1,406 @@ +package google + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" +) + +type Client struct { + APIKey string + BaseURL string +} + +type ContentPart struct { + Text string `json:"text,omitempty"` + FunctionCall *FunctionCall `json:"functionCall,omitempty"` + FunctionResp *FunctionResponse `json:"functionResponse,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name"` + Args map[string]string `json:"args"` +} + +type FunctionResponse struct { + Name string `json:"name"` + Response interface{} `json:"response"` +} + +type Content struct { + Role string `json:"role"` + Parts []ContentPart `json:"parts"` +} + +type GenerateContentRequest struct { + Contents []Content `json:"contents"` + Tools []Tool `json:"tools,omitempty"` +} + +type Candidate struct { + Content Content `json:"content"` + FinishReason string `json:"finishReason"` + Index int `json:"index"` +} + +type GenerateContentResponse struct { + Candidates []Candidate `json:"candidates"` +} + +type Tool struct { + FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"` +} + +type FunctionDeclaration struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters ToolParameters `json:"parameters"` +} + +type ToolParameters struct { + Type string `json:"type"` + Properties map[string]ToolParameter `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +type ToolParameter struct { + Type string `json:"type"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` +} + +func convertTools(tools []model.Tool) []Tool { + geminiTools := make([]Tool, 0, len(tools)) + for _, tool := range tools { + params := make(map[string]ToolParameter) + var required []string + + for _, param := range tool.Parameters { + params[param.Name] = ToolParameter{ + Type: param.Type, + Description: param.Description, + Enum: param.Enum, + } + if param.Required { + required = append(required, param.Name) + } + } + + geminiTools = append(geminiTools, Tool{ + FunctionDeclarations: []FunctionDeclaration{ + { + Name: tool.Name, + Description: tool.Description, + Parameters: ToolParameters{ + Type: "OBJECT", + Properties: params, + Required: required, + }, + }, + }, + }) + } + return geminiTools +} + +func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart { + converted := make([]ContentPart, len(toolCalls)) + for i, call := range toolCalls { + args := make(map[string]string) + for k, v := range call.Parameters { + args[k] = fmt.Sprintf("%v", v) + } + converted[i].FunctionCall = &FunctionCall{ + Name: call.Name, + Args: args, + } + } + return converted +} + +func convertToolCallToAPI(parts []ContentPart) []model.ToolCall { + converted := make([]model.ToolCall, len(parts)) + for i, part := range parts { + if part.FunctionCall != nil { + params := make(map[string]interface{}) + for k, v := range part.FunctionCall.Args { + params[k] = v + } + converted[i].Name = part.FunctionCall.Name + converted[i].Parameters = params + } + } + return converted +} + +func createGenerateContentRequest( + params model.RequestParameters, + messages []model.Message, +) GenerateContentRequest { + requestContents := make([]Content, 0, len(messages)) + + for _, m := range messages { + switch m.Role { + case "tool_call": + content := Content{ + Role: "model", + Parts: convertToolCallToGemini(m.ToolCalls), + } + requestContents = append(requestContents, content) + case "tool_result": + // expand tool_result messages' results into multiple gemini messages + for _, result := range m.ToolResults { + content := Content{ + Role: "function", + Parts: []ContentPart{ + { + FunctionResp: &FunctionResponse{ + Name: result.ToolCallID, + Response: result.Result, + }, + }, + }, + } + requestContents = append(requestContents, content) + } + default: + content := Content{ + Role: string(m.Role), + Parts: []ContentPart{ + { + Text: m.Content, + }, + }, + } + requestContents = append(requestContents, content) + } + } + + request := GenerateContentRequest{ + Contents: requestContents, + } + + if len(params.ToolBag) > 0 { + request.Tools = convertTools(params.ToolBag) + } + + return request +} + +func handleToolCalls( + params model.RequestParameters, + content string, + toolCalls []model.ToolCall, + callback provider.ReplyCallback, + messages []model.Message, +) ([]model.Message, error) { + lastMessage := messages[len(messages)-1] + continuation := false + if lastMessage.Role.IsAssistant() { + continuation = true + } + + toolCall := model.Message{ + Role: model.MessageRoleToolCall, + Content: content, + ToolCalls: toolCalls, + } + + toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag) + if err != nil { + return nil, err + } + + toolResult := model.Message{ + Role: model.MessageRoleToolResult, + ToolResults: toolResults, + } + + if callback != nil { + callback(toolCall) + callback(toolResult) + } + + if continuation { + messages[len(messages)-1] = toolCall + } else { + messages = append(messages, toolCall) + } + messages = append(messages, toolResult) + + return messages, nil +} + +func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + + client := &http.Client{} + resp, err := client.Do(req.WithContext(ctx)) + + if resp.StatusCode != 200 { + bytes, _ := io.ReadAll(resp.Body) + return resp, fmt.Errorf("%v", string(bytes)) + } + + return resp, err +} + +func (c *Client) CreateChatCompletion( + ctx context.Context, + params model.RequestParameters, + messages []model.Message, + callback provider.ReplyCallback, +) (string, error) { + if len(messages) == 0 { + return "", fmt.Errorf("Can't create completion from no messages") + } + + req := createGenerateContentRequest(params, messages) + jsonData, err := json.Marshal(req) + if err != nil { + return "", err + } + + httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:generateContent", c.BaseURL, params.Model), bytes.NewBuffer(jsonData)) + if err != nil { + return "", err + } + + resp, err := c.sendRequest(ctx, httpReq) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var completionResp GenerateContentResponse + err = json.NewDecoder(resp.Body).Decode(&completionResp) + if err != nil { + return "", err + } + + choice := completionResp.Candidates[0] + + var content string + lastMessage := messages[len(messages)-1] + if lastMessage.Role.IsAssistant() { + content = lastMessage.Content + } + + for _, part := range choice.Content.Parts { + if part.Text != "" { + content += part.Text + } + } + + toolCalls := convertToolCallToAPI(choice.Content.Parts) + if len(toolCalls) > 0 { + messages, err := handleToolCalls(params, content, toolCalls, callback, messages) + if err != nil { + return content, err + } + + return c.CreateChatCompletion(ctx, params, messages, callback) + } + + if callback != nil { + callback(model.Message{ + Role: model.MessageRoleAssistant, + Content: content, + }) + } + + // Return the user-facing message. + return content, nil +} + +func (c *Client) CreateChatCompletionStream( + ctx context.Context, + params model.RequestParameters, + messages []model.Message, + callback provider.ReplyCallback, + output chan<- string, +) (string, error) { + if len(params.ToolBag) > 0 { + return "", fmt.Errorf("Tool calling is not supported in streaming mode.") + } + + if len(messages) == 0 { + return "", fmt.Errorf("Can't create completion from no messages") + } + + req := createGenerateContentRequest(params, messages) + jsonData, err := json.Marshal(req) + if err != nil { + return "", err + } + + httpReq, err := http.NewRequest("POST", fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", c.BaseURL, params.Model), bytes.NewBuffer(jsonData)) + if err != nil { + return "", err + } + + resp, err := c.sendRequest(ctx, httpReq) + if err != nil { + return "", err + } + defer resp.Body.Close() + + content := strings.Builder{} + + lastMessage := messages[len(messages)-1] + if lastMessage.Role.IsAssistant() { + content.WriteString(lastMessage.Content) + } + + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + if err != nil { + if err == io.EOF { + break + } + return "", err + } + + line = bytes.TrimSpace(line) + if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) { + continue + } + + line = bytes.TrimPrefix(line, []byte("data: ")) + + var streamResp GenerateContentResponse + err = json.Unmarshal(line, &streamResp) + if err != nil { + return "", err + } + + for _, candidate := range streamResp.Candidates { + for _, part := range candidate.Content.Parts { + if part.Text != "" { + output <- part.Text + content.WriteString(part.Text) + } + } + } + } + + if callback != nil { + callback(model.Message{ + Role: model.MessageRoleAssistant, + Content: content.String(), + }) + } + + return content.String(), nil +}