Update config handling
- Stop using pointers where unnecessary - Removed default system prompt - Set indent level to 2 when writing config - Update ordering of config struct, which affects marshalling - Make provider `name` optional, defaulting to the provider's `kind`
This commit is contained in:
parent
f89cc7b410
commit
ba7018af11
@ -57,7 +57,7 @@ func applyGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) {
|
||||
f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
|
||||
|
||||
// --system-prompt
|
||||
f.StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
|
||||
f.StringVar(&ctx.Config.Defaults.SystemPrompt, "system-prompt", ctx.Config.Defaults.SystemPrompt, "System prompt")
|
||||
// --system-prompt-file
|
||||
f.StringVar(&ctx.Config.Defaults.SystemPromptFile, "system-prompt-file", ctx.Config.Defaults.SystemPromptFile, "A path to a file containing the system prompt")
|
||||
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
||||
|
@ -10,29 +10,29 @@ import (
|
||||
|
||||
type Config struct {
|
||||
Defaults *struct {
|
||||
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
|
||||
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
|
||||
Model *string `yaml:"model" default:"gpt-4"`
|
||||
MaxTokens *int `yaml:"maxTokens" default:"256"`
|
||||
Temperature *float32 `yaml:"temperature" default:"0.2"`
|
||||
Model *string `yaml:"model" default:"gpt-4"`
|
||||
SystemPrompt string `yaml:"systemPrompt,omitempty"`
|
||||
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
|
||||
} `yaml:"defaults"`
|
||||
Conversations *struct {
|
||||
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
|
||||
} `yaml:"conversations"`
|
||||
Tools *struct {
|
||||
EnabledTools []string `yaml:"enabledTools"`
|
||||
} `yaml:"tools"`
|
||||
Providers []*struct {
|
||||
Name *string `yaml:"name,omitempty"`
|
||||
Kind *string `yaml:"kind"`
|
||||
BaseURL *string `yaml:"baseUrl,omitempty"`
|
||||
APIKey *string `yaml:"apiKey,omitempty"`
|
||||
Models *[]string `yaml:"models"`
|
||||
} `yaml:"providers"`
|
||||
Chroma *struct {
|
||||
Style *string `yaml:"style" default:"onedark"`
|
||||
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
||||
} `yaml:"chroma"`
|
||||
Tools *struct {
|
||||
EnabledTools []string `yaml:"enabledTools"`
|
||||
} `yaml:"tools"`
|
||||
Providers []*struct {
|
||||
Name string `yaml:"name,omitempty"`
|
||||
Kind string `yaml:"kind"`
|
||||
BaseURL string `yaml:"baseUrl,omitempty"`
|
||||
APIKey string `yaml:"apiKey,omitempty"`
|
||||
Models []string `yaml:"models"`
|
||||
} `yaml:"providers"`
|
||||
}
|
||||
|
||||
func NewConfig(configFile string) (*Config, error) {
|
||||
@ -60,8 +60,9 @@ func NewConfig(configFile string) (*Config, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
|
||||
}
|
||||
bytes, _ := yaml.Marshal(c)
|
||||
_, err = file.Write(bytes)
|
||||
encoder := yaml.NewEncoder(file)
|
||||
encoder.SetIndent(2)
|
||||
err = encoder.Encode(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not save default configuration: %v", err)
|
||||
}
|
||||
@ -78,8 +79,5 @@ func (c *Config) GetSystemPrompt() string {
|
||||
}
|
||||
return content
|
||||
}
|
||||
if c.Defaults.SystemPrompt == nil {
|
||||
return ""
|
||||
}
|
||||
return *c.Defaults.SystemPrompt
|
||||
return c.Defaults.SystemPrompt
|
||||
}
|
||||
|
@ -18,7 +18,8 @@ import (
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
Config *Config // may be updated at runtime
|
||||
// high level app configuration, may be mutated at runtime
|
||||
Config Config
|
||||
Store ConversationStore
|
||||
|
||||
Chroma *tty.ChromaHighlighter
|
||||
@ -54,15 +55,20 @@ func NewContext() (*Context, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return &Context{config, store, chroma, enabledTools}, nil
|
||||
return &Context{*config, store, chroma, enabledTools}, nil
|
||||
}
|
||||
|
||||
func (c *Context) GetModels() (models []string) {
|
||||
modelCounts := make(map[string]int)
|
||||
for _, p := range c.Config.Providers {
|
||||
for _, m := range *p.Models {
|
||||
name := p.Kind
|
||||
if p.Name != "" {
|
||||
name = p.Name
|
||||
}
|
||||
|
||||
for _, m := range p.Models {
|
||||
modelCounts[m]++
|
||||
models = append(models, fmt.Sprintf("%s@%s", m, *p.Name))
|
||||
models = append(models, fmt.Sprintf("%s@%s", m, name))
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,50 +91,55 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
|
||||
}
|
||||
|
||||
for _, p := range c.Config.Providers {
|
||||
if provider != "" && *p.Name != provider {
|
||||
name := p.Kind
|
||||
if p.Name != "" {
|
||||
name = p.Name
|
||||
}
|
||||
|
||||
if provider != "" && name != provider {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, m := range *p.Models {
|
||||
for _, m := range p.Models {
|
||||
if m == model {
|
||||
switch *p.Kind {
|
||||
switch p.Kind {
|
||||
case "anthropic":
|
||||
url := "https://api.anthropic.com"
|
||||
if p.BaseURL != nil {
|
||||
url = *p.BaseURL
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, &anthropic.AnthropicClient{
|
||||
BaseURL: url,
|
||||
APIKey: *p.APIKey,
|
||||
APIKey: p.APIKey,
|
||||
}, nil
|
||||
case "google":
|
||||
url := "https://generativelanguage.googleapis.com"
|
||||
if p.BaseURL != nil {
|
||||
url = *p.BaseURL
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, &google.Client{
|
||||
BaseURL: url,
|
||||
APIKey: *p.APIKey,
|
||||
APIKey: p.APIKey,
|
||||
}, nil
|
||||
case "ollama":
|
||||
url := "http://localhost:11434/api"
|
||||
if p.BaseURL != nil {
|
||||
url = *p.BaseURL
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, &ollama.OllamaClient{
|
||||
BaseURL: url,
|
||||
}, nil
|
||||
case "openai":
|
||||
url := "https://api.openai.com"
|
||||
if p.BaseURL != nil {
|
||||
url = *p.BaseURL
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, &openai.OpenAIClient{
|
||||
BaseURL: url,
|
||||
APIKey: *p.APIKey,
|
||||
APIKey: p.APIKey,
|
||||
}, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
|
||||
return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user