Fixes to message/conversation handling in tui chat view

This set of changes fixes root/child message cycling and ensures all
database operations happen within a `tea.Cmd`
This commit is contained in:
Matt Low 2024-06-08 21:28:29 +00:00
parent 136c463924
commit 45df957a06
6 changed files with 251 additions and 169 deletions

View File

@ -241,6 +241,7 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
}
// update parent selected reply
currentParent.Replies = append(currentParent.Replies, message)
currentParent.SelectedReply = &message
if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil {
return err
@ -257,7 +258,7 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
// CloneBranch returns a deep clone of the given message and its replies, returning
// a new message object. The new message will be attached to the same parent as
// the message to clone.
// the messageToClone
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
newMessage := messageToClone
newMessage.ID = 0
@ -302,7 +303,7 @@ func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, ui
func fetchMessages(db *gorm.DB) ([]model.Message, error) {
var messages []model.Message
if err := db.Find(&messages).Error; err != nil {
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
return nil, fmt.Errorf("Could not fetch messages: %v", err)
}
@ -375,7 +376,9 @@ func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message
return path, nil
}
// PathToRoot traverses message Parent until reaching the tree root
// PathToRoot traverses the provided message's Parent until reaching the tree
// root and returns a slice of all messages traversed in chronological order
// (starting with the root and ending with the message provided)
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID")
@ -392,7 +395,9 @@ func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
return path, nil
}
// PathToLeaf traverses message SelectedReply until reaching a tree leaf
// PathToLeaf traverses the provided message's SelectedReply until reaching a
// tree leaf and returns a slice of all messages traversed in chronological
// order (starting with the message provided and ending with the leaf)
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID")

View File

@ -36,13 +36,29 @@ type (
// a special case of common.MsgError that stops the response waiting animation
msgResponseError error
// sent on each completed reply
msgAssistantReply models.Message
msgResponse models.Message
// sent when a conversation is (re)loaded
msgConversationLoaded *models.Conversation
// sent when a new conversation title is set
msgConversationTitleChanged string
msgConversationLoaded struct {
conversation *models.Conversation
rootMessages []models.Message
}
// sent when a new conversation title generated
msgConversationTitleGenerated string
// sent when a conversation's messages are laoded
msgMessagesLoaded []models.Message
// sent when the conversation has been persisted, triggers a reload of contents
msgConversationPersisted struct {
conversation *models.Conversation
messages []models.Message
}
// sent when the given message is made the new selected reply of its parent
msgSelectedReplyCycled *models.Message
// sent when the given message is made the new selected root of the current conversation
msgSelectedRootCycled *models.Message
// sent when a message's contents are updated and saved
msgMessageUpdated *models.Message
// sent when a message is cloned, with the cloned message
msgMessageCloned *models.Message
)
type Model struct {
@ -141,7 +157,7 @@ func Chat(state shared.State) Model {
func (m Model) Init() tea.Cmd {
return tea.Batch(
m.waitForChunk(),
m.waitForReply(),
m.waitForResponseChunk(),
m.waitForResponse(),
)
}

View File

@ -60,13 +60,17 @@ func (m *Model) loadConversation(shortname string) tea.Cmd {
if c.ID == 0 {
return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname))
}
return msgConversationLoaded(c)
rootMessages, err := m.State.Ctx.Store.RootMessages(c.ID)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation root messages: %v\n", err))
}
return msgConversationLoaded{c, rootMessages}
}
}
func (m *Model) loadMessages(c *models.Conversation) tea.Cmd {
func (m *Model) loadConversationMessages() tea.Cmd {
return func() tea.Msg {
messages, err := m.State.Ctx.Store.PathToLeaf(c.SelectedRoot)
messages, err := m.State.Ctx.Store.PathToLeaf(m.conversation.SelectedRoot)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
}
@ -80,127 +84,170 @@ func (m *Model) generateConversationTitle() tea.Cmd {
if err != nil {
return shared.MsgError(err)
}
return msgConversationTitleChanged(title)
return msgConversationTitleGenerated(title)
}
}
func cycleMessages(curr *models.Message, msgs []models.Message, dir MessageCycleDirection) (*models.Message, error) {
func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.Cmd {
return func() tea.Msg {
err := m.State.Ctx.Store.UpdateConversation(conversation)
if err != nil {
return shared.WrapError(err)
}
return nil
}
}
// Clones the given message (and its descendents). If selected is true, updates
// either its parent's SelectedReply or its conversation's SelectedRoot to
// point to the new clone
func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
return func() tea.Msg {
msg, _, err := m.Ctx.Store.CloneBranch(message)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not clone message: %v", err))
}
if selected {
if msg.Parent == nil {
msg.Conversation.SelectedRoot = msg
err = m.State.Ctx.Store.UpdateConversation(&msg.Conversation)
} else {
msg.Parent.SelectedReply = msg
err = m.State.Ctx.Store.UpdateMessage(msg.Parent)
}
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update selected message: %v", err))
}
}
return msgMessageCloned(msg)
}
}
func (m *Model) updateMessageContent(message *models.Message) tea.Cmd {
return func() tea.Msg {
err := m.State.Ctx.Store.UpdateMessage(message)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update message: %v", err))
}
return msgMessageUpdated(message)
}
}
func cycleSelectedMessage(selected *models.Message, choices []models.Message, dir MessageCycleDirection) (*models.Message, error) {
currentIndex := -1
for i, reply := range msgs {
if reply.ID == curr.ID {
for i, reply := range choices {
if reply.ID == selected.ID {
currentIndex = i
break
}
}
if currentIndex < 0 {
return nil, fmt.Errorf("message not found")
// this should probably be an assert
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
}
var next int
if dir == CyclePrev {
// Wrap around to the last reply if at the beginning
next = (currentIndex - 1 + len(msgs)) % len(msgs)
next = (currentIndex - 1 + len(choices)) % len(choices)
} else {
// Wrap around to the first reply if at the end
next = (currentIndex + 1) % len(msgs)
next = (currentIndex + 1) % len(choices)
}
return &msgs[next], nil
return &choices[next], nil
}
func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDirection) (*models.Message, error) {
func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDirection) tea.Cmd {
if len(m.rootMessages) < 2 {
return nil, nil
return nil
}
nextRoot, err := cycleMessages(conv.SelectedRoot, m.rootMessages, dir)
return func() tea.Msg {
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, m.rootMessages, dir)
if err != nil {
return nil, err
return shared.WrapError(err)
}
conv.SelectedRoot = nextRoot
err = m.State.Ctx.Store.UpdateConversation(conv)
if err != nil {
return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err)
return shared.WrapError(fmt.Errorf("Could not update conversation SelectedRoot: %v", err))
}
return msgSelectedRootCycled(nextRoot)
}
return nextRoot, nil
}
func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDirection) (*models.Message, error) {
func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDirection) tea.Cmd {
if len(message.Replies) < 2 {
return nil, nil
return nil
}
nextReply, err := cycleMessages(message.SelectedReply, message.Replies, dir)
return func() tea.Msg {
nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
if err != nil {
return nil, err
return shared.WrapError(err)
}
message.SelectedReply = nextReply
err = m.State.Ctx.Store.UpdateMessage(message)
if err != nil {
return nil, fmt.Errorf("Could not update message SelectedReply: %v", err)
return shared.WrapError(fmt.Errorf("Could not update message SelectedReply: %v", err))
}
return msgSelectedReplyCycled(nextReply)
}
return nextReply, nil
}
func (m *Model) persistConversation() error {
func (m *Model) persistConversation() tea.Cmd {
conversation := m.conversation
messages := m.messages
var err error
if m.conversation.ID == 0 {
return func() tea.Msg {
// Start a new conversation with all messages so far
c, messages, err := m.State.Ctx.Store.StartConversation(m.messages...)
conversation, messages, err = m.State.Ctx.Store.StartConversation(messages...)
if err != nil {
return err
return shared.MsgError(fmt.Errorf("Could not start new conversation: %v", err))
}
return msgConversationPersisted{conversation, messages}
}
m.conversation = c
m.messages = messages
return nil
}
return func() tea.Msg {
// else, we'll handle updating an existing conversation's messages
for i := 0; i < len(m.messages); i++ {
if m.messages[i].ID > 0 {
for i := range messages {
if messages[i].ID > 0 {
// message has an ID, update its contents
// TODO: check for content/tool equality before updating?
err := m.State.Ctx.Store.UpdateMessage(&m.messages[i])
err := m.State.Ctx.Store.UpdateMessage(&messages[i])
if err != nil {
return err
return shared.MsgError(err)
}
} else if i > 0 {
// messages is new, so add it as a reply to previous message
saved, err := m.State.Ctx.Store.Reply(&m.messages[i-1], m.messages[i])
if err != nil {
return err
if messages[i].Content == "" {
continue
}
// add this message as a reply to the previous
m.messages[i-1].Replies = append(m.messages[i-1].Replies, saved[0])
m.messages[i] = saved[0]
// messages is new, so add it as a reply to previous message
saved, err := m.State.Ctx.Store.Reply(&messages[i-1], messages[i])
if err != nil {
return shared.MsgError(err)
}
messages[i] = saved[0]
} else {
// message has no id and no previous messages to add it to
// this shouldn't happen?
return fmt.Errorf("Error: no messages to reply to")
}
}
return nil
return msgConversationPersisted{conversation, messages}
}
}
func (m *Model) promptLLM() tea.Cmd {
m.waitingForReply = true
m.replyCursor.Blink = false
m.status = "Press ctrl+c to cancel"
toPrompt := m.messages
// Add response placeholder message
if m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant {
m.addMessage(models.Message{
Role: models.MessageRoleAssistant,
Content: "",
})
}
m.tokenCount = 0
m.startTime = time.Now()
m.elapsed = 0
@ -234,7 +281,7 @@ func (m *Model) promptLLM() tea.Cmd {
}()
resp, err := provider.CreateChatCompletionStream(
ctx, requestParams, toPrompt, replyHandler, m.replyChunkChan,
ctx, requestParams, m.messages, replyHandler, m.replyChunkChan,
)
if err != nil && !canceled {

View File

@ -71,10 +71,15 @@ func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) {
m.input.Focus()
return true, nil
case "e":
message := m.messages[m.selectedMessage]
cmd := tuiutil.OpenTempfileEditor("message.*.md", message.Content, "# Edit the message below\n")
if m.selectedMessage < len(m.messages) {
m.editorTarget = selectedMessage
return true, cmd
return true, tuiutil.OpenTempfileEditor(
"message.*.md",
m.messages[m.selectedMessage].Content,
"# Edit the message below\n",
)
}
return false, nil
case "ctrl+k":
if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) {
m.selectedMessage--
@ -97,34 +102,14 @@ func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) {
dir = CycleNext
}
var err error
var selected *models.Message
var cmd tea.Cmd
if m.selectedMessage == 0 {
selected, err = m.cycleSelectedRoot(m.conversation, dir)
if err != nil {
return true, shared.WrapError(fmt.Errorf("Could not cycle conversation root: %v", err))
}
cmd = m.cycleSelectedRoot(m.conversation, dir)
} else if m.selectedMessage > 0 {
selected, err = m.cycleSelectedReply(&m.messages[m.selectedMessage-1], dir)
if err != nil {
return true, shared.WrapError(fmt.Errorf("Could not cycle reply: %v", err))
}
cmd = m.cycleSelectedReply(&m.messages[m.selectedMessage-1], dir)
}
if selected == nil {
return false, nil
}
// Retrieve updated view at this point
newPath, err := m.State.Ctx.Store.PathToLeaf(selected)
if err != nil {
m.State.Err = fmt.Errorf("Could not fetch messages: %v", err)
}
m.messages = append(m.messages[:m.selectedMessage], newPath...)
m.rebuildMessageCache()
m.updateContent()
return true, nil
return cmd != nil, cmd
case "ctrl+r":
// resubmit the conversation with all messages up until and including the selected message
if m.waitingForReply || len(m.messages) == 0 {
@ -177,17 +162,16 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
m.input.SetValue("")
var cmds []tea.Cmd
if m.persistence {
err := m.persistConversation()
if err != nil {
return true, shared.WrapError(err)
}
cmds = append(cmds, m.persistConversation())
}
cmd := m.promptLLM()
cmds = append(cmds, m.promptLLM())
m.updateContent()
m.content.GotoBottom()
return true, cmd
return true, tea.Batch(cmds...)
case "ctrl+e":
cmd := tuiutil.OpenTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
m.editorTarget = input

View File

@ -1,7 +1,6 @@
package chat
import (
"fmt"
"strings"
"time"
@ -22,13 +21,13 @@ func (m *Model) HandleResize(width, height int) {
}
}
func (m *Model) waitForReply() tea.Cmd {
func (m *Model) waitForResponse() tea.Cmd {
return func() tea.Msg {
return msgAssistantReply(<-m.replyChan)
return msgResponse(<-m.replyChan)
}
}
func (m *Model) waitForChunk() tea.Cmd {
func (m *Model) waitForResponseChunk() tea.Cmd {
return func() tea.Msg {
return msgResponseChunk(<-m.replyChunkChan)
}
@ -37,46 +36,59 @@ func (m *Model) waitForChunk() tea.Cmd {
func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.HandleResize(msg.Width, msg.Height)
case shared.MsgViewEnter:
// wake up spinners and cursors
cmds = append(cmds, cursor.Blink, m.spinner.Tick)
if m.State.Values.ConvShortname != "" && m.conversation.ShortName.String != m.State.Values.ConvShortname {
if m.State.Values.ConvShortname != "" {
// (re)load conversation contents
cmds = append(cmds, m.loadConversation(m.State.Values.ConvShortname))
if m.conversation.ShortName.String != m.State.Values.ConvShortname {
// clear existing messages if we're loading a new conversation
m.messages = []models.Message{}
m.selectedMessage = 0
}
}
m.rebuildMessageCache()
m.updateContent()
case tea.WindowSizeMsg:
m.HandleResize(msg.Width, msg.Height)
case tuiutil.MsgTempfileEditorClosed:
contents := string(msg)
switch m.editorTarget {
case input:
m.input.SetValue(contents)
case selectedMessage:
m.setMessageContents(m.selectedMessage, contents)
if m.persistence && m.messages[m.selectedMessage].ID > 0 {
// update persisted message
err := m.State.Ctx.Store.UpdateMessage(&m.messages[m.selectedMessage])
if err != nil {
cmds = append(cmds, shared.WrapError(fmt.Errorf("Could not save edited message: %v", err)))
toEdit := m.messages[m.selectedMessage]
if toEdit.Content != contents {
toEdit.Content = contents
m.setMessage(m.selectedMessage, toEdit)
if m.persistence && toEdit.ID > 0 {
// create clone of message with its new contents
cmds = append(cmds, m.cloneMessage(toEdit, true))
}
}
m.updateContent()
}
case msgConversationLoaded:
m.conversation = (*models.Conversation)(msg)
m.rootMessages, _ = m.State.Ctx.Store.RootMessages(m.conversation.ID)
cmds = append(cmds, m.loadMessages(m.conversation))
m.conversation = msg.conversation
m.rootMessages = msg.rootMessages
m.selectedMessage = -1
if len(m.rootMessages) > 0 {
cmds = append(cmds, m.loadConversationMessages())
}
case msgMessagesLoaded:
m.selectedMessage = len(msg) - 1
m.messages = msg
if m.selectedMessage == -1 {
m.selectedMessage = len(msg) - 1
} else {
m.selectedMessage = min(m.selectedMessage, len(m.messages))
}
m.rebuildMessageCache()
m.updateContent()
m.content.GotoBottom()
case msgResponseChunk:
cmds = append(cmds, m.waitForChunk()) // wait for the next chunk
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
chunk := string(msg)
if chunk == "" {
@ -102,8 +114,8 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
m.tokenCount++
m.elapsed = time.Now().Sub(m.startTime)
case msgAssistantReply:
cmds = append(cmds, m.waitForReply()) // wait for the next reply
case msgResponse:
cmds = append(cmds, m.waitForResponse()) // wait for the next response
reply := models.Message(msg)
reply.Content = strings.TrimSpace(reply.Content)
@ -121,10 +133,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
}
if m.persistence {
err := m.persistConversation()
if err != nil {
cmds = append(cmds, shared.WrapError(err))
}
cmds = append(cmds, m.persistConversation())
}
if m.conversation.Title == "" {
@ -146,20 +155,30 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
m.status = "Press ctrl+s to send"
m.State.Err = error(msg)
m.updateContent()
case msgConversationTitleChanged:
case msgConversationTitleGenerated:
title := string(msg)
m.conversation.Title = title
if m.persistence {
err := m.State.Ctx.Store.UpdateConversation(m.conversation)
if err != nil {
cmds = append(cmds, shared.WrapError(err))
}
cmds = append(cmds, m.updateConversationTitle(m.conversation))
}
case cursor.BlinkMsg:
if m.waitingForReply {
// ensure we show updated "wait for response" cursor blink state
// ensure we show the updated "wait for response" cursor blink state
m.updateContent()
}
case msgConversationPersisted:
m.conversation = msg.conversation
m.messages = msg.messages
m.rebuildMessageCache()
m.updateContent()
case msgMessageCloned:
if msg.Parent == nil {
m.conversation = &msg.Conversation
m.rootMessages = append(m.rootMessages, *msg)
}
cmds = append(cmds, m.loadConversationMessages())
case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated:
cmds = append(cmds, m.loadConversationMessages())
}
var cmd tea.Cmd
@ -218,12 +237,12 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
// move cursor up until content reaches the bottom of the viewport
m.input.CursorUp()
}
m.input, cmd = m.input.Update(nil)
m.input, _ = m.input.Update(nil)
for i := 0; i < dist; i++ {
// move cursor back down to its previous position
m.input.CursorDown()
}
m.input, cmd = m.input.Update(nil)
m.input, _ = m.input.Update(nil)
}
}

View File

@ -89,7 +89,7 @@ func (m *Model) renderMessageHeading(i int, message *models.Message) string {
faint := lipgloss.NewStyle().Faint(true)
if i == 0 && len(m.rootMessages) > 0 {
if i == 0 && len(m.rootMessages) > 1 && m.conversation.SelectedRootID != nil {
selectedRootIndex := 0
for j, reply := range m.rootMessages {
if reply.ID == *m.conversation.SelectedRootID {
@ -139,7 +139,7 @@ func (m *Model) renderMessage(i int) string {
}
// Show the assistant's cursor
if m.waitingForReply && i == len(m.messages)-1 {
if m.waitingForReply && i == len(m.messages)-1 && msg.Role == models.MessageRoleAssistant {
sb.WriteString(m.replyCursor.View())
}
@ -250,6 +250,17 @@ func (m *Model) conversationMessagesView() string {
lineCnt += lipgloss.Height(rendered)
}
// Render a placeholder for the incoming assistant reply
if m.waitingForReply && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant) {
heading := m.renderMessageHeading(-1, &models.Message{
Role: models.MessageRoleAssistant,
})
sb.WriteString(heading)
sb.WriteString("\n")
sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View()))
sb.WriteString("\n")
}
return sb.String()
}