From 8bc831215428c8e8988c012fd766a91e20ea4998 Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Sat, 18 Nov 2023 15:07:17 +0000
Subject: [PATCH] Add --length flag to control model output "maxTokens"

---
 pkg/cli/cmd.go    | 16 +++++++++++-----
 pkg/cli/config.go |  1 +
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go
index 0ab667e..088b9f4 100644
--- a/pkg/cli/cmd.go
+++ b/pkg/cli/cmd.go
@@ -9,10 +9,16 @@ import (
 	"github.com/spf13/cobra"
 )
 
-// TODO: allow setting with flag
-const MAX_TOKENS = 256
+var (
+	maxTokens int
+)
 
 func init() {
+	inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd}
+	for _, cmd := range inputCmds {
+		cmd.Flags().IntVar(&maxTokens, "length", config.OpenAI.DefaultMaxLength, "Max response length in tokens")
+	}
+
 	rootCmd.AddCommand(
 		lsCmd,
 		newCmd,
@@ -268,7 +274,7 @@ var replyCmd = &cobra.Command{
 			response <- HandleDelayedResponse(receiver)
 		}()
 
-		err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver)
+		err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver)
 		if err != nil {
 			Fatal("%v\n", err)
 		}
@@ -348,7 +354,7 @@ var newCmd = &cobra.Command{
 			response <- HandleDelayedResponse(receiver)
 		}()
 
-		err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver)
+		err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver)
 		if err != nil {
 			Fatal("%v\n", err)
 		}
@@ -397,7 +403,7 @@ var promptCmd = &cobra.Command{
 		receiver := make(chan string)
 		go HandleDelayedResponse(receiver)
 
-		err := CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver)
+		err := CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver)
 		if err != nil {
 			Fatal("%v\n", err)
 		}
diff --git a/pkg/cli/config.go b/pkg/cli/config.go
index 41aa2a7..5e94dbb 100644
--- a/pkg/cli/config.go
+++ b/pkg/cli/config.go
@@ -12,6 +12,7 @@ type Config struct {
 	OpenAI struct {
 		APIKey       string `yaml:"apiKey"`
 		DefaultModel string `yaml:"defaultModel"`
+		DefaultMaxLength int `yaml:"defaultMaxLength"`
 	} `yaml:"openai"`
 }
 
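
For reference: the new defaultMaxLength key sits under the existing
"openai" section of the config file, following the yaml struct tags in
pkg/cli/config.go. A minimal sketch of a config using this key, with
placeholder values (256 mirrors the previously hardcoded MAX_TOKENS):

    openai:
      apiKey: sk-...
      defaultModel: gpt-3.5-turbo
      defaultMaxLength: 256

The --length flag then overrides that default per invocation on the new,
prompt, and reply commands, e.g. (binary name hypothetical):

    $ lmcli prompt --length 512 "your prompt here"

All three commands bind the flag to the same package-level maxTokens
variable; that is safe here because cobra runs exactly one command per
process, so only one flag set is ever parsed.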