Set config defaults using a "default" struct tag

Add new SetStructDefaults function to handle the "defaults" struct tag.

Only works on struct fields which are pointers (in order to be able to
distinguish between not set (nil) and zero values). So, the Config
struct has been updated to use pointer fields and we now need to
dereference those pointers to use them.
This commit is contained in:
Matt Low 2023-11-19 01:14:00 +00:00
parent 6426b04e2c
commit 8780856854
4 changed files with 89 additions and 22 deletions

View File

@ -17,8 +17,8 @@ var (
func init() { func init() {
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd} inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd}
for _, cmd := range inputCmds { for _, cmd := range inputCmds {
cmd.Flags().IntVar(&maxTokens, "length", config.OpenAI.DefaultMaxLength, "Max response length in tokens") cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Max response length in tokens")
cmd.Flags().StringVar(&model, "model", config.OpenAI.DefaultModel, "The language model to use") cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "The language model to use")
} }
rootCmd.AddCommand( rootCmd.AddCommand(

View File

@ -9,10 +9,10 @@ import (
) )
type Config struct { type Config struct {
OpenAI struct { OpenAI *struct {
APIKey string `yaml:"apiKey"` APIKey *string `yaml:"apiKey" default:"your_key_here"`
DefaultModel string `yaml:"defaultModel"` DefaultModel *string `yaml:"defaultModel" default:"gpt-4"`
DefaultMaxLength int `yaml:"defaultMaxLength"` DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"`
} `yaml:"openai"` } `yaml:"openai"`
} }
@ -33,30 +33,30 @@ func getConfigDir() string {
func NewConfig() (*Config, error) { func NewConfig() (*Config, error) {
configFile := filepath.Join(getConfigDir(), "config.yaml") configFile := filepath.Join(getConfigDir(), "config.yaml")
shouldWriteDefaults := false
c := &Config{}
configBytes, err := os.ReadFile(configFile) configBytes, err := os.ReadFile(configFile)
if os.IsNotExist(err) { if os.IsNotExist(err) {
defaultConfig := &Config{} shouldWriteDefaults = true
defaultConfig.OpenAI.APIKey = "your_key_here" } else if err != nil {
return nil, fmt.Errorf("Could not read config file: %v", err)
} else {
yaml.Unmarshal(configBytes, c)
}
shouldWriteDefaults = SetStructDefaults(c)
if shouldWriteDefaults {
file, err := os.Create(configFile) file, err := os.Create(configFile)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not open config file for writing: %v", err) return nil, fmt.Errorf("Could not open config file for writing: %v", err)
} }
bytes, _ := yaml.Marshal(c)
fmt.Printf("Writing default configuration to: %s\n", configFile)
bytes, _ := yaml.Marshal(defaultConfig)
_, err = file.Write(bytes) _, err = file.Write(bytes)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not save default configuration: %v", err) return nil, fmt.Errorf("Could not save default configuration: %v", err)
} }
} else if err != nil {
return nil, fmt.Errorf("Could not read config file: %v", err)
} }
config := &Config{} return c, nil
yaml.Unmarshal(configBytes, config)
return config, nil
} }

View File

@ -27,7 +27,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
// CreateChatCompletion submits a Chat Completion API request and returns the // CreateChatCompletion submits a Chat Completion API request and returns the
// response. // response.
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) { func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
client := openai.NewClient(config.OpenAI.APIKey) client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens) req := CreateChatCompletionRequest(model, messages, maxTokens)
resp, err := client.CreateChatCompletion(context.Background(), req) resp, err := client.CreateChatCompletion(context.Background(), req)
if err != nil { if err != nil {
@ -40,7 +40,7 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
// CreateChatCompletionStream submits a streaming Chat Completion API request // CreateChatCompletionStream submits a streaming Chat Completion API request
// and streams the response to the provided output channel. // and streams the response to the provided output channel.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error { func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error {
client := openai.NewClient(config.OpenAI.APIKey) client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens) req := CreateChatCompletionRequest(model, messages, maxTokens)
defer close(output) defer close(output)

View File

@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"reflect"
"strconv"
"strings" "strings"
"time" "time"
) )
@ -51,8 +53,8 @@ func InputFromEditor(placeholder string, pattern string) (string, error) {
return strings.Trim(content, "\n \t"), nil return strings.Trim(content, "\n \t"), nil
} }
// humanTimeElapsedSince returns a human-friendly representation of the given time // humanTimeElapsedSince returns a human-friendly "in the past" representation
// duration. // of the given duration.
func humanTimeElapsedSince(d time.Duration) string { func humanTimeElapsedSince(d time.Duration) string {
seconds := d.Seconds() seconds := d.Seconds()
minutes := seconds / 60 minutes := seconds / 60
@ -91,3 +93,68 @@ func humanTimeElapsedSince(d time.Duration) string {
return fmt.Sprintf("%d years ago", int64(years)) return fmt.Sprintf("%d years ago", int64(years))
} }
} }
// SetStructDefaultValues checks for any nil ptr fields within the passed
// struct, and sets the values of those fields to the value that is defined by
// their "default" struct tag. Handles setting string, int, and bool values.
// Returns whether any changes were made to the struct.
func SetStructDefaults(data interface{}) bool {
v := reflect.ValueOf(data).Elem()
changed := false
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
// Check if we can set the field's value
if !field.CanSet() {
continue
}
// We won't bother with non-pointer fields
if field.Kind() != reflect.Ptr {
continue
}
t := field.Type() // type of pointer
e := t.Elem() // type of value of pointer
// Handle nested structs recursively
if e.Kind() == reflect.Struct {
if field.IsNil() {
field.Set(reflect.New(e))
changed = true
}
result := SetStructDefaults(field.Interface())
if result {
changed = true
}
continue
}
if !field.IsNil() {
continue
}
// Get the "default" struct tag
defaultTag := v.Type().Field(i).Tag.Get("default")
if defaultTag == "" {
continue
}
// Set nil pointer fields to their defined defaults
switch e.Kind() {
case reflect.String:
defaultValue := defaultTag
field.Set(reflect.ValueOf(&defaultValue))
case reflect.Int, reflect.Int32, reflect.Int64:
intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
field.Set(reflect.New(e))
field.Elem().SetInt(intValue)
case reflect.Bool:
boolValue := defaultTag == "true"
field.Set(reflect.ValueOf(&boolValue))
}
changed = true
}
return changed
}