From 8c53752146314b19547e2e75947eb66fc8915343 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Mon, 20 May 2024 18:12:44 +0000 Subject: [PATCH] Add message branching Updated the behaviour of commands: - `lmcli edit` - by default create a new branch/message branch with the edited contents - add --in-place to avoid creating a branch - no longer delete messages after the edited message - only do the edit, don't fetch a new response - `lmcli retry` - create a new branch rather than replacing old messages - add --offset to change where to retry from --- pkg/cmd/clone.go | 31 +-- pkg/cmd/continue.go | 2 +- pkg/cmd/edit.go | 65 +++---- pkg/cmd/list.go | 22 +-- pkg/cmd/new.go | 47 ++--- pkg/cmd/prompt.go | 25 ++- pkg/cmd/rename.go | 13 +- pkg/cmd/reply.go | 6 +- pkg/cmd/retry.go | 41 ++-- pkg/cmd/util/util.go | 57 +++--- pkg/cmd/view.go | 4 +- pkg/lmcli/lmcli.go | 4 +- pkg/lmcli/model/conversation.go | 21 +- pkg/lmcli/store.go | 331 +++++++++++++++++++++++++++----- pkg/tui/chat.go | 125 +++++------- pkg/tui/conversations.go | 19 +- 16 files changed, 505 insertions(+), 308 deletions(-) diff --git a/pkg/cmd/clone.go b/pkg/cmd/clone.go index a32024f..5055257 100644 --- a/pkg/cmd/clone.go +++ b/pkg/cmd/clone.go @@ -5,7 +5,6 @@ import ( cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "github.com/spf13/cobra" ) @@ -28,36 +27,12 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command { return err } - messagesToCopy, err := ctx.Store.Messages(toClone) + clone, messageCnt, err := ctx.Store.CloneConversation(*toClone) if err != nil { - return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String) + return fmt.Errorf("Failed to clone conversation: %v", err) } - clone := &model.Conversation{ - Title: toClone.Title + " - Clone", - } - if err := ctx.Store.SaveConversation(clone); err != nil { - return fmt.Errorf("Cloud not create clone: %s", err) - } - - var errors []error - messageCnt := 0 - for _, message := range messagesToCopy { - newMessage := message - newMessage.ConversationID = clone.ID - newMessage.ID = 0 - if err := ctx.Store.SaveMessage(&newMessage); err != nil { - errors = append(errors, err) - } else { - messageCnt++ - } - } - - if len(errors) > 0 { - return fmt.Errorf("Messages failed to be cloned: %v", errors) - } - - fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title) + fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title) return nil }, ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { diff --git a/pkg/cmd/continue.go b/pkg/cmd/continue.go index 0869769..1ee4317 100644 --- a/pkg/cmd/continue.go +++ b/pkg/cmd/continue.go @@ -26,7 +26,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command { shortName := args[0] conversation := cmdutil.LookupConversation(ctx, shortName) - messages, err := ctx.Store.Messages(conversation) + messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) if err != nil { return fmt.Errorf("could not retrieve conversation messages: %v", err) } diff --git a/pkg/cmd/edit.go b/pkg/cmd/edit.go index b8bcbaf..c710a95 100644 --- a/pkg/cmd/edit.go +++ b/pkg/cmd/edit.go @@ -24,7 +24,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command { shortName := args[0] conversation := cmdutil.LookupConversation(ctx, shortName) - messages, err := ctx.Store.Messages(conversation) + messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) if err != nil { return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) } @@ -39,21 +39,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command { } desiredIdx := len(messages) - 1 - offset - - // walk backwards through the conversation deleting messages until and - // including the last user message - toRemove := []model.Message{} - var toEdit *model.Message - for i := len(messages) - 1; i >= 0; i-- { - if i == desiredIdx { - toEdit = &messages[i] - } - toRemove = append(toRemove, messages[i]) - messages = messages[:i] - if toEdit != nil { - break - } - } + toEdit := messages[desiredIdx] newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content) switch newContents { @@ -63,26 +49,38 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command { return fmt.Errorf("No message was provided.") } + toEdit.Content = newContents + role, _ := cmd.Flags().GetString("role") - if role == "" { - role = string(toEdit.Role) - } else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) { - return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.") - } - - for _, message := range toRemove { - err = ctx.Store.DeleteMessage(&message) - if err != nil { - lmcli.Warn("Could not delete message: %v\n", err) + if role != "" { + if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) { + return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.") } + toEdit.Role = model.MessageRole(role) } - cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{ - ConversationID: conversation.ID, - Role: model.MessageRole(role), - Content: newContents, - }) - return nil + // Update the message in-place + inplace, _ := cmd.Flags().GetBool("in-place") + if inplace { + return ctx.Store.UpdateMessage(&toEdit) + } + + // Otherwise, create a branch for the edited message + message, _, err := ctx.Store.CloneBranch(toEdit) + if err != nil { + return err + } + + if desiredIdx > 0 { + // update selected reply + messages[desiredIdx-1].SelectedReply = message + err = ctx.Store.UpdateMessage(&messages[desiredIdx-1]) + } else { + // update selected root + conversation.SelectedRoot = message + err = ctx.Store.UpdateConversation(conversation) + } + return err }, ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { compMode := cobra.ShellCompDirectiveNoFileComp @@ -93,8 +91,9 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command { }, } + cmd.Flags().BoolP("in-place", "i", true, "Edit the message in-place, rather than creating a branch") cmd.Flags().Int("offset", 1, "Offset from the last message to edit") - cmd.Flags().StringP("role", "r", "", "Role of the edited message (user or assistant)") + cmd.Flags().StringP("role", "r", "", "Change the role of the edited message (user or assistant)") return cmd } diff --git a/pkg/cmd/list.go b/pkg/cmd/list.go index 317deac..6367ec0 100644 --- a/pkg/cmd/list.go +++ b/pkg/cmd/list.go @@ -2,7 +2,6 @@ package cmd import ( "fmt" - "slices" "time" "git.mlow.ca/mlow/lmcli/pkg/lmcli" @@ -21,7 +20,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command { Short: "List conversations", Long: `List conversations in order of recent activity`, RunE: func(cmd *cobra.Command, args []string) error { - conversations, err := ctx.Store.Conversations() + messages, err := ctx.Store.LatestConversationMessages() if err != nil { return fmt.Errorf("Could not fetch conversations: %v", err) } @@ -58,13 +57,8 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command { all, _ := cmd.Flags().GetBool("all") - for _, conversation := range conversations { - lastMessage, err := ctx.Store.LastMessage(&conversation) - if lastMessage == nil || err != nil { - continue - } - - messageAge := now.Sub(lastMessage.CreatedAt) + for _, message := range messages { + messageAge := now.Sub(message.CreatedAt) var category string for _, c := range categories { @@ -76,9 +70,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command { formatted := fmt.Sprintf( "%s - %s - %s", - conversation.ShortName.String, + message.Conversation.ShortName.String, util.HumanTimeElapsedSince(messageAge), - conversation.Title, + message.Conversation.Title, ) categorized[category] = append( @@ -96,14 +90,10 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command { continue } - slices.SortFunc(conversationLines, func(a, b ConversationLine) int { - return int(a.timeSinceReply - b.timeSinceReply) - }) - fmt.Printf("%s:\n", category.name) for _, conv := range conversationLines { if conversationsPrinted >= count && !all { - fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted) + fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted) break outer } diff --git a/pkg/cmd/new.go b/pkg/cmd/new.go index 0ef6a5d..e5a74da 100644 --- a/pkg/cmd/new.go +++ b/pkg/cmd/new.go @@ -15,42 +15,43 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command { Short: "Start a new conversation", Long: `Start a new conversation with the Large Language Model.`, RunE: func(cmd *cobra.Command, args []string) error { - messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "") - if messageContents == "" { + input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "") + if input == "" { return fmt.Errorf("No message was provided.") } - conversation := &model.Conversation{} - err := ctx.Store.SaveConversation(conversation) - if err != nil { - return fmt.Errorf("Could not save new conversation: %v", err) + var messages []model.Message + + // TODO: probably just make this part of the conversation + system := ctx.GetSystemPrompt() + if system != "" { + messages = append(messages, model.Message{ + Role: model.MessageRoleSystem, + Content: system, + }) } - messages := []model.Message{ - { - ConversationID: conversation.ID, - Role: model.MessageRoleSystem, - Content: ctx.GetSystemPrompt(), - }, - { - ConversationID: conversation.ID, - Role: model.MessageRoleUser, - Content: messageContents, - }, + messages = append(messages, model.Message{ + Role: model.MessageRoleUser, + Content: input, + }) + + conversation, messages, err := ctx.Store.StartConversation(messages...) + if err != nil { + return fmt.Errorf("Could not start a new conversation: %v", err) } - cmdutil.HandleConversationReply(ctx, conversation, true, messages...) + cmdutil.HandleReply(ctx, &messages[len(messages)-1], true) - title, err := cmdutil.GenerateTitle(ctx, conversation) + title, err := cmdutil.GenerateTitle(ctx, messages) if err != nil { - lmcli.Warn("Could not generate title for conversation: %v\n", err) + lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err) } conversation.Title = title - - err = ctx.Store.SaveConversation(conversation) + err = ctx.Store.UpdateConversation(conversation) if err != nil { - lmcli.Warn("Could not save conversation after generating title: %v\n", err) + lmcli.Warn("Could not save conversation title: %v\n", err) } return nil }, diff --git a/pkg/cmd/prompt.go b/pkg/cmd/prompt.go index 8e9411b..13e5998 100644 --- a/pkg/cmd/prompt.go +++ b/pkg/cmd/prompt.go @@ -15,22 +15,27 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command { Short: "Do a one-shot prompt", Long: `Prompt the Large Language Model and get a response.`, RunE: func(cmd *cobra.Command, args []string) error { - message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "") - if message == "" { + input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "") + if input == "" { return fmt.Errorf("No message was provided.") } - messages := []model.Message{ - { + var messages []model.Message + + // TODO: stop supplying system prompt as a message + system := ctx.GetSystemPrompt() + if system != "" { + messages = append(messages, model.Message{ Role: model.MessageRoleSystem, - Content: ctx.GetSystemPrompt(), - }, - { - Role: model.MessageRoleUser, - Content: message, - }, + Content: system, + }) } + messages = append(messages, model.Message{ + Role: model.MessageRoleUser, + Content: input, + }) + _, err := cmdutil.Prompt(ctx, messages, nil) if err != nil { return fmt.Errorf("Error fetching LLM response: %v", err) diff --git a/pkg/cmd/rename.go b/pkg/cmd/rename.go index 6a75725..c45bbfd 100644 --- a/pkg/cmd/rename.go +++ b/pkg/cmd/rename.go @@ -24,12 +24,17 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { shortName := args[0] conversation := cmdutil.LookupConversation(ctx, shortName) + var err error + var title string generate, _ := cmd.Flags().GetBool("generate") - var title string if generate { - title, err = cmdutil.GenerateTitle(ctx, conversation) + messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) + if err != nil { + return fmt.Errorf("Could not retrieve conversation messages: %v", err) + } + title, err = cmdutil.GenerateTitle(ctx, messages) if err != nil { return fmt.Errorf("Could not generate conversation title: %v", err) } @@ -41,9 +46,9 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command { } conversation.Title = title - err = ctx.Store.SaveConversation(conversation) + err = ctx.Store.UpdateConversation(conversation) if err != nil { - lmcli.Warn("Could not save conversation with new title: %v\n", err) + lmcli.Warn("Could not update conversation title: %v\n", err) } return nil }, diff --git a/pkg/cmd/reply.go b/pkg/cmd/reply.go index 25292cb..3063427 100644 --- a/pkg/cmd/reply.go +++ b/pkg/cmd/reply.go @@ -31,9 +31,9 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command { } cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{ - ConversationID: conversation.ID, - Role: model.MessageRoleUser, - Content: reply, + ConversationID: conversation.ID, + Role: model.MessageRoleUser, + Content: reply, }) return nil }, diff --git a/pkg/cmd/retry.go b/pkg/cmd/retry.go index f0e2ba0..3e6b1fe 100644 --- a/pkg/cmd/retry.go +++ b/pkg/cmd/retry.go @@ -13,7 +13,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command { cmd := &cobra.Command{ Use: "retry ", Short: "Retry the last user reply in a conversation", - Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`, + Long: `Prompt the conversation from the last user response.`, Args: func(cmd *cobra.Command, args []string) error { argCount := 1 if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { @@ -25,25 +25,36 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command { shortName := args[0] conversation := cmdutil.LookupConversation(ctx, shortName) - messages, err := ctx.Store.Messages(conversation) + // Load the complete thread from the root message + messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) if err != nil { return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) } - // walk backwards through the conversation and delete messages, break - // when we find the latest user response - for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Role == model.MessageRoleUser { - break - } - - err = ctx.Store.DeleteMessage(&messages[i]) - if err != nil { - lmcli.Warn("Could not delete previous reply: %v\n", err) - } + offset, _ := cmd.Flags().GetInt("offset") + if offset < 0 { + offset = -offset } - cmdutil.HandleConversationReply(ctx, conversation, true) + if offset > len(messages)-1 { + return fmt.Errorf("Offset %d is before the start of the conversation.", offset) + } + + retryFromIdx := len(messages) - 1 - offset + + // decrease retryFromIdx until we hit a user message + for retryFromIdx >= 0 && messages[retryFromIdx].Role != model.MessageRoleUser { + retryFromIdx-- + } + + if messages[retryFromIdx].Role != model.MessageRoleUser { + return fmt.Errorf("No user messages to retry") + } + + fmt.Printf("Idx: %d Message: %v\n", retryFromIdx, messages[retryFromIdx]) + + // Start a new branch at the last user message + cmdutil.HandleReply(ctx, &messages[retryFromIdx], true) return nil }, ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { @@ -55,6 +66,8 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command { }, } + cmd.Flags().Int("offset", 1, "Offset from the last message retry from.") + applyPromptFlags(ctx, cmd) return cmd } diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go index 5f325ba..f249a5f 100644 --- a/pkg/cmd/util/util.go +++ b/pkg/cmd/util/util.go @@ -73,43 +73,57 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversat return c, nil } +func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) { + messages, err := ctx.Store.PathToLeaf(c.SelectedRoot) + if err != nil { + lmcli.Fatal("Could not load messages: %v\n", err) + } + HandleReply(ctx, &messages[len(messages)-1], persist, toSend...) +} + // handleConversationReply handles sending messages to an existing // conversation, optionally persisting both the sent replies and responses. -func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) { - existing, err := ctx.Store.Messages(c) - if err != nil { - lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title) +func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) { + if to == nil { + lmcli.Fatal("Can't prompt from an empty message.") } - if persist { - for _, message := range toSend { - err = ctx.Store.SaveMessage(&message) - if err != nil { - lmcli.Warn("Could not save %s message: %v\n", message.Role, err) - } + existing, err := ctx.Store.PathToRoot(to) + if err != nil { + lmcli.Fatal("Could not load messages: %v\n", err) + } + + RenderConversation(ctx, append(existing, messages...), true) + + var savedReplies []model.Message + if persist && len(messages) > 0 { + savedReplies, err = ctx.Store.Reply(to, messages...) + if err != nil { + lmcli.Warn("Could not save messages: %v\n", err) } } - allMessages := append(existing, toSend...) - - RenderConversation(ctx, allMessages, true) - // render a message header with no contents RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant})) + var lastSavedMessage *model.Message + lastSavedMessage = to + if len(savedReplies) > 0 { + lastSavedMessage = &savedReplies[len(savedReplies)-1] + } + replyCallback := func(reply model.Message) { if !persist { return } - - reply.ConversationID = c.ID - err = ctx.Store.SaveMessage(&reply) + savedReplies, err = ctx.Store.Reply(lastSavedMessage, reply) if err != nil { lmcli.Warn("Could not save reply: %v\n", err) } + lastSavedMessage = &savedReplies[0] } - _, err = Prompt(ctx, allMessages, replyCallback) + _, err = Prompt(ctx, append(existing, messages...), replyCallback) if err != nil { lmcli.Fatal("Error fetching LLM response: %v\n", err) } @@ -134,12 +148,7 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string { return sb.String() } -func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) { - messages, err := ctx.Store.Messages(c) - if err != nil { - return "", err - } - +func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) { const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective. Example conversation: diff --git a/pkg/cmd/view.go b/pkg/cmd/view.go index 5cffc55..772e869 100644 --- a/pkg/cmd/view.go +++ b/pkg/cmd/view.go @@ -24,9 +24,9 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command { shortName := args[0] conversation := cmdutil.LookupConversation(ctx, shortName) - messages, err := ctx.Store.Messages(conversation) + messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) if err != nil { - return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) + return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err) } cmdutil.RenderConversation(ctx, messages, false) diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index 5efba8d..814be7a 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -36,7 +36,9 @@ func NewContext() (*Context, error) { } databaseFile := filepath.Join(dataDir(), "conversations.db") - db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{}) + db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{ + //Logger: logger.Default.LogMode(logger.Info), + }) if err != nil { return nil, fmt.Errorf("Error establishing connection to store: %v", err) } diff --git a/pkg/lmcli/model/conversation.go b/pkg/lmcli/model/conversation.go index 3aa1516..fb1d60d 100644 --- a/pkg/lmcli/model/conversation.go +++ b/pkg/lmcli/model/conversation.go @@ -16,19 +16,28 @@ const ( ) type Message struct { - ID uint `gorm:"primaryKey"` - ConversationID uint `gorm:"foreignKey:ConversationID"` + ID uint `gorm:"primaryKey"` + ConversationID uint `gorm:"index"` + Conversation Conversation `gorm:"foreignKey:ConversationID"` Content string Role MessageRole CreatedAt time.Time - ToolCalls ToolCalls // a json array of tool calls (from the modl) + ToolCalls ToolCalls // a json array of tool calls (from the model) ToolResults ToolResults // a json array of tool results + ParentID *uint + Parent *Message `gorm:"foreignKey:ParentID"` + Replies []Message `gorm:"foreignKey:ParentID"` + + SelectedReplyID *uint + SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` } type Conversation struct { - ID uint `gorm:"primaryKey"` - ShortName sql.NullString - Title string + ID uint `gorm:"primaryKey"` + ShortName sql.NullString + Title string + SelectedRootID *uint + SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"` } type RequestParameters struct { diff --git a/pkg/lmcli/store.go b/pkg/lmcli/store.go index 3cecb86..97a80da 100644 --- a/pkg/lmcli/store.go +++ b/pkg/lmcli/store.go @@ -13,21 +13,26 @@ import ( ) type ConversationStore interface { - Conversations() ([]model.Conversation, error) - ConversationByShortName(shortName string) (*model.Conversation, error) ConversationShortNameCompletions(search string) []string + RootMessages(conversationID uint) ([]model.Message, error) + LatestConversationMessages() ([]model.Message, error) - SaveConversation(conversation *model.Conversation) error + StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) + UpdateConversation(conversation *model.Conversation) error DeleteConversation(conversation *model.Conversation) error + CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) - Messages(conversation *model.Conversation) ([]model.Message, error) - LastMessage(conversation *model.Conversation) (*model.Message, error) + MessageByID(messageID uint) (*model.Message, error) + MessageReplies(messageID uint) ([]model.Message, error) - SaveMessage(message *model.Message) error - DeleteMessage(message *model.Message) error UpdateMessage(message *model.Message) error - AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error) + DeleteMessage(message *model.Message, prune bool) error + CloneBranch(toClone model.Message) (*model.Message, uint, error) + Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) + + PathToRoot(message *model.Message) ([]model.Message, error) + PathToLeaf(message *model.Message) ([]model.Message, error) } type SQLStore struct { @@ -52,47 +57,52 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) { return &SQLStore{db, _sqids}, nil } -func (s *SQLStore) SaveConversation(conversation *model.Conversation) error { - err := s.db.Save(&conversation).Error +func (s *SQLStore) saveNewConversation(c *model.Conversation) error { + // Save the new conversation + err := s.db.Save(&c).Error if err != nil { return err } - if !conversation.ShortName.Valid { - shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)}) - conversation.ShortName = sql.NullString{String: shortName, Valid: true} - err = s.db.Save(&conversation).Error + // Generate and save its "short name" + shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)}) + c.ShortName = sql.NullString{String: shortName, Valid: true} + return s.UpdateConversation(c) +} + +func (s *SQLStore) UpdateConversation(c *model.Conversation) error { + if c == nil || c.ID == 0 { + return fmt.Errorf("Conversation is nil or invalid (missing ID)") } - - return err + return s.db.Updates(&c).Error } -func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error { - s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{}) - return s.db.Delete(&conversation).Error +func (s *SQLStore) DeleteConversation(c *model.Conversation) error { + // Delete messages first + err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error + if err != nil { + return err + } + return s.db.Delete(&c).Error } -func (s *SQLStore) SaveMessage(message *model.Message) error { - return s.db.Create(message).Error +func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error { + panic("Not yet implemented") + //return s.db.Delete(&message).Error } -func (s *SQLStore) DeleteMessage(message *model.Message) error { - return s.db.Delete(&message).Error -} - -func (s *SQLStore) UpdateMessage(message *model.Message) error { - return s.db.Updates(&message).Error -} - -func (s *SQLStore) Conversations() ([]model.Conversation, error) { - var conversations []model.Conversation - err := s.db.Find(&conversations).Error - return conversations, err +func (s *SQLStore) UpdateMessage(m *model.Message) error { + if m == nil || m.ID == 0 { + return fmt.Errorf("Message is nil or invalid (missing ID)") + } + return s.db.Updates(&m).Error } func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string { - var completions []string - conversations, _ := s.Conversations() // ignore error for completions + var conversations []model.Conversation + // ignore error for completions + s.db.Find(&conversations) + completions := make([]string, 0, len(conversations)) for _, conversation := range conversations { if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) { completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title)) @@ -106,27 +116,250 @@ func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversatio return nil, errors.New("shortName is empty") } var conversation model.Conversation - err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error + err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error return &conversation, err } -func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) { - var messages []model.Message - err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error - return messages, err +func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) { + var rootMessages []model.Message + err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error + if err != nil { + return nil, err + } + return rootMessages, nil } -func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) { +func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) { var message model.Message - err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error + err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error return &message, err } -// AddReply adds the given messages as a reply to the given conversation, can be -// used to easily copy a message associated with one conversation, to another -func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) { - m.ConversationID = c.ID - m.ID = 0 - m.CreatedAt = time.Time{} - return &m, s.SaveMessage(&m) +func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) { + var replies []model.Message + err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error + return replies, err +} + +// StartConversation starts a new conversation with the provided messages +func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) { + if len(messages) == 0 { + return nil, nil, fmt.Errorf("Must provide at least 1 message") + } + + // Create new conversation + conversation := &model.Conversation{} + err := s.saveNewConversation(conversation) + if err != nil { + return nil, nil, err + } + + // Create first message + messages[0].ConversationID = conversation.ID + err = s.db.Create(&messages[0]).Error + if err != nil { + return nil, nil, err + } + + // Update conversation's selected root message + conversation.SelectedRoot = &messages[0] + err = s.UpdateConversation(conversation) + if err != nil { + return nil, nil, err + } + + // Add additional replies to conversation + if len(messages) > 1 { + newMessages, err := s.Reply(&messages[0], messages[1:]...) + if err != nil { + return nil, nil, err + } + messages = append([]model.Message{messages[0]}, newMessages...) + } + return conversation, messages, nil +} + +// CloneConversation clones the given conversation and all of its root meesages +func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) { + rootMessages, err := s.RootMessages(toClone.ID) + if err != nil { + return nil, 0, err + } + + clone := &model.Conversation{ + Title: toClone.Title + " - Clone", + } + if err := s.saveNewConversation(clone); err != nil { + return nil, 0, fmt.Errorf("Could not create clone: %s", err) + } + + var errors []error + var messageCnt uint = 0 + for _, root := range rootMessages { + messageCnt++ + newRoot := root + newRoot.ConversationID = clone.ID + + cloned, count, err := s.CloneBranch(newRoot) + if err != nil { + errors = append(errors, err) + continue + } + messageCnt += count + + if root.ID == *toClone.SelectedRootID { + clone.SelectedRootID = &cloned.ID + if err := s.UpdateConversation(clone); err != nil { + errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err)) + } + } + } + + if len(errors) > 0 { + return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors) + } + + return clone, messageCnt, nil +} + +// Reply to a message with a series of messages (each following the next) +func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) { + var savedMessages []model.Message + + err := s.db.Transaction(func(tx *gorm.DB) error { + currentParent := to + for i := range messages { + message := messages[i] + message.ConversationID = currentParent.ConversationID + message.ParentID = ¤tParent.ID + message.ID = 0 + message.CreatedAt = time.Time{} + + if err := tx.Create(&message).Error; err != nil { + return err + } + + // update parent selected reply + currentParent.SelectedReply = &message + if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil { + return err + } + + savedMessages = append(savedMessages, message) + currentParent = &message + } + return nil + }) + + return savedMessages, err +} + +// CloneBranch returns a deep clone of the given message and its replies, returning +// a new message object. The new message will be attached to the same parent as +// the message to clone. +func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) { + newMessage := messageToClone + newMessage.ID = 0 + newMessage.Replies = nil + newMessage.SelectedReplyID = nil + newMessage.SelectedReply = nil + + originalReplies, err := s.MessageReplies(messageToClone.ID) + if err != nil { + return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err) + } + + if err := s.db.Create(&newMessage).Error; err != nil { + return nil, 0, fmt.Errorf("Could not clone message: %s", err) + } + + var replyCount uint = 0 + for _, reply := range originalReplies { + replyCount++ + + newReply := reply + newReply.ConversationID = messageToClone.ConversationID + newReply.ParentID = &newMessage.ID + newReply.Parent = &newMessage + + res, c, err := s.CloneBranch(newReply) + if err != nil { + return nil, 0, err + } + newMessage.Replies = append(newMessage.Replies, *res) + replyCount += c + + if reply.ID == *messageToClone.SelectedReplyID { + newMessage.SelectedReplyID = &res.ID + if err := s.UpdateMessage(&newMessage); err != nil { + return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err) + } + } + } + return &newMessage, replyCount, nil +} + +// PathToRoot traverses message Parent until reaching the tree root +func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) { + if message == nil { + return nil, fmt.Errorf("Message is nil") + } + var path []model.Message + current := message + for { + path = append([]model.Message{*current}, path...) + if current.Parent == nil { + break + } + + var err error + current, err = s.MessageByID(*current.ParentID) + if err != nil { + return nil, fmt.Errorf("finding parent message: %w", err) + } + } + return path, nil +} + +// PathToLeaf traverses message SelectedReply until reaching a tree leaf +func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) { + if message == nil { + return nil, fmt.Errorf("Message is nil") + } + var path []model.Message + current := message + for { + path = append(path, *current) + if current.SelectedReplyID == nil { + break + } + + var err error + current, err = s.MessageByID(*current.SelectedReplyID) + if err != nil { + return nil, fmt.Errorf("finding selected reply: %w", err) + } + } + return path, nil +} + +func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) { + var latestMessages []model.Message + + subQuery := s.db.Model(&model.Message{}). + Select("MAX(created_at) as max_created_at, conversation_id"). + Group("conversation_id") + + err := s.db.Model(&model.Message{}). + Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery). + Group("messages.conversation_id"). + Order("created_at DESC"). + Preload("Conversation"). + Find(&latestMessages).Error + + if err != nil { + return nil, err + } + + return latestMessages, nil } diff --git a/pkg/tui/chat.go b/pkg/tui/chat.go index 2ff20a0..21360df 100644 --- a/pkg/tui/chat.go +++ b/pkg/tui/chat.go @@ -307,14 +307,9 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) { } if m.persistence { - var err error - if m.conversation.ID == 0 { - err = m.ctx.Store.SaveConversation(m.conversation) - } + err := m.persistConversation() if err != nil { cmds = append(cmds, wrapError(err)) - } else { - cmds = append(cmds, m.persistConversation()) } } @@ -341,7 +336,7 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) { title := string(msg) m.conversation.Title = title if m.persistence { - err := m.ctx.Store.SaveConversation(m.conversation) + err := m.ctx.Store.UpdateConversation(m.conversation) if err != nil { cmds = append(cmds, wrapError(err)) } @@ -469,8 +464,8 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) { m.input.Blur() return true, nil case "ctrl+s": - userInput := strings.TrimSpace(m.input.Value()) - if strings.TrimSpace(userInput) == "" { + input := strings.TrimSpace(m.input.Value()) + if input == "" { return true, nil } @@ -478,35 +473,19 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) { return true, wrapError(fmt.Errorf("Can't reply to a user message")) } - reply := models.Message{ + m.addMessage(models.Message{ Role: models.MessageRoleUser, - Content: userInput, - } - - if m.persistence { - var err error - if m.conversation.ID == 0 { - err = m.ctx.Store.SaveConversation(m.conversation) - } - if err != nil { - return true, wrapError(err) - } - - // ensure all messages up to the one we're about to add are persisted - cmd := m.persistConversation() - if cmd != nil { - return true, cmd - } - - savedReply, err := m.ctx.Store.AddReply(m.conversation, reply) - if err != nil { - return true, wrapError(err) - } - reply = *savedReply - } + Content: input, + }) m.input.SetValue("") - m.addMessage(reply) + + if m.persistence { + err := m.persistConversation() + if err != nil { + return true, wrapError(err) + } + } m.updateContent() m.content.GotoBottom() @@ -783,7 +762,7 @@ func (m *chatModel) loadConversation(shortname string) tea.Cmd { func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd { return func() tea.Msg { - messages, err := m.ctx.Store.Messages(c) + messages, err := m.ctx.Store.PathToLeaf(c.SelectedRoot) if err != nil { return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err)) } @@ -791,62 +770,48 @@ func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd { } } -func (m *chatModel) persistConversation() tea.Cmd { - existingMessages, err := m.ctx.Store.Messages(m.conversation) - if err != nil { - return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err)) - } - - existingById := make(map[uint]*models.Message, len(existingMessages)) - for _, msg := range existingMessages { - existingById[msg.ID] = &msg - } - - currentById := make(map[uint]*models.Message, len(m.messages)) - for _, msg := range m.messages { - currentById[msg.ID] = &msg - } - - for _, msg := range existingMessages { - _, ok := currentById[msg.ID] - if !ok { - err := m.ctx.Store.DeleteMessage(&msg) - if err != nil { - return wrapError(fmt.Errorf("Failed to remove messages: %v", err)) - } +func (m *chatModel) persistConversation() error { + if m.conversation.ID == 0 { + // Start a new conversation with all messages so far + c, messages, err := m.ctx.Store.StartConversation(m.messages...) + if err != nil { + return err } + m.conversation = c + m.messages = messages + + return nil } - for i, msg := range m.messages { - if msg.ID > 0 { - exist, ok := existingById[msg.ID] - if ok { - if msg.Content == exist.Content { - continue - } - // update message when contents don't match that of store - err := m.ctx.Store.UpdateMessage(&msg) - if err != nil { - return wrapError(err) - } - } else { - // this would be quite odd... and I'm not sure how to handle - // it at the time of writing this + // else, we'll handle updating an existing conversation's messages + for i := 0; i < len(m.messages); i++ { + if m.messages[i].ID > 0 { + // message has an ID, update its contents + // TODO: check for content/tool equality before updating? + err := m.ctx.Store.UpdateMessage(&m.messages[i]) + if err != nil { + return err } + } else if i > 0 { + // messages is new, so add it as a reply to previous message + saved, err := m.ctx.Store.Reply(&m.messages[i-1], m.messages[i]) + if err != nil { + return err + } + m.messages[i] = saved[0] } else { - newMessage, err := m.ctx.Store.AddReply(m.conversation, msg) - if err != nil { - return wrapError(err) - } - m.setMessage(i, *newMessage) + // message has no id and no previous messages to add it to + // this shouldn't happen? + return fmt.Errorf("Error: no messages to reply to") } } + return nil } func (m *chatModel) generateConversationTitle() tea.Cmd { return func() tea.Msg { - title, err := cmdutil.GenerateTitle(m.ctx, m.conversation) + title, err := cmdutil.GenerateTitle(m.ctx, m.messages) if err != nil { return msgError(err) } diff --git a/pkg/tui/conversations.go b/pkg/tui/conversations.go index 8dc61e6..f3a2f4f 100644 --- a/pkg/tui/conversations.go +++ b/pkg/tui/conversations.go @@ -2,7 +2,6 @@ package tui import ( "fmt" - "slices" "strings" "time" @@ -145,25 +144,17 @@ func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) { func (m *conversationsModel) loadConversations() tea.Cmd { return func() tea.Msg { - conversations, err := m.ctx.Store.Conversations() + messages, err := m.ctx.Store.LatestConversationMessages() if err != nil { return msgError(fmt.Errorf("Could not load conversations: %v", err)) } - loaded := make([]loadedConversation, len(conversations)) - for i, c := range conversations { - lastMessage, err := m.ctx.Store.LastMessage(&c) - if err != nil { - return msgError(err) - } - loaded[i].conv = c - loaded[i].lastReply = *lastMessage + loaded := make([]loadedConversation, len(messages)) + for i, m := range messages { + loaded[i].lastReply = m + loaded[i].conv = m.Conversation } - slices.SortFunc(loaded, func(a, b loadedConversation) int { - return b.lastReply.CreatedAt.Compare(a.lastReply.CreatedAt) - }) - return msgConversationsLoaded(loaded) } }