From a291e7b42c0b2260dc13ef9f62a588211c5cf187 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sun, 19 May 2024 01:38:02 +0000 Subject: [PATCH] Gemini cleanup, tool calling working --- pkg/lmcli/provider/google/google.go | 115 ++++++++++++++++++++-------- 1 file changed, 84 insertions(+), 31 deletions(-) diff --git a/pkg/lmcli/provider/google/google.go b/pkg/lmcli/provider/google/google.go index 88eb7bc..cad3d0c 100644 --- a/pkg/lmcli/provider/google/google.go +++ b/pkg/lmcli/provider/google/google.go @@ -41,9 +41,18 @@ type Content struct { Parts []ContentPart `json:"parts"` } +type GenerationConfig struct { + MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"topP,omitempty"` + TopK *int `json:"topK,omitempty"` +} + type GenerateContentRequest struct { - Contents []Content `json:"contents"` - Tools []Tool `json:"tools,omitempty"` + Contents []Content `json:"contents"` + Tools []Tool `json:"tools,omitempty"` + SystemInstructions string `json:"systemInstructions,omitempty"` + GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"` } type Candidate struct { @@ -52,8 +61,15 @@ type Candidate struct { Index int `json:"index"` } +type UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` +} + type GenerateContentResponse struct { - Candidates []Candidate `json:"candidates"` + Candidates []Candidate `json:"candidates"` + UsageMetadata UsageMetadata `json:"usageMetadata"` } type Tool struct { @@ -75,7 +91,7 @@ type ToolParameters struct { type ToolParameter struct { Type string `json:"type"` Description string `json:"description"` - Enum []string `json:"enum,omitempty"` + Values []string `json:"values,omitempty"` } func convertTools(tools []model.Tool) []Tool { @@ -85,10 +101,11 @@ func convertTools(tools []model.Tool) []Tool { var required []string for _, 
param := range tool.Parameters { + // TODO: proper enum handling params[param.Name] = ToolParameter{ Type: param.Type, Description: param.Description, - Enum: param.Enum, + Values: param.Enum, } if param.Required { required = append(required, param.Name) @@ -142,13 +159,36 @@ func convertToolCallToAPI(parts []ContentPart) []model.ToolCall { return converted } +func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) { + results := make([]FunctionResponse, len(toolResults)) + for i, result := range toolResults { + var obj interface{} + err := json.Unmarshal([]byte(result.Result), &obj) + if err != nil { + return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err) + } + results[i] = FunctionResponse{ + Name: result.ToolName, + Response: obj, + } + } + return results, nil +} + func createGenerateContentRequest( params model.RequestParameters, messages []model.Message, -) GenerateContentRequest { +) (*GenerateContentRequest, error) { requestContents := make([]Content, 0, len(messages)) - for _, m := range messages { + startIdx := 0 + var system string + if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem { + system = messages[0].Content + startIdx = 1 + } + + for _, m := range messages[startIdx:] { switch m.Role { case "tool_call": content := Content{ @@ -157,16 +197,17 @@ func createGenerateContentRequest( } requestContents = append(requestContents, content) case "tool_result": + results, err := convertToolResultsToGemini(m.ToolResults) + if err != nil { + return nil, err + } // expand tool_result messages' results into multiple gemini messages - for _, result := range m.ToolResults { + for _, result := range results { content := Content{ Role: "function", Parts: []ContentPart{ { - FunctionResp: &FunctionResponse{ - Name: result.ToolCallID, - Response: result.Result, - }, + FunctionResp: &result, }, }, } @@ -175,10 +216,10 @@ func createGenerateContentRequest( default: var role string switch m.Role { - 
case model.MessageRoleAssistant: - role = "model" - case model.MessageRoleUser: - role = "user" + case model.MessageRoleAssistant: + role = "model" + case model.MessageRoleUser: + role = "user" } if role == "" { @@ -197,15 +238,21 @@ func createGenerateContentRequest( } } - request := GenerateContentRequest{ + request := &GenerateContentRequest{ Contents: requestContents, + SystemInstructions: system, + GenerationConfig: &GenerationConfig{ + MaxOutputTokens: ¶ms.MaxTokens, + Temperature: ¶ms.Temperature, + TopP: ¶ms.TopP, + }, } if len(params.ToolBag) > 0 { request.Tools = convertTools(params.ToolBag) } - return request + return request, nil } func handleToolCalls( @@ -276,7 +323,10 @@ func (c *Client) CreateChatCompletion( return "", fmt.Errorf("Can't create completion from no messages") } - req := createGenerateContentRequest(params, messages) + req, err := createGenerateContentRequest(params, messages) + if err != nil { + return "", err + } jsonData, err := json.Marshal(req) if err != nil { return "", err @@ -349,7 +399,10 @@ func (c *Client) CreateChatCompletionStream( return "", fmt.Errorf("Can't create completion from no messages") } - req := createGenerateContentRequest(params, messages) + req, err := createGenerateContentRequest(params, messages) + if err != nil { + return "", err + } jsonData, err := json.Marshal(req) if err != nil { return "", err @@ -377,6 +430,8 @@ func (c *Client) CreateChatCompletionStream( content.WriteString(lastMessage.Content) } + var toolCalls []model.ToolCall + reader := bufio.NewReader(resp.Body) for { line, err := reader.ReadBytes('\n') @@ -401,8 +456,6 @@ func (c *Client) CreateChatCompletionStream( } for _, candidate := range streamResp.Candidates { - var toolCalls []model.ToolCall - for _, part := range candidate.Content.Parts { if part.FunctionCall != nil { toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...) 
@@ -411,18 +464,18 @@ func (c *Client) CreateChatCompletionStream( content.WriteString(part.Text) } } - - // If there are function calls, handle them and recurse - if len(toolCalls) > 0 { - messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages) - if err != nil { - return content.String(), err - } - return c.CreateChatCompletionStream(ctx, params, messages, callback, output) - } } } + // If there are function calls, handle them and recurse + if len(toolCalls) > 0 { + messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages) + if err != nil { + return content.String(), err + } + return c.CreateChatCompletionStream(ctx, params, messages, callback, output) + } + if callback != nil { callback(model.Message{ Role: model.MessageRoleAssistant,