Private
Public Access
1
0

Add support for per-model configuration

A provider's models may now be provided in the config as a simple
string, or as a mapping with at least a `name` key.

This is a valid configuration:

```yaml
models:
- name: model-a
  reasoning: true
- model-b
```

This opens the door to providing model-specific configuration, e.g. an
overridden max tokens count, or whether the model is a 'reasoning'
model.
This commit is contained in:
2025-03-29 23:25:01 +00:00
parent 43f8de89c5
commit 7c3a2c3cb9
3 changed files with 117 additions and 33 deletions

View File

@@ -8,6 +8,27 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type Model struct {
Name string
Reasoning bool
MaxTokens int
}
type Provider struct {
Name string `yaml:"name,omitempty"`
Display string `yaml:"display,omitempty"`
Kind string `yaml:"kind"`
BaseURL string `yaml:"baseUrl,omitempty"`
APIKey string `yaml:"apiKey,omitempty"`
RawModels []any `yaml:"models"` // Raw models from YAML
_models []Model // Parsed models for internal use
Headers map[string]string `yaml:"headers"`
}
func (p *Provider) Models() []Model {
return p._models
}
type Config struct { type Config struct {
Defaults *struct { Defaults *struct {
Model *string `yaml:"model" default:"gpt-4"` Model *string `yaml:"model" default:"gpt-4"`
@@ -29,47 +50,110 @@ type Config struct {
SystemPrompt string `yaml:"systemPrompt"` SystemPrompt string `yaml:"systemPrompt"`
Tools []string `yaml:"tools"` Tools []string `yaml:"tools"`
} `yaml:"agents"` } `yaml:"agents"`
Providers []*struct { Providers []*Provider `yaml:"providers"`
Name string `yaml:"name,omitempty"` }
Display string `yaml:"display,omitempty"`
Kind string `yaml:"kind"` // Helper function to parse the Models field for a provider
BaseURL string `yaml:"baseUrl,omitempty"` func parseModels(rawModels []any) ([]Model, error) {
APIKey string `yaml:"apiKey,omitempty"` if rawModels == nil {
Models []string `yaml:"models"` return nil, fmt.Errorf("No models to parse")
Headers map[string]string `yaml:"headers"` }
} `yaml:"providers"`
models := make([]Model, 0, len(rawModels))
for i, modelInterface := range rawModels {
var parsedModel Model
switch rawModel := modelInterface.(type) {
case string:
parsedModel.Name = rawModel
case map[string]any:
// Case 2: Model defined as a map
if nameVal, ok := rawModel["name"]; ok {
if nameStr, ok := nameVal.(string); ok {
parsedModel.Name = nameStr
} else {
return nil, fmt.Errorf("Invalid 'name' type (%T) for model index %d", nameVal, i)
}
} else {
return nil, fmt.Errorf("Missing 'name' for model index %d", i)
}
if reasoningVal, ok := rawModel["reasoning"]; ok {
if reasoningBool, ok := reasoningVal.(bool); ok {
parsedModel.Reasoning = reasoningBool
} else {
return nil, fmt.Errorf("Invalid 'reasoning' type (%T) for model '%s'", reasoningVal, parsedModel.Name)
}
} // else: default is false
if maxTokensVal, ok := rawModel["maxTokens"]; ok {
// YAML numbers often unmarshal as int, sometimes float64. Handle int primarily.
if maxTokensInt, ok := maxTokensVal.(int); ok {
parsedModel.MaxTokens = maxTokensInt
} else {
return nil, fmt.Errorf("Invalid 'maxTokens' type (%T) for model '%s'", maxTokensVal, parsedModel.Name)
}
} // else: default is 0
default:
return nil, fmt.Errorf("Invalid model definition type (%T) at index %d in provider '%s'", modelInterface, i)
continue // Skip this unknown model definition format
}
models = append(models, parsedModel)
}
return models, nil
} }
func NewConfig(configFile string) (*Config, error) { func NewConfig(configFile string) (*Config, error) {
shouldWriteDefaults := false
c := &Config{} c := &Config{}
configExists := true configExists := true
configBytes, err := os.ReadFile(configFile) configBytes, err := os.ReadFile(configFile)
if os.IsNotExist(err) { if os.IsNotExist(err) {
configExists = false configExists = false
fmt.Printf("Configuration file not found at %s. Creating with defaults.\n", configFile)
} else if err != nil { } else if err != nil {
return nil, fmt.Errorf("Could not read config file: %v", err) return nil, fmt.Errorf("Could not read config file '%s': %w", configFile, err)
} else { } else {
yaml.Unmarshal(configBytes, c) err = yaml.Unmarshal(configBytes, c)
if err != nil {
return nil, fmt.Errorf("Could not parse config file '%s': %w", configFile, err)
}
} }
shouldWriteDefaults = util.SetStructDefaults(c) // Update the config with default values
if !configExists || shouldWriteDefaults { util.SetStructDefaults(c)
if configExists {
fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile+".bak") // Create a default config file
os.Rename(configFile, configFile+".bak") if !configExists {
}
fmt.Printf("Writing configuration file to %s\n", configFile)
file, err := os.Create(configFile) file, err := os.Create(configFile)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not open config file for writing: %v", err) return nil, fmt.Errorf("Could not open config file '%s' for writing: %w", configFile, err)
} }
defer file.Close()
encoder := yaml.NewEncoder(file) encoder := yaml.NewEncoder(file)
encoder.SetIndent(2) encoder.SetIndent(2)
err = encoder.Encode(c) err = encoder.Encode(c)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not save default configuration: %v", err) file.Close()
os.Remove(configFile)
return nil, fmt.Errorf("Could not save default configuration to '%s': %w", configFile, err)
}
}
for _, p := range c.Providers {
p._models, err = parseModels(p.RawModels)
if err != nil {
name := p.Name
if name == "" {
name = p.Display
}
if name == "" {
name = p.Kind
}
return nil, fmt.Errorf("Failed to models for provider '%s': %v", name, err)
} }
} }

View File

@@ -12,12 +12,12 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/agents" "git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/provider" "git.mlow.ca/mlow/lmcli/pkg/provider"
"git.mlow.ca/mlow/lmcli/pkg/provider/anthropic" "git.mlow.ca/mlow/lmcli/pkg/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/provider/google" "git.mlow.ca/mlow/lmcli/pkg/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/provider/ollama" "git.mlow.ca/mlow/lmcli/pkg/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/provider/openai" "git.mlow.ca/mlow/lmcli/pkg/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty" "git.mlow.ca/mlow/lmcli/pkg/util/tty"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
@@ -33,9 +33,9 @@ type Agent struct {
type Context struct { type Context struct {
// high level app configuration, may be mutated at runtime // high level app configuration, may be mutated at runtime
Config Config Config Config
Conversations conversation.Repo Conversations conversation.Repo
Chroma *tty.ChromaHighlighter Chroma tty.ChromaHighlighter
} }
func NewContext() (*Context, error) { func NewContext() (*Context, error) {
@@ -51,7 +51,7 @@ func NewContext() (*Context, error) {
} }
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style) chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
return &Context{*config, store, chroma}, nil return &Context{*config, store, *chroma}, nil
} }
func createOrOpenAppend(path string) (*os.File, error) { func createOrOpenAppend(path string) (*os.File, error) {
@@ -102,9 +102,9 @@ func (c *Context) GetModels() (models []string) {
name = p.Name name = p.Name
} }
for _, m := range p.Models { for _, m := range p.Models() {
modelCounts[m]++ modelCounts[m.Name]++
models = append(models, fmt.Sprintf("%s@%s", m, name)) models = append(models, fmt.Sprintf("%s@%s", m.Name, name))
} }
} }
@@ -180,8 +180,8 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin
continue continue
} }
for _, m := range p.Models { for _, m := range p.Models() {
if m == model { if m.Name == model {
switch p.Kind { switch p.Kind {
case "anthropic": case "anthropic":
url := "https://api.anthropic.com" url := "https://api.anthropic.com"

View File

@@ -103,10 +103,10 @@ func (m *Model) getModelOptions() []list.OptionGroup {
group := list.OptionGroup{ group := list.OptionGroup{
Name: providerLabel, Name: providerLabel,
} }
for _, model := range p.Models { for _, model := range p.Models() {
group.Options = append(group.Options, list.Option{ group.Options = append(group.Options, list.Option{
Label: model, Label: model.Name,
Value: modelOpt{provider, model}, Value: modelOpt{provider, model.Name},
}) })
} }
modelOpts = append(modelOpts, group) modelOpts = append(modelOpts, group)