Started work to make it possible to pass in per-model reasoning config. Cleaned up how we instantiate RequestParameters (TBD: remove RequestParameters?).
173 lines
5.1 KiB
Go
173 lines
5.1 KiB
Go
package lmcli
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
// Model describes a single model offered by a provider, together with
// optional per-model request settings. The pointer fields are nil when the
// corresponding option was not specified for the model.
type Model struct {
	// Name is the model identifier.
	Name string
	// Reasoning is the per-model reasoning setting, if one was given.
	Reasoning *bool
	// MaxTokens is the per-model token limit, if one was given.
	MaxTokens *int
	// Temperature is the per-model sampling temperature, if one was given.
	Temperature *float32
}
|
|
|
|
type Provider struct {
|
|
Name string `yaml:"name,omitempty"`
|
|
Display string `yaml:"display,omitempty"`
|
|
Kind string `yaml:"kind"`
|
|
BaseURL string `yaml:"baseUrl,omitempty"`
|
|
APIKey string `yaml:"apiKey,omitempty"`
|
|
RawModels []any `yaml:"models"` // Raw models from YAML
|
|
_models []Model // Parsed models for internal use
|
|
Headers map[string]string `yaml:"headers"`
|
|
}
|
|
|
|
func (p *Provider) Models() []Model {
|
|
return p._models
|
|
}
|
|
|
|
type Config struct {
|
|
Defaults *struct {
|
|
Model *string `yaml:"model" default:"default-model"`
|
|
MaxTokens *int `yaml:"maxTokens" default:"2048"`
|
|
Temperature *float32 `yaml:"temperature" default:"0.6"`
|
|
Reasoning *bool `yaml:"reasoning" default:"true"`
|
|
SystemPrompt string `yaml:"systemPrompt,omitempty"`
|
|
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
|
|
Agent string `yaml:"agent"`
|
|
} `yaml:"defaults"`
|
|
Conversations *struct {
|
|
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
|
|
} `yaml:"conversations"`
|
|
Chroma *struct {
|
|
Style *string `yaml:"style" default:"onedark"`
|
|
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
|
} `yaml:"chroma"`
|
|
Agents []*struct {
|
|
Name string `yaml:"name"`
|
|
SystemPrompt string `yaml:"systemPrompt"`
|
|
Tools []string `yaml:"tools"`
|
|
} `yaml:"agents"`
|
|
Providers []*Provider `yaml:"providers"`
|
|
}
|
|
|
|
// Helper function to parse the Models field for a provider
|
|
func parseModels(rawModels []any) ([]Model, error) {
|
|
if rawModels == nil {
|
|
return nil, fmt.Errorf("No models to parse")
|
|
}
|
|
|
|
models := make([]Model, 0, len(rawModels))
|
|
for i, modelInterface := range rawModels {
|
|
var parsedModel Model
|
|
|
|
switch rawModel := modelInterface.(type) {
|
|
case string:
|
|
parsedModel.Name = rawModel
|
|
case map[string]any:
|
|
// Case 2: Model defined as a map
|
|
if nameVal, ok := rawModel["name"]; ok {
|
|
if nameStr, ok := nameVal.(string); ok {
|
|
parsedModel.Name = nameStr
|
|
} else {
|
|
return nil, fmt.Errorf("Invalid 'name' type (%T) for model index %d", nameVal, i)
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("Missing 'name' for model index %d", i)
|
|
}
|
|
|
|
if reasoningVal, ok := rawModel["reasoning"]; ok {
|
|
if reasoningBool, ok := reasoningVal.(bool); ok {
|
|
parsedModel.Reasoning = &reasoningBool
|
|
} else {
|
|
return nil, fmt.Errorf("Invalid 'reasoning' type (%T) for model '%s'", reasoningVal, parsedModel.Name)
|
|
}
|
|
} // else: default is false
|
|
|
|
if maxTokensVal, ok := rawModel["maxTokens"]; ok {
|
|
// YAML numbers often unmarshal as int, sometimes float64. Handle int primarily.
|
|
if maxTokensInt, ok := maxTokensVal.(int); ok {
|
|
parsedModel.MaxTokens = &maxTokensInt
|
|
} else {
|
|
return nil, fmt.Errorf("Invalid 'maxTokens' type (%T) for model '%s'", maxTokensVal, parsedModel.Name)
|
|
}
|
|
} // else: default is nil
|
|
|
|
if temperatureVal, ok := rawModel["temperature"]; ok {
|
|
// YAML numbers often unmarshal as int, sometimes float64. Handle int primarily.
|
|
if temperatureFloat, ok := temperatureVal.(float64); ok {
|
|
asFloat32 := float32(temperatureFloat)
|
|
parsedModel.Temperature = &asFloat32
|
|
} else {
|
|
return nil, fmt.Errorf("Invalid 'temperature' type (%T) for model '%s'", temperatureVal, parsedModel.Name)
|
|
}
|
|
} // else: default is nil
|
|
|
|
default:
|
|
return nil, fmt.Errorf("Invalid model definition type (%T) at index %d", modelInterface, i)
|
|
}
|
|
|
|
models = append(models, parsedModel)
|
|
}
|
|
return models, nil
|
|
}
|
|
|
|
func NewConfig(configFile string) (*Config, error) {
|
|
c := &Config{}
|
|
|
|
configExists := true
|
|
configBytes, err := os.ReadFile(configFile)
|
|
if os.IsNotExist(err) {
|
|
configExists = false
|
|
fmt.Printf("Configuration file not found at %s. Creating with defaults.\n", configFile)
|
|
} else if err != nil {
|
|
return nil, fmt.Errorf("Could not read config file '%s': %w", configFile, err)
|
|
} else {
|
|
err = yaml.Unmarshal(configBytes, c)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not parse config file '%s': %w", configFile, err)
|
|
}
|
|
}
|
|
|
|
// Update the config with default values
|
|
util.SetStructDefaults(c)
|
|
|
|
// Create a default config file
|
|
if !configExists {
|
|
file, err := os.Create(configFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not open config file '%s' for writing: %w", configFile, err)
|
|
}
|
|
defer file.Close()
|
|
|
|
encoder := yaml.NewEncoder(file)
|
|
encoder.SetIndent(2)
|
|
err = encoder.Encode(c)
|
|
if err != nil {
|
|
file.Close()
|
|
os.Remove(configFile)
|
|
return nil, fmt.Errorf("Could not save default configuration to '%s': %w", configFile, err)
|
|
}
|
|
}
|
|
|
|
for _, p := range c.Providers {
|
|
p._models, err = parseModels(p.RawModels)
|
|
if err != nil {
|
|
name := p.Name
|
|
if name == "" {
|
|
name = p.Display
|
|
}
|
|
if name == "" {
|
|
name = p.Kind
|
|
}
|
|
return nil, fmt.Errorf("Failed to models for provider '%s': %v", name, err)
|
|
}
|
|
}
|
|
|
|
return c, nil
|
|
}
|