Update config handling

- Stop using pointers where unnecessary
- Remove the default system prompt
- Set indent level to 2 when writing config
- Update ordering of config struct, which affects marshalling (illustrated below)
- Make provider `name` optional, defaulting to the provider's `kind`
Matt Low 2024-06-23 16:02:26 +00:00
parent f89cc7b410
commit ba7018af11
3 changed files with 48 additions and 39 deletions
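
For context on the reordering noted above: gopkg.in/yaml.v3 (assumed to be the YAML library in use here) marshals struct fields in declaration order, so rearranging the Config struct changes the layout of a freshly written config file. A minimal, standalone sketch; the `defaults` type and its field set are illustrative, not the repository's actual Config:

package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// defaults mirrors the shape of the reordered section of the config;
// the names here are illustrative only.
type defaults struct {
	Model       string  `yaml:"model"`
	MaxTokens   int     `yaml:"maxTokens"`
	Temperature float32 `yaml:"temperature"`
}

func main() {
	out, err := yaml.Marshal(defaults{Model: "gpt-4", MaxTokens: 256, Temperature: 0.2})
	if err != nil {
		panic(err)
	}
	// Fields are emitted in declaration order:
	//   model: gpt-4
	//   maxTokens: 256
	//   temperature: 0.2
	fmt.Print(string(out))
}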

View File

@@ -57,7 +57,7 @@ func applyGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) {
     f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
     // --system-prompt
-    f.StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
+    f.StringVar(&ctx.Config.Defaults.SystemPrompt, "system-prompt", ctx.Config.Defaults.SystemPrompt, "System prompt")
     // --system-prompt-file
     f.StringVar(&ctx.Config.Defaults.SystemPromptFile, "system-prompt-file", ctx.Config.Defaults.SystemPromptFile, "A path to a file containing the system prompt")
     cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")

View File

@@ -10,29 +10,29 @@ import (
 type Config struct {
     Defaults *struct {
-        SystemPromptFile string   `yaml:"systemPromptFile,omitempty"`
-        SystemPrompt     *string  `yaml:"systemPrompt" default:"You are a helpful assistant."`
+        Model            *string  `yaml:"model" default:"gpt-4"`
         MaxTokens        *int     `yaml:"maxTokens" default:"256"`
         Temperature      *float32 `yaml:"temperature" default:"0.2"`
-        Model            *string  `yaml:"model" default:"gpt-4"`
+        SystemPrompt     string   `yaml:"systemPrompt,omitempty"`
+        SystemPromptFile string   `yaml:"systemPromptFile,omitempty"`
     } `yaml:"defaults"`
     Conversations *struct {
         TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
     } `yaml:"conversations"`
-    Tools *struct {
-        EnabledTools []string `yaml:"enabledTools"`
-    } `yaml:"tools"`
-    Providers []*struct {
-        Name    *string   `yaml:"name,omitempty"`
-        Kind    *string   `yaml:"kind"`
-        BaseURL *string   `yaml:"baseUrl,omitempty"`
-        APIKey  *string   `yaml:"apiKey,omitempty"`
-        Models  *[]string `yaml:"models"`
-    } `yaml:"providers"`
     Chroma *struct {
         Style     *string `yaml:"style" default:"onedark"`
         Formatter *string `yaml:"formatter" default:"terminal16m"`
     } `yaml:"chroma"`
+    Tools *struct {
+        EnabledTools []string `yaml:"enabledTools"`
+    } `yaml:"tools"`
+    Providers []*struct {
+        Name    string   `yaml:"name,omitempty"`
+        Kind    string   `yaml:"kind"`
+        BaseURL string   `yaml:"baseUrl,omitempty"`
+        APIKey  string   `yaml:"apiKey,omitempty"`
+        Models  []string `yaml:"models"`
+    } `yaml:"providers"`
 }

 func NewConfig(configFile string) (*Config, error) {
@@ -60,8 +60,9 @@ func NewConfig(configFile string) (*Config, error) {
         if err != nil {
             return nil, fmt.Errorf("Could not open config file for writing: %v", err)
         }
-        bytes, _ := yaml.Marshal(c)
-        _, err = file.Write(bytes)
+        encoder := yaml.NewEncoder(file)
+        encoder.SetIndent(2)
+        err = encoder.Encode(c)
         if err != nil {
             return nil, fmt.Errorf("Could not save default configuration: %v", err)
         }
@@ -78,8 +79,5 @@ func (c *Config) GetSystemPrompt() string {
         }
         return content
     }
-    if c.Defaults.SystemPrompt == nil {
-        return ""
-    }
-    return *c.Defaults.SystemPrompt
+    return c.Defaults.SystemPrompt
 }

View File

@@ -18,7 +18,8 @@ import (
 )

 type Context struct {
-    Config *Config // may be updated at runtime
+    // high level app configuration, may be mutated at runtime
+    Config Config
     Store  ConversationStore
     Chroma *tty.ChromaHighlighter
@@ -54,15 +55,20 @@ func NewContext() (*Context, error) {
         }
     }
-    return &Context{config, store, chroma, enabledTools}, nil
+    return &Context{*config, store, chroma, enabledTools}, nil
 }

 func (c *Context) GetModels() (models []string) {
     modelCounts := make(map[string]int)
     for _, p := range c.Config.Providers {
-        for _, m := range *p.Models {
+        name := p.Kind
+        if p.Name != "" {
+            name = p.Name
+        }
+        for _, m := range p.Models {
             modelCounts[m]++
-            models = append(models, fmt.Sprintf("%s@%s", m, *p.Name))
+            models = append(models, fmt.Sprintf("%s@%s", m, name))
         }
     }
@@ -85,50 +91,55 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
     }
     for _, p := range c.Config.Providers {
-        if provider != "" && *p.Name != provider {
+        name := p.Kind
+        if p.Name != "" {
+            name = p.Name
+        }
+        if provider != "" && name != provider {
             continue
         }
-        for _, m := range *p.Models {
+        for _, m := range p.Models {
             if m == model {
-                switch *p.Kind {
+                switch p.Kind {
                 case "anthropic":
                     url := "https://api.anthropic.com"
-                    if p.BaseURL != nil {
-                        url = *p.BaseURL
+                    if p.BaseURL != "" {
+                        url = p.BaseURL
                     }
                     return model, &anthropic.AnthropicClient{
                         BaseURL: url,
-                        APIKey:  *p.APIKey,
+                        APIKey:  p.APIKey,
                     }, nil
                 case "google":
                     url := "https://generativelanguage.googleapis.com"
-                    if p.BaseURL != nil {
-                        url = *p.BaseURL
+                    if p.BaseURL != "" {
+                        url = p.BaseURL
                     }
                     return model, &google.Client{
                         BaseURL: url,
-                        APIKey:  *p.APIKey,
+                        APIKey:  p.APIKey,
                     }, nil
                 case "ollama":
                     url := "http://localhost:11434/api"
-                    if p.BaseURL != nil {
-                        url = *p.BaseURL
+                    if p.BaseURL != "" {
+                        url = p.BaseURL
                     }
                     return model, &ollama.OllamaClient{
                         BaseURL: url,
                     }, nil
                 case "openai":
                     url := "https://api.openai.com"
-                    if p.BaseURL != nil {
-                        url = *p.BaseURL
+                    if p.BaseURL != "" {
+                        url = p.BaseURL
                     }
                     return model, &openai.OpenAIClient{
                         BaseURL: url,
-                        APIKey:  *p.APIKey,
+                        APIKey:  p.APIKey,
                     }, nil
                 default:
-                    return "", nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
+                    return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
                 }
             }
         }
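
The net effect of the name fallback on provider selection: a provider configured without a `name` can now be referenced by its `kind` in a qualified model identifier. A standalone sketch of that matching rule; the `provider` type, the `resolve` helper, and the split on "@" are illustrative assumptions rather than the repository's actual parsing:

package main

import (
	"fmt"
	"strings"
)

type provider struct {
	Name   string // optional display name
	Kind   string // anthropic, google, ollama, openai, ...
	Models []string
}

// resolve mirrors the selection rule above: strip an optional "@provider"
// suffix, then match the requested provider against the configured name,
// falling back to the provider's kind when no name is set.
func resolve(providers []provider, qualified string) (model, providerName string, ok bool) {
	model = qualified
	want := ""
	if i := strings.LastIndex(qualified, "@"); i >= 0 {
		model, want = qualified[:i], qualified[i+1:]
	}
	for _, p := range providers {
		name := p.Kind
		if p.Name != "" {
			name = p.Name
		}
		if want != "" && name != want {
			continue
		}
		for _, m := range p.Models {
			if m == model {
				return model, name, true
			}
		}
	}
	return "", "", false
}

func main() {
	providers := []provider{{Kind: "openai", Models: []string{"gpt-4"}}}
	fmt.Println(resolve(providers, "gpt-4@openai")) // gpt-4 openai true
	fmt.Println(resolve(providers, "gpt-4"))        // gpt-4 openai true
}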