Add support for per-model configuration
A provider's models may now be provided in the config as a simple string, or as a mapping with at least a `name` key. This is a valid configuration: ```yaml models: - name: model-a reasoning: true - model-b ``` This opens the door to providing model-specific configuration, e.g. an overridden max tokens count, or whether the model is a 'reasoning' model.
This commit is contained in:
@@ -8,6 +8,27 @@ import (
|
|||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
Name string
|
||||||
|
Reasoning bool
|
||||||
|
MaxTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
type Provider struct {
|
||||||
|
Name string `yaml:"name,omitempty"`
|
||||||
|
Display string `yaml:"display,omitempty"`
|
||||||
|
Kind string `yaml:"kind"`
|
||||||
|
BaseURL string `yaml:"baseUrl,omitempty"`
|
||||||
|
APIKey string `yaml:"apiKey,omitempty"`
|
||||||
|
RawModels []any `yaml:"models"` // Raw models from YAML
|
||||||
|
_models []Model // Parsed models for internal use
|
||||||
|
Headers map[string]string `yaml:"headers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) Models() []Model {
|
||||||
|
return p._models
|
||||||
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Defaults *struct {
|
Defaults *struct {
|
||||||
Model *string `yaml:"model" default:"gpt-4"`
|
Model *string `yaml:"model" default:"gpt-4"`
|
||||||
@@ -29,47 +50,110 @@ type Config struct {
|
|||||||
SystemPrompt string `yaml:"systemPrompt"`
|
SystemPrompt string `yaml:"systemPrompt"`
|
||||||
Tools []string `yaml:"tools"`
|
Tools []string `yaml:"tools"`
|
||||||
} `yaml:"agents"`
|
} `yaml:"agents"`
|
||||||
Providers []*struct {
|
Providers []*Provider `yaml:"providers"`
|
||||||
Name string `yaml:"name,omitempty"`
|
}
|
||||||
Display string `yaml:"display,omitempty"`
|
|
||||||
Kind string `yaml:"kind"`
|
// Helper function to parse the Models field for a provider
|
||||||
BaseURL string `yaml:"baseUrl,omitempty"`
|
func parseModels(rawModels []any) ([]Model, error) {
|
||||||
APIKey string `yaml:"apiKey,omitempty"`
|
if rawModels == nil {
|
||||||
Models []string `yaml:"models"`
|
return nil, fmt.Errorf("No models to parse")
|
||||||
Headers map[string]string `yaml:"headers"`
|
}
|
||||||
} `yaml:"providers"`
|
|
||||||
|
models := make([]Model, 0, len(rawModels))
|
||||||
|
for i, modelInterface := range rawModels {
|
||||||
|
var parsedModel Model
|
||||||
|
|
||||||
|
switch rawModel := modelInterface.(type) {
|
||||||
|
case string:
|
||||||
|
parsedModel.Name = rawModel
|
||||||
|
case map[string]any:
|
||||||
|
// Case 2: Model defined as a map
|
||||||
|
if nameVal, ok := rawModel["name"]; ok {
|
||||||
|
if nameStr, ok := nameVal.(string); ok {
|
||||||
|
parsedModel.Name = nameStr
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("Invalid 'name' type (%T) for model index %d", nameVal, i)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("Missing 'name' for model index %d", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if reasoningVal, ok := rawModel["reasoning"]; ok {
|
||||||
|
if reasoningBool, ok := reasoningVal.(bool); ok {
|
||||||
|
parsedModel.Reasoning = reasoningBool
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("Invalid 'reasoning' type (%T) for model '%s'", reasoningVal, parsedModel.Name)
|
||||||
|
}
|
||||||
|
} // else: default is false
|
||||||
|
|
||||||
|
if maxTokensVal, ok := rawModel["maxTokens"]; ok {
|
||||||
|
// YAML numbers often unmarshal as int, sometimes float64. Handle int primarily.
|
||||||
|
if maxTokensInt, ok := maxTokensVal.(int); ok {
|
||||||
|
parsedModel.MaxTokens = maxTokensInt
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("Invalid 'maxTokens' type (%T) for model '%s'", maxTokensVal, parsedModel.Name)
|
||||||
|
}
|
||||||
|
} // else: default is 0
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("Invalid model definition type (%T) at index %d in provider '%s'", modelInterface, i)
|
||||||
|
continue // Skip this unknown model definition format
|
||||||
|
}
|
||||||
|
|
||||||
|
models = append(models, parsedModel)
|
||||||
|
}
|
||||||
|
return models, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfig(configFile string) (*Config, error) {
|
func NewConfig(configFile string) (*Config, error) {
|
||||||
shouldWriteDefaults := false
|
|
||||||
c := &Config{}
|
c := &Config{}
|
||||||
|
|
||||||
configExists := true
|
configExists := true
|
||||||
configBytes, err := os.ReadFile(configFile)
|
configBytes, err := os.ReadFile(configFile)
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
configExists = false
|
configExists = false
|
||||||
|
fmt.Printf("Configuration file not found at %s. Creating with defaults.\n", configFile)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, fmt.Errorf("Could not read config file: %v", err)
|
return nil, fmt.Errorf("Could not read config file '%s': %w", configFile, err)
|
||||||
} else {
|
} else {
|
||||||
yaml.Unmarshal(configBytes, c)
|
err = yaml.Unmarshal(configBytes, c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not parse config file '%s': %w", configFile, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
shouldWriteDefaults = util.SetStructDefaults(c)
|
// Update the config with default values
|
||||||
if !configExists || shouldWriteDefaults {
|
util.SetStructDefaults(c)
|
||||||
if configExists {
|
|
||||||
fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile+".bak")
|
// Create a default config file
|
||||||
os.Rename(configFile, configFile+".bak")
|
if !configExists {
|
||||||
}
|
|
||||||
fmt.Printf("Writing configuration file to %s\n", configFile)
|
|
||||||
file, err := os.Create(configFile)
|
file, err := os.Create(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
|
return nil, fmt.Errorf("Could not open config file '%s' for writing: %w", configFile, err)
|
||||||
}
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
encoder := yaml.NewEncoder(file)
|
encoder := yaml.NewEncoder(file)
|
||||||
encoder.SetIndent(2)
|
encoder.SetIndent(2)
|
||||||
err = encoder.Encode(c)
|
err = encoder.Encode(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not save default configuration: %v", err)
|
file.Close()
|
||||||
|
os.Remove(configFile)
|
||||||
|
return nil, fmt.Errorf("Could not save default configuration to '%s': %w", configFile, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range c.Providers {
|
||||||
|
p._models, err = parseModels(p.RawModels)
|
||||||
|
if err != nil {
|
||||||
|
name := p.Name
|
||||||
|
if name == "" {
|
||||||
|
name = p.Display
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
name = p.Kind
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("Failed to models for provider '%s': %v", name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,12 +12,12 @@ import (
|
|||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/provider/anthropic"
|
"git.mlow.ca/mlow/lmcli/pkg/provider/anthropic"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/provider/google"
|
"git.mlow.ca/mlow/lmcli/pkg/provider/google"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/provider/ollama"
|
"git.mlow.ca/mlow/lmcli/pkg/provider/ollama"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/provider/openai"
|
"git.mlow.ca/mlow/lmcli/pkg/provider/openai"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
@@ -35,7 +35,7 @@ type Context struct {
|
|||||||
// high level app configuration, may be mutated at runtime
|
// high level app configuration, may be mutated at runtime
|
||||||
Config Config
|
Config Config
|
||||||
Conversations conversation.Repo
|
Conversations conversation.Repo
|
||||||
Chroma *tty.ChromaHighlighter
|
Chroma tty.ChromaHighlighter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContext() (*Context, error) {
|
func NewContext() (*Context, error) {
|
||||||
@@ -51,7 +51,7 @@ func NewContext() (*Context, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
||||||
return &Context{*config, store, chroma}, nil
|
return &Context{*config, store, *chroma}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createOrOpenAppend(path string) (*os.File, error) {
|
func createOrOpenAppend(path string) (*os.File, error) {
|
||||||
@@ -102,9 +102,9 @@ func (c *Context) GetModels() (models []string) {
|
|||||||
name = p.Name
|
name = p.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range p.Models {
|
for _, m := range p.Models() {
|
||||||
modelCounts[m]++
|
modelCounts[m.Name]++
|
||||||
models = append(models, fmt.Sprintf("%s@%s", m, name))
|
models = append(models, fmt.Sprintf("%s@%s", m.Name, name))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,8 +180,8 @@ func (c *Context) GetModelProvider(model string, provider string) (string, strin
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range p.Models {
|
for _, m := range p.Models() {
|
||||||
if m == model {
|
if m.Name == model {
|
||||||
switch p.Kind {
|
switch p.Kind {
|
||||||
case "anthropic":
|
case "anthropic":
|
||||||
url := "https://api.anthropic.com"
|
url := "https://api.anthropic.com"
|
||||||
|
|||||||
@@ -103,10 +103,10 @@ func (m *Model) getModelOptions() []list.OptionGroup {
|
|||||||
group := list.OptionGroup{
|
group := list.OptionGroup{
|
||||||
Name: providerLabel,
|
Name: providerLabel,
|
||||||
}
|
}
|
||||||
for _, model := range p.Models {
|
for _, model := range p.Models() {
|
||||||
group.Options = append(group.Options, list.Option{
|
group.Options = append(group.Options, list.Option{
|
||||||
Label: model,
|
Label: model.Name,
|
||||||
Value: modelOpt{provider, model},
|
Value: modelOpt{provider, model.Name},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
modelOpts = append(modelOpts, group)
|
modelOpts = append(modelOpts, group)
|
||||||
|
|||||||
Reference in New Issue
Block a user