From 239ded18f357d01780a6db175fa8e2af8fe7eb2e Mon Sep 17 00:00:00 2001 From: Matt Low Date: Tue, 2 Jan 2024 04:31:21 +0000 Subject: [PATCH] Add edit command Various refactoring: - reduced repetition with conversation message handling - made some functions internal --- main.go | 2 +- pkg/cli/cmd.go | 419 ++++++++++++++++++++++------------------ pkg/cli/config.go | 4 +- pkg/cli/conversation.go | 19 +- pkg/cli/store.go | 8 +- pkg/cli/tty.go | 4 +- pkg/cli/util.go | 6 +- 7 files changed, 257 insertions(+), 205 deletions(-) diff --git a/main.go b/main.go index 52232c7..28f9102 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,7 @@ import ( func main() { if err := cli.Execute(); err != nil { - fmt.Fprint(os.Stderr, err) + fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1) } } diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go index 5a42938..7409c11 100644 --- a/pkg/cli/cmd.go +++ b/pkg/cli/cmd.go @@ -18,7 +18,7 @@ var ( ) func init() { - inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd} + inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd} for _, cmd := range inputCmds { cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Max response length in tokens") cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "The language model to use") @@ -39,6 +39,7 @@ func init() { retryCmd, rmCmd, viewCmd, + editCmd, ) } @@ -57,43 +58,81 @@ func SystemPrompt() string { return systemPrompt } -// LLMRequest prompts the LLM with the given messages, writing the response -// to stdout. Returns all reply messages added by the LLM, including any -// function call messages. -func LLMRequest(messages []Message) ([]Message, error) { - // receiver receives the reponse from LLM - receiver := make(chan string) - defer close(receiver) +// fetchAndShowCompletion prompts the LLM with the given messages and streams +// the response to stdout. Returns all model reply messages. +func fetchAndShowCompletion(messages []Message) ([]Message, error) { + content := make(chan string) // receives the reponse from LLM + defer close(content) - // start HandleDelayedContent goroutine to print received data to stdout - go HandleDelayedContent(receiver) + // render all content received over the channel + go ShowDelayedContent(content) var replies []Message - response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver, &replies) + response, err := CreateChatCompletionStream(model, messages, maxTokens, content, &replies) if response != "" { + // there was some content, so break to a new line after it + fmt.Println() + if err != nil { Warn("Received partial response. Error: %v\n", err) err = nil } - // there was some content, so break to a new line after it - fmt.Println() } return replies, err } -func (c *Conversation) GenerateAndSaveReplies(messages []Message) { - replies, err := LLMRequest(messages) +// lookupConversationByShortname either returns the conversation found by the +// short name or exits the program +func lookupConversationByShortname(shortName string) *Conversation { + c, err := store.ConversationByShortName(shortName) + if err != nil { + Fatal("Could not lookup conversation: %v\n", err) + } + if c.ID == 0 { + Fatal("Conversation not found with short name: %s\n", shortName) + } + return c +} + +// handleConversationReply handles sending messages to an existing +// conversation, optionally persisting them. It displays the entire +// conversation before +func handleConversationReply(c *Conversation, persist bool, toSend ...Message) { + existing, err := store.Messages(c) + if err != nil { + Fatal("Could not retrieve messages for conversation: %s\n", c.Title) + } + + if persist { + for _, message := range toSend { + err = store.SaveMessage(&message) + if err != nil { + Warn("Could not save %s message: %v\n", message.Role, err) + } + } + } + + allMessages := append(existing, toSend...) + + RenderConversation(allMessages, true) + + // render a message header with no contents + (&Message{Role: MessageRoleAssistant}).RenderTTY() + + replies, err := fetchAndShowCompletion(allMessages) if err != nil { Fatal("Error fetching LLM response: %v\n", err) } - for _, reply := range replies { - reply.ConversationID = c.ID + if persist { + for _, reply := range replies { + reply.ConversationID = c.ID - err = store.SaveMessage(&reply) - if err != nil { - Warn("Could not save reply: %v\n", err) + err = store.SaveMessage(&reply) + if err != nil { + Warn("Could not save reply: %v\n", err) + } } } } @@ -102,10 +141,10 @@ func (c *Conversation) GenerateAndSaveReplies(messages []Message) { // (joined with spaces), or if len(args) is 0, opens an editor and returns // whatever input was provided there. placeholder is a string which populates // the editor and gets stripped from the final output. -func InputFromArgsOrEditor(args []string, placeholder string) (message string) { +func InputFromArgsOrEditor(args []string, placeholder string, existingMessage string) (message string) { var err error if len(args) == 0 { - message, err = InputFromEditor(placeholder, "message.*.md") + message, err = InputFromEditor(placeholder, "message.*.md", existingMessage) if err != nil { Fatal("Failed to get input: %v\n", err) } @@ -116,9 +155,8 @@ func InputFromArgsOrEditor(args []string, placeholder string) (message string) { } var rootCmd = &cobra.Command{ - Use: "lmcli", - Short: "Interact with Large Language Models", - Long: `lmcli is a CLI tool to interact with Large Language Models.`, + Use: "lmcli", + Long: `lmcli - command-line interface with Large Language Models.`, Run: func(cmd *cobra.Command, args []string) { // execute `lm ls` by default }, @@ -242,13 +280,7 @@ var rmCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { var toRemove []*Conversation for _, shortName := range args { - conversation, err := store.ConversationByShortName(shortName) - if err != nil { - Fatal("Could not search for conversation: %v\n", err) - } - if conversation.ID == 0 { - Fatal("Conversation not found with short name: %s\n", shortName) - } + conversation := lookupConversationByShortname(shortName) toRemove = append(toRemove, conversation) } var errors []error @@ -268,7 +300,8 @@ var rmCmd = &cobra.Command{ ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { compMode := cobra.ShellCompDirectiveNoFileComp var completions []string - outer: for _, completion := range store.ConversationShortNameCompletions(toComplete) { + outer: + for _, completion := range store.ConversationShortNameCompletions(toComplete) { parts := strings.Split(completion, "\t") for _, arg := range args { if parts[0] == arg { @@ -294,10 +327,7 @@ var viewCmd = &cobra.Command{ }, Run: func(cmd *cobra.Command, args []string) { shortName := args[0] - conversation, err := store.ConversationByShortName(shortName) - if conversation.ID == 0 { - Fatal("Conversation not found with short name: %s\n", shortName) - } + conversation := lookupConversationByShortname(shortName) messages, err := store.Messages(conversation) if err != nil { @@ -315,115 +345,6 @@ var viewCmd = &cobra.Command{ }, } -var replyCmd = &cobra.Command{ - Use: "reply [message]", - Short: "Send a reply to a conversation", - Long: `Sends a reply to conversation and writes the response to stdout.`, - Args: func(cmd *cobra.Command, args []string) error { - argCount := 1 - if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { - return err - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - shortName := args[0] - conversation, err := store.ConversationByShortName(shortName) - if conversation.ID == 0 { - Fatal("Conversation not found with short name: %s\n", shortName) - } - - messages, err := store.Messages(conversation) - if err != nil { - Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title) - } - - messageContents := InputFromArgsOrEditor(args[1:], "# How would you like to reply?\n") - if messageContents == "" { - Fatal("No reply was provided.\n") - } - - userReply := Message{ - ConversationID: conversation.ID, - Role: MessageRoleUser, - OriginalContent: messageContents, - } - - err = store.SaveMessage(&userReply) - if err != nil { - Warn("Could not save your reply: %v\n", err) - } - - messages = append(messages, userReply) - - RenderConversation(messages, true) - (&Message{Role: MessageRoleAssistant}).RenderTTY() - - conversation.GenerateAndSaveReplies(messages) - }, - ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - compMode := cobra.ShellCompDirectiveNoFileComp - if len(args) != 0 { - return nil, compMode - } - return store.ConversationShortNameCompletions(toComplete), compMode - }, -} - -var newCmd = &cobra.Command{ - Use: "new [message]", - Short: "Start a new conversation", - Long: `Start a new conversation with the Large Language Model.`, - Run: func(cmd *cobra.Command, args []string) { - messageContents := InputFromArgsOrEditor(args, "# What would you like to say?\n") - if messageContents == "" { - Fatal("No message was provided.\n") - } - - conversation := &Conversation{} - err := store.SaveConversation(conversation) - if err != nil { - Fatal("Could not save new conversation: %v\n", err) - } - - messages := []Message{ - { - ConversationID: conversation.ID, - Role: MessageRoleSystem, - OriginalContent: SystemPrompt(), - }, - { - ConversationID: conversation.ID, - Role: MessageRoleUser, - OriginalContent: messageContents, - }, - } - for _, message := range messages { - err = store.SaveMessage(&message) - if err != nil { - Warn("Could not save %s message: %v\n", message.Role, err) - } - } - - RenderConversation(messages, true) - (&Message{Role: MessageRoleAssistant}).RenderTTY() - - conversation.GenerateAndSaveReplies(messages) - - title, err := conversation.GenerateTitle() - if err != nil { - Warn("Could not generate title for conversation: %v\n", err) - } - - conversation.Title = title - - err = store.SaveConversation(conversation) - if err != nil { - Warn("Could not save conversation after generating title: %v\n", err) - } - }, -} - var renameCmd = &cobra.Command{ Use: "rename [title]", Short: "Rename a conversation", @@ -437,17 +358,15 @@ var renameCmd = &cobra.Command{ }, Run: func(cmd *cobra.Command, args []string) { shortName := args[0] - conversation, err := store.ConversationByShortName(shortName) - if conversation.ID == 0 { - Fatal("Conversation not found with short name: %s\n", shortName) - } + conversation := lookupConversationByShortname(shortName) + var err error generate, _ := cmd.Flags().GetBool("generate") var title string if generate { title, err = conversation.GenerateTitle() if err != nil { - Fatal("Could not generate title for conversation: %v\n", err) + Fatal("Could not generate conversation title: %v\n", err) } } else { if len(args) < 2 { @@ -471,12 +390,92 @@ var renameCmd = &cobra.Command{ }, } +var replyCmd = &cobra.Command{ + Use: "reply [message]", + Short: "Send a reply to a conversation", + Long: `Sends a reply to conversation and writes the response to stdout.`, + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + Run: func(cmd *cobra.Command, args []string) { + shortName := args[0] + conversation := lookupConversationByShortname(shortName) + + reply := InputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "") + if reply == "" { + Fatal("No reply was provided.\n") + } + + handleConversationReply(conversation, true, Message{ + ConversationID: conversation.ID, + Role: MessageRoleUser, + OriginalContent: reply, + }) + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + if len(args) != 0 { + return nil, compMode + } + return store.ConversationShortNameCompletions(toComplete), compMode + }, +} + +var newCmd = &cobra.Command{ + Use: "new [message]", + Short: "Start a new conversation", + Long: `Start a new conversation with the Large Language Model.`, + Run: func(cmd *cobra.Command, args []string) { + messageContents := InputFromArgsOrEditor(args, "# What would you like to say?\n", "") + if messageContents == "" { + Fatal("No message was provided.\n") + } + + conversation := &Conversation{} + err := store.SaveConversation(conversation) + if err != nil { + Fatal("Could not save new conversation: %v\n", err) + } + + messages := []Message{ + { + ConversationID: conversation.ID, + Role: MessageRoleSystem, + OriginalContent: SystemPrompt(), + }, + { + ConversationID: conversation.ID, + Role: MessageRoleUser, + OriginalContent: messageContents, + }, + } + + handleConversationReply(conversation, true, messages...) + + title, err := conversation.GenerateTitle() + if err != nil { + Warn("Could not generate title for conversation: %v\n", err) + } + + conversation.Title = title + + err = store.SaveConversation(conversation) + if err != nil { + Warn("Could not save conversation after generating title: %v\n", err) + } + }, +} + var promptCmd = &cobra.Command{ Use: "prompt [message]", Short: "Do a one-shot prompt", Long: `Prompt the Large Language Model and get a response.`, Run: func(cmd *cobra.Command, args []string) { - message := InputFromArgsOrEditor(args, "# What would you like to say?\n") + message := InputFromArgsOrEditor(args, "# What would you like to say?\n", "") if message == "" { Fatal("No message was provided.\n") } @@ -492,7 +491,7 @@ var promptCmd = &cobra.Command{ }, } - _, err := LLMRequest(messages) + _, err := fetchAndShowCompletion(messages) if err != nil { Fatal("Error fetching LLM response: %v\n", err) } @@ -512,39 +511,27 @@ var retryCmd = &cobra.Command{ }, Run: func(cmd *cobra.Command, args []string) { shortName := args[0] - conversation, err := store.ConversationByShortName(shortName) - if conversation.ID == 0 { - Fatal("Conversation not found with short name: %s\n", shortName) - } + conversation := lookupConversationByShortname(shortName) messages, err := store.Messages(conversation) if err != nil { Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title) } - var lastUserMessageIndex int - // walk backwards through conversations to find last user message + // walk backwards through the conversation and delete messages, break + // when we find the latest user response for i := len(messages) - 1; i >= 0; i-- { if messages[i].Role == MessageRoleUser { - lastUserMessageIndex = i break } - if lastUserMessageIndex == 0 { - // haven't found the the last user message yet, delete this one - err = store.DeleteMessage(&messages[i]) - if err != nil { - Warn("Could not delete previous reply: %v\n", err) - } + err = store.DeleteMessage(&messages[i]) + if err != nil { + Warn("Could not delete previous reply: %v\n", err) } } - messages = messages[:lastUserMessageIndex+1] - - RenderConversation(messages, true) - (&Message{Role: MessageRoleAssistant}).RenderTTY() - - conversation.GenerateAndSaveReplies(messages) + handleConversationReply(conversation, true) }, ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { compMode := cobra.ShellCompDirectiveNoFileComp @@ -568,20 +555,80 @@ var continueCmd = &cobra.Command{ }, Run: func(cmd *cobra.Command, args []string) { shortName := args[0] - conversation, err := store.ConversationByShortName(shortName) - if conversation.ID == 0 { - Fatal("Conversation not found with short name: %s\n", shortName) - } - - messages, err := store.Messages(conversation) - if err != nil { - Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title) - } - - RenderConversation(messages, true) - (&Message{Role: MessageRoleAssistant}).RenderTTY() - - conversation.GenerateAndSaveReplies(messages) + conversation := lookupConversationByShortname(shortName) + handleConversationReply(conversation, true) + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + compMode := cobra.ShellCompDirectiveNoFileComp + if len(args) != 0 { + return nil, compMode + } + return store.ConversationShortNameCompletions(toComplete), compMode + }, +} + +var editCmd = &cobra.Command{ + Use: "edit ", + Short: "Edit the last user message in a conversation.", + Args: func(cmd *cobra.Command, args []string) error { + argCount := 1 + if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { + return err + } + return nil + }, + Run: func(cmd *cobra.Command, args []string) { + shortName := args[0] + conversation := lookupConversationByShortname(shortName) + + messages, err := store.Messages(conversation) + if err != nil { + Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title) + } + + // walk backwards through the conversation deleting messages until and + // including the last user message + toRemove := []Message{} + var lastUserMessage *Message + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == MessageRoleUser { + lastUserMessage = &messages[i] + } + + toRemove = append(toRemove, messages[i]) + messages = messages[:i] + if lastUserMessage != nil { + break + } + } + + if lastUserMessage == nil { + Fatal("No messages left in the conversation, nothing to edit.\n") + } + + existingContents := lastUserMessage.OriginalContent + + newContents := InputFromArgsOrEditor(args[1:], "# Save when finished editing\n", existingContents) + if newContents == existingContents { + Fatal("No edits were made.\n") + } + + if newContents == "" { + Fatal("No message was provided.\n") + } + + for _, message := range(toRemove) { + err = store.DeleteMessage(&message) + if err != nil { + Warn("Could not delete message: %v\n", err) + } + } + + handleConversationReply(conversation, true, Message{ + ConversationID: conversation.ID, + Role: MessageRoleUser, + OriginalContent: newContents, + }) }, ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { compMode := cobra.ShellCompDirectiveNoFileComp diff --git a/pkg/cli/config.go b/pkg/cli/config.go index 53e25f6..16f0d75 100644 --- a/pkg/cli/config.go +++ b/pkg/cli/config.go @@ -24,7 +24,7 @@ type Config struct { } `yaml:"chroma"` } -func ConfigDir() string { +func configDir() string { var configDir string xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") @@ -40,7 +40,7 @@ func ConfigDir() string { } func NewConfig() (*Config, error) { - configFile := filepath.Join(ConfigDir(), "config.yaml") + configFile := filepath.Join(configDir(), "config.yaml") shouldWriteDefaults := false c := &Config{} diff --git a/pkg/cli/conversation.go b/pkg/cli/conversation.go index cac4b73..3ea3078 100644 --- a/pkg/cli/conversation.go +++ b/pkg/cli/conversation.go @@ -30,10 +30,15 @@ func (m *Message) FriendlyRole() string { } func (c *Conversation) GenerateTitle() (string, error) { - const header = "Generate a consise 4-5 word title for the conversation below." - prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, c.FormatForExternalPrompting(false)) + messages, err := store.Messages(c) + if err != nil { + return "", err + } - messages := []Message{ + const header = "Generate a concise 4-5 word title for the conversation below." + prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, formatForExternalPrompting(messages, false)) + + generateRequest := []Message{ { Role: MessageRoleUser, OriginalContent: prompt, @@ -41,7 +46,7 @@ func (c *Conversation) GenerateTitle() (string, error) { } model := "gpt-3.5-turbo" // use cheap model to generate title - response, err := CreateChatCompletion(model, messages, 25, nil) + response, err := CreateChatCompletion(model, generateRequest, 25, nil) if err != nil { return "", err } @@ -49,12 +54,8 @@ func (c *Conversation) GenerateTitle() (string, error) { return response, nil } -func (c *Conversation) FormatForExternalPrompting(system bool) string { +func formatForExternalPrompting(messages []Message, system bool) string { sb := strings.Builder{} - messages, err := store.Messages(c) - if err != nil { - Fatal("Could not retrieve messages for conversation %v", c) - } for _, message := range messages { if message.Role == MessageRoleSystem && !system { continue diff --git a/pkg/cli/store.go b/pkg/cli/store.go index af9ed55..5ffbacd 100644 --- a/pkg/cli/store.go +++ b/pkg/cli/store.go @@ -1,6 +1,7 @@ package cli import ( + "errors" "database/sql" "fmt" "os" @@ -35,7 +36,7 @@ type Conversation struct { Title string } -func DataDir() string { +func dataDir() string { var dataDir string xdgDataHome := os.Getenv("XDG_DATA_HOME") @@ -51,7 +52,7 @@ func DataDir() string { } func NewStore() (*Store, error) { - databaseFile := filepath.Join(DataDir(), "conversations.db") + databaseFile := filepath.Join(dataDir(), "conversations.db") db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{}) if err != nil { return nil, fmt.Errorf("Error establishing connection to store: %v", err) @@ -119,6 +120,9 @@ func (s *Store) ConversationShortNameCompletions(shortName string) []string { } func (s *Store) ConversationByShortName(shortName string) (*Conversation, error) { + if shortName == "" { + return nil, errors.New("shortName is empty") + } var conversation Conversation err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error return &conversation, err diff --git a/pkg/cli/tty.go b/pkg/cli/tty.go index 2a1e13e..ab275cc 100644 --- a/pkg/cli/tty.go +++ b/pkg/cli/tty.go @@ -36,12 +36,12 @@ func ShowWaitAnimation(signal chan any) { } } -// HandleDelayedContent displays a waiting animation to stdout while waiting +// ShowDelayedContent displays a waiting animation to stdout while waiting // for content to be received on the provided channel. As soon as any (possibly // chunked) content is received on the channel, the waiting animation is // replaced by the content. // Blocks until the channel is closed. -func HandleDelayedContent(content <-chan string) { +func ShowDelayedContent(content <-chan string) { waitSignal := make(chan any) go ShowWaitAnimation(waitSignal) diff --git a/pkg/cli/util.go b/pkg/cli/util.go index 5acc547..1bd9c95 100644 --- a/pkg/cli/util.go +++ b/pkg/cli/util.go @@ -17,11 +17,11 @@ import ( // contents of the file exactly match the value of placeholder (no edits to the // file were made), then an empty string is returned. Otherwise, the contents // are returned. Example patten: message.*.md -func InputFromEditor(placeholder string, pattern string) (string, error) { +func InputFromEditor(placeholder string, pattern string, content string) (string, error) { msgFile, _ := os.CreateTemp("/tmp", pattern) defer os.Remove(msgFile.Name()) - os.WriteFile(msgFile.Name(), []byte(placeholder), os.ModeAppend) + os.WriteFile(msgFile.Name(), []byte(placeholder + content), os.ModeAppend) editor := os.Getenv("EDITOR") if editor == "" { @@ -38,7 +38,7 @@ func InputFromEditor(placeholder string, pattern string) (string, error) { } bytes, _ := os.ReadFile(msgFile.Name()) - content := string(bytes) + content = string(bytes) if placeholder != "" { if content == placeholder {