Add support for google-style reasoning
This commit is contained in:
@@ -21,6 +21,7 @@ type Client struct {
|
|||||||
|
|
||||||
type ContentPart struct {
|
type ContentPart struct {
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
|
Thought bool `json:"thought,omitempty"`
|
||||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||||
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -40,11 +41,18 @@ type Content struct {
|
|||||||
Parts []ContentPart `json:"parts"`
|
Parts []ContentPart `json:"parts"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ThinkingConfig struct {
|
||||||
|
// Indicates whether to include thoughts in the response. If true, thoughts are returned
|
||||||
|
// only if the model supports thought and thoughts are available.
|
||||||
|
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type GenerationConfig struct {
|
type GenerationConfig struct {
|
||||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||||
Temperature *float32 `json:"temperature,omitempty"`
|
Temperature *float32 `json:"temperature,omitempty"`
|
||||||
TopP *float32 `json:"topP,omitempty"`
|
TopP *float32 `json:"topP,omitempty"`
|
||||||
TopK *int `json:"topK,omitempty"`
|
TopK *int `json:"topK,omitempty"`
|
||||||
|
ThinkingConfig *ThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GenerateContentRequest struct {
|
type GenerateContentRequest struct {
|
||||||
@@ -241,6 +249,9 @@ func createGenerateContentRequest(
|
|||||||
MaxOutputTokens: ¶ms.MaxTokens,
|
MaxOutputTokens: ¶ms.MaxTokens,
|
||||||
Temperature: ¶ms.Temperature,
|
Temperature: ¶ms.Temperature,
|
||||||
TopP: ¶ms.TopP,
|
TopP: ¶ms.TopP,
|
||||||
|
ThinkingConfig: &ThinkingConfig{
|
||||||
|
IncludeThoughts: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -363,7 +374,7 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf(
|
url := fmt.Sprintf(
|
||||||
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
"%s/v1alpha/models/%s:streamGenerateContent?key=%s&alt=sse",
|
||||||
c.BaseURL, params.Model, c.APIKey,
|
c.BaseURL, params.Model, c.APIKey,
|
||||||
)
|
)
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
@@ -378,14 +389,14 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
content := strings.Builder{}
|
content := strings.Builder{}
|
||||||
|
reasoning := strings.Builder{}
|
||||||
|
var toolCalls []FunctionCall
|
||||||
|
|
||||||
lastMessage := messages[len(messages)-1]
|
lastMessage := messages[len(messages)-1]
|
||||||
if lastMessage.Role.IsAssistant() {
|
if lastMessage.Role.IsAssistant() {
|
||||||
content.WriteString(lastMessage.Content)
|
content.WriteString(lastMessage.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
var toolCalls []FunctionCall
|
|
||||||
|
|
||||||
reader := bufio.NewReader(resp.Body)
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
|
||||||
lastTokenCount := 0
|
lastTokenCount := 0
|
||||||
@@ -416,14 +427,20 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
|
|
||||||
choice := resp.Candidates[0]
|
choice := resp.Candidates[0]
|
||||||
for _, part := range choice.Content.Parts {
|
for _, part := range choice.Content.Parts {
|
||||||
if part.FunctionCall != nil {
|
if part.Thought {
|
||||||
toolCalls = append(toolCalls, *part.FunctionCall)
|
output <- provider.Chunk{
|
||||||
|
ReasoningContent: part.Text,
|
||||||
|
TokenCount: uint(tokens),
|
||||||
|
}
|
||||||
|
reasoning.WriteString(part.Text)
|
||||||
} else if part.Text != "" {
|
} else if part.Text != "" {
|
||||||
output <- provider.Chunk{
|
output <- provider.Chunk{
|
||||||
Content: part.Text,
|
Content: part.Text,
|
||||||
TokenCount: uint(tokens),
|
TokenCount: uint(tokens),
|
||||||
}
|
}
|
||||||
content.WriteString(part.Text)
|
content.WriteString(part.Text)
|
||||||
|
} else if part.FunctionCall != nil {
|
||||||
|
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -432,5 +449,5 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
|
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return api.NewMessageWithAssistant(content.String(), ""), nil
|
return api.NewMessageWithAssistant(content.String(), reasoning.String()), nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user