diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go index 088b9f4..427a192 100644 --- a/pkg/cli/cmd.go +++ b/pkg/cli/cmd.go @@ -11,12 +11,14 @@ import ( var ( maxTokens int + model string ) func init() { inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd} for _, cmd := range inputCmds { cmd.Flags().IntVar(&maxTokens, "length", config.OpenAI.DefaultMaxLength, "Max response length in tokens") + cmd.Flags().StringVar(&model, "model", config.OpenAI.DefaultModel, "The language model to use") } rootCmd.AddCommand( @@ -274,7 +276,7 @@ var replyCmd = &cobra.Command{ response <- HandleDelayedResponse(receiver) }() - err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver) + err = CreateChatCompletionStream(model, messages, maxTokens, receiver) if err != nil { Fatal("%v\n", err) } @@ -354,7 +356,7 @@ var newCmd = &cobra.Command{ response <- HandleDelayedResponse(receiver) }() - err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver) + err = CreateChatCompletionStream(model, messages, maxTokens, receiver) if err != nil { Fatal("%v\n", err) } @@ -403,7 +405,7 @@ var promptCmd = &cobra.Command{ receiver := make(chan string) go HandleDelayedResponse(receiver) - err := CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, maxTokens, receiver) + err := CreateChatCompletionStream(model, messages, maxTokens, receiver) if err != nil { Fatal("%v\n", err) }