Gemini cleanup, tool calling working

This commit is contained in:
Matt Low 2024-05-19 01:38:02 +00:00
parent 1b8d04c96d
commit a291e7b42c

View File

@ -41,9 +41,18 @@ type Content struct {
Parts []ContentPart `json:"parts"` Parts []ContentPart `json:"parts"`
} }
type GenerationConfig struct {
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"`
}
type GenerateContentRequest struct { type GenerateContentRequest struct {
Contents []Content `json:"contents"` Contents []Content `json:"contents"`
Tools []Tool `json:"tools,omitempty"` Tools []Tool `json:"tools,omitempty"`
SystemInstructions string `json:"systemInstructions,omitempty"`
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
} }
type Candidate struct { type Candidate struct {
@ -52,8 +61,15 @@ type Candidate struct {
Index int `json:"index"` Index int `json:"index"`
} }
type UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
type GenerateContentResponse struct { type GenerateContentResponse struct {
Candidates []Candidate `json:"candidates"` Candidates []Candidate `json:"candidates"`
UsageMetadata UsageMetadata `json:"usageMetadata"`
} }
type Tool struct { type Tool struct {
@ -75,7 +91,7 @@ type ToolParameters struct {
type ToolParameter struct { type ToolParameter struct {
Type string `json:"type"` Type string `json:"type"`
Description string `json:"description"` Description string `json:"description"`
Enum []string `json:"enum,omitempty"` Values []string `json:"values,omitempty"`
} }
func convertTools(tools []model.Tool) []Tool { func convertTools(tools []model.Tool) []Tool {
@ -85,10 +101,11 @@ func convertTools(tools []model.Tool) []Tool {
var required []string var required []string
for _, param := range tool.Parameters { for _, param := range tool.Parameters {
// TODO: proper enum handing
params[param.Name] = ToolParameter{ params[param.Name] = ToolParameter{
Type: param.Type, Type: param.Type,
Description: param.Description, Description: param.Description,
Enum: param.Enum, Values: param.Enum,
} }
if param.Required { if param.Required {
required = append(required, param.Name) required = append(required, param.Name)
@ -142,13 +159,36 @@ func convertToolCallToAPI(parts []ContentPart) []model.ToolCall {
return converted return converted
} }
func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
results := make([]FunctionResponse, len(toolResults))
for i, result := range toolResults {
var obj interface{}
err := json.Unmarshal([]byte(result.Result), &obj)
if err != nil {
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
}
results[i] = FunctionResponse{
Name: result.ToolName,
Response: obj,
}
}
return results, nil
}
func createGenerateContentRequest( func createGenerateContentRequest(
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
) GenerateContentRequest { ) (*GenerateContentRequest, error) {
requestContents := make([]Content, 0, len(messages)) requestContents := make([]Content, 0, len(messages))
for _, m := range messages { startIdx := 0
var system string
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
system = messages[0].Content
startIdx = 1
}
for _, m := range messages[startIdx:] {
switch m.Role { switch m.Role {
case "tool_call": case "tool_call":
content := Content{ content := Content{
@ -157,16 +197,17 @@ func createGenerateContentRequest(
} }
requestContents = append(requestContents, content) requestContents = append(requestContents, content)
case "tool_result": case "tool_result":
results, err := convertToolResultsToGemini(m.ToolResults)
if err != nil {
return nil, err
}
// expand tool_result messages' results into multiple gemini messages // expand tool_result messages' results into multiple gemini messages
for _, result := range m.ToolResults { for _, result := range results {
content := Content{ content := Content{
Role: "function", Role: "function",
Parts: []ContentPart{ Parts: []ContentPart{
{ {
FunctionResp: &FunctionResponse{ FunctionResp: &result,
Name: result.ToolCallID,
Response: result.Result,
},
}, },
}, },
} }
@ -197,15 +238,21 @@ func createGenerateContentRequest(
} }
} }
request := GenerateContentRequest{ request := &GenerateContentRequest{
Contents: requestContents, Contents: requestContents,
SystemInstructions: system,
GenerationConfig: &GenerationConfig{
MaxOutputTokens: &params.MaxTokens,
Temperature: &params.Temperature,
TopP: &params.TopP,
},
} }
if len(params.ToolBag) > 0 { if len(params.ToolBag) > 0 {
request.Tools = convertTools(params.ToolBag) request.Tools = convertTools(params.ToolBag)
} }
return request return request, nil
} }
func handleToolCalls( func handleToolCalls(
@ -276,7 +323,10 @@ func (c *Client) CreateChatCompletion(
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
req := createGenerateContentRequest(params, messages) req, err := createGenerateContentRequest(params, messages)
if err != nil {
return "", err
}
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return "", err return "", err
@ -349,7 +399,10 @@ func (c *Client) CreateChatCompletionStream(
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
req := createGenerateContentRequest(params, messages) req, err := createGenerateContentRequest(params, messages)
if err != nil {
return "", err
}
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return "", err return "", err
@ -377,6 +430,8 @@ func (c *Client) CreateChatCompletionStream(
content.WriteString(lastMessage.Content) content.WriteString(lastMessage.Content)
} }
var toolCalls []model.ToolCall
reader := bufio.NewReader(resp.Body) reader := bufio.NewReader(resp.Body)
for { for {
line, err := reader.ReadBytes('\n') line, err := reader.ReadBytes('\n')
@ -401,8 +456,6 @@ func (c *Client) CreateChatCompletionStream(
} }
for _, candidate := range streamResp.Candidates { for _, candidate := range streamResp.Candidates {
var toolCalls []model.ToolCall
for _, part := range candidate.Content.Parts { for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil { if part.FunctionCall != nil {
toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...) toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...)
@ -411,6 +464,8 @@ func (c *Client) CreateChatCompletionStream(
content.WriteString(part.Text) content.WriteString(part.Text)
} }
} }
}
}
// If there are function calls, handle them and recurse // If there are function calls, handle them and recurse
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
@ -420,8 +475,6 @@ func (c *Client) CreateChatCompletionStream(
} }
return c.CreateChatCompletionStream(ctx, params, messages, callback, output) return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} }
}
}
if callback != nil { if callback != nil {
callback(model.Message{ callback(model.Message{