Gemini cleanup, tool calling working
parent 1b8d04c96d
commit a291e7b42c
@@ -41,9 +41,18 @@ type Content struct {
     Parts []ContentPart `json:"parts"`
 }
 
+type GenerationConfig struct {
+    MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"`
+    Temperature     *float32 `json:"temperature,omitempty"`
+    TopP            *float32 `json:"topP,omitempty"`
+    TopK            *int     `json:"topK,omitempty"`
+}
+
 type GenerateContentRequest struct {
-    Contents []Content `json:"contents"`
-    Tools    []Tool    `json:"tools,omitempty"`
+    Contents           []Content         `json:"contents"`
+    Tools              []Tool            `json:"tools,omitempty"`
+    SystemInstructions string            `json:"systemInstructions,omitempty"`
+    GenerationConfig   *GenerationConfig `json:"generationConfig,omitempty"`
 }
 
 type Candidate struct {
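
For reference, a minimal sketch of what a request built from the structs above marshals to. It assumes the ContentPart.Text field used later in this diff and the file's existing encoding/json import; the prompt and settings are made up.

// Sketch only: build a small GenerateContentRequest with the types above and
// return the JSON produced by their struct tags.
func exampleRequestJSON() ([]byte, error) {
    maxTokens := 1024
    temperature := float32(0.7)
    req := GenerateContentRequest{
        Contents: []Content{
            {Role: "user", Parts: []ContentPart{{Text: "Hello"}}},
        },
        GenerationConfig: &GenerationConfig{
            MaxOutputTokens: &maxTokens,
            Temperature:     &temperature,
        },
    }
    return json.MarshalIndent(req, "", "  ")
}
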
@@ -52,8 +61,15 @@ type Candidate struct {
     Index int `json:"index"`
 }
 
+type UsageMetadata struct {
+    PromptTokenCount     int `json:"promptTokenCount"`
+    CandidatesTokenCount int `json:"candidatesTokenCount"`
+    TotalTokenCount      int `json:"totalTokenCount"`
+}
+
 type GenerateContentResponse struct {
-    Candidates []Candidate `json:"candidates"`
+    Candidates    []Candidate   `json:"candidates"`
+    UsageMetadata UsageMetadata `json:"usageMetadata"`
 }
 
 type Tool struct {
@@ -75,7 +91,7 @@ type ToolParameters struct {
 type ToolParameter struct {
     Type        string   `json:"type"`
     Description string   `json:"description"`
-    Enum        []string `json:"enum,omitempty"`
+    Values      []string `json:"values,omitempty"`
 }
 
 func convertTools(tools []model.Tool) []Tool {
@@ -85,10 +101,11 @@ func convertTools(tools []model.Tool) []Tool {
         var required []string
 
         for _, param := range tool.Parameters {
+            // TODO: proper enum handling
             params[param.Name] = ToolParameter{
                 Type:        param.Type,
                 Description: param.Description,
-                Enum:        param.Enum,
+                Values:      param.Enum,
             }
             if param.Required {
                 required = append(required, param.Name)
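
A possible shape for the enum handling flagged by the TODO above, kept strictly as a sketch: only the fields already dereferenced in this diff (param.Type, param.Description, param.Enum) are assumed, and model.ToolParameter is a guess at the element type of tool.Parameters.

// Sketch: attach enum values only where they can apply, instead of always
// copying param.Enum into Values.
func convertParam(param model.ToolParameter) ToolParameter {
    p := ToolParameter{
        Type:        param.Type,
        Description: param.Description,
    }
    if param.Type == "string" && len(param.Enum) > 0 {
        p.Values = param.Enum
    }
    return p
}
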
@@ -142,13 +159,36 @@ func convertToolCallToAPI(parts []ContentPart) []model.ToolCall {
     return converted
 }
 
+func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
+    results := make([]FunctionResponse, len(toolResults))
+    for i, result := range toolResults {
+        var obj interface{}
+        err := json.Unmarshal([]byte(result.Result), &obj)
+        if err != nil {
+            return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
+        }
+        results[i] = FunctionResponse{
+            Name:     result.ToolName,
+            Response: obj,
+        }
+    }
+    return results, nil
+}
+
 func createGenerateContentRequest(
     params model.RequestParameters,
     messages []model.Message,
-) GenerateContentRequest {
+) (*GenerateContentRequest, error) {
     requestContents := make([]Content, 0, len(messages))
 
-    for _, m := range messages {
+    startIdx := 0
+    var system string
+    if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
+        system = messages[0].Content
+        startIdx = 1
+    }
+
+    for _, m := range messages[startIdx:] {
         switch m.Role {
         case "tool_call":
             content := Content{
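
A quick usage sketch of the new helper; the tool name and payload are invented, and only the model.ToolResult fields used above (ToolName, Result) are assumed.

// Sketch: convert one tool result whose Result field holds a JSON string.
func exampleToolResultConversion() {
    results, err := convertToolResultsToGemini([]model.ToolResult{
        {ToolName: "get_weather", Result: `{"temperature": 21, "unit": "C"}`},
    })
    if err != nil {
        return // a Result that is not valid JSON ends up here
    }
    _ = results // results[0].Response now holds the decoded value, not the raw string
}
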
@@ -157,16 +197,17 @@ func createGenerateContentRequest(
             }
             requestContents = append(requestContents, content)
         case "tool_result":
+            results, err := convertToolResultsToGemini(m.ToolResults)
+            if err != nil {
+                return nil, err
+            }
             // expand tool_result messages' results into multiple gemini messages
-            for _, result := range m.ToolResults {
+            for _, result := range results {
                 content := Content{
                     Role: "function",
                     Parts: []ContentPart{
                         {
-                            FunctionResp: &FunctionResponse{
-                                Name:     result.ToolCallID,
-                                Response: result.Result,
-                            },
+                            FunctionResp: &result,
                         },
                     },
                 }
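
One note on the loop above: FunctionResp takes the address of the range variable. Before Go 1.22 that variable is reused across iterations, so with more than one tool result every appended part would point at the same (last) FunctionResponse. A defensive copy sidesteps the toolchain dependency; sketch of the same loop:

for _, result := range results {
    result := result // per-iteration copy; safe on any Go version
    content := Content{
        Role:  "function",
        Parts: []ContentPart{{FunctionResp: &result}},
    }
    requestContents = append(requestContents, content)
}
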
@@ -175,10 +216,10 @@ func createGenerateContentRequest(
         default:
             var role string
             switch m.Role {
-            case model.MessageRoleAssistant:
-                role = "model"
-            case model.MessageRoleUser:
-                role = "user"
+            case model.MessageRoleAssistant:
+                role = "model"
+            case model.MessageRoleUser:
+                role = "user"
             }
 
             if role == "" {
@@ -197,15 +238,21 @@ func createGenerateContentRequest(
         }
     }
 
-    request := GenerateContentRequest{
+    request := &GenerateContentRequest{
         Contents: requestContents,
+        SystemInstructions: system,
+        GenerationConfig: &GenerationConfig{
+            MaxOutputTokens: &params.MaxTokens,
+            Temperature: &params.Temperature,
+            TopP: &params.TopP,
+        },
     }
 
     if len(params.ToolBag) > 0 {
         request.Tools = convertTools(params.ToolBag)
     }
 
-    return request
+    return request, nil
 }
 
 func handleToolCalls(
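
The behaviour of the updated createGenerateContentRequest in one place, as a sketch; it assumes model.RequestParameters and model.Message can be built with exactly the fields this diff dereferences (MaxTokens, Temperature, Role, Content).

// Sketch: a leading system message becomes SystemInstructions rather than a
// Contents entry, and the request now comes back as a pointer plus an error.
func exampleBuildRequest() {
    req, err := createGenerateContentRequest(model.RequestParameters{
        MaxTokens:   1024,
        Temperature: 0.7,
    }, []model.Message{
        {Role: model.MessageRoleSystem, Content: "Answer briefly."},
        {Role: model.MessageRoleUser, Content: "Hi"},
    })
    if err != nil {
        return
    }
    _ = req // req.SystemInstructions == "Answer briefly."; req.Contents holds only the user turn
}
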
@@ -276,7 +323,10 @@ func (c *Client) CreateChatCompletion(
         return "", fmt.Errorf("Can't create completion from no messages")
     }
 
-    req := createGenerateContentRequest(params, messages)
+    req, err := createGenerateContentRequest(params, messages)
+    if err != nil {
+        return "", err
+    }
     jsonData, err := json.Marshal(req)
     if err != nil {
         return "", err
@@ -349,7 +399,10 @@ func (c *Client) CreateChatCompletionStream(
         return "", fmt.Errorf("Can't create completion from no messages")
     }
 
-    req := createGenerateContentRequest(params, messages)
+    req, err := createGenerateContentRequest(params, messages)
+    if err != nil {
+        return "", err
+    }
     jsonData, err := json.Marshal(req)
     if err != nil {
         return "", err
@@ -377,6 +430,8 @@ func (c *Client) CreateChatCompletionStream(
         content.WriteString(lastMessage.Content)
     }
 
+    var toolCalls []model.ToolCall
+
     reader := bufio.NewReader(resp.Body)
     for {
         line, err := reader.ReadBytes('\n')
@@ -401,8 +456,6 @@ func (c *Client) CreateChatCompletionStream(
         }
 
         for _, candidate := range streamResp.Candidates {
-            var toolCalls []model.ToolCall
-
             for _, part := range candidate.Content.Parts {
                 if part.FunctionCall != nil {
                     toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...)
@@ -411,18 +464,18 @@ func (c *Client) CreateChatCompletionStream(
                     content.WriteString(part.Text)
                 }
             }
-
-            // If there are function calls, handle them and recurse
-            if len(toolCalls) > 0 {
-                messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
-                if err != nil {
-                    return content.String(), err
-                }
-                return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
-            }
         }
     }
+
+    // If there are function calls, handle them and recurse
+    if len(toolCalls) > 0 {
+        messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
+        if err != nil {
+            return content.String(), err
+        }
+        return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
+    }
 
     if callback != nil {
         callback(model.Message{
             Role: model.MessageRoleAssistant,
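
Taken together, the last few hunks hoist tool-call handling out of the per-chunk loop: toolCalls is declared once for the whole stream, each streamed candidate only appends to it, and handleToolCalls plus the recursive CreateChatCompletionStream call now run a single time after the reader loop ends. The pattern in isolation, as a sketch with a hypothetical chunk shape:

// Sketch: collect across every chunk, act once after the stream is done.
func collectToolCalls(chunks [][]model.ToolCall) []model.ToolCall {
    var toolCalls []model.ToolCall
    for _, chunk := range chunks {
        toolCalls = append(toolCalls, chunk...)
    }
    return toolCalls // non-empty => run the tools once, then re-enter the stream call
}
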