Add --length flag to control model output "maxTokens"

This commit is contained in:
Matt Low 2023-11-18 15:07:17 +00:00
parent 681b52a55c
commit 8bc8312154
2 changed files with 12 additions and 5 deletions

View File

@@ -9,10 +9,16 @@ import (
 	"github.com/spf13/cobra"
 )
 
-// TODO: allow setting with flag
-const MAX_TOKENS = 256
+var (
+	maxTokens int
+)
 
 func init() {
+	inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd}
+	for _, cmd := range inputCmds {
+		cmd.Flags().IntVar(&maxTokens, "length", config.OpenAI.DefaultMaxLength, "Max response length in tokens")
+	}
+
 	rootCmd.AddCommand(
 		lsCmd,
 		newCmd,
@@ -268,7 +274,7 @@ var replyCmd = &cobra.Command{
 			response <- HandleDelayedResponse(receiver)
 		}()
 
-		err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver)
+		err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver)
 		if err != nil {
 			Fatal("%v\n", err)
 		}
@@ -348,7 +354,7 @@ var newCmd = &cobra.Command{
 			response <- HandleDelayedResponse(receiver)
 		}()
 
-		err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver)
+		err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver)
 		if err != nil {
 			Fatal("%v\n", err)
 		}
@@ -397,7 +403,7 @@ var promptCmd = &cobra.Command{
 		receiver := make(chan string)
 		go HandleDelayedResponse(receiver)
 
-		err := CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver)
+		err := CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver)
 		if err != nil {
 			Fatal("%v\n", err)
 		}

View File

@@ -12,6 +12,7 @@ type Config struct {
 	OpenAI struct {
 		APIKey       string `yaml:"apiKey"`
 		DefaultModel string `yaml:"defaultModel"`
+		DefaultMaxLength int `yaml:"defaultMaxLength"`
 	} `yaml:"openai"`
 }