Gemini cleanup, tool calling working
This commit is contained in:
parent
1b8d04c96d
commit
a291e7b42c
@ -41,9 +41,18 @@ type Content struct {
|
|||||||
Parts []ContentPart `json:"parts"`
|
Parts []ContentPart `json:"parts"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GenerationConfig struct {
|
||||||
|
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||||
|
Temperature *float32 `json:"temperature,omitempty"`
|
||||||
|
TopP *float32 `json:"topP,omitempty"`
|
||||||
|
TopK *int `json:"topK,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type GenerateContentRequest struct {
|
type GenerateContentRequest struct {
|
||||||
Contents []Content `json:"contents"`
|
Contents []Content `json:"contents"`
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
SystemInstructions string `json:"systemInstructions,omitempty"`
|
||||||
|
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Candidate struct {
|
type Candidate struct {
|
||||||
@ -52,8 +61,15 @@ type Candidate struct {
|
|||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UsageMetadata struct {
|
||||||
|
PromptTokenCount int `json:"promptTokenCount"`
|
||||||
|
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||||
|
TotalTokenCount int `json:"totalTokenCount"`
|
||||||
|
}
|
||||||
|
|
||||||
type GenerateContentResponse struct {
|
type GenerateContentResponse struct {
|
||||||
Candidates []Candidate `json:"candidates"`
|
Candidates []Candidate `json:"candidates"`
|
||||||
|
UsageMetadata UsageMetadata `json:"usageMetadata"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
@ -75,7 +91,7 @@ type ToolParameters struct {
|
|||||||
type ToolParameter struct {
|
type ToolParameter struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Enum []string `json:"enum,omitempty"`
|
Values []string `json:"values,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertTools(tools []model.Tool) []Tool {
|
func convertTools(tools []model.Tool) []Tool {
|
||||||
@ -85,10 +101,11 @@ func convertTools(tools []model.Tool) []Tool {
|
|||||||
var required []string
|
var required []string
|
||||||
|
|
||||||
for _, param := range tool.Parameters {
|
for _, param := range tool.Parameters {
|
||||||
|
// TODO: proper enum handing
|
||||||
params[param.Name] = ToolParameter{
|
params[param.Name] = ToolParameter{
|
||||||
Type: param.Type,
|
Type: param.Type,
|
||||||
Description: param.Description,
|
Description: param.Description,
|
||||||
Enum: param.Enum,
|
Values: param.Enum,
|
||||||
}
|
}
|
||||||
if param.Required {
|
if param.Required {
|
||||||
required = append(required, param.Name)
|
required = append(required, param.Name)
|
||||||
@ -142,13 +159,36 @@ func convertToolCallToAPI(parts []ContentPart) []model.ToolCall {
|
|||||||
return converted
|
return converted
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
|
||||||
|
results := make([]FunctionResponse, len(toolResults))
|
||||||
|
for i, result := range toolResults {
|
||||||
|
var obj interface{}
|
||||||
|
err := json.Unmarshal([]byte(result.Result), &obj)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
|
||||||
|
}
|
||||||
|
results[i] = FunctionResponse{
|
||||||
|
Name: result.ToolName,
|
||||||
|
Response: obj,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
func createGenerateContentRequest(
|
func createGenerateContentRequest(
|
||||||
params model.RequestParameters,
|
params model.RequestParameters,
|
||||||
messages []model.Message,
|
messages []model.Message,
|
||||||
) GenerateContentRequest {
|
) (*GenerateContentRequest, error) {
|
||||||
requestContents := make([]Content, 0, len(messages))
|
requestContents := make([]Content, 0, len(messages))
|
||||||
|
|
||||||
for _, m := range messages {
|
startIdx := 0
|
||||||
|
var system string
|
||||||
|
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
||||||
|
system = messages[0].Content
|
||||||
|
startIdx = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range messages[startIdx:] {
|
||||||
switch m.Role {
|
switch m.Role {
|
||||||
case "tool_call":
|
case "tool_call":
|
||||||
content := Content{
|
content := Content{
|
||||||
@ -157,16 +197,17 @@ func createGenerateContentRequest(
|
|||||||
}
|
}
|
||||||
requestContents = append(requestContents, content)
|
requestContents = append(requestContents, content)
|
||||||
case "tool_result":
|
case "tool_result":
|
||||||
|
results, err := convertToolResultsToGemini(m.ToolResults)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
// expand tool_result messages' results into multiple gemini messages
|
// expand tool_result messages' results into multiple gemini messages
|
||||||
for _, result := range m.ToolResults {
|
for _, result := range results {
|
||||||
content := Content{
|
content := Content{
|
||||||
Role: "function",
|
Role: "function",
|
||||||
Parts: []ContentPart{
|
Parts: []ContentPart{
|
||||||
{
|
{
|
||||||
FunctionResp: &FunctionResponse{
|
FunctionResp: &result,
|
||||||
Name: result.ToolCallID,
|
|
||||||
Response: result.Result,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -175,10 +216,10 @@ func createGenerateContentRequest(
|
|||||||
default:
|
default:
|
||||||
var role string
|
var role string
|
||||||
switch m.Role {
|
switch m.Role {
|
||||||
case model.MessageRoleAssistant:
|
case model.MessageRoleAssistant:
|
||||||
role = "model"
|
role = "model"
|
||||||
case model.MessageRoleUser:
|
case model.MessageRoleUser:
|
||||||
role = "user"
|
role = "user"
|
||||||
}
|
}
|
||||||
|
|
||||||
if role == "" {
|
if role == "" {
|
||||||
@ -197,15 +238,21 @@ func createGenerateContentRequest(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
request := GenerateContentRequest{
|
request := &GenerateContentRequest{
|
||||||
Contents: requestContents,
|
Contents: requestContents,
|
||||||
|
SystemInstructions: system,
|
||||||
|
GenerationConfig: &GenerationConfig{
|
||||||
|
MaxOutputTokens: ¶ms.MaxTokens,
|
||||||
|
Temperature: ¶ms.Temperature,
|
||||||
|
TopP: ¶ms.TopP,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(params.ToolBag) > 0 {
|
if len(params.ToolBag) > 0 {
|
||||||
request.Tools = convertTools(params.ToolBag)
|
request.Tools = convertTools(params.ToolBag)
|
||||||
}
|
}
|
||||||
|
|
||||||
return request
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleToolCalls(
|
func handleToolCalls(
|
||||||
@ -276,7 +323,10 @@ func (c *Client) CreateChatCompletion(
|
|||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return "", fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
req := createGenerateContentRequest(params, messages)
|
req, err := createGenerateContentRequest(params, messages)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
jsonData, err := json.Marshal(req)
|
jsonData, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@ -349,7 +399,10 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return "", fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
req := createGenerateContentRequest(params, messages)
|
req, err := createGenerateContentRequest(params, messages)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
jsonData, err := json.Marshal(req)
|
jsonData, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@ -377,6 +430,8 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
content.WriteString(lastMessage.Content)
|
content.WriteString(lastMessage.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var toolCalls []model.ToolCall
|
||||||
|
|
||||||
reader := bufio.NewReader(resp.Body)
|
reader := bufio.NewReader(resp.Body)
|
||||||
for {
|
for {
|
||||||
line, err := reader.ReadBytes('\n')
|
line, err := reader.ReadBytes('\n')
|
||||||
@ -401,8 +456,6 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, candidate := range streamResp.Candidates {
|
for _, candidate := range streamResp.Candidates {
|
||||||
var toolCalls []model.ToolCall
|
|
||||||
|
|
||||||
for _, part := range candidate.Content.Parts {
|
for _, part := range candidate.Content.Parts {
|
||||||
if part.FunctionCall != nil {
|
if part.FunctionCall != nil {
|
||||||
toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...)
|
toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...)
|
||||||
@ -411,18 +464,18 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
content.WriteString(part.Text)
|
content.WriteString(part.Text)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are function calls, handle them and recurse
|
|
||||||
if len(toolCalls) > 0 {
|
|
||||||
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
|
|
||||||
if err != nil {
|
|
||||||
return content.String(), err
|
|
||||||
}
|
|
||||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If there are function calls, handle them and recurse
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
|
||||||
|
if err != nil {
|
||||||
|
return content.String(), err
|
||||||
|
}
|
||||||
|
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||||
|
}
|
||||||
|
|
||||||
if callback != nil {
|
if callback != nil {
|
||||||
callback(model.Message{
|
callback(model.Message{
|
||||||
Role: model.MessageRoleAssistant,
|
Role: model.MessageRoleAssistant,
|
||||||
|
Loading…
Reference in New Issue
Block a user