From e564a93c374be09e5859b319ca4053a5a5cbd65c Mon Sep 17 00:00:00 2001 From: Matt Low Date: Mon, 27 Jan 2025 06:57:55 +0000 Subject: [PATCH] Add support for google-style reasoning --- pkg/provider/google/google.go | 37 +++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/pkg/provider/google/google.go b/pkg/provider/google/google.go index 7a61f4a..b6a827b 100644 --- a/pkg/provider/google/google.go +++ b/pkg/provider/google/google.go @@ -21,6 +21,7 @@ type Client struct { type ContentPart struct { Text string `json:"text,omitempty"` + Thought bool `json:"thought,omitempty"` FunctionCall *FunctionCall `json:"functionCall,omitempty"` FunctionResp *FunctionResponse `json:"functionResponse,omitempty"` } @@ -40,11 +41,18 @@ type Content struct { Parts []ContentPart `json:"parts"` } +type ThinkingConfig struct { + // Indicates whether to include thoughts in the response. If true, thoughts are returned + // only if the model supports thought and thoughts are available. + IncludeThoughts bool `json:"includeThoughts,omitempty"` +} + type GenerationConfig struct { - MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` - Temperature *float32 `json:"temperature,omitempty"` - TopP *float32 `json:"topP,omitempty"` - TopK *int `json:"topK,omitempty"` + MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"topP,omitempty"` + TopK *int `json:"topK,omitempty"` + ThinkingConfig *ThinkingConfig `json:"thinkingConfig,omitempty"` } type GenerateContentRequest struct { @@ -241,6 +249,9 @@ func createGenerateContentRequest( MaxOutputTokens: ¶ms.MaxTokens, Temperature: ¶ms.Temperature, TopP: ¶ms.TopP, + ThinkingConfig: &ThinkingConfig{ + IncludeThoughts: true, + }, }, } @@ -363,7 +374,7 @@ func (c *Client) CreateChatCompletionStream( } url := fmt.Sprintf( - "%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse", + "%s/v1alpha/models/%s:streamGenerateContent?key=%s&alt=sse", c.BaseURL, params.Model, c.APIKey, ) httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) @@ -378,14 +389,14 @@ func (c *Client) CreateChatCompletionStream( defer resp.Body.Close() content := strings.Builder{} + reasoning := strings.Builder{} + var toolCalls []FunctionCall lastMessage := messages[len(messages)-1] if lastMessage.Role.IsAssistant() { content.WriteString(lastMessage.Content) } - var toolCalls []FunctionCall - reader := bufio.NewReader(resp.Body) lastTokenCount := 0 @@ -416,14 +427,20 @@ func (c *Client) CreateChatCompletionStream( choice := resp.Candidates[0] for _, part := range choice.Content.Parts { - if part.FunctionCall != nil { - toolCalls = append(toolCalls, *part.FunctionCall) + if part.Thought { + output <- provider.Chunk{ + ReasoningContent: part.Text, + TokenCount: uint(tokens), + } + reasoning.WriteString(part.Text) } else if part.Text != "" { output <- provider.Chunk{ Content: part.Text, TokenCount: uint(tokens), } content.WriteString(part.Text) + } else if part.FunctionCall != nil { + toolCalls = append(toolCalls, *part.FunctionCall) } } } @@ -432,5 +449,5 @@ func (c *Client) CreateChatCompletionStream( return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil } - return api.NewMessageWithAssistant(content.String(), ""), nil + return api.NewMessageWithAssistant(content.String(), reasoning.String()), nil }