Private
Public Access
1
0

Add support for openrouter reasoning + refactor

Started work to make it possible to pass in per-model reasoning config
Cleaned up how we instantiate RequestParameters (TBD: remove
RequestParameters?)
This commit is contained in:
2025-07-29 00:55:28 +00:00
parent 54da088dee
commit 5335b5c28f
6 changed files with 125 additions and 41 deletions

View File

@@ -25,11 +25,7 @@ func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(c
}
p := modelConfig.Client
params := provider.RequestParameters{
Model: modelConfig.Model,
MaxTokens: modelConfig.MaxTokens,
Temperature: modelConfig.Temperature,
}
params := provider.NewRequestParameters(*modelConfig)
system := ctx.DefaultSystemPrompt()

View File

@@ -10,7 +10,7 @@ import (
type Model struct {
Name string
Reasoning bool
Reasoning *bool
MaxTokens *int
Temperature *float32
}
@@ -32,9 +32,10 @@ func (p *Provider) Models() []Model {
type Config struct {
Defaults *struct {
Model *string `yaml:"model" default:"gpt-4"`
MaxTokens *int `yaml:"maxTokens" default:"256"`
Temperature *float32 `yaml:"temperature" default:"0.2"`
Model *string `yaml:"model" default:"default-model"`
MaxTokens *int `yaml:"maxTokens" default:"2048"`
Temperature *float32 `yaml:"temperature" default:"0.6"`
Reasoning *bool `yaml:"reasoning" default:"true"`
SystemPrompt string `yaml:"systemPrompt,omitempty"`
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
Agent string `yaml:"agent"`
@@ -81,7 +82,7 @@ func parseModels(rawModels []any) ([]Model, error) {
if reasoningVal, ok := rawModel["reasoning"]; ok {
if reasoningBool, ok := reasoningVal.(bool); ok {
parsedModel.Reasoning = reasoningBool
parsedModel.Reasoning = &reasoningBool
} else {
return nil, fmt.Errorf("Invalid 'reasoning' type (%T) for model '%s'", reasoningVal, parsedModel.Name)
}
@@ -107,8 +108,7 @@ func parseModels(rawModels []any) ([]Model, error) {
} // else: default is nil
default:
return nil, fmt.Errorf("Invalid model definition type (%T) at index %d in provider '%s'", modelInterface, i)
continue // Skip this unknown model definition format
return nil, fmt.Errorf("Invalid model definition type (%T) at index %d", modelInterface, i)
}
models = append(models, parsedModel)

View File

@@ -70,7 +70,6 @@ func NewContext() (*Context, error) {
return &Context{*config, repo, *chroma}, nil
}
func (c *Context) GetModels() (models []string) {
modelCounts := make(map[string]int)
for _, p := range c.Config.Providers {
@@ -139,15 +138,7 @@ func (c *Context) DefaultSystemPrompt() string {
return c.Config.Defaults.SystemPrompt
}
type ModelConfig struct {
Provider string
Client provider.ChatCompletionProvider
Model string
MaxTokens int
Temperature float32
}
func (c *Context) fillModelConfig(cfg *ModelConfig, m Model) *ModelConfig {
func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.ModelConfig {
// Set model name
cfg.Model = m.Name
@@ -164,15 +155,21 @@ func (c *Context) fillModelConfig(cfg *ModelConfig, m Model) *ModelConfig {
} else {
cfg.Temperature = *m.Temperature
}
if m.Reasoning == nil {
cfg.Reasoning = *c.Config.Defaults.Reasoning
} else {
cfg.Reasoning = *m.Reasoning
}
return cfg
}
func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig, error) {
func (c *Context) GetModelProvider(model string, providerName string) (*provider.ModelConfig, error) {
parts := strings.Split(model, "@")
if provider == "" && len(parts) > 1 {
if providerName == "" && len(parts) > 1 {
model = parts[0]
provider = parts[1]
providerName = parts[1]
}
for _, p := range c.Config.Providers {
@@ -181,12 +178,12 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
name = p.Name
}
if provider != "" && name != provider {
if providerName != "" && name != providerName {
continue
}
for _, m := range p.Models() {
var cfg *ModelConfig
var cfg *provider.ModelConfig
if m.Name == model {
switch p.Kind {
case "anthropic":
@@ -194,7 +191,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
if p.BaseURL != "" {
url = p.BaseURL
}
cfg = &ModelConfig{
cfg = &provider.ModelConfig{
Client: &anthropic.AnthropicClient{
BaseURL: url,
APIKey: p.APIKey,
@@ -206,7 +203,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
if p.BaseURL != "" {
url = p.BaseURL
}
cfg := &ModelConfig{
cfg := &provider.ModelConfig{
Client: &google.Client{
BaseURL: url,
APIKey: p.APIKey,
@@ -218,7 +215,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
if p.BaseURL != "" {
url = p.BaseURL
}
cfg := &ModelConfig{
cfg := &provider.ModelConfig{
Client: &ollama.OllamaClient{
BaseURL: url,
},
@@ -229,7 +226,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
if p.BaseURL != "" {
url = p.BaseURL
}
cfg := &ModelConfig{
cfg := &provider.ModelConfig{
Client: &openai.OpenAIClient{
BaseURL: url,
APIKey: p.APIKey,

View File

@@ -23,7 +23,8 @@ type OpenAIClient struct {
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Reasoning string `json:"reasoning,omitempty"` // OpenRouter
ReasoningContent string `json:"reasoning_content,omitempty"` // Deepseek, llama-server
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
@@ -59,6 +60,13 @@ type Tool struct {
Function FunctionDefinition `json:"function"`
}
type OpenRouterReasoning struct {
Effort provider.ReasoningEffort `json:"effort,omitempty"` // "high", "medium", "low"
MaxTokens int `json:"max_tokens,omitempty"` // Specific token limit
Exclude bool `json:"exclude,omitempty"` // Exclude reasoning tokens from response
Enabled bool `json:"enabled,omitempty"` // Enable reasoning (default: inferred)
}
type ChatCompletionRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens,omitempty"`
@@ -68,6 +76,8 @@ type ChatCompletionRequest struct {
Tools []Tool `json:"tools,omitempty"`
ToolChoice string `json:"tool_choice,omitempty"`
Stream bool `json:"stream,omitempty"`
// Reasoning config. TBD: handle for multiple providers using the same field
Reasoning *OpenRouterReasoning `json:"reasoning,omitempty"`
}
type ChatCompletionChoice struct {
@@ -185,6 +195,16 @@ func createChatCompletionRequest(
request.ToolChoice = "auto"
}
// Add OpenRouter reasoning config
if params.Provider.Kind == provider.OpenRouter {
request.Reasoning = &OpenRouterReasoning{
Effort: params.Reasoning.Effort,
MaxTokens: params.Reasoning.MaxTokens,
Exclude: params.Reasoning.Exclude,
Enabled: params.Reasoning.Enabled,
}
}
return request
}
@@ -243,13 +263,20 @@ func (c *OpenAIClient) CreateChatCompletion(
}
choice := completionResp.Choices[0]
lastMessage := messages[len(messages)-1]
var content string
lastMessage := messages[len(messages)-1]
var reasoning string
// Check if last message was a pre-fill
if lastMessage.Role.IsAssistant() {
// Append new contents to previous last message
content = lastMessage.Content + choice.Message.Content
// TBD: reasoning
} else {
content = choice.Message.Content
if len(choice.Message.Reasoning) > 0 {
reasoning = choice.Message.Reasoning
}
}
toolCalls := choice.Message.ToolCalls
@@ -257,7 +284,7 @@ func (c *OpenAIClient) CreateChatCompletion(
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
}
return api.NewMessageWithAssistant(content, ""), nil
return api.NewMessageWithAssistant(content, reasoning), nil
}
func (c *OpenAIClient) CreateChatCompletionStream(
@@ -284,7 +311,9 @@ func (c *OpenAIClient) CreateChatCompletionStream(
toolCalls := []ToolCall{}
lastMessage := messages[len(messages)-1]
// Check if this was a prefill
if lastMessage.Role.IsAssistant() {
// Append the last message's contents to the buffer
content.WriteString(lastMessage.Content)
}
@@ -342,10 +371,23 @@ func (c *OpenAIClient) CreateChatCompletionStream(
}
reasoning.WriteString(delta.ReasoningContent)
}
// Handle reasoning field in stream response
if len(delta.Reasoning) > 0 {
output <- provider.Chunk{
ReasoningContent: delta.Reasoning,
TokenCount: 1,
}
reasoning.WriteString(delta.Reasoning)
}
}
if len(toolCalls) > 0 {
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
msg := api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls))
if err != nil {
return nil, err
}
msg.ReasoningContent = reasoning.String()
return msg, nil
}
return api.NewMessageWithAssistant(content.String(), reasoning.String()), nil

View File

@@ -12,8 +12,61 @@ type Chunk struct {
TokenCount uint
}
type RequestParameters struct {
type ModelConfig struct {
Provider string
Client ChatCompletionProvider
Model string
MaxTokens int
Temperature float32
Reasoning bool
}
func NewRequestParameters(modelConfig ModelConfig) RequestParameters {
params := RequestParameters{
Model: modelConfig.Model,
MaxTokens: modelConfig.MaxTokens,
Temperature: modelConfig.Temperature,
Reasoning: ReasoningConfig{
Enabled: modelConfig.Reasoning,
},
}
return params
}
type ReasoningEffort string
const (
High ReasoningEffort = "high"
Medium ReasoningEffort = "medium"
Low ReasoningEffort = "low"
)
// ProviderKind is a bit leaky, it informs the ChatCompletionProvider what
// provider we're on so we know how to format requests, etc.
type ProviderKind string
const (
OpenRouter ProviderKind = "openrouter"
OpenAI ProviderKind = "openai"
Anthropic ProviderKind = "anthropic"
)
type ProviderConfig struct {
Kind ProviderKind
SupportPrefill bool
}
type ReasoningConfig struct {
Effort ReasoningEffort
MaxTokens int
Exclude bool
Enabled bool
}
type RequestParameters struct {
Provider ProviderConfig
Model string
Reasoning ReasoningConfig
MaxTokens int
Temperature float32

View File

@@ -263,11 +263,7 @@ func (a *AppModel) Prompt(
}
p := modelConfig.Client
params := provider.RequestParameters{
Model: modelConfig.Model,
MaxTokens: modelConfig.MaxTokens,
Temperature: modelConfig.Temperature,
}
params := provider.NewRequestParameters(*modelConfig)
if a.Agent != nil {
params.Toolbox = a.Agent.Toolbox