diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go index 0ab667e..088b9f4 100644 --- a/pkg/cli/cmd.go +++ b/pkg/cli/cmd.go @@ -9,10 +9,16 @@ import ( "github.com/spf13/cobra" ) -// TODO: allow setting with flag -const MAX_TOKENS = 256 +var ( + maxTokens int +) func init() { + inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd} + for _, cmd := range inputCmds { + cmd.Flags().IntVar(&maxTokens, "length", config.OpenAI.DefaultMaxLength, "Max response length in tokens") + } + rootCmd.AddCommand( lsCmd, newCmd, @@ -268,7 +274,7 @@ var replyCmd = &cobra.Command{ response <- HandleDelayedResponse(receiver) }() - err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver) + err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver) if err != nil { Fatal("%v\n", err) } @@ -348,7 +354,7 @@ var newCmd = &cobra.Command{ response <- HandleDelayedResponse(receiver) }() - err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver) + err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver) if err != nil { Fatal("%v\n", err) } @@ -397,7 +403,7 @@ var promptCmd = &cobra.Command{ receiver := make(chan string) go HandleDelayedResponse(receiver) - err := CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver) + err := CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver) if err != nil { Fatal("%v\n", err) } diff --git a/pkg/cli/config.go b/pkg/cli/config.go index 41aa2a7..5e94dbb 100644 --- a/pkg/cli/config.go +++ b/pkg/cli/config.go @@ -12,6 +12,7 @@ type Config struct { OpenAI struct { APIKey string `yaml:"apiKey"` DefaultModel string `yaml:"defaultModel"` + DefaultMaxLength int `yaml:"defaultMaxLength"` } `yaml:"openai"` }