Add retry and continue commands

This commit is contained in:
Matt Low 2023-11-22 06:53:22 +00:00
parent cb9e27542e
commit a2bd911ac8

View File

@ -17,7 +17,7 @@ var (
)
func init() {
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd}
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd}
for _, cmd := range inputCmds {
cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Max response length in tokens")
cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "The language model to use")
@ -27,10 +27,12 @@ func init() {
}
rootCmd.AddCommand(
continueCmd,
lsCmd,
newCmd,
promptCmd,
replyCmd,
retryCmd,
rmCmd,
viewCmd,
)
@ -432,3 +434,134 @@ var promptCmd = &cobra.Command{
fmt.Println()
},
}
var retryCmd = &cobra.Command{
Use: "retry <conversation>",
Short: "Retries the last conversation prompt.",
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
shortName := args[0]
conversation, err := store.ConversationByShortName(shortName)
if conversation.ID == 0 {
Fatal("Conversation not found with short name: %s\n", shortName)
}
messages, err := store.Messages(conversation)
if err != nil {
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
}
var lastUserMessageIndex int
for i := len(messages) - 1; i >=0; i-- {
if messages[i].Role == "user" {
lastUserMessageIndex = i
break
}
}
messages = messages[:lastUserMessageIndex+1]
RenderConversation(messages, true)
assistantReply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
assistantReply.RenderTTY()
receiver := make(chan string)
response := make(chan string)
go func() {
response <- HandleStreamedResponse(receiver)
}()
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
if err != nil {
Fatal("%v\n", err)
}
assistantReply.OriginalContent = <-response
err = store.SaveMessage(&assistantReply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
}
fmt.Println()
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return store.ConversationShortNameCompletions(toComplete), compMode
},
}
var continueCmd = &cobra.Command{
Use: "continue <conversation>",
Short: "Continues where the previous prompt left off.",
Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`,
Args: func(cmd *cobra.Command, args []string) error {
argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
return err
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
shortName := args[0]
conversation, err := store.ConversationByShortName(shortName)
if conversation.ID == 0 {
Fatal("Conversation not found with short name: %s\n", shortName)
}
messages, err := store.Messages(conversation)
if err != nil {
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
}
RenderConversation(messages, true)
assistantReply := Message{
ConversationID: conversation.ID,
Role: "assistant",
}
assistantReply.RenderTTY()
receiver := make(chan string)
response := make(chan string)
go func() {
response <- HandleStreamedResponse(receiver)
}()
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
if err != nil {
Fatal("%v\n", err)
}
assistantReply.OriginalContent = <-response
err = store.SaveMessage(&assistantReply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
}
fmt.Println()
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return store.ConversationShortNameCompletions(toComplete), compMode
},
}