From 3cd897d494e4c3959b5927dd893d04115586f5ad Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Wed, 25 Jun 2025 07:49:29 +0000
Subject: [PATCH] Properly support per-model maxTokens/temperature

---
 pkg/cmd/util/util.go   | 14 ++++----
 pkg/lmcli/lmcli.go     | 79 ++++++++++++++++++++++++++++++++----------
 pkg/tui/model/model.go | 15 ++++----
 3 files changed, 76 insertions(+), 32 deletions(-)

diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go
index 19f481a..69d26ba 100644
--- a/pkg/cmd/util/util.go
+++ b/pkg/cmd/util/util.go
@@ -19,15 +19,16 @@ import (
 // Prompt prompts the configured the configured model and streams the response
 // to stdout. Returns all model reply messages.
 func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
-	m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
+	modelConfig, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
 	if err != nil {
 		return nil, err
 	}
+	p := modelConfig.Client
 
 	params := provider.RequestParameters{
-		Model:       m,
-		MaxTokens:   *ctx.Config.Defaults.MaxTokens,
-		Temperature: *ctx.Config.Defaults.Temperature,
+		Model:       modelConfig.Model,
+		MaxTokens:   modelConfig.MaxTokens,
+		Temperature: modelConfig.Temperature,
 	}
 
 	system := ctx.DefaultSystemPrompt()
@@ -206,15 +207,16 @@ Example response:
 		},
 	}
 
-	m, _, p, err := ctx.GetModelProvider(
+	modelConfig, err := ctx.GetModelProvider(
 		*ctx.Config.Conversations.TitleGenerationModel,
 		"",
 	)
 	if err != nil {
 		return "", err
 	}
+	p := modelConfig.Client
 
 	requestParams := provider.RequestParameters{
-		Model:     m,
+		Model:     modelConfig.Model,
 		MaxTokens: 25,
 	}
 
diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go
index 10e3402..c6fa448 100644
--- a/pkg/lmcli/lmcli.go
+++ b/pkg/lmcli/lmcli.go
@@ -139,7 +139,35 @@ func (c *Context) DefaultSystemPrompt() string {
 	return c.Config.Defaults.SystemPrompt
 }
 
-func (c *Context) GetModelProvider(model string, provider string) (string, string, provider.ChatCompletionProvider, error) {
+type ModelConfig struct {
+	Provider    string
+	Client      provider.ChatCompletionProvider
+	Model       string
+	MaxTokens   int
+	Temperature float32
+}
+
+func (c *Context) fillModelConfig(cfg *ModelConfig, m Model) *ModelConfig {
+	// Set model name
+	cfg.Model = m.Name
+
+	// Set max tokens
+	if m.MaxTokens == nil {
+		cfg.MaxTokens = *c.Config.Defaults.MaxTokens
+	} else {
+		cfg.MaxTokens = *m.MaxTokens
+	}
+
+	// Set temperature
+	if m.Temperature == nil {
+		cfg.Temperature = *c.Config.Defaults.Temperature
+	} else {
+		cfg.Temperature = *m.Temperature
+	}
+	return cfg
+}
+
+func (c *Context) GetModelProvider(model string, provider string) (*ModelConfig, error) {
 	parts := strings.Split(model, "@")
 
 	if provider == "" && len(parts) > 1 {
@@ -158,6 +186,7 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin
 	}
 
 	for _, m := range p.Models() {
+		var cfg *ModelConfig
 		if m.Name == model {
 			switch p.Kind {
 			case "anthropic":
@@ -165,44 +194,56 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin
 				url := "https://api.anthropic.com"
 				if p.BaseURL != "" {
 					url = p.BaseURL
 				}
-				return model, name, &anthropic.AnthropicClient{
-					BaseURL: url,
-					APIKey:  p.APIKey,
-				}, nil
+				cfg = &ModelConfig{
+					Client: &anthropic.AnthropicClient{
+						BaseURL: url,
+						APIKey:  p.APIKey,
+					},
+				}
+				return c.fillModelConfig(cfg, m), nil
 			case "google":
 				url := "https://generativelanguage.googleapis.com"
 				if p.BaseURL != "" {
 					url = p.BaseURL
 				}
-				return model, name, &google.Client{
-					BaseURL: url,
-					APIKey:  p.APIKey,
-				}, nil
+				cfg := &ModelConfig{
+					Client: &google.Client{
+						BaseURL: url,
+						APIKey:  p.APIKey,
+					},
+				}
+				return c.fillModelConfig(cfg, m), nil
 			case "ollama":
 				url := "http://localhost:11434/api"
 				if p.BaseURL != "" {
 					url = p.BaseURL
 				}
-				return model, name, &ollama.OllamaClient{
-					BaseURL: url,
-				}, nil
+				cfg := &ModelConfig{
+					Client: &ollama.OllamaClient{
+						BaseURL: url,
+					},
+				}
+				return c.fillModelConfig(cfg, m), nil
 			case "openai":
 				url := "https://api.openai.com"
 				if p.BaseURL != "" {
 					url = p.BaseURL
 				}
-				return model, name, &openai.OpenAIClient{
-					BaseURL: url,
-					APIKey:  p.APIKey,
-					Headers: p.Headers,
-				}, nil
+				cfg := &ModelConfig{
+					Client: &openai.OpenAIClient{
+						BaseURL: url,
+						APIKey:  p.APIKey,
+						Headers: p.Headers,
+					},
+				}
+				return c.fillModelConfig(cfg, m), nil
 			default:
-				return "", "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
+				return nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
 			}
 		}
 	}
-	return "", "", nil, fmt.Errorf("unknown model: %s", model)
+	return nil, fmt.Errorf("unknown model: %s", model)
 }
 
 func Fatal(format string, args ...any) {
diff --git a/pkg/tui/model/model.go b/pkg/tui/model/model.go
index 95a3f61..dbd50e8 100644
--- a/pkg/tui/model/model.go
+++ b/pkg/tui/model/model.go
@@ -39,9 +39,9 @@ func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversat
 	}
 
-	model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
-	app.Model = model
-	app.ProviderName = provider
+	modelConfig, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
+	app.Model = modelConfig.Model
+	app.ProviderName = modelConfig.Provider
 	app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
 	return app
 }
@@ -257,15 +257,16 @@ func (a *AppModel) Prompt(
 	chatReplyChunks chan provider.Chunk,
 	stopSignal chan struct{},
 ) (*conversation.Message, error) {
-	model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
+	modelConfig, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
 	if err != nil {
 		return nil, err
 	}
+	p := modelConfig.Client
 
 	params := provider.RequestParameters{
-		Model:       model,
-		MaxTokens:   *a.Ctx.Config.Defaults.MaxTokens,
-		Temperature: *a.Ctx.Config.Defaults.Temperature,
+		Model:       modelConfig.Model,
+		MaxTokens:   modelConfig.MaxTokens,
+		Temperature: modelConfig.Temperature,
 	}
 
 	if a.Agent != nil {