Private
Public Access
1
0

Add support for google-style reasoning

This commit is contained in:
2025-01-27 06:57:55 +00:00
parent ed0d8784d5
commit e564a93c37

View File

@@ -21,6 +21,7 @@ type Client struct {
type ContentPart struct { type ContentPart struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"` FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"` FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
} }
@@ -40,11 +41,18 @@ type Content struct {
Parts []ContentPart `json:"parts"` Parts []ContentPart `json:"parts"`
} }
type ThinkingConfig struct {
// Indicates whether to include thoughts in the response. If true, thoughts are returned
// only if the model supports thought and thoughts are available.
IncludeThoughts bool `json:"includeThoughts,omitempty"`
}
type GenerationConfig struct { type GenerationConfig struct {
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty"` Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"topP,omitempty"` TopP *float32 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"` TopK *int `json:"topK,omitempty"`
ThinkingConfig *ThinkingConfig `json:"thinkingConfig,omitempty"`
} }
type GenerateContentRequest struct { type GenerateContentRequest struct {
@@ -241,6 +249,9 @@ func createGenerateContentRequest(
MaxOutputTokens: &params.MaxTokens, MaxOutputTokens: &params.MaxTokens,
Temperature: &params.Temperature, Temperature: &params.Temperature,
TopP: &params.TopP, TopP: &params.TopP,
ThinkingConfig: &ThinkingConfig{
IncludeThoughts: true,
},
}, },
} }
@@ -363,7 +374,7 @@ func (c *Client) CreateChatCompletionStream(
} }
url := fmt.Sprintf( url := fmt.Sprintf(
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse", "%s/v1alpha/models/%s:streamGenerateContent?key=%s&alt=sse",
c.BaseURL, params.Model, c.APIKey, c.BaseURL, params.Model, c.APIKey,
) )
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
@@ -378,14 +389,14 @@ func (c *Client) CreateChatCompletionStream(
defer resp.Body.Close() defer resp.Body.Close()
content := strings.Builder{} content := strings.Builder{}
reasoning := strings.Builder{}
var toolCalls []FunctionCall
lastMessage := messages[len(messages)-1] lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() { if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content) content.WriteString(lastMessage.Content)
} }
var toolCalls []FunctionCall
reader := bufio.NewReader(resp.Body) reader := bufio.NewReader(resp.Body)
lastTokenCount := 0 lastTokenCount := 0
@@ -416,14 +427,20 @@ func (c *Client) CreateChatCompletionStream(
choice := resp.Candidates[0] choice := resp.Candidates[0]
for _, part := range choice.Content.Parts { for _, part := range choice.Content.Parts {
if part.FunctionCall != nil { if part.Thought {
toolCalls = append(toolCalls, *part.FunctionCall) output <- provider.Chunk{
ReasoningContent: part.Text,
TokenCount: uint(tokens),
}
reasoning.WriteString(part.Text)
} else if part.Text != "" { } else if part.Text != "" {
output <- provider.Chunk{ output <- provider.Chunk{
Content: part.Text, Content: part.Text,
TokenCount: uint(tokens), TokenCount: uint(tokens),
} }
content.WriteString(part.Text) content.WriteString(part.Text)
} else if part.FunctionCall != nil {
toolCalls = append(toolCalls, *part.FunctionCall)
} }
} }
} }
@@ -432,5 +449,5 @@ func (c *Client) CreateChatCompletionStream(
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
} }
return api.NewMessageWithAssistant(content.String(), ""), nil return api.NewMessageWithAssistant(content.String(), reasoning.String()), nil
} }