Update config handling

- Stop using pointers where unnecessary
- Remove default system prompt
- Set indent level to 2 when writing config
- Update ordering of config struct, which affects marshalling
- Make provider `name` optional, defaulting to the provider's `kind`
This commit is contained in:
Matt Low 2024-06-23 16:02:26 +00:00
parent f89cc7b410
commit ba7018af11
3 changed files with 48 additions and 39 deletions

View File

@@ -57,7 +57,7 @@ func applyGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) {
f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
// --system-prompt
f.StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
f.StringVar(&ctx.Config.Defaults.SystemPrompt, "system-prompt", ctx.Config.Defaults.SystemPrompt, "System prompt")
// --system-prompt-file
f.StringVar(&ctx.Config.Defaults.SystemPromptFile, "system-prompt-file", ctx.Config.Defaults.SystemPromptFile, "A path to a file containing the system prompt")
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")

View File

@@ -10,29 +10,29 @@ import (
type Config struct {
Defaults *struct {
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
Model *string `yaml:"model" default:"gpt-4"`
MaxTokens *int `yaml:"maxTokens" default:"256"`
Temperature *float32 `yaml:"temperature" default:"0.2"`
Model *string `yaml:"model" default:"gpt-4"`
SystemPrompt string `yaml:"systemPrompt,omitempty"`
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
} `yaml:"defaults"`
Conversations *struct {
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
} `yaml:"conversations"`
Tools *struct {
EnabledTools []string `yaml:"enabledTools"`
} `yaml:"tools"`
Providers []*struct {
Name *string `yaml:"name,omitempty"`
Kind *string `yaml:"kind"`
BaseURL *string `yaml:"baseUrl,omitempty"`
APIKey *string `yaml:"apiKey,omitempty"`
Models *[]string `yaml:"models"`
} `yaml:"providers"`
Chroma *struct {
Style *string `yaml:"style" default:"onedark"`
Formatter *string `yaml:"formatter" default:"terminal16m"`
} `yaml:"chroma"`
Tools *struct {
EnabledTools []string `yaml:"enabledTools"`
} `yaml:"tools"`
Providers []*struct {
Name string `yaml:"name,omitempty"`
Kind string `yaml:"kind"`
BaseURL string `yaml:"baseUrl,omitempty"`
APIKey string `yaml:"apiKey,omitempty"`
Models []string `yaml:"models"`
} `yaml:"providers"`
}
func NewConfig(configFile string) (*Config, error) {
@@ -60,8 +60,9 @@ func NewConfig(configFile string) (*Config, error) {
if err != nil {
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
}
bytes, _ := yaml.Marshal(c)
_, err = file.Write(bytes)
encoder := yaml.NewEncoder(file)
encoder.SetIndent(2)
err = encoder.Encode(c)
if err != nil {
return nil, fmt.Errorf("Could not save default configuration: %v", err)
}
@@ -78,8 +79,5 @@ func (c *Config) GetSystemPrompt() string {
}
return content
}
if c.Defaults.SystemPrompt == nil {
return ""
}
return *c.Defaults.SystemPrompt
return c.Defaults.SystemPrompt
}

View File

@@ -18,7 +18,8 @@ import (
)
type Context struct {
Config *Config // may be updated at runtime
// high level app configuration, may be mutated at runtime
Config Config
Store ConversationStore
Chroma *tty.ChromaHighlighter
@@ -54,15 +55,20 @@ func NewContext() (*Context, error) {
}
}
return &Context{config, store, chroma, enabledTools}, nil
return &Context{*config, store, chroma, enabledTools}, nil
}
func (c *Context) GetModels() (models []string) {
modelCounts := make(map[string]int)
for _, p := range c.Config.Providers {
for _, m := range *p.Models {
name := p.Kind
if p.Name != "" {
name = p.Name
}
for _, m := range p.Models {
modelCounts[m]++
models = append(models, fmt.Sprintf("%s@%s", m, *p.Name))
models = append(models, fmt.Sprintf("%s@%s", m, name))
}
}
@@ -85,50 +91,55 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
}
for _, p := range c.Config.Providers {
if provider != "" && *p.Name != provider {
name := p.Kind
if p.Name != "" {
name = p.Name
}
if provider != "" && name != provider {
continue
}
for _, m := range *p.Models {
for _, m := range p.Models {
if m == model {
switch *p.Kind {
switch p.Kind {
case "anthropic":
url := "https://api.anthropic.com"
if p.BaseURL != nil {
url = *p.BaseURL
if p.BaseURL != "" {
url = p.BaseURL
}
return model, &anthropic.AnthropicClient{
BaseURL: url,
APIKey: *p.APIKey,
APIKey: p.APIKey,
}, nil
case "google":
url := "https://generativelanguage.googleapis.com"
if p.BaseURL != nil {
url = *p.BaseURL
if p.BaseURL != "" {
url = p.BaseURL
}
return model, &google.Client{
BaseURL: url,
APIKey: *p.APIKey,
APIKey: p.APIKey,
}, nil
case "ollama":
url := "http://localhost:11434/api"
if p.BaseURL != nil {
url = *p.BaseURL
if p.BaseURL != "" {
url = p.BaseURL
}
return model, &ollama.OllamaClient{
BaseURL: url,
}, nil
case "openai":
url := "https://api.openai.com"
if p.BaseURL != nil {
url = *p.BaseURL
if p.BaseURL != "" {
url = p.BaseURL
}
return model, &openai.OpenAIClient{
BaseURL: url,
APIKey: *p.APIKey,
APIKey: p.APIKey,
}, nil
default:
return "", nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
}
}
}