From a2bd911ac856d900e7b90f7809a6241904b182e8 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Wed, 22 Nov 2023 06:53:22 +0000 Subject: [PATCH] Add `retry` and `continue` commands --- pkg/cli/cmd.go | 135 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 134 insertions(+), 1 deletion(-) diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go index 8ead74e..41ea536 100644 --- a/pkg/cli/cmd.go +++ b/pkg/cli/cmd.go @@ -17,7 +17,7 @@ var ( ) func init() { - inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd} + inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd} for _, cmd := range inputCmds { cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Max response length in tokens") cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "The language model to use") @@ -27,10 +27,12 @@ func init() { } rootCmd.AddCommand( + continueCmd, lsCmd, newCmd, promptCmd, replyCmd, + retryCmd, rmCmd, viewCmd, ) @@ -432,3 +434,134 @@ var promptCmd = &cobra.Command{ fmt.Println() }, } + +var retryCmd = &cobra.Command{ + Use: "retry ", + Short: "Retries the last conversation prompt.", + Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`, + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + Run: func(cmd *cobra.Command, args []string) { + shortName := args[0] + conversation, err := store.ConversationByShortName(shortName) + if conversation.ID == 0 { + Fatal("Conversation not found with short name: %s\n", shortName) + } + + messages, err := store.Messages(conversation) + if err != nil { + Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title) + } + + var lastUserMessageIndex int + for i := len(messages) - 1; i >=0; i-- { + if messages[i].Role == "user" { + lastUserMessageIndex = i + break + } + } + + messages = messages[:lastUserMessageIndex+1] + + RenderConversation(messages, true) + assistantReply := Message{ + ConversationID: conversation.ID, + Role: "assistant", + } + assistantReply.RenderTTY() + + receiver := make(chan string) + response := make(chan string) + go func() { + response <- HandleStreamedResponse(receiver) + }() + + err = CreateChatCompletionStream(model, messages, maxTokens, receiver) + if err != nil { + Fatal("%v\n", err) + } + + assistantReply.OriginalContent = <-response + + err = store.SaveMessage(&assistantReply) + if err != nil { + Fatal("Could not save assistant reply: %v\n", err) + } + + fmt.Println() + + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + if len(args) != 0 { + return nil, compMode + } + return store.ConversationShortNameCompletions(toComplete), compMode + }, +} + + +var continueCmd = &cobra.Command{ + Use: "continue ", + Short: "Continues where the previous prompt left off.", + Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`, + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + Run: func(cmd *cobra.Command, args []string) { + shortName := args[0] + conversation, err := store.ConversationByShortName(shortName) + if conversation.ID == 0 { + Fatal("Conversation not found with short name: %s\n", shortName) + } + + messages, err := store.Messages(conversation) + if err != nil { + Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title) + } + + RenderConversation(messages, true) + assistantReply := Message{ + ConversationID: conversation.ID, + Role: "assistant", + } + assistantReply.RenderTTY() + + receiver := make(chan string) + response := make(chan string) + go func() { + response <- HandleStreamedResponse(receiver) + }() + + err = CreateChatCompletionStream(model, messages, maxTokens, receiver) + if err != nil { + Fatal("%v\n", err) + } + + assistantReply.OriginalContent = <-response + + err = store.SaveMessage(&assistantReply) + if err != nil { + Fatal("Could not save assistant reply: %v\n", err) + } + + fmt.Println() + + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + if len(args) != 0 { + return nil, compMode + } + return store.ConversationShortNameCompletions(toComplete), compMode + }, +}