diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go index 9116010..a00d1f5 100644 --- a/pkg/cli/cmd.go +++ b/pkg/cli/cmd.go @@ -17,8 +17,8 @@ var ( func init() { inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd} for _, cmd := range inputCmds { - cmd.Flags().IntVar(&maxTokens, "length", config.OpenAI.DefaultMaxLength, "Max response length in tokens") - cmd.Flags().StringVar(&model, "model", config.OpenAI.DefaultModel, "The language model to use") + cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Max response length in tokens") + cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "The language model to use") } rootCmd.AddCommand( diff --git a/pkg/cli/config.go b/pkg/cli/config.go index 5e94dbb..b9c7137 100644 --- a/pkg/cli/config.go +++ b/pkg/cli/config.go @@ -9,10 +9,10 @@ import ( ) type Config struct { - OpenAI struct { - APIKey string `yaml:"apiKey"` - DefaultModel string `yaml:"defaultModel"` - DefaultMaxLength int `yaml:"defaultMaxLength"` + OpenAI *struct { + APIKey *string `yaml:"apiKey" default:"your_key_here"` + DefaultModel *string `yaml:"defaultModel" default:"gpt-4"` + DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"` } `yaml:"openai"` } @@ -33,30 +33,30 @@ func getConfigDir() string { func NewConfig() (*Config, error) { configFile := filepath.Join(getConfigDir(), "config.yaml") + shouldWriteDefaults := false + c := &Config{} configBytes, err := os.ReadFile(configFile) if os.IsNotExist(err) { - defaultConfig := &Config{} - defaultConfig.OpenAI.APIKey = "your_key_here" + shouldWriteDefaults = true + } else if err != nil { + return nil, fmt.Errorf("Could not read config file: %v", err) + } else { + yaml.Unmarshal(configBytes, c) + } + shouldWriteDefaults = SetStructDefaults(c) + if shouldWriteDefaults { file, err := os.Create(configFile) if err != nil { return nil, fmt.Errorf("Could not open config file for writing: %v", err) } - - fmt.Printf("Writing default configuration to: %s\n", configFile) - - bytes, _ := yaml.Marshal(defaultConfig) - + bytes, _ := yaml.Marshal(c) _, err = file.Write(bytes) if err != nil { return nil, fmt.Errorf("Could not save default configuration: %v", err) } - } else if err != nil { - return nil, fmt.Errorf("Could not read config file: %v", err) } - config := &Config{} - yaml.Unmarshal(configBytes, config) - return config, nil + return c, nil } diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go index 09435e6..57f1407 100644 --- a/pkg/cli/openai.go +++ b/pkg/cli/openai.go @@ -27,7 +27,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int // CreateChatCompletion submits a Chat Completion API request and returns the // response. func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) { - client := openai.NewClient(config.OpenAI.APIKey) + client := openai.NewClient(*config.OpenAI.APIKey) req := CreateChatCompletionRequest(model, messages, maxTokens) resp, err := client.CreateChatCompletion(context.Background(), req) if err != nil { @@ -40,7 +40,7 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri // CreateChatCompletionStream submits a streaming Chat Completion API request // and streams the response to the provided output channel. func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error { - client := openai.NewClient(config.OpenAI.APIKey) + client := openai.NewClient(*config.OpenAI.APIKey) req := CreateChatCompletionRequest(model, messages, maxTokens) defer close(output) diff --git a/pkg/cli/util.go b/pkg/cli/util.go index 6157c48..11af566 100644 --- a/pkg/cli/util.go +++ b/pkg/cli/util.go @@ -4,6 +4,8 @@ import ( "fmt" "os" "os/exec" + "reflect" + "strconv" "strings" "time" ) @@ -51,8 +53,8 @@ func InputFromEditor(placeholder string, pattern string) (string, error) { return strings.Trim(content, "\n \t"), nil } -// humanTimeElapsedSince returns a human-friendly representation of the given time -// duration. +// humanTimeElapsedSince returns a human-friendly "in the past" representation +// of the given duration. func humanTimeElapsedSince(d time.Duration) string { seconds := d.Seconds() minutes := seconds / 60 @@ -91,3 +93,68 @@ func humanTimeElapsedSince(d time.Duration) string { return fmt.Sprintf("%d years ago", int64(years)) } } + +// SetStructDefaultValues checks for any nil ptr fields within the passed +// struct, and sets the values of those fields to the value that is defined by +// their "default" struct tag. Handles setting string, int, and bool values. +// Returns whether any changes were made to the struct. +func SetStructDefaults(data interface{}) bool { + v := reflect.ValueOf(data).Elem() + changed := false + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + + // Check if we can set the field's value + if !field.CanSet() { + continue + } + + // We won't bother with non-pointer fields + if field.Kind() != reflect.Ptr { + continue + } + + t := field.Type() // type of pointer + e := t.Elem() // type of value of pointer + + // Handle nested structs recursively + if e.Kind() == reflect.Struct { + if field.IsNil() { + field.Set(reflect.New(e)) + changed = true + } + result := SetStructDefaults(field.Interface()) + if result { + changed = true + } + continue + } + + if !field.IsNil() { + continue + } + + // Get the "default" struct tag + defaultTag := v.Type().Field(i).Tag.Get("default") + if defaultTag == "" { + continue + } + + // Set nil pointer fields to their defined defaults + switch e.Kind() { + case reflect.String: + defaultValue := defaultTag + field.Set(reflect.ValueOf(&defaultValue)) + case reflect.Int, reflect.Int32, reflect.Int64: + intValue, _ := strconv.ParseInt(defaultTag, 10, 64) + field.Set(reflect.New(e)) + field.Elem().SetInt(intValue) + case reflect.Bool: + boolValue := defaultTag == "true" + field.Set(reflect.ValueOf(&boolValue)) + } + changed = true + } + return changed +}