From ba7018af11174eda8fa0aa873d4cf607e15a157b Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sun, 23 Jun 2024 16:02:26 +0000 Subject: [PATCH] Update config handling - Stop using pointers where unnecessary - Removed default system prompt - Set indent level to 2 when writing config - Update ordering of config struct, which affects marshalling - Make provider `name` optional, defaulting to the provider's `kind` --- pkg/cmd/cmd.go | 2 +- pkg/lmcli/config.go | 36 ++++++++++++++++----------------- pkg/lmcli/lmcli.go | 49 +++++++++++++++++++++++++++------------------ 3 files changed, 48 insertions(+), 39 deletions(-) diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 764cea0..90dc878 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -57,7 +57,7 @@ func applyGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) { f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature") // --system-prompt - f.StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt") + f.StringVar(&ctx.Config.Defaults.SystemPrompt, "system-prompt", ctx.Config.Defaults.SystemPrompt, "System prompt") // --system-prompt-file f.StringVar(&ctx.Config.Defaults.SystemPromptFile, "system-prompt-file", ctx.Config.Defaults.SystemPromptFile, "A path to a file containing the system prompt") cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file") diff --git a/pkg/lmcli/config.go b/pkg/lmcli/config.go index 89190e1..ad25250 100644 --- a/pkg/lmcli/config.go +++ b/pkg/lmcli/config.go @@ -10,29 +10,29 @@ import ( type Config struct { Defaults *struct { - SystemPromptFile string `yaml:"systemPromptFile,omitempty"` - SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."` + Model *string `yaml:"model" default:"gpt-4"` MaxTokens *int `yaml:"maxTokens" default:"256"` Temperature *float32 `yaml:"temperature" default:"0.2"` - Model *string `yaml:"model" default:"gpt-4"` + SystemPrompt string `yaml:"systemPrompt,omitempty"` + SystemPromptFile string `yaml:"systemPromptFile,omitempty"` } `yaml:"defaults"` Conversations *struct { TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"` } `yaml:"conversations"` - Tools *struct { - EnabledTools []string `yaml:"enabledTools"` - } `yaml:"tools"` - Providers []*struct { - Name *string `yaml:"name,omitempty"` - Kind *string `yaml:"kind"` - BaseURL *string `yaml:"baseUrl,omitempty"` - APIKey *string `yaml:"apiKey,omitempty"` - Models *[]string `yaml:"models"` - } `yaml:"providers"` Chroma *struct { Style *string `yaml:"style" default:"onedark"` Formatter *string `yaml:"formatter" default:"terminal16m"` } `yaml:"chroma"` + Tools *struct { + EnabledTools []string `yaml:"enabledTools"` + } `yaml:"tools"` + Providers []*struct { + Name string `yaml:"name,omitempty"` + Kind string `yaml:"kind"` + BaseURL string `yaml:"baseUrl,omitempty"` + APIKey string `yaml:"apiKey,omitempty"` + Models []string `yaml:"models"` + } `yaml:"providers"` } func NewConfig(configFile string) (*Config, error) { @@ -60,8 +60,9 @@ func NewConfig(configFile string) (*Config, error) { if err != nil { return nil, fmt.Errorf("Could not open config file for writing: %v", err) } - bytes, _ := yaml.Marshal(c) - _, err = file.Write(bytes) + encoder := yaml.NewEncoder(file) + encoder.SetIndent(2) + err = encoder.Encode(c) if err != nil { return nil, fmt.Errorf("Could not save default configuration: %v", err) } @@ -78,8 +79,5 @@ func (c *Config) GetSystemPrompt() string { } return content } - if c.Defaults.SystemPrompt == nil { - return "" - } - return *c.Defaults.SystemPrompt + return c.Defaults.SystemPrompt } diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index 3ee8c3e..6de906a 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -18,7 +18,8 @@ import ( ) type Context struct { - Config *Config // may be updated at runtime + // high level app configuration, may be mutated at runtime + Config Config Store ConversationStore Chroma *tty.ChromaHighlighter @@ -54,15 +55,20 @@ func NewContext() (*Context, error) { } } - return &Context{config, store, chroma, enabledTools}, nil + return &Context{*config, store, chroma, enabledTools}, nil } func (c *Context) GetModels() (models []string) { modelCounts := make(map[string]int) for _, p := range c.Config.Providers { - for _, m := range *p.Models { + name := p.Kind + if p.Name != "" { + name = p.Name + } + + for _, m := range p.Models { modelCounts[m]++ - models = append(models, fmt.Sprintf("%s@%s", m, *p.Name)) + models = append(models, fmt.Sprintf("%s@%s", m, name)) } } @@ -85,50 +91,55 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv } for _, p := range c.Config.Providers { - if provider != "" && *p.Name != provider { + name := p.Kind + if p.Name != "" { + name = p.Name + } + + if provider != "" && name != provider { continue } - for _, m := range *p.Models { + for _, m := range p.Models { if m == model { - switch *p.Kind { + switch p.Kind { case "anthropic": url := "https://api.anthropic.com" - if p.BaseURL != nil { - url = *p.BaseURL + if p.BaseURL != "" { + url = p.BaseURL } return model, &anthropic.AnthropicClient{ BaseURL: url, - APIKey: *p.APIKey, + APIKey: p.APIKey, }, nil case "google": url := "https://generativelanguage.googleapis.com" - if p.BaseURL != nil { - url = *p.BaseURL + if p.BaseURL != "" { + url = p.BaseURL } return model, &google.Client{ BaseURL: url, - APIKey: *p.APIKey, + APIKey: p.APIKey, }, nil case "ollama": url := "http://localhost:11434/api" - if p.BaseURL != nil { - url = *p.BaseURL + if p.BaseURL != "" { + url = p.BaseURL } return model, &ollama.OllamaClient{ BaseURL: url, }, nil case "openai": url := "https://api.openai.com" - if p.BaseURL != nil { - url = *p.BaseURL + if p.BaseURL != "" { + url = p.BaseURL } return model, &openai.OpenAIClient{ BaseURL: url, - APIKey: *p.APIKey, + APIKey: p.APIKey, }, nil default: - return "", nil, fmt.Errorf("unknown provider kind: %s", *p.Kind) + return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind) } } }