lmcli/pkg/provider/google/google.go

437 lines
9.9 KiB
Go
Raw Permalink Normal View History

2024-05-18 15:15:15 -06:00
package google
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
2024-05-18 15:15:15 -06:00
)
// Client talks to the Gemini generateContent REST endpoints. BaseURL is the
// prefix every request URL is built from, and APIKey is appended to each
// request URL as the `key` query parameter.
type Client struct {
	APIKey  string
	BaseURL string
}

// ContentPart is one element of a Content's Parts. The code in this file sets
// exactly one of its fields per part; unset fields are omitted from the JSON.
type ContentPart struct {
	Text         string            `json:"text,omitempty"`
	FunctionCall *FunctionCall     `json:"functionCall,omitempty"`
	FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
}

// FunctionCall is a tool invocation, either requested by the model in a
// response or replayed back to it in a request.
//
// NOTE(review): Args only holds string values. If the API sends non-string
// argument values (numbers, booleans, objects), decoding a response into this
// type will fail — confirm against the API and consider map[string]interface{}.
type FunctionCall struct {
	Name string            `json:"name"`
	Args map[string]string `json:"args"`
}

// FunctionResponse carries a tool's result back to the model. Response holds
// the already-decoded JSON value of that result.
type FunctionResponse struct {
	Name     string      `json:"name"`
	Response interface{} `json:"response"`
}

// Content is a single conversation turn: a role (this file uses "user",
// "model" and "function") together with its parts.
type Content struct {
	Role  string        `json:"role"`
	Parts []ContentPart `json:"parts"`
}

// GenerationConfig holds optional sampling settings. Nil pointers are omitted
// from the request JSON.
type GenerationConfig struct {
	MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"`
	Temperature     *float32 `json:"temperature,omitempty"`
	TopP            *float32 `json:"topP,omitempty"`
	TopK            *int     `json:"topK,omitempty"`
}

// GenerateContentRequest is the JSON body POSTed to the generateContent and
// streamGenerateContent endpoints.
type GenerateContentRequest struct {
	Contents          []Content         `json:"contents"`
	Tools             []Tool            `json:"tools,omitempty"`
	SystemInstruction *Content          `json:"systemInstruction,omitempty"`
	GenerationConfig  *GenerationConfig `json:"generationConfig,omitempty"`
}

// Candidate is one generated reply within a response.
type Candidate struct {
	Content      Content `json:"content"`
	FinishReason string  `json:"finishReason"`
	Index        int     `json:"index"`
}

// UsageMetadata reports token counts for a response.
type UsageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount"`
	CandidatesTokenCount int `json:"candidatesTokenCount"`
	TotalTokenCount      int `json:"totalTokenCount"`
}

// GenerateContentResponse is a decoded response body (or, when streaming, a
// single SSE event payload).
type GenerateContentResponse struct {
	Candidates    []Candidate   `json:"candidates"`
	UsageMetadata UsageMetadata `json:"usageMetadata"`
}

// Tool groups function declarations offered to the model.
type Tool struct {
	FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
}

// FunctionDeclaration describes one callable function and its parameters.
type FunctionDeclaration struct {
	Name        string         `json:"name"`
	Description string         `json:"description"`
	Parameters  ToolParameters `json:"parameters"`
}

// ToolParameters is the object schema describing a function's arguments.
type ToolParameters struct {
	Type       string                   `json:"type"`
	Properties map[string]ToolParameter `json:"properties,omitempty"`
	Required   []string                 `json:"required,omitempty"`
}

// ToolParameter is the schema of a single function argument.
type ToolParameter struct {
	Type        string   `json:"type"`
	Description string   `json:"description"`
	// Values lists the allowed values of the argument. The Gemini Schema
	// object names this field "enum" ("values" is not a Schema field), so the
	// Go name is kept for compatibility but it is serialized as "enum".
	Values []string `json:"enum,omitempty"`
}
func convertTools(tools []api.ToolSpec) []Tool {
2024-05-19 15:50:43 -06:00
geminiTools := make([]Tool, len(tools))
for i, tool := range tools {
2024-05-18 15:15:15 -06:00
params := make(map[string]ToolParameter)
var required []string
for _, param := range tool.Parameters {
2024-05-18 19:38:02 -06:00
// TODO: proper enum handing
2024-05-18 15:15:15 -06:00
params[param.Name] = ToolParameter{
Type: param.Type,
Description: param.Description,
2024-05-18 19:38:02 -06:00
Values: param.Enum,
2024-05-18 15:15:15 -06:00
}
if param.Required {
required = append(required, param.Name)
}
}
2024-05-22 23:53:13 -06:00
geminiTools[i] = Tool{
2024-05-18 15:15:15 -06:00
FunctionDeclarations: []FunctionDeclaration{
{
Name: tool.Name,
Description: tool.Description,
Parameters: ToolParameters{
Type: "OBJECT",
Properties: params,
Required: required,
},
},
},
2024-05-19 15:50:43 -06:00
}
2024-05-18 15:15:15 -06:00
}
return geminiTools
}
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
2024-05-18 15:15:15 -06:00
converted := make([]ContentPart, len(toolCalls))
for i, call := range toolCalls {
args := make(map[string]string)
for k, v := range call.Parameters {
args[k] = fmt.Sprintf("%v", v)
}
converted[i].FunctionCall = &FunctionCall{
Name: call.Name,
Args: args,
}
}
return converted
}
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
converted := make([]api.ToolCall, len(functionCalls))
2024-05-18 20:59:43 -06:00
for i, call := range functionCalls {
params := make(map[string]interface{})
for k, v := range call.Args {
params[k] = v
2024-05-18 15:15:15 -06:00
}
2024-05-18 20:59:43 -06:00
converted[i].Name = call.Name
converted[i].Parameters = params
2024-05-18 15:15:15 -06:00
}
return converted
}
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
2024-05-18 19:38:02 -06:00
results := make([]FunctionResponse, len(toolResults))
for i, result := range toolResults {
var obj interface{}
err := json.Unmarshal([]byte(result.Result), &obj)
if err != nil {
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
}
results[i] = FunctionResponse{
Name: result.ToolName,
Response: obj,
}
}
return results, nil
}
2024-05-18 15:15:15 -06:00
func createGenerateContentRequest(
params provider.RequestParameters,
messages []api.Message,
2024-05-18 19:38:02 -06:00
) (*GenerateContentRequest, error) {
2024-05-18 15:15:15 -06:00
requestContents := make([]Content, 0, len(messages))
2024-05-18 19:38:02 -06:00
startIdx := 0
var system string
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
2024-05-18 19:38:02 -06:00
system = messages[0].Content
startIdx = 1
}
for _, m := range messages[startIdx:] {
2024-05-18 15:15:15 -06:00
switch m.Role {
case "tool_call":
content := Content{
Role: "model",
Parts: convertToolCallToGemini(m.ToolCalls),
}
requestContents = append(requestContents, content)
case "tool_result":
2024-05-18 19:38:02 -06:00
results, err := convertToolResultsToGemini(m.ToolResults)
if err != nil {
return nil, err
}
2024-05-18 15:15:15 -06:00
// expand tool_result messages' results into multiple gemini messages
2024-05-18 19:38:02 -06:00
for _, result := range results {
2024-05-18 15:15:15 -06:00
content := Content{
Role: "function",
Parts: []ContentPart{
{
2024-05-18 19:38:02 -06:00
FunctionResp: &result,
2024-05-18 15:15:15 -06:00
},
},
}
requestContents = append(requestContents, content)
}
default:
2024-05-18 17:18:53 -06:00
var role string
switch m.Role {
case api.MessageRoleAssistant:
2024-05-18 19:38:02 -06:00
role = "model"
case api.MessageRoleUser:
2024-05-18 19:38:02 -06:00
role = "user"
2024-05-18 17:18:53 -06:00
}
if role == "" {
panic("Unhandled role: " + m.Role)
}
2024-05-18 15:15:15 -06:00
content := Content{
2024-05-18 17:18:53 -06:00
Role: role,
2024-05-18 15:15:15 -06:00
Parts: []ContentPart{
{
Text: m.Content,
},
},
}
requestContents = append(requestContents, content)
}
}
2024-05-18 19:38:02 -06:00
request := &GenerateContentRequest{
2024-06-01 13:47:08 -06:00
Contents: requestContents,
2024-05-18 19:38:02 -06:00
GenerationConfig: &GenerationConfig{
MaxOutputTokens: &params.MaxTokens,
Temperature: &params.Temperature,
TopP: &params.TopP,
},
2024-05-18 15:15:15 -06:00
}
2024-06-01 13:47:08 -06:00
if system != "" {
request.SystemInstruction = &Content{
Parts: []ContentPart{
{
Text: system,
},
},
}
}
if len(params.Toolbox) > 0 {
request.Tools = convertTools(params.Toolbox)
2024-05-18 15:15:15 -06:00
}
2024-05-18 19:38:02 -06:00
return request, nil
2024-05-18 15:15:15 -06:00
}
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json")
2024-05-18 15:15:15 -06:00
client := &http.Client{}
resp, err := client.Do(req)
2024-05-18 15:15:15 -06:00
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, err
}
func (c *Client) CreateChatCompletion(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
2024-05-18 15:15:15 -06:00
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
2024-05-18 15:15:15 -06:00
}
2024-05-18 19:38:02 -06:00
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return nil, err
2024-05-18 19:38:02 -06:00
}
2024-05-18 15:15:15 -06:00
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
2024-05-18 15:15:15 -06:00
}
2024-05-18 17:18:53 -06:00
url := fmt.Sprintf(
"%s/v1beta/models/%s:generateContent?key=%s",
c.BaseURL, params.Model, c.APIKey,
)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
2024-05-18 15:15:15 -06:00
if err != nil {
return nil, err
2024-05-18 15:15:15 -06:00
}
resp, err := c.sendRequest(httpReq)
2024-05-18 15:15:15 -06:00
if err != nil {
return nil, err
2024-05-18 15:15:15 -06:00
}
defer resp.Body.Close()
var completionResp GenerateContentResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, err
2024-05-18 15:15:15 -06:00
}
choice := completionResp.Candidates[0]
var content string
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content = lastMessage.Content
}
2024-05-18 20:59:43 -06:00
var toolCalls []FunctionCall
2024-05-18 15:15:15 -06:00
for _, part := range choice.Content.Parts {
if part.Text != "" {
content += part.Text
}
2024-05-18 20:59:43 -06:00
if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall)
}
2024-05-18 15:15:15 -06:00
}
if len(toolCalls) > 0 {
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
2024-05-18 15:15:15 -06:00
}
return api.NewMessageWithAssistant(content), nil
2024-05-18 15:15:15 -06:00
}
func (c *Client) CreateChatCompletionStream(
ctx context.Context,
params provider.RequestParameters,
messages []api.Message,
output chan<- provider.Chunk,
) (*api.Message, error) {
2024-05-18 15:15:15 -06:00
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
2024-05-18 15:15:15 -06:00
}
2024-05-18 19:38:02 -06:00
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return nil, err
2024-05-18 19:38:02 -06:00
}
2024-05-18 15:15:15 -06:00
jsonData, err := json.Marshal(req)
if err != nil {
return nil, err
2024-05-18 15:15:15 -06:00
}
2024-05-18 17:18:53 -06:00
url := fmt.Sprintf(
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
c.BaseURL, params.Model, c.APIKey,
)
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
2024-05-18 15:15:15 -06:00
if err != nil {
return nil, err
2024-05-18 15:15:15 -06:00
}
resp, err := c.sendRequest(httpReq)
2024-05-18 15:15:15 -06:00
if err != nil {
return nil, err
2024-05-18 15:15:15 -06:00
}
defer resp.Body.Close()
content := strings.Builder{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content)
}
2024-05-18 20:59:43 -06:00
var toolCalls []FunctionCall
2024-05-18 19:38:02 -06:00
2024-05-18 15:15:15 -06:00
reader := bufio.NewReader(resp.Body)
lastTokenCount := 0
2024-05-18 15:15:15 -06:00
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, err
2024-05-18 15:15:15 -06:00
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
continue
}
line = bytes.TrimPrefix(line, []byte("data: "))
var resp GenerateContentResponse
err = json.Unmarshal(line, &resp)
2024-05-18 15:15:15 -06:00
if err != nil {
return nil, err
2024-05-18 15:15:15 -06:00
}
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
lastTokenCount += tokens
choice := resp.Candidates[0]
for _, part := range choice.Content.Parts {
if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall)
} else if part.Text != "" {
output <- provider.Chunk{
Content: part.Text,
TokenCount: uint(tokens),
2024-05-18 15:15:15 -06:00
}
content.WriteString(part.Text)
2024-05-18 15:15:15 -06:00
}
2024-05-18 19:38:02 -06:00
}
}
2024-05-18 17:18:53 -06:00
2024-05-18 19:38:02 -06:00
if len(toolCalls) > 0 {
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
2024-05-18 15:15:15 -06:00
}
return api.NewMessageWithAssistant(content.String()), nil
2024-05-18 15:15:15 -06:00
}