Fix for non-streamed gemini responses
parent b82f3019f0
commit 62d98289e8
@@ -144,17 +144,15 @@ func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
 	return converted
 }
 
-func convertToolCallToAPI(parts []ContentPart) []model.ToolCall {
-	converted := make([]model.ToolCall, len(parts))
-	for i, part := range parts {
-		if part.FunctionCall != nil {
-			params := make(map[string]interface{})
-			for k, v := range part.FunctionCall.Args {
-				params[k] = v
-			}
-			converted[i].Name = part.FunctionCall.Name
-			converted[i].Parameters = params
-		}
+func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
+	converted := make([]model.ToolCall, len(functionCalls))
+	for i, call := range functionCalls {
+		params := make(map[string]interface{})
+		for k, v := range call.Args {
+			params[k] = v
+		}
+		converted[i].Name = call.Name
+		converted[i].Parameters = params
 	}
 	return converted
 }
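The new convertToolCallToAPI takes already-extracted FunctionCall values instead of raw ContentParts, so callers decide what counts as a call before converting. A minimal, self-contained sketch of the mapping, assuming the field sets visible in the diff (Name/Args on the Gemini side, Name/Parameters on the model side); the real types live in this repo's packages:

// Sketch only: FunctionCall and ToolCall are local stand-ins for the repo's
// gemini FunctionCall and model.ToolCall; field names are taken from the diff.
package main

import "fmt"

type FunctionCall struct {
	Name string
	Args map[string]interface{}
}

type ToolCall struct {
	Name       string
	Parameters map[string]interface{}
}

// convertToolCallToAPI mirrors the new signature: it consumes already-extracted
// function calls instead of raw content parts.
func convertToolCallToAPI(functionCalls []FunctionCall) []ToolCall {
	converted := make([]ToolCall, len(functionCalls))
	for i, call := range functionCalls {
		params := make(map[string]interface{})
		for k, v := range call.Args {
			params[k] = v // copy args into a fresh parameter map
		}
		converted[i].Name = call.Name
		converted[i].Parameters = params
	}
	return converted
}

func main() {
	calls := []FunctionCall{{Name: "get_weather", Args: map[string]interface{}{"city": "Oslo"}}}
	fmt.Printf("%+v\n", convertToolCallToAPI(calls))
}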
@@ -239,7 +237,7 @@ func createGenerateContentRequest(
 	}
 
 	request := &GenerateContentRequest{
-		Contents: requestContents,
+		Contents: requestContents,
 		SystemInstructions: system,
 		GenerationConfig: &GenerationConfig{
 			MaxOutputTokens: &params.MaxTokens,
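One detail worth noting in this hunk: MaxOutputTokens is populated from &params.MaxTokens, a pointer, which lets the limit be omitted from the request when it is not set. A rough sketch of that optional-field pattern with stand-in types; the JSON tags and the Contents element type are assumptions, not the repo's real definitions:

// Sketch only: stand-in request types. Field names follow the diff; the JSON
// tags and the Contents element type are assumptions for illustration.
package main

import (
	"encoding/json"
	"fmt"
)

type GenerationConfig struct {
	MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
}

type GenerateContentRequest struct {
	Contents           []string          `json:"contents"`
	SystemInstructions string            `json:"systemInstructions,omitempty"`
	GenerationConfig   *GenerationConfig `json:"generationConfig,omitempty"`
}

func main() {
	maxTokens := 1024
	req := &GenerateContentRequest{
		Contents:           []string{"hello"},
		SystemInstructions: "You are a helpful assistant.",
		// A pointer (like &params.MaxTokens in the diff) lets the limit be
		// dropped from the JSON payload entirely when it is nil.
		GenerationConfig: &GenerationConfig{MaxOutputTokens: &maxTokens},
	}
	out, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(out))
}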
@@ -361,15 +359,21 @@ func (c *Client) CreateChatCompletion(
 		content = lastMessage.Content
 	}
 
+	var toolCalls []FunctionCall
 	for _, part := range choice.Content.Parts {
 		if part.Text != "" {
 			content += part.Text
 		}
+
+		if part.FunctionCall != nil {
+			toolCalls = append(toolCalls, *part.FunctionCall)
+		}
 	}
 
-	toolCalls := convertToolCallToAPI(choice.Content.Parts)
 	if len(toolCalls) > 0 {
-		messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
+		messages, err := handleToolCalls(
+			params, content, convertToolCallToAPI(toolCalls), callback, messages,
+		)
 		if err != nil {
 			return content, err
 		}
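This is the core of the fix for non-streamed responses: a single pass over the candidate's parts now both concatenates text and collects the raw FunctionCall values, and conversion to model.ToolCall is deferred until the hand-off to handleToolCalls. A standalone sketch of that pass, using simplified stand-in types and a hypothetical collectParts helper:

// Sketch only: ContentPart and FunctionCall are simplified stand-ins, and
// collectParts is a hypothetical helper isolating the new single-pass loop.
package main

import "fmt"

type FunctionCall struct {
	Name string
	Args map[string]interface{}
}

type ContentPart struct {
	Text         string
	FunctionCall *FunctionCall
}

// collectParts concatenates text parts and gathers raw function calls in one
// pass; conversion to model.ToolCall happens later, only if calls were found.
func collectParts(parts []ContentPart) (string, []FunctionCall) {
	var content string
	var toolCalls []FunctionCall
	for _, part := range parts {
		if part.Text != "" {
			content += part.Text
		}
		if part.FunctionCall != nil {
			toolCalls = append(toolCalls, *part.FunctionCall)
		}
	}
	return content, toolCalls
}

func main() {
	parts := []ContentPart{
		{Text: "Checking the weather. "},
		{FunctionCall: &FunctionCall{Name: "get_weather", Args: map[string]interface{}{"city": "Oslo"}}},
	}
	content, calls := collectParts(parts)
	fmt.Printf("content=%q calls=%d\n", content, len(calls))
}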
@@ -384,7 +388,6 @@ func (c *Client) CreateChatCompletion(
 		})
 	}
 
-	// Return the user-facing message.
 	return content, nil
 }
 
@@ -430,7 +433,7 @@ func (c *Client) CreateChatCompletionStream(
 		content.WriteString(lastMessage.Content)
 	}
 
-	var toolCalls []model.ToolCall
+	var toolCalls []FunctionCall
 
 	reader := bufio.NewReader(resp.Body)
 	for {
@@ -458,7 +461,7 @@ func (c *Client) CreateChatCompletionStream(
 		for _, candidate := range streamResp.Candidates {
 			for _, part := range candidate.Content.Parts {
 				if part.FunctionCall != nil {
-					toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...)
+					toolCalls = append(toolCalls, *part.FunctionCall)
 				} else if part.Text != "" {
 					output <- part.Text
 					content.WriteString(part.Text)
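The streamed path mirrors the same idea: text parts go straight to the output channel while function-call parts are accumulated as raw FunctionCall values for later conversion. A stripped-down sketch of that part-dispatch logic with stand-in types; the bufio/SSE plumbing is omitted:

// Sketch only: stand-in types and an in-memory part list instead of the real
// streamed response body; only the part-dispatch logic from the diff is shown.
package main

import (
	"fmt"
	"strings"
)

type FunctionCall struct {
	Name string
	Args map[string]interface{}
}

type ContentPart struct {
	Text         string
	FunctionCall *FunctionCall
}

func main() {
	output := make(chan string, 8) // buffered so this sketch needs no reader goroutine
	var content strings.Builder
	var toolCalls []FunctionCall // accumulates raw FunctionCall values, not model.ToolCall

	parts := []ContentPart{
		{Text: "Looking that up"},
		{FunctionCall: &FunctionCall{Name: "search", Args: map[string]interface{}{"q": "gemini"}}},
	}

	for _, part := range parts {
		if part.FunctionCall != nil {
			toolCalls = append(toolCalls, *part.FunctionCall) // defer conversion
		} else if part.Text != "" {
			output <- part.Text // stream text to the caller as it arrives
			content.WriteString(part.Text)
		}
	}
	close(output)

	for chunk := range output {
		fmt.Println("chunk:", chunk)
	}
	fmt.Println("buffered:", content.String(), "calls:", len(toolCalls))
}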
@@ -469,7 +472,9 @@ func (c *Client) CreateChatCompletionStream(
 
 	// If there are function calls, handle them and recurse
 	if len(toolCalls) > 0 {
-		messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
+		messages, err := handleToolCalls(
+			params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
+		)
 		if err != nil {
 			return content.String(), err
 		}
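Both completion paths now end the same way: convert the accumulated FunctionCall slice with convertToolCallToAPI only when it is non-empty, then hand the result to handleToolCalls. A sketch of that hand-off; handleToolCalls here is a stub with a reduced signature (the real one also takes params, callback and messages, runs the tools and recurses):

// Sketch only: handleToolCalls is a stub standing in for the repo's helper.
package main

import "fmt"

type FunctionCall struct {
	Name string
	Args map[string]interface{}
}

type ToolCall struct {
	Name       string
	Parameters map[string]interface{}
}

func convertToolCallToAPI(functionCalls []FunctionCall) []ToolCall {
	converted := make([]ToolCall, len(functionCalls))
	for i, call := range functionCalls {
		// The real converter copies Args into a fresh map; reusing the map
		// directly keeps this sketch short.
		converted[i].Name = call.Name
		converted[i].Parameters = call.Args
	}
	return converted
}

func handleToolCalls(content string, toolCalls []ToolCall) error {
	fmt.Printf("running %d tool call(s) after %q\n", len(toolCalls), content)
	return nil
}

func main() {
	collected := []FunctionCall{{Name: "get_weather", Args: map[string]interface{}{"city": "Oslo"}}}
	// Convert only at the hand-off, and only when something was collected.
	if len(collected) > 0 {
		if err := handleToolCalls("checking the weather", convertToolCallToAPI(collected)); err != nil {
			fmt.Println("error:", err)
		}
	}
}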