Add support for openrouter reasoning + refactor
Started work to make it possible to pass in per-model reasoning config Cleaned up how we instantiate RequestParameters (TBD: remove RequestParameters?)
This commit is contained in:
@@ -25,11 +25,7 @@ func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(c
|
|||||||
}
|
}
|
||||||
p := modelConfig.Client
|
p := modelConfig.Client
|
||||||
|
|
||||||
params := provider.RequestParameters{
|
params := provider.NewRequestParameters(*modelConfig)
|
||||||
Model: modelConfig.Model,
|
|
||||||
MaxTokens: modelConfig.MaxTokens,
|
|
||||||
Temperature: modelConfig.Temperature,
|
|
||||||
}
|
|
||||||
|
|
||||||
system := ctx.DefaultSystemPrompt()
|
system := ctx.DefaultSystemPrompt()
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Name string
|
Name string
|
||||||
Reasoning bool
|
Reasoning *bool
|
||||||
MaxTokens *int
|
MaxTokens *int
|
||||||
Temperature *float32
|
Temperature *float32
|
||||||
}
|
}
|
||||||
@@ -32,9 +32,10 @@ func (p *Provider) Models() []Model {
|
|||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Defaults *struct {
|
Defaults *struct {
|
||||||
Model *string `yaml:"model" default:"gpt-4"`
|
Model *string `yaml:"model" default:"default-model"`
|
||||||
MaxTokens *int `yaml:"maxTokens" default:"256"`
|
MaxTokens *int `yaml:"maxTokens" default:"2048"`
|
||||||
Temperature *float32 `yaml:"temperature" default:"0.2"`
|
Temperature *float32 `yaml:"temperature" default:"0.6"`
|
||||||
|
Reasoning *bool `yaml:"reasoning" default:"true"`
|
||||||
SystemPrompt string `yaml:"systemPrompt,omitempty"`
|
SystemPrompt string `yaml:"systemPrompt,omitempty"`
|
||||||
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
|
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
|
||||||
Agent string `yaml:"agent"`
|
Agent string `yaml:"agent"`
|
||||||
@@ -81,7 +82,7 @@ func parseModels(rawModels []any) ([]Model, error) {
|
|||||||
|
|
||||||
if reasoningVal, ok := rawModel["reasoning"]; ok {
|
if reasoningVal, ok := rawModel["reasoning"]; ok {
|
||||||
if reasoningBool, ok := reasoningVal.(bool); ok {
|
if reasoningBool, ok := reasoningVal.(bool); ok {
|
||||||
parsedModel.Reasoning = reasoningBool
|
parsedModel.Reasoning = &reasoningBool
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("Invalid 'reasoning' type (%T) for model '%s'", reasoningVal, parsedModel.Name)
|
return nil, fmt.Errorf("Invalid 'reasoning' type (%T) for model '%s'", reasoningVal, parsedModel.Name)
|
||||||
}
|
}
|
||||||
@@ -107,8 +108,7 @@ func parseModels(rawModels []any) ([]Model, error) {
|
|||||||
} // else: default is nil
|
} // else: default is nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Invalid model definition type (%T) at index %d in provider '%s'", modelInterface, i)
|
return nil, fmt.Errorf("Invalid model definition type (%T) at index %d", modelInterface, i)
|
||||||
continue // Skip this unknown model definition format
|
|
||||||
}
|
}
|
||||||
|
|
||||||
models = append(models, parsedModel)
|
models = append(models, parsedModel)
|
||||||
|
|||||||
@@ -70,7 +70,6 @@ func NewContext() (*Context, error) {
|
|||||||
return &Context{*config, repo, *chroma}, nil
|
return &Context{*config, repo, *chroma}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (c *Context) GetModels() (models []string) {
|
func (c *Context) GetModels() (models []string) {
|
||||||
modelCounts := make(map[string]int)
|
modelCounts := make(map[string]int)
|
||||||
for _, p := range c.Config.Providers {
|
for _, p := range c.Config.Providers {
|
||||||
@@ -139,15 +138,7 @@ func (c *Context) DefaultSystemPrompt() string {
|
|||||||
return c.Config.Defaults.SystemPrompt
|
return c.Config.Defaults.SystemPrompt
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelConfig struct {
|
func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.ModelConfig {
|
||||||
Provider string
|
|
||||||
Client provider.ChatCompletionProvider
|
|
||||||
Model string
|
|
||||||
MaxTokens int
|
|
||||||
Temperature float32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Context) fillModelConfig(cfg *ModelConfig, m Model) *ModelConfig {
|
|
||||||
// Set model name
|
// Set model name
|
||||||
cfg.Model = m.Name
|
cfg.Model = m.Name
|
||||||
|
|
||||||
@@ -164,15 +155,21 @@ func (c *Context) fillModelConfig(cfg *ModelConfig, m Model) *ModelConfig {
|
|||||||
} else {
|
} else {
|
||||||
cfg.Temperature = *m.Temperature
|
cfg.Temperature = *m.Temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.Reasoning == nil {
|
||||||
|
cfg.Reasoning = *c.Config.Defaults.Reasoning
|
||||||
|
} else {
|
||||||
|
cfg.Reasoning = *m.Reasoning
|
||||||
|
}
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig, error) {
|
func (c *Context) GetModelProvider(model string, providerName string) (*provider.ModelConfig, error) {
|
||||||
parts := strings.Split(model, "@")
|
parts := strings.Split(model, "@")
|
||||||
|
|
||||||
if provider == "" && len(parts) > 1 {
|
if providerName == "" && len(parts) > 1 {
|
||||||
model = parts[0]
|
model = parts[0]
|
||||||
provider = parts[1]
|
providerName = parts[1]
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range c.Config.Providers {
|
for _, p := range c.Config.Providers {
|
||||||
@@ -181,12 +178,12 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
|
|||||||
name = p.Name
|
name = p.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
if provider != "" && name != provider {
|
if providerName != "" && name != providerName {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range p.Models() {
|
for _, m := range p.Models() {
|
||||||
var cfg *ModelConfig
|
var cfg *provider.ModelConfig
|
||||||
if m.Name == model {
|
if m.Name == model {
|
||||||
switch p.Kind {
|
switch p.Kind {
|
||||||
case "anthropic":
|
case "anthropic":
|
||||||
@@ -194,7 +191,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
|
|||||||
if p.BaseURL != "" {
|
if p.BaseURL != "" {
|
||||||
url = p.BaseURL
|
url = p.BaseURL
|
||||||
}
|
}
|
||||||
cfg = &ModelConfig{
|
cfg = &provider.ModelConfig{
|
||||||
Client: &anthropic.AnthropicClient{
|
Client: &anthropic.AnthropicClient{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
APIKey: p.APIKey,
|
APIKey: p.APIKey,
|
||||||
@@ -206,7 +203,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
|
|||||||
if p.BaseURL != "" {
|
if p.BaseURL != "" {
|
||||||
url = p.BaseURL
|
url = p.BaseURL
|
||||||
}
|
}
|
||||||
cfg := &ModelConfig{
|
cfg := &provider.ModelConfig{
|
||||||
Client: &google.Client{
|
Client: &google.Client{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
APIKey: p.APIKey,
|
APIKey: p.APIKey,
|
||||||
@@ -218,7 +215,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
|
|||||||
if p.BaseURL != "" {
|
if p.BaseURL != "" {
|
||||||
url = p.BaseURL
|
url = p.BaseURL
|
||||||
}
|
}
|
||||||
cfg := &ModelConfig{
|
cfg := &provider.ModelConfig{
|
||||||
Client: &ollama.OllamaClient{
|
Client: &ollama.OllamaClient{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
},
|
},
|
||||||
@@ -229,7 +226,7 @@ func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig,
|
|||||||
if p.BaseURL != "" {
|
if p.BaseURL != "" {
|
||||||
url = p.BaseURL
|
url = p.BaseURL
|
||||||
}
|
}
|
||||||
cfg := &ModelConfig{
|
cfg := &provider.ModelConfig{
|
||||||
Client: &openai.OpenAIClient{
|
Client: &openai.OpenAIClient{
|
||||||
BaseURL: url,
|
BaseURL: url,
|
||||||
APIKey: p.APIKey,
|
APIKey: p.APIKey,
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ type OpenAIClient struct {
|
|||||||
type ChatCompletionMessage struct {
|
type ChatCompletionMessage struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content,omitempty"`
|
Content string `json:"content,omitempty"`
|
||||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
Reasoning string `json:"reasoning,omitempty"` // OpenRouter
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"` // Deepseek, llama-server
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -59,6 +60,13 @@ type Tool struct {
|
|||||||
Function FunctionDefinition `json:"function"`
|
Function FunctionDefinition `json:"function"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OpenRouterReasoning struct {
|
||||||
|
Effort provider.ReasoningEffort `json:"effort,omitempty"` // "high", "medium", "low"
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"` // Specific token limit
|
||||||
|
Exclude bool `json:"exclude,omitempty"` // Exclude reasoning tokens from response
|
||||||
|
Enabled bool `json:"enabled,omitempty"` // Enable reasoning (default: inferred)
|
||||||
|
}
|
||||||
|
|
||||||
type ChatCompletionRequest struct {
|
type ChatCompletionRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
@@ -68,6 +76,8 @@ type ChatCompletionRequest struct {
|
|||||||
Tools []Tool `json:"tools,omitempty"`
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
ToolChoice string `json:"tool_choice,omitempty"`
|
ToolChoice string `json:"tool_choice,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
// Reasoning config. TBD: handle for mulitple providers using the same field
|
||||||
|
Reasoning *OpenRouterReasoning `json:"reasoning,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatCompletionChoice struct {
|
type ChatCompletionChoice struct {
|
||||||
@@ -185,6 +195,16 @@ func createChatCompletionRequest(
|
|||||||
request.ToolChoice = "auto"
|
request.ToolChoice = "auto"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add OpenRouter reasoning config
|
||||||
|
if params.Provider.Kind == provider.OpenRouter {
|
||||||
|
request.Reasoning = &OpenRouterReasoning{
|
||||||
|
Effort: params.Reasoning.Effort,
|
||||||
|
MaxTokens: params.Reasoning.MaxTokens,
|
||||||
|
Exclude: params.Reasoning.Exclude,
|
||||||
|
Enabled: params.Reasoning.Enabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return request
|
return request
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,13 +263,20 @@ func (c *OpenAIClient) CreateChatCompletion(
|
|||||||
}
|
}
|
||||||
|
|
||||||
choice := completionResp.Choices[0]
|
choice := completionResp.Choices[0]
|
||||||
|
lastMessage := messages[len(messages)-1]
|
||||||
|
|
||||||
var content string
|
var content string
|
||||||
lastMessage := messages[len(messages)-1]
|
var reasoning string
|
||||||
|
// Check if last message was a pre-fill
|
||||||
if lastMessage.Role.IsAssistant() {
|
if lastMessage.Role.IsAssistant() {
|
||||||
|
// Append new contents to previous last message
|
||||||
content = lastMessage.Content + choice.Message.Content
|
content = lastMessage.Content + choice.Message.Content
|
||||||
|
// TBD: reasoning
|
||||||
} else {
|
} else {
|
||||||
content = choice.Message.Content
|
content = choice.Message.Content
|
||||||
|
if len(choice.Message.Reasoning) > 0 {
|
||||||
|
reasoning = choice.Message.Reasoning
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
toolCalls := choice.Message.ToolCalls
|
toolCalls := choice.Message.ToolCalls
|
||||||
@@ -257,7 +284,7 @@ func (c *OpenAIClient) CreateChatCompletion(
|
|||||||
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
|
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return api.NewMessageWithAssistant(content, ""), nil
|
return api.NewMessageWithAssistant(content, reasoning), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||||
@@ -284,7 +311,9 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
toolCalls := []ToolCall{}
|
toolCalls := []ToolCall{}
|
||||||
|
|
||||||
lastMessage := messages[len(messages)-1]
|
lastMessage := messages[len(messages)-1]
|
||||||
|
// Check if this was a prefill
|
||||||
if lastMessage.Role.IsAssistant() {
|
if lastMessage.Role.IsAssistant() {
|
||||||
|
// Append the last message's contents to the buffer
|
||||||
content.WriteString(lastMessage.Content)
|
content.WriteString(lastMessage.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -342,10 +371,23 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
reasoning.WriteString(delta.ReasoningContent)
|
reasoning.WriteString(delta.ReasoningContent)
|
||||||
}
|
}
|
||||||
|
// Handle reasoning field in stream response
|
||||||
|
if len(delta.Reasoning) > 0 {
|
||||||
|
output <- provider.Chunk{
|
||||||
|
ReasoningContent: delta.Reasoning,
|
||||||
|
TokenCount: 1,
|
||||||
|
}
|
||||||
|
reasoning.WriteString(delta.Reasoning)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
|
msg := api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msg.ReasoningContent = reasoning.String()
|
||||||
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return api.NewMessageWithAssistant(content.String(), reasoning.String()), nil
|
return api.NewMessageWithAssistant(content.String(), reasoning.String()), nil
|
||||||
|
|||||||
@@ -12,8 +12,61 @@ type Chunk struct {
|
|||||||
TokenCount uint
|
TokenCount uint
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ModelConfig struct {
|
||||||
|
Provider string
|
||||||
|
Client ChatCompletionProvider
|
||||||
|
Model string
|
||||||
|
MaxTokens int
|
||||||
|
Temperature float32
|
||||||
|
Reasoning bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRequestParameters(modelConfig ModelConfig) RequestParameters {
|
||||||
|
params := RequestParameters{
|
||||||
|
Model: modelConfig.Model,
|
||||||
|
MaxTokens: modelConfig.MaxTokens,
|
||||||
|
Temperature: modelConfig.Temperature,
|
||||||
|
Reasoning: ReasoningConfig{
|
||||||
|
Enabled: modelConfig.Reasoning,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReasoningEffort string
|
||||||
|
|
||||||
|
const (
|
||||||
|
High ReasoningEffort = "high"
|
||||||
|
Medium ReasoningEffort = "medium"
|
||||||
|
Low ReasoningEffort = "low"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderKind is a bit leaky, it informs the ChatCompletionProvider what
|
||||||
|
// provider we're on so we know how to format requests, etc.
|
||||||
|
type ProviderKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
OpenRouter ProviderKind = "openrouter"
|
||||||
|
OpenAI ProviderKind = "openai"
|
||||||
|
Anthropic ProviderKind = "anthropic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProviderConfig struct {
|
||||||
|
Kind ProviderKind
|
||||||
|
SupportPrefill bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReasoningConfig struct {
|
||||||
|
Effort ReasoningEffort
|
||||||
|
MaxTokens int
|
||||||
|
Exclude bool
|
||||||
|
Enabled bool
|
||||||
|
}
|
||||||
|
|
||||||
type RequestParameters struct {
|
type RequestParameters struct {
|
||||||
Model string
|
Provider ProviderConfig
|
||||||
|
Model string
|
||||||
|
Reasoning ReasoningConfig
|
||||||
|
|
||||||
MaxTokens int
|
MaxTokens int
|
||||||
Temperature float32
|
Temperature float32
|
||||||
|
|||||||
@@ -263,11 +263,7 @@ func (a *AppModel) Prompt(
|
|||||||
}
|
}
|
||||||
p := modelConfig.Client
|
p := modelConfig.Client
|
||||||
|
|
||||||
params := provider.RequestParameters{
|
params := provider.NewRequestParameters(*modelConfig)
|
||||||
Model: modelConfig.Model,
|
|
||||||
MaxTokens: modelConfig.MaxTokens,
|
|
||||||
Temperature: modelConfig.Temperature,
|
|
||||||
}
|
|
||||||
|
|
||||||
if a.Agent != nil {
|
if a.Agent != nil {
|
||||||
params.Toolbox = a.Agent.Toolbox
|
params.Toolbox = a.Agent.Toolbox
|
||||||
|
|||||||
Reference in New Issue
Block a user