Set config defaults using a "default" struct tag
Add new SetStructDefaults function to handle the "defaults" struct tag. Only works on struct fields which are pointers (in order to be able to distinguish between not set (nil) and zero values). So, the Config struct has been updated to use pointer fields and we now need to dereference those pointers to use them.
This commit is contained in:
parent
6426b04e2c
commit
dce62e7748
@ -17,8 +17,8 @@ var (
|
|||||||
func init() {
|
func init() {
|
||||||
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd}
|
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd}
|
||||||
for _, cmd := range inputCmds {
|
for _, cmd := range inputCmds {
|
||||||
cmd.Flags().IntVar(&maxTokens, "length", config.OpenAI.DefaultMaxLength, "Max response length in tokens")
|
cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Max response length in tokens")
|
||||||
cmd.Flags().StringVar(&model, "model", config.OpenAI.DefaultModel, "The language model to use")
|
cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "The language model to use")
|
||||||
}
|
}
|
||||||
|
|
||||||
rootCmd.AddCommand(
|
rootCmd.AddCommand(
|
||||||
|
@ -9,10 +9,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
OpenAI struct {
|
OpenAI *struct {
|
||||||
APIKey string `yaml:"apiKey"`
|
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||||
DefaultModel string `yaml:"defaultModel"`
|
DefaultModel *string `yaml:"defaultModel" default:"gpt-4"`
|
||||||
DefaultMaxLength int `yaml:"defaultMaxLength"`
|
DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"`
|
||||||
} `yaml:"openai"`
|
} `yaml:"openai"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -33,30 +33,30 @@ func getConfigDir() string {
|
|||||||
|
|
||||||
func NewConfig() (*Config, error) {
|
func NewConfig() (*Config, error) {
|
||||||
configFile := filepath.Join(getConfigDir(), "config.yaml")
|
configFile := filepath.Join(getConfigDir(), "config.yaml")
|
||||||
|
shouldWriteDefaults := false
|
||||||
|
c := &Config{}
|
||||||
|
|
||||||
configBytes, err := os.ReadFile(configFile)
|
configBytes, err := os.ReadFile(configFile)
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
defaultConfig := &Config{}
|
shouldWriteDefaults = true
|
||||||
defaultConfig.OpenAI.APIKey = "your_key_here"
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not read config file: %v", err)
|
||||||
|
} else {
|
||||||
|
yaml.Unmarshal(configBytes, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldWriteDefaults = SetStructDefaults(c)
|
||||||
|
if shouldWriteDefaults {
|
||||||
file, err := os.Create(configFile)
|
file, err := os.Create(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
|
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
|
||||||
}
|
}
|
||||||
|
bytes, _ := yaml.Marshal(c)
|
||||||
fmt.Printf("Writing default configuration to: %s\n", configFile)
|
|
||||||
|
|
||||||
bytes, _ := yaml.Marshal(defaultConfig)
|
|
||||||
|
|
||||||
_, err = file.Write(bytes)
|
_, err = file.Write(bytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not save default configuration: %v", err)
|
return nil, fmt.Errorf("Could not save default configuration: %v", err)
|
||||||
}
|
}
|
||||||
} else if err != nil {
|
|
||||||
return nil, fmt.Errorf("Could not read config file: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &Config{}
|
return c, nil
|
||||||
yaml.Unmarshal(configBytes, config)
|
|
||||||
return config, nil
|
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ func CreateChatCompletionRequest(model string, messages []Message, maxTokens int
|
|||||||
// CreateChatCompletion submits a Chat Completion API request and returns the
|
// CreateChatCompletion submits a Chat Completion API request and returns the
|
||||||
// response.
|
// response.
|
||||||
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
|
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
|
||||||
client := openai.NewClient(config.OpenAI.APIKey)
|
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||||
resp, err := client.CreateChatCompletion(context.Background(), req)
|
resp, err := client.CreateChatCompletion(context.Background(), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -40,7 +40,7 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
|
|||||||
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
||||||
// and streams the response to the provided output channel.
|
// and streams the response to the provided output channel.
|
||||||
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error {
|
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error {
|
||||||
client := openai.NewClient(config.OpenAI.APIKey)
|
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||||
|
|
||||||
defer close(output)
|
defer close(output)
|
||||||
|
@ -4,6 +4,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -51,8 +53,8 @@ func InputFromEditor(placeholder string, pattern string) (string, error) {
|
|||||||
return strings.Trim(content, "\n \t"), nil
|
return strings.Trim(content, "\n \t"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// humanTimeElapsedSince returns a human-friendly representation of the given time
|
// humanTimeElapsedSince returns a human-friendly "in the past" representation
|
||||||
// duration.
|
// of the given duration.
|
||||||
func humanTimeElapsedSince(d time.Duration) string {
|
func humanTimeElapsedSince(d time.Duration) string {
|
||||||
seconds := d.Seconds()
|
seconds := d.Seconds()
|
||||||
minutes := seconds / 60
|
minutes := seconds / 60
|
||||||
@ -91,3 +93,68 @@ func humanTimeElapsedSince(d time.Duration) string {
|
|||||||
return fmt.Sprintf("%d years ago", int64(years))
|
return fmt.Sprintf("%d years ago", int64(years))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetStructDefaultValues checks for any nil ptr fields within the passed
|
||||||
|
// struct, and sets the values of those fields to the value that is contained
|
||||||
|
// within their "default" struct tag. Handles setting string, int, and bool
|
||||||
|
// values. Returns whether any changes were made to the struct.
|
||||||
|
func SetStructDefaults(data interface{}) bool {
|
||||||
|
v := reflect.ValueOf(data).Elem()
|
||||||
|
changed := false
|
||||||
|
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
field := v.Field(i)
|
||||||
|
|
||||||
|
// Check if we can set the field value
|
||||||
|
if !field.CanSet() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// We won't bother with non-pointer fields
|
||||||
|
if field.Kind() != reflect.Ptr {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
t := field.Type() // type of pointer
|
||||||
|
e := t.Elem() // type of value of pointer
|
||||||
|
|
||||||
|
// Handle nested structs recursively
|
||||||
|
if e.Kind() == reflect.Struct {
|
||||||
|
if field.IsNil() {
|
||||||
|
field.Set(reflect.New(e))
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
result := SetStructDefaults(field.Interface())
|
||||||
|
if result {
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !field.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the "default" struct tag
|
||||||
|
defaultTag := v.Type().Field(i).Tag.Get("default")
|
||||||
|
if defaultTag == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set nil pointer fields to their defined defaults
|
||||||
|
switch e.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
defaultValue := defaultTag
|
||||||
|
field.Set(reflect.ValueOf(&defaultValue))
|
||||||
|
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||||
|
intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
|
||||||
|
field.Set(reflect.New(e))
|
||||||
|
field.Elem().SetInt(intValue)
|
||||||
|
case reflect.Bool:
|
||||||
|
boolValue := defaultTag == "true"
|
||||||
|
field.Set(reflect.ValueOf(&boolValue))
|
||||||
|
}
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
return changed
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user