Fix for non-streamed gemini responses
This commit is contained in:
parent
b82f3019f0
commit
62d98289e8
@ -144,17 +144,15 @@ func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
|
|||||||
return converted
|
return converted
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToolCallToAPI(parts []ContentPart) []model.ToolCall {
|
func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
|
||||||
converted := make([]model.ToolCall, len(parts))
|
converted := make([]model.ToolCall, len(functionCalls))
|
||||||
for i, part := range parts {
|
for i, call := range functionCalls {
|
||||||
if part.FunctionCall != nil {
|
params := make(map[string]interface{})
|
||||||
params := make(map[string]interface{})
|
for k, v := range call.Args {
|
||||||
for k, v := range part.FunctionCall.Args {
|
params[k] = v
|
||||||
params[k] = v
|
|
||||||
}
|
|
||||||
converted[i].Name = part.FunctionCall.Name
|
|
||||||
converted[i].Parameters = params
|
|
||||||
}
|
}
|
||||||
|
converted[i].Name = call.Name
|
||||||
|
converted[i].Parameters = params
|
||||||
}
|
}
|
||||||
return converted
|
return converted
|
||||||
}
|
}
|
||||||
@ -239,7 +237,7 @@ func createGenerateContentRequest(
|
|||||||
}
|
}
|
||||||
|
|
||||||
request := &GenerateContentRequest{
|
request := &GenerateContentRequest{
|
||||||
Contents: requestContents,
|
Contents: requestContents,
|
||||||
SystemInstructions: system,
|
SystemInstructions: system,
|
||||||
GenerationConfig: &GenerationConfig{
|
GenerationConfig: &GenerationConfig{
|
||||||
MaxOutputTokens: ¶ms.MaxTokens,
|
MaxOutputTokens: ¶ms.MaxTokens,
|
||||||
@ -361,15 +359,21 @@ func (c *Client) CreateChatCompletion(
|
|||||||
content = lastMessage.Content
|
content = lastMessage.Content
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var toolCalls []FunctionCall
|
||||||
for _, part := range choice.Content.Parts {
|
for _, part := range choice.Content.Parts {
|
||||||
if part.Text != "" {
|
if part.Text != "" {
|
||||||
content += part.Text
|
content += part.Text
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if part.FunctionCall != nil {
|
||||||
|
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
toolCalls := convertToolCallToAPI(choice.Content.Parts)
|
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
|
messages, err := handleToolCalls(
|
||||||
|
params, content, convertToolCallToAPI(toolCalls), callback, messages,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return content, err
|
return content, err
|
||||||
}
|
}
|
||||||
@ -384,7 +388,6 @@ func (c *Client) CreateChatCompletion(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the user-facing message.
|
|
||||||
return content, nil
|
return content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -430,7 +433,7 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
content.WriteString(lastMessage.Content)
|
content.WriteString(lastMessage.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
var toolCalls []model.ToolCall
|
var toolCalls []FunctionCall
|
||||||
|
|
||||||
reader := bufio.NewReader(resp.Body)
|
reader := bufio.NewReader(resp.Body)
|
||||||
for {
|
for {
|
||||||
@ -458,7 +461,7 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
for _, candidate := range streamResp.Candidates {
|
for _, candidate := range streamResp.Candidates {
|
||||||
for _, part := range candidate.Content.Parts {
|
for _, part := range candidate.Content.Parts {
|
||||||
if part.FunctionCall != nil {
|
if part.FunctionCall != nil {
|
||||||
toolCalls = append(toolCalls, convertToolCallToAPI([]ContentPart{part})...)
|
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||||
} else if part.Text != "" {
|
} else if part.Text != "" {
|
||||||
output <- part.Text
|
output <- part.Text
|
||||||
content.WriteString(part.Text)
|
content.WriteString(part.Text)
|
||||||
@ -469,7 +472,9 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
|
|
||||||
// If there are function calls, handle them and recurse
|
// If there are function calls, handle them and recurse
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
|
messages, err := handleToolCalls(
|
||||||
|
params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return content.String(), err
|
return content.String(), err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user