Add support for per-model configuration
A provider's models may now be provided in the config as a simple string, or as a mapping with at least a `name` key. This is a valid configuration: ```yaml models: - name: model-a reasoning: true - model-b ``` This opens the door to providing model-specific configuration, e.g. an overridden max tokens count, or whether the model is a 'reasoning' model.
This commit is contained in:
@@ -8,6 +8,27 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
Name string
|
||||
Reasoning bool
|
||||
MaxTokens int
|
||||
}
|
||||
|
||||
type Provider struct {
|
||||
Name string `yaml:"name,omitempty"`
|
||||
Display string `yaml:"display,omitempty"`
|
||||
Kind string `yaml:"kind"`
|
||||
BaseURL string `yaml:"baseUrl,omitempty"`
|
||||
APIKey string `yaml:"apiKey,omitempty"`
|
||||
RawModels []any `yaml:"models"` // Raw models from YAML
|
||||
_models []Model // Parsed models for internal use
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
}
|
||||
|
||||
func (p *Provider) Models() []Model {
|
||||
return p._models
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Defaults *struct {
|
||||
Model *string `yaml:"model" default:"gpt-4"`
|
||||
@@ -29,47 +50,110 @@ type Config struct {
|
||||
SystemPrompt string `yaml:"systemPrompt"`
|
||||
Tools []string `yaml:"tools"`
|
||||
} `yaml:"agents"`
|
||||
Providers []*struct {
|
||||
Name string `yaml:"name,omitempty"`
|
||||
Display string `yaml:"display,omitempty"`
|
||||
Kind string `yaml:"kind"`
|
||||
BaseURL string `yaml:"baseUrl,omitempty"`
|
||||
APIKey string `yaml:"apiKey,omitempty"`
|
||||
Models []string `yaml:"models"`
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
} `yaml:"providers"`
|
||||
Providers []*Provider `yaml:"providers"`
|
||||
}
|
||||
|
||||
// Helper function to parse the Models field for a provider
|
||||
func parseModels(rawModels []any) ([]Model, error) {
|
||||
if rawModels == nil {
|
||||
return nil, fmt.Errorf("No models to parse")
|
||||
}
|
||||
|
||||
models := make([]Model, 0, len(rawModels))
|
||||
for i, modelInterface := range rawModels {
|
||||
var parsedModel Model
|
||||
|
||||
switch rawModel := modelInterface.(type) {
|
||||
case string:
|
||||
parsedModel.Name = rawModel
|
||||
case map[string]any:
|
||||
// Case 2: Model defined as a map
|
||||
if nameVal, ok := rawModel["name"]; ok {
|
||||
if nameStr, ok := nameVal.(string); ok {
|
||||
parsedModel.Name = nameStr
|
||||
} else {
|
||||
return nil, fmt.Errorf("Invalid 'name' type (%T) for model index %d", nameVal, i)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("Missing 'name' for model index %d", i)
|
||||
}
|
||||
|
||||
if reasoningVal, ok := rawModel["reasoning"]; ok {
|
||||
if reasoningBool, ok := reasoningVal.(bool); ok {
|
||||
parsedModel.Reasoning = reasoningBool
|
||||
} else {
|
||||
return nil, fmt.Errorf("Invalid 'reasoning' type (%T) for model '%s'", reasoningVal, parsedModel.Name)
|
||||
}
|
||||
} // else: default is false
|
||||
|
||||
if maxTokensVal, ok := rawModel["maxTokens"]; ok {
|
||||
// YAML numbers often unmarshal as int, sometimes float64. Handle int primarily.
|
||||
if maxTokensInt, ok := maxTokensVal.(int); ok {
|
||||
parsedModel.MaxTokens = maxTokensInt
|
||||
} else {
|
||||
return nil, fmt.Errorf("Invalid 'maxTokens' type (%T) for model '%s'", maxTokensVal, parsedModel.Name)
|
||||
}
|
||||
} // else: default is 0
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("Invalid model definition type (%T) at index %d in provider '%s'", modelInterface, i)
|
||||
continue // Skip this unknown model definition format
|
||||
}
|
||||
|
||||
models = append(models, parsedModel)
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func NewConfig(configFile string) (*Config, error) {
|
||||
shouldWriteDefaults := false
|
||||
c := &Config{}
|
||||
|
||||
configExists := true
|
||||
configBytes, err := os.ReadFile(configFile)
|
||||
if os.IsNotExist(err) {
|
||||
configExists = false
|
||||
fmt.Printf("Configuration file not found at %s. Creating with defaults.\n", configFile)
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("Could not read config file: %v", err)
|
||||
return nil, fmt.Errorf("Could not read config file '%s': %w", configFile, err)
|
||||
} else {
|
||||
yaml.Unmarshal(configBytes, c)
|
||||
err = yaml.Unmarshal(configBytes, c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not parse config file '%s': %w", configFile, err)
|
||||
}
|
||||
}
|
||||
|
||||
shouldWriteDefaults = util.SetStructDefaults(c)
|
||||
if !configExists || shouldWriteDefaults {
|
||||
if configExists {
|
||||
fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile+".bak")
|
||||
os.Rename(configFile, configFile+".bak")
|
||||
}
|
||||
fmt.Printf("Writing configuration file to %s\n", configFile)
|
||||
// Update the config with default values
|
||||
util.SetStructDefaults(c)
|
||||
|
||||
// Create a default config file
|
||||
if !configExists {
|
||||
file, err := os.Create(configFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
|
||||
return nil, fmt.Errorf("Could not open config file '%s' for writing: %w", configFile, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
encoder := yaml.NewEncoder(file)
|
||||
encoder.SetIndent(2)
|
||||
err = encoder.Encode(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not save default configuration: %v", err)
|
||||
file.Close()
|
||||
os.Remove(configFile)
|
||||
return nil, fmt.Errorf("Could not save default configuration to '%s': %w", configFile, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, p := range c.Providers {
|
||||
p._models, err = parseModels(p.RawModels)
|
||||
if err != nil {
|
||||
name := p.Name
|
||||
if name == "" {
|
||||
name = p.Display
|
||||
}
|
||||
if name == "" {
|
||||
name = p.Kind
|
||||
}
|
||||
return nil, fmt.Errorf("Failed to models for provider '%s': %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,12 +12,12 @@ import (
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/provider/anthropic"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/provider/google"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/provider/ollama"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/provider/openai"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||
"gorm.io/driver/sqlite"
|
||||
@@ -33,9 +33,9 @@ type Agent struct {
|
||||
|
||||
type Context struct {
|
||||
// high level app configuration, may be mutated at runtime
|
||||
Config Config
|
||||
Config Config
|
||||
Conversations conversation.Repo
|
||||
Chroma *tty.ChromaHighlighter
|
||||
Chroma tty.ChromaHighlighter
|
||||
}
|
||||
|
||||
func NewContext() (*Context, error) {
|
||||
@@ -51,7 +51,7 @@ func NewContext() (*Context, error) {
|
||||
}
|
||||
|
||||
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
||||
return &Context{*config, store, chroma}, nil
|
||||
return &Context{*config, store, *chroma}, nil
|
||||
}
|
||||
|
||||
func createOrOpenAppend(path string) (*os.File, error) {
|
||||
@@ -102,9 +102,9 @@ func (c *Context) GetModels() (models []string) {
|
||||
name = p.Name
|
||||
}
|
||||
|
||||
for _, m := range p.Models {
|
||||
modelCounts[m]++
|
||||
models = append(models, fmt.Sprintf("%s@%s", m, name))
|
||||
for _, m := range p.Models() {
|
||||
modelCounts[m.Name]++
|
||||
models = append(models, fmt.Sprintf("%s@%s", m.Name, name))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,8 +180,8 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin
|
||||
continue
|
||||
}
|
||||
|
||||
for _, m := range p.Models {
|
||||
if m == model {
|
||||
for _, m := range p.Models() {
|
||||
if m.Name == model {
|
||||
switch p.Kind {
|
||||
case "anthropic":
|
||||
url := "https://api.anthropic.com"
|
||||
|
||||
@@ -103,10 +103,10 @@ func (m *Model) getModelOptions() []list.OptionGroup {
|
||||
group := list.OptionGroup{
|
||||
Name: providerLabel,
|
||||
}
|
||||
for _, model := range p.Models {
|
||||
for _, model := range p.Models() {
|
||||
group.Options = append(group.Options, list.Option{
|
||||
Label: model,
|
||||
Value: modelOpt{provider, model},
|
||||
Label: model.Name,
|
||||
Value: modelOpt{provider, model.Name},
|
||||
})
|
||||
}
|
||||
modelOpts = append(modelOpts, group)
|
||||
|
||||
Reference in New Issue
Block a user