Gemini cleanup, tool calling working

This commit is contained in:
Matt Low 2024-05-19 01:38:02 +00:00
parent 1b8d04c96d
commit a291e7b42c

View File

@ -41,9 +41,18 @@ type Content struct {
Parts []ContentPart `json:"parts"`
}
type GenerationConfig struct {
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"`
}
type GenerateContentRequest struct {
Contents []Content `json:"contents"`
Tools []Tool `json:"tools,omitempty"`
Contents []Content `json:"contents"`
Tools []Tool `json:"tools,omitempty"`
SystemInstructions string `json:"systemInstructions,omitempty"`
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
}
type Candidate struct {
@ -52,8 +61,15 @@ type Candidate struct {
Index int `json:"index"`
}
type UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
type GenerateContentResponse struct {
Candidates []Candidate `json:"candidates"`
Candidates []Candidate `json:"candidates"`
UsageMetadata UsageMetadata `json:"usageMetadata"`
}
type Tool struct {
@ -75,7 +91,7 @@ type ToolParameters struct {
type ToolParameter struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
Values []string `json:"values,omitempty"`
}
func convertTools(tools []model.Tool) []Tool {
@ -85,10 +101,11 @@ func convertTools(tools []model.Tool) []Tool {
var required []string
for _, param := range tool.Parameters {
// TODO: proper enum handing
params[param.Name] = ToolParameter{
Type: param.Type,
Description: param.Description,
Enum: param.Enum,
Values: param.Enum,
}
if param.Required {
required = append(required, param.Name)
@ -142,13 +159,36 @@ func convertToolCallToAPI(parts []ContentPart) []model.ToolCall {
return converted
}
func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
results := make([]FunctionResponse, len(toolResults))
for i, result := range toolResults {
var obj interface{}
err := json.Unmarshal([]byte(result.Result), &obj)
if err != nil {
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
}
results[i] = FunctionResponse{
Name: result.ToolName,
Response: obj,
}
}
return results, nil
}
func createGenerateContentRequest(
params model.RequestParameters,
messages []model.Message,
) GenerateContentRequest {
) (*GenerateContentRequest, error) {
requestContents := make([]Content, 0, len(messages))
for _, m := range messages {
startIdx := 0
var system string
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
system = messages[0].Content
startIdx = 1
}
for _, m := range messages[startIdx:] {
switch m.Role {
case "tool_call":
content := Content{
@ -157,16 +197,17 @@ func createGenerateContentRequest(
}
requestContents = append(requestContents, content)
case "tool_result":
results, err := convertToolResultsToGemini(m.ToolResults)
if err != nil {
return nil, err
}
// expand tool_result messages' results into multiple gemini messages
for _, result := range m.ToolResults {
for _, result := range results {
content := Content{
Role: "function",
Parts: []ContentPart{
{
FunctionResp: &FunctionResponse{
Name: result.ToolCallID,
Response: result.Result,
},
FunctionResp: &result,
},
},
}
@ -175,10 +216,10 @@ func createGenerateContentRequest(
default:
var role string
switch m.Role {
case model.MessageRoleAssistant:
role = "model"
case model.MessageRoleUser:
role = "user"
case model.MessageRoleAssistant:
role = "model"
case model.MessageRoleUser:
role = "user"
}
if role == "" {
@ -197,15 +238,21 @@ func createGenerateContentRequest(
}
}
request := GenerateContentRequest{
request := &GenerateContentRequest{
Contents: requestContents,
SystemInstructions: system,
GenerationConfig: &GenerationConfig{
MaxOutputTokens: &params.MaxTokens,
Temperature: &params.Temperature,
TopP: &params.TopP,
},
}
if len(params.ToolBag) > 0 {
request.Tools = convertTools(params.ToolBag)
}
return request
return request, nil
}
func handleToolCalls(
@ -276,7 +323,10 @@ func (c *Client) CreateChatCompletion(
return "", fmt.Errorf("Can't create completion from no messages")
}
req := createGenerateContentRequest(params, messages)
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return "", err
}
jsonData, err := json.Marshal(req)
if err != nil {
return "", err
@ -349,7 +399,10 @@ func (c *Client) CreateChatCompletionStream(
return "", fmt.Errorf("Can't create completion from no messages")
}
req := createGenerateContentRequest(params, messages)
req, err := createGenerateContentRequest(params, messages)
if err != nil {
return "", err
}
jsonData, err := json.Marshal(req)
if err != nil {
return "", err
@ -377,6 +430,8 @@ func (c *Client) CreateChatCompletionStream(
content.WriteString(lastMessage.Content)
}
var toolCalls []model.ToolCall
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
@ -401,8 +456,6 @@ func (c *Client) CreateChatCompletionStream(
}
for _, candidate := range streamResp.Candidates {
var toolCalls []model.ToolCall
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...)
@ -411,18 +464,18 @@ func (c *Client) CreateChatCompletionStream(
content.WriteString(part.Text)
}
}
// If there are function calls, handle them and recurse
if len(toolCalls) > 0 {
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
if err != nil {
return content.String(), err
}
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
}
}
}
// If there are function calls, handle them and recurse
if len(toolCalls) > 0 {
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
if err != nil {
return content.String(), err
}
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,