2024-05-18 15:15:15 -06:00
|
|
|
package google
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bufio"
|
|
|
|
"bytes"
|
|
|
|
"context"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"net/http"
|
|
|
|
"strings"
|
|
|
|
|
2024-06-09 10:42:53 -06:00
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
2024-05-18 15:15:15 -06:00
|
|
|
)
|
|
|
|
|
2024-06-22 19:48:31 -06:00
|
|
|
// Client is a minimal HTTP client for Google's generative-language (Gemini)
// REST API.
type Client struct {
	// APIKey is sent with every request as the `key` URL query parameter.
	APIKey string

	// BaseURL is the API root; request paths such as
	// "/v1beta/models/<model>:generateContent" are appended to it.
	BaseURL string
}
|
|
|
|
|
|
|
|
type ContentPart struct {
|
|
|
|
Text string `json:"text,omitempty"`
|
|
|
|
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
|
|
|
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
|
|
|
}
|
|
|
|
|
|
|
|
// FunctionCall is a tool invocation requested by the model, as carried in a
// "functionCall" content part.
//
// Args is a flat string map because that is the form this package converts
// to and from the api package. The API may send argument values of any JSON
// type, so decoding goes through UnmarshalJSON below rather than the default
// decoder, which would reject non-string values.
type FunctionCall struct {
	Name string            `json:"name"`
	Args map[string]string `json:"args"`
}

// UnmarshalJSON decodes a function call whose argument values may be of any
// JSON type. String values are stored unquoted. Any other value (number,
// bool, object, array) is stored as its raw JSON text; for example, 3 becomes
// "3". Without this, a single non-string argument would make decoding of the
// whole response fail.
func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
	var raw struct {
		Name string                     `json:"name"`
		Args map[string]json.RawMessage `json:"args"`
	}
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}

	fc.Name = raw.Name
	fc.Args = nil
	if raw.Args == nil {
		// Missing or null "args": keep Args nil, as the default decoder would.
		return nil
	}

	fc.Args = make(map[string]string, len(raw.Args))
	for key, value := range raw.Args {
		var s string
		if err := json.Unmarshal(value, &s); err == nil {
			fc.Args[key] = s
			continue
		}
		// Not a JSON string: keep the literal JSON text.
		fc.Args[key] = string(value)
	}
	return nil
}
|
|
|
|
|
|
|
|
// FunctionResponse returns the output of a tool call to the model inside a
// "functionResponse" content part.
type FunctionResponse struct {
	// Name is the name of the tool that produced the output.
	Name string `json:"name"`

	// Response is the tool output after JSON decoding; it is re-encoded
	// verbatim when the request is marshalled.
	Response interface{} `json:"response"`
}
|
|
|
|
|
|
|
|
type Content struct {
|
|
|
|
Role string `json:"role"`
|
|
|
|
Parts []ContentPart `json:"parts"`
|
|
|
|
}
|
|
|
|
|
|
|
|
// GenerationConfig carries optional sampling settings. Each field is a
// pointer so that "not set" (nil) can be told apart from a zero value; nil
// fields are left out of the JSON entirely.
type GenerationConfig struct {
	MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"`
	Temperature     *float32 `json:"temperature,omitempty"`
	TopP            *float32 `json:"topP,omitempty"`
	TopK            *int     `json:"topK,omitempty"`
}
|
|
|
|
|
|
|
|
type GenerateContentRequest struct {
|
|
|
|
Contents []Content `json:"contents"`
|
|
|
|
Tools []Tool `json:"tools,omitempty"`
|
|
|
|
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
|
|
|
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type Candidate struct {
|
|
|
|
Content Content `json:"content"`
|
|
|
|
FinishReason string `json:"finishReason"`
|
|
|
|
Index int `json:"index"`
|
|
|
|
}
|
|
|
|
|
|
|
|
// UsageMetadata holds the token accounting attached to a response.
//
// The streaming code in this file treats CandidatesTokenCount as a running
// total across events and derives per-event deltas from it.
type UsageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount"`
	CandidatesTokenCount int `json:"candidatesTokenCount"`
	TotalTokenCount      int `json:"totalTokenCount"`
}
|
|
|
|
|
|
|
|
type GenerateContentResponse struct {
|
|
|
|
Candidates []Candidate `json:"candidates"`
|
|
|
|
UsageMetadata UsageMetadata `json:"usageMetadata"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type Tool struct {
|
|
|
|
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type FunctionDeclaration struct {
|
|
|
|
Name string `json:"name"`
|
|
|
|
Description string `json:"description"`
|
|
|
|
Parameters ToolParameters `json:"parameters"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type ToolParameters struct {
|
|
|
|
Type string `json:"type"`
|
|
|
|
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
|
|
|
Required []string `json:"required,omitempty"`
|
|
|
|
}
|
|
|
|
|
|
|
|
// ToolParameter is the schema of a single function argument.
type ToolParameter struct {
	Type        string `json:"type"`
	Description string `json:"description"`

	// Values lists the permitted values of an enum-valued argument. The
	// Gemini Schema object names this field "enum". The previous tag,
	// "values", is not a Schema field, so the allowed values never reached
	// the API as an enum.
	Values []string `json:"enum,omitempty"`
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
func convertTools(tools []api.ToolSpec) []Tool {
|
2024-05-19 15:50:43 -06:00
|
|
|
geminiTools := make([]Tool, len(tools))
|
|
|
|
for i, tool := range tools {
|
2024-05-18 15:15:15 -06:00
|
|
|
params := make(map[string]ToolParameter)
|
|
|
|
var required []string
|
|
|
|
|
|
|
|
for _, param := range tool.Parameters {
|
2024-05-18 19:38:02 -06:00
|
|
|
// TODO: proper enum handing
|
2024-05-18 15:15:15 -06:00
|
|
|
params[param.Name] = ToolParameter{
|
|
|
|
Type: param.Type,
|
|
|
|
Description: param.Description,
|
2024-05-18 19:38:02 -06:00
|
|
|
Values: param.Enum,
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
if param.Required {
|
|
|
|
required = append(required, param.Name)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-05-22 23:53:13 -06:00
|
|
|
geminiTools[i] = Tool{
|
2024-05-18 15:15:15 -06:00
|
|
|
FunctionDeclarations: []FunctionDeclaration{
|
|
|
|
{
|
|
|
|
Name: tool.Name,
|
|
|
|
Description: tool.Description,
|
|
|
|
Parameters: ToolParameters{
|
|
|
|
Type: "OBJECT",
|
|
|
|
Properties: params,
|
|
|
|
Required: required,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
2024-05-19 15:50:43 -06:00
|
|
|
}
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
return geminiTools
|
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
|
2024-05-18 15:15:15 -06:00
|
|
|
converted := make([]ContentPart, len(toolCalls))
|
|
|
|
for i, call := range toolCalls {
|
|
|
|
args := make(map[string]string)
|
|
|
|
for k, v := range call.Parameters {
|
|
|
|
args[k] = fmt.Sprintf("%v", v)
|
|
|
|
}
|
|
|
|
converted[i].FunctionCall = &FunctionCall{
|
|
|
|
Name: call.Name,
|
|
|
|
Args: args,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return converted
|
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
|
|
|
|
converted := make([]api.ToolCall, len(functionCalls))
|
2024-05-18 20:59:43 -06:00
|
|
|
for i, call := range functionCalls {
|
|
|
|
params := make(map[string]interface{})
|
|
|
|
for k, v := range call.Args {
|
|
|
|
params[k] = v
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
2024-05-18 20:59:43 -06:00
|
|
|
converted[i].Name = call.Name
|
|
|
|
converted[i].Parameters = params
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
return converted
|
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
|
2024-05-18 19:38:02 -06:00
|
|
|
results := make([]FunctionResponse, len(toolResults))
|
|
|
|
for i, result := range toolResults {
|
|
|
|
var obj interface{}
|
|
|
|
err := json.Unmarshal([]byte(result.Result), &obj)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
|
|
|
|
}
|
|
|
|
results[i] = FunctionResponse{
|
|
|
|
Name: result.ToolName,
|
|
|
|
Response: obj,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return results, nil
|
|
|
|
}
|
|
|
|
|
2024-05-18 15:15:15 -06:00
|
|
|
func createGenerateContentRequest(
|
2024-06-12 02:35:07 -06:00
|
|
|
params api.RequestParameters,
|
|
|
|
messages []api.Message,
|
2024-05-18 19:38:02 -06:00
|
|
|
) (*GenerateContentRequest, error) {
|
2024-05-18 15:15:15 -06:00
|
|
|
requestContents := make([]Content, 0, len(messages))
|
|
|
|
|
2024-05-18 19:38:02 -06:00
|
|
|
startIdx := 0
|
|
|
|
var system string
|
2024-06-12 02:35:07 -06:00
|
|
|
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
2024-05-18 19:38:02 -06:00
|
|
|
system = messages[0].Content
|
|
|
|
startIdx = 1
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, m := range messages[startIdx:] {
|
2024-05-18 15:15:15 -06:00
|
|
|
switch m.Role {
|
|
|
|
case "tool_call":
|
|
|
|
content := Content{
|
|
|
|
Role: "model",
|
|
|
|
Parts: convertToolCallToGemini(m.ToolCalls),
|
|
|
|
}
|
|
|
|
requestContents = append(requestContents, content)
|
|
|
|
case "tool_result":
|
2024-05-18 19:38:02 -06:00
|
|
|
results, err := convertToolResultsToGemini(m.ToolResults)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2024-05-18 15:15:15 -06:00
|
|
|
// expand tool_result messages' results into multiple gemini messages
|
2024-05-18 19:38:02 -06:00
|
|
|
for _, result := range results {
|
2024-05-18 15:15:15 -06:00
|
|
|
content := Content{
|
|
|
|
Role: "function",
|
|
|
|
Parts: []ContentPart{
|
|
|
|
{
|
2024-05-18 19:38:02 -06:00
|
|
|
FunctionResp: &result,
|
2024-05-18 15:15:15 -06:00
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
requestContents = append(requestContents, content)
|
|
|
|
}
|
|
|
|
default:
|
2024-05-18 17:18:53 -06:00
|
|
|
var role string
|
|
|
|
switch m.Role {
|
2024-06-12 02:35:07 -06:00
|
|
|
case api.MessageRoleAssistant:
|
2024-05-18 19:38:02 -06:00
|
|
|
role = "model"
|
2024-06-12 02:35:07 -06:00
|
|
|
case api.MessageRoleUser:
|
2024-05-18 19:38:02 -06:00
|
|
|
role = "user"
|
2024-05-18 17:18:53 -06:00
|
|
|
}
|
|
|
|
|
|
|
|
if role == "" {
|
|
|
|
panic("Unhandled role: " + m.Role)
|
|
|
|
}
|
|
|
|
|
2024-05-18 15:15:15 -06:00
|
|
|
content := Content{
|
2024-05-18 17:18:53 -06:00
|
|
|
Role: role,
|
2024-05-18 15:15:15 -06:00
|
|
|
Parts: []ContentPart{
|
|
|
|
{
|
|
|
|
Text: m.Content,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
requestContents = append(requestContents, content)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-05-18 19:38:02 -06:00
|
|
|
request := &GenerateContentRequest{
|
2024-06-01 13:47:08 -06:00
|
|
|
Contents: requestContents,
|
2024-05-18 19:38:02 -06:00
|
|
|
GenerationConfig: &GenerationConfig{
|
|
|
|
MaxOutputTokens: ¶ms.MaxTokens,
|
|
|
|
Temperature: ¶ms.Temperature,
|
|
|
|
TopP: ¶ms.TopP,
|
|
|
|
},
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-06-01 13:47:08 -06:00
|
|
|
if system != "" {
|
|
|
|
request.SystemInstruction = &Content{
|
|
|
|
Parts: []ContentPart{
|
|
|
|
{
|
|
|
|
Text: system,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-06-23 13:10:03 -06:00
|
|
|
if len(params.Toolbox) > 0 {
|
|
|
|
request.Tools = convertTools(params.Toolbox)
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-05-18 19:38:02 -06:00
|
|
|
return request, nil
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
2024-05-18 15:15:15 -06:00
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
client := &http.Client{}
|
|
|
|
resp, err := client.Do(req)
|
2024-05-18 15:15:15 -06:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
if resp.StatusCode != 200 {
|
|
|
|
bytes, _ := io.ReadAll(resp.Body)
|
|
|
|
return resp, fmt.Errorf("%v", string(bytes))
|
|
|
|
}
|
|
|
|
|
|
|
|
return resp, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Client) CreateChatCompletion(
|
|
|
|
ctx context.Context,
|
2024-06-12 02:35:07 -06:00
|
|
|
params api.RequestParameters,
|
|
|
|
messages []api.Message,
|
|
|
|
) (*api.Message, error) {
|
2024-05-18 15:15:15 -06:00
|
|
|
if len(messages) == 0 {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-05-18 19:38:02 -06:00
|
|
|
req, err := createGenerateContentRequest(params, messages)
|
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-05-18 19:38:02 -06:00
|
|
|
}
|
2024-05-18 15:15:15 -06:00
|
|
|
jsonData, err := json.Marshal(req)
|
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-05-18 17:18:53 -06:00
|
|
|
url := fmt.Sprintf(
|
|
|
|
"%s/v1beta/models/%s:generateContent?key=%s",
|
|
|
|
c.BaseURL, params.Model, c.APIKey,
|
|
|
|
)
|
2024-06-12 02:35:07 -06:00
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
2024-05-18 15:15:15 -06:00
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
resp, err := c.sendRequest(httpReq)
|
2024-05-18 15:15:15 -06:00
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
var completionResp GenerateContentResponse
|
|
|
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
|
|
|
choice := completionResp.Candidates[0]
|
|
|
|
|
|
|
|
var content string
|
|
|
|
lastMessage := messages[len(messages)-1]
|
|
|
|
if lastMessage.Role.IsAssistant() {
|
|
|
|
content = lastMessage.Content
|
|
|
|
}
|
|
|
|
|
2024-05-18 20:59:43 -06:00
|
|
|
var toolCalls []FunctionCall
|
2024-05-18 15:15:15 -06:00
|
|
|
for _, part := range choice.Content.Parts {
|
|
|
|
if part.Text != "" {
|
|
|
|
content += part.Text
|
|
|
|
}
|
2024-05-18 20:59:43 -06:00
|
|
|
|
|
|
|
if part.FunctionCall != nil {
|
|
|
|
toolCalls = append(toolCalls, *part.FunctionCall)
|
|
|
|
}
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
|
|
|
if len(toolCalls) > 0 {
|
2024-06-12 02:35:07 -06:00
|
|
|
return &api.Message{
|
|
|
|
Role: api.MessageRoleToolCall,
|
|
|
|
Content: content,
|
|
|
|
ToolCalls: convertToolCallToAPI(toolCalls),
|
|
|
|
}, nil
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
return &api.Message{
|
|
|
|
Role: api.MessageRoleAssistant,
|
|
|
|
Content: content,
|
|
|
|
}, nil
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Client) CreateChatCompletionStream(
|
|
|
|
ctx context.Context,
|
2024-06-12 02:35:07 -06:00
|
|
|
params api.RequestParameters,
|
|
|
|
messages []api.Message,
|
2024-06-09 10:42:53 -06:00
|
|
|
output chan<- api.Chunk,
|
2024-06-12 02:35:07 -06:00
|
|
|
) (*api.Message, error) {
|
2024-05-18 15:15:15 -06:00
|
|
|
if len(messages) == 0 {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-05-18 19:38:02 -06:00
|
|
|
req, err := createGenerateContentRequest(params, messages)
|
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-05-18 19:38:02 -06:00
|
|
|
}
|
2024-05-18 15:15:15 -06:00
|
|
|
jsonData, err := json.Marshal(req)
|
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-05-18 17:18:53 -06:00
|
|
|
url := fmt.Sprintf(
|
|
|
|
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
|
|
|
c.BaseURL, params.Model, c.APIKey,
|
|
|
|
)
|
2024-06-12 02:35:07 -06:00
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
2024-05-18 15:15:15 -06:00
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
resp, err := c.sendRequest(httpReq)
|
2024-05-18 15:15:15 -06:00
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
content := strings.Builder{}
|
|
|
|
|
|
|
|
lastMessage := messages[len(messages)-1]
|
|
|
|
if lastMessage.Role.IsAssistant() {
|
|
|
|
content.WriteString(lastMessage.Content)
|
|
|
|
}
|
|
|
|
|
2024-05-18 20:59:43 -06:00
|
|
|
var toolCalls []FunctionCall
|
2024-05-18 19:38:02 -06:00
|
|
|
|
2024-05-18 15:15:15 -06:00
|
|
|
reader := bufio.NewReader(resp.Body)
|
2024-06-09 14:45:18 -06:00
|
|
|
|
|
|
|
lastTokenCount := 0
|
2024-05-18 15:15:15 -06:00
|
|
|
for {
|
|
|
|
line, err := reader.ReadBytes('\n')
|
|
|
|
if err != nil {
|
|
|
|
if err == io.EOF {
|
|
|
|
break
|
|
|
|
}
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
|
|
|
line = bytes.TrimSpace(line)
|
|
|
|
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
line = bytes.TrimPrefix(line, []byte("data: "))
|
|
|
|
|
2024-06-09 14:45:18 -06:00
|
|
|
var resp GenerateContentResponse
|
|
|
|
err = json.Unmarshal(line, &resp)
|
2024-05-18 15:15:15 -06:00
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-06-09 14:45:18 -06:00
|
|
|
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
|
|
|
|
lastTokenCount += tokens
|
|
|
|
|
|
|
|
choice := resp.Candidates[0]
|
|
|
|
for _, part := range choice.Content.Parts {
|
|
|
|
if part.FunctionCall != nil {
|
|
|
|
toolCalls = append(toolCalls, *part.FunctionCall)
|
|
|
|
} else if part.Text != "" {
|
|
|
|
output <- api.Chunk{
|
|
|
|
Content: part.Text,
|
|
|
|
TokenCount: uint(tokens),
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
2024-06-09 14:45:18 -06:00
|
|
|
content.WriteString(part.Text)
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
2024-05-18 19:38:02 -06:00
|
|
|
}
|
|
|
|
}
|
2024-05-18 17:18:53 -06:00
|
|
|
|
2024-05-18 19:38:02 -06:00
|
|
|
// If there are function calls, handle them and recurse
|
|
|
|
if len(toolCalls) > 0 {
|
2024-06-12 02:35:07 -06:00
|
|
|
return &api.Message{
|
|
|
|
Role: api.MessageRoleToolCall,
|
|
|
|
Content: content.String(),
|
|
|
|
ToolCalls: convertToolCallToAPI(toolCalls),
|
|
|
|
}, nil
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
return &api.Message{
|
|
|
|
Role: api.MessageRoleAssistant,
|
|
|
|
Content: content.String(),
|
|
|
|
}, nil
|
2024-05-18 15:15:15 -06:00
|
|
|
}
|