diff --git a/pkg/lmcli/store.go b/pkg/lmcli/store.go index bdac762..61432b9 100644 --- a/pkg/lmcli/store.go +++ b/pkg/lmcli/store.go @@ -241,6 +241,7 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model. } // update parent selected reply + currentParent.Replies = append(currentParent.Replies, message) currentParent.SelectedReply = &message if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil { return err @@ -257,7 +258,7 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model. // CloneBranch returns a deep clone of the given message and its replies, returning // a new message object. The new message will be attached to the same parent as -// the message to clone. +// the messageToClone func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) { newMessage := messageToClone newMessage.ID = 0 @@ -302,7 +303,7 @@ func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, ui func fetchMessages(db *gorm.DB) ([]model.Message, error) { var messages []model.Message - if err := db.Find(&messages).Error; err != nil { + if err := db.Preload("Conversation").Find(&messages).Error; err != nil { return nil, fmt.Errorf("Could not fetch messages: %v", err) } @@ -375,7 +376,9 @@ func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message return path, nil } -// PathToRoot traverses message Parent until reaching the tree root +// PathToRoot traverses the provided message's Parent until reaching the tree +// root and returns a slice of all messages traversed in chronological order +// (starting with the root and ending with the message provided) func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) { if message == nil || message.ID <= 0 { return nil, fmt.Errorf("Message is nil or has invalid ID") @@ -392,7 +395,9 @@ func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) { return path, nil } -// PathToLeaf traverses message SelectedReply until reaching a tree leaf +// PathToLeaf traverses the provided message's SelectedReply until reaching a +// tree leaf and returns a slice of all messages traversed in chronological +// order (starting with the message provided and ending with the leaf) func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) { if message == nil || message.ID <= 0 { return nil, fmt.Errorf("Message is nil or has invalid ID") diff --git a/pkg/tui/views/chat/chat.go b/pkg/tui/views/chat/chat.go index 1e0526b..d3113f7 100644 --- a/pkg/tui/views/chat/chat.go +++ b/pkg/tui/views/chat/chat.go @@ -36,13 +36,29 @@ type ( // a special case of common.MsgError that stops the response waiting animation msgResponseError error // sent on each completed reply - msgAssistantReply models.Message + msgResponse models.Message // sent when a conversation is (re)loaded - msgConversationLoaded *models.Conversation - // sent when a new conversation title is set - msgConversationTitleChanged string + msgConversationLoaded struct { + conversation *models.Conversation + rootMessages []models.Message + } + // sent when a new conversation title generated + msgConversationTitleGenerated string // sent when a conversation's messages are laoded msgMessagesLoaded []models.Message + // sent when the conversation has been persisted, triggers a reload of contents + msgConversationPersisted struct { + conversation *models.Conversation + messages []models.Message + } + // sent when the given message is made the new selected reply of its parent + msgSelectedReplyCycled *models.Message + // sent when the given message is made the new selected root of the current conversation + msgSelectedRootCycled *models.Message + // sent when a message's contents are updated and saved + msgMessageUpdated *models.Message + // sent when a message is cloned, with the cloned message + msgMessageCloned *models.Message ) type Model struct { @@ -141,7 +157,7 @@ func Chat(state shared.State) Model { func (m Model) Init() tea.Cmd { return tea.Batch( - m.waitForChunk(), - m.waitForReply(), + m.waitForResponseChunk(), + m.waitForResponse(), ) } diff --git a/pkg/tui/views/chat/conversation.go b/pkg/tui/views/chat/conversation.go index 72b4356..0c098bf 100644 --- a/pkg/tui/views/chat/conversation.go +++ b/pkg/tui/views/chat/conversation.go @@ -60,13 +60,17 @@ func (m *Model) loadConversation(shortname string) tea.Cmd { if c.ID == 0 { return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname)) } - return msgConversationLoaded(c) + rootMessages, err := m.State.Ctx.Store.RootMessages(c.ID) + if err != nil { + return shared.MsgError(fmt.Errorf("Could not load conversation root messages: %v\n", err)) + } + return msgConversationLoaded{c, rootMessages} } } -func (m *Model) loadMessages(c *models.Conversation) tea.Cmd { +func (m *Model) loadConversationMessages() tea.Cmd { return func() tea.Msg { - messages, err := m.State.Ctx.Store.PathToLeaf(c.SelectedRoot) + messages, err := m.State.Ctx.Store.PathToLeaf(m.conversation.SelectedRoot) if err != nil { return shared.MsgError(fmt.Errorf("Could not load conversation messages: %v\n", err)) } @@ -80,127 +84,170 @@ func (m *Model) generateConversationTitle() tea.Cmd { if err != nil { return shared.MsgError(err) } - return msgConversationTitleChanged(title) + return msgConversationTitleGenerated(title) } } -func cycleMessages(curr *models.Message, msgs []models.Message, dir MessageCycleDirection) (*models.Message, error) { +func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.Cmd { + return func() tea.Msg { + err := m.State.Ctx.Store.UpdateConversation(conversation) + if err != nil { + return shared.WrapError(err) + } + return nil + } +} + +// Clones the given message (and its descendents). If selected is true, updates +// either its parent's SelectedReply or its conversation's SelectedRoot to +// point to the new clone +func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd { + return func() tea.Msg { + msg, _, err := m.Ctx.Store.CloneBranch(message) + if err != nil { + return shared.WrapError(fmt.Errorf("Could not clone message: %v", err)) + } + if selected { + if msg.Parent == nil { + msg.Conversation.SelectedRoot = msg + err = m.State.Ctx.Store.UpdateConversation(&msg.Conversation) + } else { + msg.Parent.SelectedReply = msg + err = m.State.Ctx.Store.UpdateMessage(msg.Parent) + } + if err != nil { + return shared.WrapError(fmt.Errorf("Could not update selected message: %v", err)) + } + } + return msgMessageCloned(msg) + } +} + +func (m *Model) updateMessageContent(message *models.Message) tea.Cmd { + return func() tea.Msg { + err := m.State.Ctx.Store.UpdateMessage(message) + if err != nil { + return shared.WrapError(fmt.Errorf("Could not update message: %v", err)) + } + return msgMessageUpdated(message) + } +} + +func cycleSelectedMessage(selected *models.Message, choices []models.Message, dir MessageCycleDirection) (*models.Message, error) { currentIndex := -1 - for i, reply := range msgs { - if reply.ID == curr.ID { + for i, reply := range choices { + if reply.ID == selected.ID { currentIndex = i break } } if currentIndex < 0 { - return nil, fmt.Errorf("message not found") + // this should probably be an assert + return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID) } var next int if dir == CyclePrev { // Wrap around to the last reply if at the beginning - next = (currentIndex - 1 + len(msgs)) % len(msgs) + next = (currentIndex - 1 + len(choices)) % len(choices) } else { // Wrap around to the first reply if at the end - next = (currentIndex + 1) % len(msgs) + next = (currentIndex + 1) % len(choices) } - return &msgs[next], nil + return &choices[next], nil } -func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDirection) (*models.Message, error) { +func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDirection) tea.Cmd { if len(m.rootMessages) < 2 { - return nil, nil - } - - nextRoot, err := cycleMessages(conv.SelectedRoot, m.rootMessages, dir) - if err != nil { - return nil, err - } - - conv.SelectedRoot = nextRoot - err = m.State.Ctx.Store.UpdateConversation(conv) - if err != nil { - return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err) - } - return nextRoot, nil -} - -func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDirection) (*models.Message, error) { - if len(message.Replies) < 2 { - return nil, nil - } - - nextReply, err := cycleMessages(message.SelectedReply, message.Replies, dir) - if err != nil { - return nil, err - } - - message.SelectedReply = nextReply - err = m.State.Ctx.Store.UpdateMessage(message) - if err != nil { - return nil, fmt.Errorf("Could not update message SelectedReply: %v", err) - } - return nextReply, nil -} - -func (m *Model) persistConversation() error { - if m.conversation.ID == 0 { - // Start a new conversation with all messages so far - c, messages, err := m.State.Ctx.Store.StartConversation(m.messages...) - if err != nil { - return err - } - m.conversation = c - m.messages = messages - return nil } - // else, we'll handle updating an existing conversation's messages - for i := 0; i < len(m.messages); i++ { - if m.messages[i].ID > 0 { - // message has an ID, update its contents - // TODO: check for content/tool equality before updating? - err := m.State.Ctx.Store.UpdateMessage(&m.messages[i]) + return func() tea.Msg { + nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, m.rootMessages, dir) + if err != nil { + return shared.WrapError(err) + } + + conv.SelectedRoot = nextRoot + err = m.State.Ctx.Store.UpdateConversation(conv) + if err != nil { + return shared.WrapError(fmt.Errorf("Could not update conversation SelectedRoot: %v", err)) + } + return msgSelectedRootCycled(nextRoot) + } +} + +func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDirection) tea.Cmd { + if len(message.Replies) < 2 { + return nil + } + + return func() tea.Msg { + nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir) + if err != nil { + return shared.WrapError(err) + } + + message.SelectedReply = nextReply + err = m.State.Ctx.Store.UpdateMessage(message) + if err != nil { + return shared.WrapError(fmt.Errorf("Could not update message SelectedReply: %v", err)) + } + return msgSelectedReplyCycled(nextReply) + } +} + +func (m *Model) persistConversation() tea.Cmd { + conversation := m.conversation + messages := m.messages + + var err error + if m.conversation.ID == 0 { + return func() tea.Msg { + // Start a new conversation with all messages so far + conversation, messages, err = m.State.Ctx.Store.StartConversation(messages...) if err != nil { - return err + return shared.MsgError(fmt.Errorf("Could not start new conversation: %v", err)) } - } else if i > 0 { - // messages is new, so add it as a reply to previous message - saved, err := m.State.Ctx.Store.Reply(&m.messages[i-1], m.messages[i]) - if err != nil { - return err - } - // add this message as a reply to the previous - m.messages[i-1].Replies = append(m.messages[i-1].Replies, saved[0]) - m.messages[i] = saved[0] - } else { - // message has no id and no previous messages to add it to - // this shouldn't happen? - return fmt.Errorf("Error: no messages to reply to") + return msgConversationPersisted{conversation, messages} } } - return nil + return func() tea.Msg { + // else, we'll handle updating an existing conversation's messages + for i := range messages { + if messages[i].ID > 0 { + // message has an ID, update its contents + err := m.State.Ctx.Store.UpdateMessage(&messages[i]) + if err != nil { + return shared.MsgError(err) + } + } else if i > 0 { + if messages[i].Content == "" { + continue + } + // messages is new, so add it as a reply to previous message + saved, err := m.State.Ctx.Store.Reply(&messages[i-1], messages[i]) + if err != nil { + return shared.MsgError(err) + } + messages[i] = saved[0] + } else { + // message has no id and no previous messages to add it to + // this shouldn't happen? + return fmt.Errorf("Error: no messages to reply to") + } + } + return msgConversationPersisted{conversation, messages} + } } - func (m *Model) promptLLM() tea.Cmd { m.waitingForReply = true m.replyCursor.Blink = false m.status = "Press ctrl+c to cancel" - toPrompt := m.messages - - // Add response placeholder message - if m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant { - m.addMessage(models.Message{ - Role: models.MessageRoleAssistant, - Content: "", - }) - } - m.tokenCount = 0 m.startTime = time.Now() m.elapsed = 0 @@ -234,7 +281,7 @@ func (m *Model) promptLLM() tea.Cmd { }() resp, err := provider.CreateChatCompletionStream( - ctx, requestParams, toPrompt, replyHandler, m.replyChunkChan, + ctx, requestParams, m.messages, replyHandler, m.replyChunkChan, ) if err != nil && !canceled { diff --git a/pkg/tui/views/chat/input.go b/pkg/tui/views/chat/input.go index c1e3107..6bc44f2 100644 --- a/pkg/tui/views/chat/input.go +++ b/pkg/tui/views/chat/input.go @@ -71,10 +71,15 @@ func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) { m.input.Focus() return true, nil case "e": - message := m.messages[m.selectedMessage] - cmd := tuiutil.OpenTempfileEditor("message.*.md", message.Content, "# Edit the message below\n") - m.editorTarget = selectedMessage - return true, cmd + if m.selectedMessage < len(m.messages) { + m.editorTarget = selectedMessage + return true, tuiutil.OpenTempfileEditor( + "message.*.md", + m.messages[m.selectedMessage].Content, + "# Edit the message below\n", + ) + } + return false, nil case "ctrl+k": if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) { m.selectedMessage-- @@ -97,34 +102,14 @@ func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) { dir = CycleNext } - var err error - var selected *models.Message + var cmd tea.Cmd if m.selectedMessage == 0 { - selected, err = m.cycleSelectedRoot(m.conversation, dir) - if err != nil { - return true, shared.WrapError(fmt.Errorf("Could not cycle conversation root: %v", err)) - } + cmd = m.cycleSelectedRoot(m.conversation, dir) } else if m.selectedMessage > 0 { - selected, err = m.cycleSelectedReply(&m.messages[m.selectedMessage-1], dir) - if err != nil { - return true, shared.WrapError(fmt.Errorf("Could not cycle reply: %v", err)) - } + cmd = m.cycleSelectedReply(&m.messages[m.selectedMessage-1], dir) } - if selected == nil { - return false, nil - } - - // Retrieve updated view at this point - newPath, err := m.State.Ctx.Store.PathToLeaf(selected) - if err != nil { - m.State.Err = fmt.Errorf("Could not fetch messages: %v", err) - } - - m.messages = append(m.messages[:m.selectedMessage], newPath...) - m.rebuildMessageCache() - m.updateContent() - return true, nil + return cmd != nil, cmd case "ctrl+r": // resubmit the conversation with all messages up until and including the selected message if m.waitingForReply || len(m.messages) == 0 { @@ -177,17 +162,16 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) { m.input.SetValue("") + var cmds []tea.Cmd if m.persistence { - err := m.persistConversation() - if err != nil { - return true, shared.WrapError(err) - } + cmds = append(cmds, m.persistConversation()) } - cmd := m.promptLLM() + cmds = append(cmds, m.promptLLM()) + m.updateContent() m.content.GotoBottom() - return true, cmd + return true, tea.Batch(cmds...) case "ctrl+e": cmd := tuiutil.OpenTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n") m.editorTarget = input diff --git a/pkg/tui/views/chat/update.go b/pkg/tui/views/chat/update.go index c10d0b0..72eba8d 100644 --- a/pkg/tui/views/chat/update.go +++ b/pkg/tui/views/chat/update.go @@ -1,7 +1,6 @@ package chat import ( - "fmt" "strings" "time" @@ -22,13 +21,13 @@ func (m *Model) HandleResize(width, height int) { } } -func (m *Model) waitForReply() tea.Cmd { +func (m *Model) waitForResponse() tea.Cmd { return func() tea.Msg { - return msgAssistantReply(<-m.replyChan) + return msgResponse(<-m.replyChan) } } -func (m *Model) waitForChunk() tea.Cmd { +func (m *Model) waitForResponseChunk() tea.Cmd { return func() tea.Msg { return msgResponseChunk(<-m.replyChunkChan) } @@ -37,46 +36,59 @@ func (m *Model) waitForChunk() tea.Cmd { func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { var cmds []tea.Cmd switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.HandleResize(msg.Width, msg.Height) case shared.MsgViewEnter: // wake up spinners and cursors cmds = append(cmds, cursor.Blink, m.spinner.Tick) - if m.State.Values.ConvShortname != "" && m.conversation.ShortName.String != m.State.Values.ConvShortname { + if m.State.Values.ConvShortname != "" { + // (re)load conversation contents cmds = append(cmds, m.loadConversation(m.State.Values.ConvShortname)) + + if m.conversation.ShortName.String != m.State.Values.ConvShortname { + // clear existing messages if we're loading a new conversation + m.messages = []models.Message{} + m.selectedMessage = 0 + } } m.rebuildMessageCache() m.updateContent() - case tea.WindowSizeMsg: - m.HandleResize(msg.Width, msg.Height) case tuiutil.MsgTempfileEditorClosed: contents := string(msg) switch m.editorTarget { case input: m.input.SetValue(contents) case selectedMessage: - m.setMessageContents(m.selectedMessage, contents) - if m.persistence && m.messages[m.selectedMessage].ID > 0 { - // update persisted message - err := m.State.Ctx.Store.UpdateMessage(&m.messages[m.selectedMessage]) - if err != nil { - cmds = append(cmds, shared.WrapError(fmt.Errorf("Could not save edited message: %v", err))) + toEdit := m.messages[m.selectedMessage] + if toEdit.Content != contents { + toEdit.Content = contents + m.setMessage(m.selectedMessage, toEdit) + if m.persistence && toEdit.ID > 0 { + // create clone of message with its new contents + cmds = append(cmds, m.cloneMessage(toEdit, true)) } } - m.updateContent() } case msgConversationLoaded: - m.conversation = (*models.Conversation)(msg) - m.rootMessages, _ = m.State.Ctx.Store.RootMessages(m.conversation.ID) - cmds = append(cmds, m.loadMessages(m.conversation)) + m.conversation = msg.conversation + m.rootMessages = msg.rootMessages + m.selectedMessage = -1 + if len(m.rootMessages) > 0 { + cmds = append(cmds, m.loadConversationMessages()) + } case msgMessagesLoaded: - m.selectedMessage = len(msg) - 1 m.messages = msg + if m.selectedMessage == -1 { + m.selectedMessage = len(msg) - 1 + } else { + m.selectedMessage = min(m.selectedMessage, len(m.messages)) + } m.rebuildMessageCache() m.updateContent() - m.content.GotoBottom() case msgResponseChunk: - cmds = append(cmds, m.waitForChunk()) // wait for the next chunk + cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk chunk := string(msg) if chunk == "" { @@ -102,8 +114,8 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { m.tokenCount++ m.elapsed = time.Now().Sub(m.startTime) - case msgAssistantReply: - cmds = append(cmds, m.waitForReply()) // wait for the next reply + case msgResponse: + cmds = append(cmds, m.waitForResponse()) // wait for the next response reply := models.Message(msg) reply.Content = strings.TrimSpace(reply.Content) @@ -121,10 +133,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { } if m.persistence { - err := m.persistConversation() - if err != nil { - cmds = append(cmds, shared.WrapError(err)) - } + cmds = append(cmds, m.persistConversation()) } if m.conversation.Title == "" { @@ -146,20 +155,30 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { m.status = "Press ctrl+s to send" m.State.Err = error(msg) m.updateContent() - case msgConversationTitleChanged: + case msgConversationTitleGenerated: title := string(msg) m.conversation.Title = title if m.persistence { - err := m.State.Ctx.Store.UpdateConversation(m.conversation) - if err != nil { - cmds = append(cmds, shared.WrapError(err)) - } + cmds = append(cmds, m.updateConversationTitle(m.conversation)) } case cursor.BlinkMsg: if m.waitingForReply { - // ensure we show updated "wait for response" cursor blink state + // ensure we show the updated "wait for response" cursor blink state m.updateContent() } + case msgConversationPersisted: + m.conversation = msg.conversation + m.messages = msg.messages + m.rebuildMessageCache() + m.updateContent() + case msgMessageCloned: + if msg.Parent == nil { + m.conversation = &msg.Conversation + m.rootMessages = append(m.rootMessages, *msg) + } + cmds = append(cmds, m.loadConversationMessages()) + case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated: + cmds = append(cmds, m.loadConversationMessages()) } var cmd tea.Cmd @@ -218,12 +237,12 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { // move cursor up until content reaches the bottom of the viewport m.input.CursorUp() } - m.input, cmd = m.input.Update(nil) + m.input, _ = m.input.Update(nil) for i := 0; i < dist; i++ { // move cursor back down to its previous position m.input.CursorDown() } - m.input, cmd = m.input.Update(nil) + m.input, _ = m.input.Update(nil) } } diff --git a/pkg/tui/views/chat/view.go b/pkg/tui/views/chat/view.go index 6b1b470..b454757 100644 --- a/pkg/tui/views/chat/view.go +++ b/pkg/tui/views/chat/view.go @@ -89,7 +89,7 @@ func (m *Model) renderMessageHeading(i int, message *models.Message) string { faint := lipgloss.NewStyle().Faint(true) - if i == 0 && len(m.rootMessages) > 0 { + if i == 0 && len(m.rootMessages) > 1 && m.conversation.SelectedRootID != nil { selectedRootIndex := 0 for j, reply := range m.rootMessages { if reply.ID == *m.conversation.SelectedRootID { @@ -139,7 +139,7 @@ func (m *Model) renderMessage(i int) string { } // Show the assistant's cursor - if m.waitingForReply && i == len(m.messages)-1 { + if m.waitingForReply && i == len(m.messages)-1 && msg.Role == models.MessageRoleAssistant { sb.WriteString(m.replyCursor.View()) } @@ -250,6 +250,17 @@ func (m *Model) conversationMessagesView() string { lineCnt += lipgloss.Height(rendered) } + // Render a placeholder for the incoming assistant reply + if m.waitingForReply && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant) { + heading := m.renderMessageHeading(-1, &models.Message{ + Role: models.MessageRoleAssistant, + }) + sb.WriteString(heading) + sb.WriteString("\n") + sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View())) + sb.WriteString("\n") + } + return sb.String() }