From 8780856854d453f4f570c786e575c7e4882e7b0f Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sun, 19 Nov 2023 01:14:00 +0000 Subject: [PATCH] Set config defaults using a "default" struct tag Add new SetStructDefaults function to handle the "defaults" struct tag. Only works on struct fields which are pointers (in order to be able to distinguish between not set (nil) and zero values). So, the Config struct has been updated to use pointer fields and we now need to dereference those pointers to use them. --- pkg/cli/cmd.go | 4 +-- pkg/cli/config.go | 32 ++++++++++----------- pkg/cli/openai.go | 4 +-- pkg/cli/util.go | 71 +++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 89 insertions(+), 22 deletions(-) diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go index 9116010..a00d1f5 100644 --- a/pkg/cli/cmd.go +++ b/pkg/cli/cmd.go @@ -17,8 +17,8 @@ var ( func init() { inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd} for _, cmd := range inputCmds { - cmd.Flags().IntVar(&maxTokens, "length", config.OpenAI.DefaultMaxLength, "Max response length in tokens") - cmd.Flags().StringVar(&model, "model", config.OpenAI.DefaultModel, "The language model to use") + cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Max response length in tokens") + cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "The language model to use") } rootCmd.AddCommand( diff --git a/pkg/cli/config.go b/pkg/cli/config.go index 5e94dbb..b9c7137 100644 --- a/pkg/cli/config.go +++ b/pkg/cli/config.go @@ -9,10 +9,10 @@ import ( ) type Config struct { - OpenAI struct { - APIKey string `yaml:"apiKey"` - DefaultModel string `yaml:"defaultModel"` - DefaultMaxLength int `yaml:"defaultMaxLength"` + OpenAI *struct { + APIKey *string `yaml:"apiKey" default:"your_key_here"` + DefaultModel *string `yaml:"defaultModel" default:"gpt-4"` + DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"` } `yaml:"openai"` } @@ -33,30 +33,30 @@ func getConfigDir() string { func NewConfig() (*Config, error) { configFile := filepath.Join(getConfigDir(), "config.yaml") + shouldWriteDefaults := false + c := &Config{} configBytes, err := os.ReadFile(configFile) if os.IsNotExist(err) { - defaultConfig := &Config{} - defaultConfig.OpenAI.APIKey = "your_key_here" + shouldWriteDefaults = true + } else if err != nil { + return nil, fmt.Errorf("Could not read config file: %v", err) + } else { + yaml.Unmarshal(configBytes, c) + } + shouldWriteDefaults = SetStructDefaults(c) + if shouldWriteDefaults { file, err := os.Create(configFile) if err != nil { return nil, fmt.Errorf("Could not open config file for writing: %v", err) } - - fmt.Printf("Writing default configuration to: %s\n", configFile) - - bytes, _ := yaml.Marshal(defaultConfig) - + bytes, _ := yaml.Marshal(c) _, err = file.Write(bytes) if err != nil { return nil, fmt.Errorf("Could not save default configuration: %v", err) } - } else if err != nil { - return nil, fmt.Errorf("Could not read config file: %v", err) } - config := &Config{} - yaml.Unmarshal(configBytes, config) - return config, nil + return c, nil } diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go index 09435e6..57f1407 100644 --- a/pkg/cli/openai.go +++ b/pkg/cli/openai.go @@ -27,7 +27,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int // CreateChatCompletion submits a Chat Completion API request and returns the // response. func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) { - client := openai.NewClient(config.OpenAI.APIKey) + client := openai.NewClient(*config.OpenAI.APIKey) req := CreateChatCompletionRequest(model, messages, maxTokens) resp, err := client.CreateChatCompletion(context.Background(), req) if err != nil { @@ -40,7 +40,7 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri // CreateChatCompletionStream submits a streaming Chat Completion API request // and streams the response to the provided output channel. func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error { - client := openai.NewClient(config.OpenAI.APIKey) + client := openai.NewClient(*config.OpenAI.APIKey) req := CreateChatCompletionRequest(model, messages, maxTokens) defer close(output) diff --git a/pkg/cli/util.go b/pkg/cli/util.go index 6157c48..11af566 100644 --- a/pkg/cli/util.go +++ b/pkg/cli/util.go @@ -4,6 +4,8 @@ import ( "fmt" "os" "os/exec" + "reflect" + "strconv" "strings" "time" ) @@ -51,8 +53,8 @@ func InputFromEditor(placeholder string, pattern string) (string, error) { return strings.Trim(content, "\n \t"), nil } -// humanTimeElapsedSince returns a human-friendly representation of the given time -// duration. +// humanTimeElapsedSince returns a human-friendly "in the past" representation +// of the given duration. func humanTimeElapsedSince(d time.Duration) string { seconds := d.Seconds() minutes := seconds / 60 @@ -91,3 +93,68 @@ func humanTimeElapsedSince(d time.Duration) string { return fmt.Sprintf("%d years ago", int64(years)) } } + +// SetStructDefaultValues checks for any nil ptr fields within the passed +// struct, and sets the values of those fields to the value that is defined by +// their "default" struct tag. Handles setting string, int, and bool values. +// Returns whether any changes were made to the struct. +func SetStructDefaults(data interface{}) bool { + v := reflect.ValueOf(data).Elem() + changed := false + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + + // Check if we can set the field's value + if !field.CanSet() { + continue + } + + // We won't bother with non-pointer fields + if field.Kind() != reflect.Ptr { + continue + } + + t := field.Type() // type of pointer + e := t.Elem() // type of value of pointer + + // Handle nested structs recursively + if e.Kind() == reflect.Struct { + if field.IsNil() { + field.Set(reflect.New(e)) + changed = true + } + result := SetStructDefaults(field.Interface()) + if result { + changed = true + } + continue + } + + if !field.IsNil() { + continue + } + + // Get the "default" struct tag + defaultTag := v.Type().Field(i).Tag.Get("default") + if defaultTag == "" { + continue + } + + // Set nil pointer fields to their defined defaults + switch e.Kind() { + case reflect.String: + defaultValue := defaultTag + field.Set(reflect.ValueOf(&defaultValue)) + case reflect.Int, reflect.Int32, reflect.Int64: + intValue, _ := strconv.ParseInt(defaultTag, 10, 64) + field.Set(reflect.New(e)) + field.Elem().SetInt(intValue) + case reflect.Bool: + boolValue := defaultTag == "true" + field.Set(reflect.ValueOf(&boolValue)) + } + changed = true + } + return changed +}