From 5335b5c28f8794bcb44a4bf660e7383122cee6e2 Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Tue, 29 Jul 2025 00:55:28 +0000
Subject: [PATCH] Add support for OpenRouter reasoning + refactor

Started work to make it possible to pass in per-model reasoning config

Cleaned up how we instantiate RequestParameters (TBD: remove
RequestParameters?)
---
 pkg/cmd/util/util.go          |  6 +---
 pkg/lmcli/config.go           | 14 ++++-----
 pkg/lmcli/lmcli.go            | 35 ++++++++++------------
 pkg/provider/openai/openai.go | 50 ++++++++++++++++++++++++++++---
 pkg/provider/provider.go      | 55 ++++++++++++++++++++++++++++++++++-
 pkg/tui/model/model.go        |  6 +---
 6 files changed, 125 insertions(+), 41 deletions(-)

diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go
index 69d26ba..493f60b 100644
--- a/pkg/cmd/util/util.go
+++ b/pkg/cmd/util/util.go
@@ -25,11 +25,7 @@ func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(c
 	}
 
 	p := modelConfig.Client
-	params := provider.RequestParameters{
-		Model:       modelConfig.Model,
-		MaxTokens:   modelConfig.MaxTokens,
-		Temperature: modelConfig.Temperature,
-	}
+	params := provider.NewRequestParameters(*modelConfig)
 
 	system := ctx.DefaultSystemPrompt()
 
diff --git a/pkg/lmcli/config.go b/pkg/lmcli/config.go
index ac3512d..e790457 100644
--- a/pkg/lmcli/config.go
+++ b/pkg/lmcli/config.go
@@ -10,7 +10,7 @@ import (
 
 type Model struct {
 	Name        string
-	Reasoning   bool
+	Reasoning   *bool
 	MaxTokens   *int
 	Temperature *float32
 }
@@ -32,9 +32,10 @@ func (p *Provider) Models() []Model {
 
 type Config struct {
 	Defaults *struct {
-		Model       *string  `yaml:"model" default:"gpt-4"`
-		MaxTokens   *int     `yaml:"maxTokens" default:"256"`
-		Temperature *float32 `yaml:"temperature" default:"0.2"`
+		Model       *string  `yaml:"model" default:"default-model"`
+		MaxTokens   *int     `yaml:"maxTokens" default:"2048"`
+		Temperature *float32 `yaml:"temperature" default:"0.6"`
+		Reasoning   *bool    `yaml:"reasoning" default:"true"`
 		SystemPrompt     string `yaml:"systemPrompt,omitempty"`
 		SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
 		Agent            string `yaml:"agent"`
@@ -81,7 +82,7 @@ func parseModels(rawModels []any) ([]Model, error) {
 
 		if reasoningVal, ok := rawModel["reasoning"]; ok {
 			if reasoningBool, ok := reasoningVal.(bool); ok {
-				parsedModel.Reasoning = reasoningBool
+				parsedModel.Reasoning = &reasoningBool
 			} else {
 				return nil, fmt.Errorf("Invalid 'reasoning' type (%T) for model '%s'", reasoningVal, parsedModel.Name)
 			}
@@ -107,8 +108,7 @@ func parseModels(rawModels []any) ([]Model, error) {
 			}
 			// else: default is nil
 		default:
-			return nil, fmt.Errorf("Invalid model definition type (%T) at index %d in provider '%s'", modelInterface, i)
-			continue // Skip this unknown model definition format
+			return nil, fmt.Errorf("Invalid model definition type (%T) at index %d", modelInterface, i)
 		}
 
 		models = append(models, parsedModel)
diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go
index 3e165bd..f2a9586 100644
--- a/pkg/lmcli/lmcli.go
+++ b/pkg/lmcli/lmcli.go
@@ -70,7 +70,6 @@ func NewContext() (*Context, error) {
 	return &Context{*config, repo, *chroma}, nil
 }
 
-
 func (c *Context) GetModels() (models []string) {
 	modelCounts := make(map[string]int)
 	for _, p := range c.Config.Providers {
@@ -139,15 +138,7 @@ func (c *Context) DefaultSystemPrompt() string {
 	return c.Config.Defaults.SystemPrompt
 }
 
-type ModelConfig struct {
-	Provider    string
-	Client      provider.ChatCompletionProvider
-	Model       string
-	MaxTokens   int
-	Temperature float32
-}
-
-func (c *Context) fillModelConfig(cfg *ModelConfig, m Model) *ModelConfig {
+func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.ModelConfig {
 	// Set model name
 	cfg.Model = m.Name
 
@@ -164,15 +155,21 @@ func (c *Context) fillModelConfig(cfg *ModelConfig, m Model) *ModelConfig {
 	} else {
 		cfg.Temperature = *m.Temperature
 	}
+
+	if m.Reasoning == nil {
+		cfg.Reasoning = *c.Config.Defaults.Reasoning
+	} else {
+		cfg.Reasoning = *m.Reasoning
+	}
 	return cfg
 }
 
-func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig, error) {
+func (c *Context) GetModelProvider(model string, providerName string) (*provider.ModelConfig, error) {
 	parts := strings.Split(model, "@")
 
-	if provider == "" && len(parts) > 1 {
+	if providerName == "" && len(parts) > 1 {
 		model = parts[0]
-		provider = parts[1]
+		providerName = parts[1]
 	}
 
 	for _, p := range c.Config.Providers {
@@ -181,12 +178,12 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
 			name = p.Name
 		}
 
-		if provider != "" && name != provider {
+		if providerName != "" && name != providerName {
 			continue
 		}
 
 		for _, m := range p.Models() {
-			var cfg *ModelConfig
+			var cfg *provider.ModelConfig
 			if m.Name == model {
 				switch p.Kind {
 				case "anthropic":
@@ -194,7 +191,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
 					if p.BaseURL != "" {
 						url = p.BaseURL
 					}
-					cfg = &ModelConfig{
+					cfg = &provider.ModelConfig{
 						Client: &anthropic.AnthropicClient{
 							BaseURL: url,
 							APIKey:  p.APIKey,
@@ -206,7 +203,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
 					if p.BaseURL != "" {
 						url = p.BaseURL
 					}
-					cfg := &ModelConfig{
+					cfg := &provider.ModelConfig{
 						Client: &google.Client{
 							BaseURL: url,
 							APIKey:  p.APIKey,
@@ -218,7 +215,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
 					if p.BaseURL != "" {
 						url = p.BaseURL
 					}
-					cfg := &ModelConfig{
+					cfg := &provider.ModelConfig{
 						Client: &ollama.OllamaClient{
 							BaseURL: url,
 						},
@@ -229,7 +226,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
 					if p.BaseURL != "" {
 						url = p.BaseURL
 					}
-					cfg := &ModelConfig{
+					cfg := &provider.ModelConfig{
 						Client: &openai.OpenAIClient{
 							BaseURL: url,
 							APIKey:  p.APIKey,
diff --git a/pkg/provider/openai/openai.go b/pkg/provider/openai/openai.go
index e1cf06d..b5344a1 100644
--- a/pkg/provider/openai/openai.go
+++ b/pkg/provider/openai/openai.go
@@ -23,7 +23,8 @@ type OpenAIClient struct {
 type ChatCompletionMessage struct {
 	Role             string     `json:"role"`
 	Content          string     `json:"content,omitempty"`
-	ReasoningContent string     `json:"reasoning_content,omitempty"`
+	Reasoning        string     `json:"reasoning,omitempty"`         // OpenRouter
+	ReasoningContent string     `json:"reasoning_content,omitempty"` // Deepseek, llama-server
 	ToolCalls        []ToolCall `json:"tool_calls,omitempty"`
 	ToolCallID       string     `json:"tool_call_id,omitempty"`
 }
@@ -59,6 +60,13 @@ type Tool struct {
 	Function FunctionDefinition `json:"function"`
 }
 
+type OpenRouterReasoning struct {
+	Effort    provider.ReasoningEffort `json:"effort,omitempty"`     // "high", "medium", "low"
+	MaxTokens int                      `json:"max_tokens,omitempty"` // Specific token limit
+	Exclude   bool                     `json:"exclude,omitempty"`    // Exclude reasoning tokens from response
+	Enabled   bool                     `json:"enabled,omitempty"`    // Enable reasoning (default: inferred)
+}
+
 type ChatCompletionRequest struct {
 	Model     string `json:"model"`
 	MaxTokens int    `json:"max_tokens,omitempty"`
@@ -68,6 +76,8 @@ type ChatCompletionRequest struct {
 	Tools      []Tool `json:"tools,omitempty"`
 	ToolChoice string `json:"tool_choice,omitempty"`
 	Stream     bool   `json:"stream,omitempty"`
+	// Reasoning config. TBD: handle for multiple providers using the same field
+	Reasoning *OpenRouterReasoning `json:"reasoning,omitempty"`
 }
 
 type ChatCompletionChoice struct {
@@ -185,6 +195,16 @@ func createChatCompletionRequest(
 		request.ToolChoice = "auto"
 	}
 
+	// Add OpenRouter reasoning config
+	if params.Provider.Kind == provider.OpenRouter {
+		request.Reasoning = &OpenRouterReasoning{
+			Effort:    params.Reasoning.Effort,
+			MaxTokens: params.Reasoning.MaxTokens,
+			Exclude:   params.Reasoning.Exclude,
+			Enabled:   params.Reasoning.Enabled,
+		}
+	}
+
 	return request
 }
 
@@ -243,13 +263,20 @@ func (c *OpenAIClient) CreateChatCompletion(
 	}
 
 	choice := completionResp.Choices[0]
+	lastMessage := messages[len(messages)-1]
 
 	var content string
-	lastMessage := messages[len(messages)-1]
+	var reasoning string
 
+	// Check if last message was a pre-fill
 	if lastMessage.Role.IsAssistant() {
+		// Append new contents to previous last message
 		content = lastMessage.Content + choice.Message.Content
+		// TBD: reasoning
 	} else {
 		content = choice.Message.Content
+		if len(choice.Message.Reasoning) > 0 {
+			reasoning = choice.Message.Reasoning
+		}
 	}
 
 	toolCalls := choice.Message.ToolCalls
@@ -257,7 +284,7 @@ func (c *OpenAIClient) CreateChatCompletion(
 		return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
 	}
 
-	return api.NewMessageWithAssistant(content, ""), nil
+	return api.NewMessageWithAssistant(content, reasoning), nil
 }
 
 func (c *OpenAIClient) CreateChatCompletionStream(
@@ -284,7 +311,9 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	toolCalls := []ToolCall{}
 
 	lastMessage := messages[len(messages)-1]
+	// Check if this was a prefill
 	if lastMessage.Role.IsAssistant() {
+		// Append the last message's contents to the buffer
 		content.WriteString(lastMessage.Content)
 	}
 
@@ -342,10 +371,20 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 			}
 			reasoning.WriteString(delta.ReasoningContent)
 		}
+		// Handle reasoning field in stream response
+		if len(delta.Reasoning) > 0 {
+			output <- provider.Chunk{
+				ReasoningContent: delta.Reasoning,
+				TokenCount:       1,
+			}
+			reasoning.WriteString(delta.Reasoning)
+		}
 	}
 
 	if len(toolCalls) > 0 {
-		return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
+		msg := api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls))
+		msg.ReasoningContent = reasoning.String()
+		return msg, nil
 	}
 
 	return api.NewMessageWithAssistant(content.String(), reasoning.String()), nil
diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go
index d387ced..ad4182a 100644
--- a/pkg/provider/provider.go
+++ b/pkg/provider/provider.go
@@ -12,8 +12,61 @@ type Chunk struct {
 	TokenCount uint
 }
 
+type ModelConfig struct {
+	Provider    string
+	Client      ChatCompletionProvider
+	Model       string
+	MaxTokens   int
+	Temperature float32
+	Reasoning   bool
+}
+
+func NewRequestParameters(modelConfig ModelConfig) RequestParameters {
+	params := RequestParameters{
+		Model:       modelConfig.Model,
+		MaxTokens:   modelConfig.MaxTokens,
+		Temperature: modelConfig.Temperature,
+		Reasoning: ReasoningConfig{
+			Enabled: modelConfig.Reasoning,
+		},
+	}
+	return params
+}
+
+type ReasoningEffort string
+
+const (
+	High   ReasoningEffort = "high"
+	Medium ReasoningEffort = "medium"
+	Low    ReasoningEffort = "low"
+)
+
+// ProviderKind is a bit leaky: it informs the ChatCompletionProvider which
+// provider we're on, so we know how to format requests, etc.
+type ProviderKind string + +const ( + OpenRouter ProviderKind = "openrouter" + OpenAI ProviderKind = "openai" + Anthropic ProviderKind = "anthropic" +) + +type ProviderConfig struct { + Kind ProviderKind + SupportPrefill bool +} + +type ReasoningConfig struct { + Effort ReasoningEffort + MaxTokens int + Exclude bool + Enabled bool +} + type RequestParameters struct { - Model string + Provider ProviderConfig + Model string + Reasoning ReasoningConfig MaxTokens int Temperature float32 diff --git a/pkg/tui/model/model.go b/pkg/tui/model/model.go index dbd50e8..0cec615 100644 --- a/pkg/tui/model/model.go +++ b/pkg/tui/model/model.go @@ -263,11 +263,7 @@ func (a *AppModel) Prompt( } p := modelConfig.Client - params := provider.RequestParameters{ - Model: modelConfig.Model, - MaxTokens: modelConfig.MaxTokens, - Temperature: modelConfig.Temperature, - } + params := provider.NewRequestParameters(*modelConfig) if a.Agent != nil { params.Toolbox = a.Agent.Toolbox
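
Review note: a minimal sketch of how the pieces above fit together, assuming
the provider package as defined in this patch. The import path, model name,
and values are illustrative, and the explicit ProviderConfig assignment
reflects that NewRequestParameters does not populate
RequestParameters.Provider itself.

package main

import (
	"fmt"

	// Assumed import path for this repo's provider package.
	"git.mlow.ca/mlow/lmcli/pkg/provider"
)

func main() {
	// Hypothetical model config, shaped like what Context.fillModelConfig
	// produces after merging per-model values with Defaults.
	modelConfig := provider.ModelConfig{
		Provider:    "openrouter",
		Model:       "some-reasoning-model", // hypothetical model name
		MaxTokens:   2048,
		Temperature: 0.6,
		Reasoning:   true, // from the model's (or Defaults') `reasoning` key
	}

	// NewRequestParameters copies Model/MaxTokens/Temperature over and maps
	// ModelConfig.Reasoning to Reasoning.Enabled.
	params := provider.NewRequestParameters(modelConfig)

	// Not set by NewRequestParameters: without this, the OpenRouter branch
	// in createChatCompletionRequest never fires.
	params.Provider = provider.ProviderConfig{Kind: provider.OpenRouter}
	params.Reasoning.Effort = provider.Medium

	fmt.Printf("%+v\n", params.Reasoning)
}

If ProviderConfig is meant to be filled in by GetModelProvider, it may be
worth threading it through ModelConfig so NewRequestParameters can copy it
along with the rest.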
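One serialization detail worth flagging: every OpenRouterReasoning field
carries omitempty, so false is indistinguishable from unset — the struct can
enable reasoning but can never send "enabled": false. A self-contained sketch
demonstrating this (the struct is copied from the patch, with Effort
simplified from provider.ReasoningEffort to a plain string):

package main

import (
	"encoding/json"
	"fmt"
)

// Local copy of the OpenRouterReasoning struct added in this patch.
type OpenRouterReasoning struct {
	Effort    string `json:"effort,omitempty"`
	MaxTokens int    `json:"max_tokens,omitempty"`
	Exclude   bool   `json:"exclude,omitempty"`
	Enabled   bool   `json:"enabled,omitempty"`
}

func main() {
	off, _ := json.Marshal(OpenRouterReasoning{Enabled: false})
	on, _ := json.Marshal(OpenRouterReasoning{Enabled: true, Exclude: true})

	fmt.Println(string(off)) // {} — omitempty drops false values entirely
	fmt.Println(string(on))  // {"exclude":true,"enabled":true}
}

If OpenRouter treats a missing "enabled" as "infer from the model", this is
fine; if a request ever needs to explicitly disable reasoning, Enabled (and
Exclude) would have to become *bool.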