From 7c3a2c3cb95cfe9248214193606f4a1dafcae8ad Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sat, 29 Mar 2025 23:25:01 +0000 Subject: [PATCH] Add support for per-model configuration A provider's models may now be provided in the config as a simple string, or as a mapping with at least a `name` key. This is a valid configuration: ```yaml models: - name: model-a reasoning: true - model-b ``` This opens the door to providing model-specific configuration, e.g. an overridden max tokens count, or whether the model is a 'reasoning' model. --- pkg/lmcli/config.go | 126 ++++++++++++++++++++++++----- pkg/lmcli/lmcli.go | 18 ++--- pkg/tui/views/settings/settings.go | 6 +- 3 files changed, 117 insertions(+), 33 deletions(-) diff --git a/pkg/lmcli/config.go b/pkg/lmcli/config.go index 914ae8d..0208a8d 100644 --- a/pkg/lmcli/config.go +++ b/pkg/lmcli/config.go @@ -8,6 +8,27 @@ import ( "gopkg.in/yaml.v3" ) +type Model struct { + Name string + Reasoning bool + MaxTokens int +} + +type Provider struct { + Name string `yaml:"name,omitempty"` + Display string `yaml:"display,omitempty"` + Kind string `yaml:"kind"` + BaseURL string `yaml:"baseUrl,omitempty"` + APIKey string `yaml:"apiKey,omitempty"` + RawModels []any `yaml:"models"` // Raw models from YAML + _models []Model // Parsed models for internal use + Headers map[string]string `yaml:"headers"` +} + +func (p *Provider) Models() []Model { + return p._models +} + type Config struct { Defaults *struct { Model *string `yaml:"model" default:"gpt-4"` @@ -29,47 +50,110 @@ type Config struct { SystemPrompt string `yaml:"systemPrompt"` Tools []string `yaml:"tools"` } `yaml:"agents"` - Providers []*struct { - Name string `yaml:"name,omitempty"` - Display string `yaml:"display,omitempty"` - Kind string `yaml:"kind"` - BaseURL string `yaml:"baseUrl,omitempty"` - APIKey string `yaml:"apiKey,omitempty"` - Models []string `yaml:"models"` - Headers map[string]string `yaml:"headers"` - } `yaml:"providers"` + Providers []*Provider `yaml:"providers"` +} + +// Helper function to parse the Models field for a provider +func parseModels(rawModels []any) ([]Model, error) { + if rawModels == nil { + return nil, fmt.Errorf("No models to parse") + } + + models := make([]Model, 0, len(rawModels)) + for i, modelInterface := range rawModels { + var parsedModel Model + + switch rawModel := modelInterface.(type) { + case string: + parsedModel.Name = rawModel + case map[string]any: + // Case 2: Model defined as a map + if nameVal, ok := rawModel["name"]; ok { + if nameStr, ok := nameVal.(string); ok { + parsedModel.Name = nameStr + } else { + return nil, fmt.Errorf("Invalid 'name' type (%T) for model index %d", nameVal, i) + } + } else { + return nil, fmt.Errorf("Missing 'name' for model index %d", i) + } + + if reasoningVal, ok := rawModel["reasoning"]; ok { + if reasoningBool, ok := reasoningVal.(bool); ok { + parsedModel.Reasoning = reasoningBool + } else { + return nil, fmt.Errorf("Invalid 'reasoning' type (%T) for model '%s'", reasoningVal, parsedModel.Name) + } + } // else: default is false + + if maxTokensVal, ok := rawModel["maxTokens"]; ok { + // YAML numbers often unmarshal as int, sometimes float64. Handle int primarily. + if maxTokensInt, ok := maxTokensVal.(int); ok { + parsedModel.MaxTokens = maxTokensInt + } else { + return nil, fmt.Errorf("Invalid 'maxTokens' type (%T) for model '%s'", maxTokensVal, parsedModel.Name) + } + } // else: default is 0 + + default: + return nil, fmt.Errorf("Invalid model definition type (%T) at index %d in provider '%s'", modelInterface, i) + continue // Skip this unknown model definition format + } + + models = append(models, parsedModel) + } + return models, nil } func NewConfig(configFile string) (*Config, error) { - shouldWriteDefaults := false c := &Config{} configExists := true configBytes, err := os.ReadFile(configFile) if os.IsNotExist(err) { configExists = false + fmt.Printf("Configuration file not found at %s. Creating with defaults.\n", configFile) } else if err != nil { - return nil, fmt.Errorf("Could not read config file: %v", err) + return nil, fmt.Errorf("Could not read config file '%s': %w", configFile, err) } else { - yaml.Unmarshal(configBytes, c) + err = yaml.Unmarshal(configBytes, c) + if err != nil { + return nil, fmt.Errorf("Could not parse config file '%s': %w", configFile, err) + } } - shouldWriteDefaults = util.SetStructDefaults(c) - if !configExists || shouldWriteDefaults { - if configExists { - fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile+".bak") - os.Rename(configFile, configFile+".bak") - } - fmt.Printf("Writing configuration file to %s\n", configFile) + // Update the config with default values + util.SetStructDefaults(c) + + // Create a default config file + if !configExists { file, err := os.Create(configFile) if err != nil { - return nil, fmt.Errorf("Could not open config file for writing: %v", err) + return nil, fmt.Errorf("Could not open config file '%s' for writing: %w", configFile, err) } + defer file.Close() + encoder := yaml.NewEncoder(file) encoder.SetIndent(2) err = encoder.Encode(c) if err != nil { - return nil, fmt.Errorf("Could not save default configuration: %v", err) + file.Close() + os.Remove(configFile) + return nil, fmt.Errorf("Could not save default configuration to '%s': %w", configFile, err) + } + } + + for _, p := range c.Providers { + p._models, err = parseModels(p.RawModels) + if err != nil { + name := p.Name + if name == "" { + name = p.Display + } + if name == "" { + name = p.Kind + } + return nil, fmt.Errorf("Failed to models for provider '%s': %v", name, err) } } diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index c216278..1522419 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -12,12 +12,12 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/agents" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/provider" "git.mlow.ca/mlow/lmcli/pkg/provider/anthropic" "git.mlow.ca/mlow/lmcli/pkg/provider/google" "git.mlow.ca/mlow/lmcli/pkg/provider/ollama" "git.mlow.ca/mlow/lmcli/pkg/provider/openai" - "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util/tty" "gorm.io/driver/sqlite" @@ -33,9 +33,9 @@ type Agent struct { type Context struct { // high level app configuration, may be mutated at runtime - Config Config + Config Config Conversations conversation.Repo - Chroma *tty.ChromaHighlighter + Chroma tty.ChromaHighlighter } func NewContext() (*Context, error) { @@ -51,7 +51,7 @@ func NewContext() (*Context, error) { } chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style) - return &Context{*config, store, chroma}, nil + return &Context{*config, store, *chroma}, nil } func createOrOpenAppend(path string) (*os.File, error) { @@ -102,9 +102,9 @@ func (c *Context) GetModels() (models []string) { name = p.Name } - for _, m := range p.Models { - modelCounts[m]++ - models = append(models, fmt.Sprintf("%s@%s", m, name)) + for _, m := range p.Models() { + modelCounts[m.Name]++ + models = append(models, fmt.Sprintf("%s@%s", m.Name, name)) } } @@ -180,8 +180,8 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin continue } - for _, m := range p.Models { - if m == model { + for _, m := range p.Models() { + if m.Name == model { switch p.Kind { case "anthropic": url := "https://api.anthropic.com" diff --git a/pkg/tui/views/settings/settings.go b/pkg/tui/views/settings/settings.go index e69a527..ff4fae8 100644 --- a/pkg/tui/views/settings/settings.go +++ b/pkg/tui/views/settings/settings.go @@ -103,10 +103,10 @@ func (m *Model) getModelOptions() []list.OptionGroup { group := list.OptionGroup{ Name: providerLabel, } - for _, model := range p.Models { + for _, model := range p.Models() { group.Options = append(group.Options, list.Option{ - Label: model, - Value: modelOpt{provider, model}, + Label: model.Name, + Value: modelOpt{provider, model.Name}, }) } modelOpts = append(modelOpts, group)