Fixes to message/conversation handling in tui chat view

This set of changes fixes root/child message cycling and ensures all
database operations happen within a `tea.Cmd`
This commit is contained in:
Matt Low 2024-06-08 21:28:29 +00:00
parent 136c463924
commit 45df957a06
6 changed files with 251 additions and 169 deletions

View File

@ -241,6 +241,7 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
} }
// update parent selected reply // update parent selected reply
currentParent.Replies = append(currentParent.Replies, message)
currentParent.SelectedReply = &message currentParent.SelectedReply = &message
if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil { if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil {
return err return err
@ -257,7 +258,7 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
// CloneBranch returns a deep clone of the given message and its replies, returning // CloneBranch returns a deep clone of the given message and its replies, returning
// a new message object. The new message will be attached to the same parent as // a new message object. The new message will be attached to the same parent as
// the message to clone. // the messageToClone
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) { func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
newMessage := messageToClone newMessage := messageToClone
newMessage.ID = 0 newMessage.ID = 0
@ -302,7 +303,7 @@ func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, ui
func fetchMessages(db *gorm.DB) ([]model.Message, error) { func fetchMessages(db *gorm.DB) ([]model.Message, error) {
var messages []model.Message var messages []model.Message
if err := db.Find(&messages).Error; err != nil { if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
return nil, fmt.Errorf("Could not fetch messages: %v", err) return nil, fmt.Errorf("Could not fetch messages: %v", err)
} }
@ -375,7 +376,9 @@ func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message
return path, nil return path, nil
} }
// PathToRoot traverses message Parent until reaching the tree root // PathToRoot traverses the provided message's Parent until reaching the tree
// root and returns a slice of all messages traversed in chronological order
// (starting with the root and ending with the message provided)
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) { func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
if message == nil || message.ID <= 0 { if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID") return nil, fmt.Errorf("Message is nil or has invalid ID")
@ -392,7 +395,9 @@ func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
return path, nil return path, nil
} }
// PathToLeaf traverses message SelectedReply until reaching a tree leaf // PathToLeaf traverses the provided message's SelectedReply until reaching a
// tree leaf and returns a slice of all messages traversed in chronological
// order (starting with the message provided and ending with the leaf)
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) { func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
if message == nil || message.ID <= 0 { if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID") return nil, fmt.Errorf("Message is nil or has invalid ID")

View File

@ -36,13 +36,29 @@ type (
// a special case of common.MsgError that stops the response waiting animation // a special case of common.MsgError that stops the response waiting animation
msgResponseError error msgResponseError error
// sent on each completed reply // sent on each completed reply
msgAssistantReply models.Message msgResponse models.Message
// sent when a conversation is (re)loaded // sent when a conversation is (re)loaded
msgConversationLoaded *models.Conversation msgConversationLoaded struct {
// sent when a new conversation title is set conversation *models.Conversation
msgConversationTitleChanged string rootMessages []models.Message
}
// sent when a new conversation title generated
msgConversationTitleGenerated string
// sent when a conversation's messages are laoded // sent when a conversation's messages are laoded
msgMessagesLoaded []models.Message msgMessagesLoaded []models.Message
// sent when the conversation has been persisted, triggers a reload of contents
msgConversationPersisted struct {
conversation *models.Conversation
messages []models.Message
}
// sent when the given message is made the new selected reply of its parent
msgSelectedReplyCycled *models.Message
// sent when the given message is made the new selected root of the current conversation
msgSelectedRootCycled *models.Message
// sent when a message's contents are updated and saved
msgMessageUpdated *models.Message
// sent when a message is cloned, with the cloned message
msgMessageCloned *models.Message
) )
type Model struct { type Model struct {
@ -141,7 +157,7 @@ func Chat(state shared.State) Model {
func (m Model) Init() tea.Cmd { func (m Model) Init() tea.Cmd {
return tea.Batch( return tea.Batch(
m.waitForChunk(), m.waitForResponseChunk(),
m.waitForReply(), m.waitForResponse(),
) )
} }

View File

@ -60,13 +60,17 @@ func (m *Model) loadConversation(shortname string) tea.Cmd {
if c.ID == 0 { if c.ID == 0 {
return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname)) return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname))
} }
return msgConversationLoaded(c) rootMessages, err := m.State.Ctx.Store.RootMessages(c.ID)
if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation root messages: %v\n", err))
}
return msgConversationLoaded{c, rootMessages}
} }
} }
func (m *Model) loadMessages(c *models.Conversation) tea.Cmd { func (m *Model) loadConversationMessages() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
messages, err := m.State.Ctx.Store.PathToLeaf(c.SelectedRoot) messages, err := m.State.Ctx.Store.PathToLeaf(m.conversation.SelectedRoot)
if err != nil { if err != nil {
return shared.MsgError(fmt.Errorf("Could not load conversation messages: %v\n", err)) return shared.MsgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
} }
@ -80,127 +84,170 @@ func (m *Model) generateConversationTitle() tea.Cmd {
if err != nil { if err != nil {
return shared.MsgError(err) return shared.MsgError(err)
} }
return msgConversationTitleChanged(title) return msgConversationTitleGenerated(title)
} }
} }
func cycleMessages(curr *models.Message, msgs []models.Message, dir MessageCycleDirection) (*models.Message, error) { func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.Cmd {
return func() tea.Msg {
err := m.State.Ctx.Store.UpdateConversation(conversation)
if err != nil {
return shared.WrapError(err)
}
return nil
}
}
// Clones the given message (and its descendents). If selected is true, updates
// either its parent's SelectedReply or its conversation's SelectedRoot to
// point to the new clone
func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
return func() tea.Msg {
msg, _, err := m.Ctx.Store.CloneBranch(message)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not clone message: %v", err))
}
if selected {
if msg.Parent == nil {
msg.Conversation.SelectedRoot = msg
err = m.State.Ctx.Store.UpdateConversation(&msg.Conversation)
} else {
msg.Parent.SelectedReply = msg
err = m.State.Ctx.Store.UpdateMessage(msg.Parent)
}
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update selected message: %v", err))
}
}
return msgMessageCloned(msg)
}
}
func (m *Model) updateMessageContent(message *models.Message) tea.Cmd {
return func() tea.Msg {
err := m.State.Ctx.Store.UpdateMessage(message)
if err != nil {
return shared.WrapError(fmt.Errorf("Could not update message: %v", err))
}
return msgMessageUpdated(message)
}
}
func cycleSelectedMessage(selected *models.Message, choices []models.Message, dir MessageCycleDirection) (*models.Message, error) {
currentIndex := -1 currentIndex := -1
for i, reply := range msgs { for i, reply := range choices {
if reply.ID == curr.ID { if reply.ID == selected.ID {
currentIndex = i currentIndex = i
break break
} }
} }
if currentIndex < 0 { if currentIndex < 0 {
return nil, fmt.Errorf("message not found") // this should probably be an assert
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
} }
var next int var next int
if dir == CyclePrev { if dir == CyclePrev {
// Wrap around to the last reply if at the beginning // Wrap around to the last reply if at the beginning
next = (currentIndex - 1 + len(msgs)) % len(msgs) next = (currentIndex - 1 + len(choices)) % len(choices)
} else { } else {
// Wrap around to the first reply if at the end // Wrap around to the first reply if at the end
next = (currentIndex + 1) % len(msgs) next = (currentIndex + 1) % len(choices)
} }
return &msgs[next], nil return &choices[next], nil
} }
func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDirection) (*models.Message, error) { func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDirection) tea.Cmd {
if len(m.rootMessages) < 2 { if len(m.rootMessages) < 2 {
return nil, nil return nil
} }
nextRoot, err := cycleMessages(conv.SelectedRoot, m.rootMessages, dir) return func() tea.Msg {
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, m.rootMessages, dir)
if err != nil { if err != nil {
return nil, err return shared.WrapError(err)
} }
conv.SelectedRoot = nextRoot conv.SelectedRoot = nextRoot
err = m.State.Ctx.Store.UpdateConversation(conv) err = m.State.Ctx.Store.UpdateConversation(conv)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err) return shared.WrapError(fmt.Errorf("Could not update conversation SelectedRoot: %v", err))
}
return msgSelectedRootCycled(nextRoot)
} }
return nextRoot, nil
} }
func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDirection) (*models.Message, error) { func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDirection) tea.Cmd {
if len(message.Replies) < 2 { if len(message.Replies) < 2 {
return nil, nil return nil
} }
nextReply, err := cycleMessages(message.SelectedReply, message.Replies, dir) return func() tea.Msg {
nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
if err != nil { if err != nil {
return nil, err return shared.WrapError(err)
} }
message.SelectedReply = nextReply message.SelectedReply = nextReply
err = m.State.Ctx.Store.UpdateMessage(message) err = m.State.Ctx.Store.UpdateMessage(message)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not update message SelectedReply: %v", err) return shared.WrapError(fmt.Errorf("Could not update message SelectedReply: %v", err))
}
return msgSelectedReplyCycled(nextReply)
} }
return nextReply, nil
} }
func (m *Model) persistConversation() error { func (m *Model) persistConversation() tea.Cmd {
conversation := m.conversation
messages := m.messages
var err error
if m.conversation.ID == 0 { if m.conversation.ID == 0 {
return func() tea.Msg {
// Start a new conversation with all messages so far // Start a new conversation with all messages so far
c, messages, err := m.State.Ctx.Store.StartConversation(m.messages...) conversation, messages, err = m.State.Ctx.Store.StartConversation(messages...)
if err != nil { if err != nil {
return err return shared.MsgError(fmt.Errorf("Could not start new conversation: %v", err))
}
return msgConversationPersisted{conversation, messages}
} }
m.conversation = c
m.messages = messages
return nil
} }
return func() tea.Msg {
// else, we'll handle updating an existing conversation's messages // else, we'll handle updating an existing conversation's messages
for i := 0; i < len(m.messages); i++ { for i := range messages {
if m.messages[i].ID > 0 { if messages[i].ID > 0 {
// message has an ID, update its contents // message has an ID, update its contents
// TODO: check for content/tool equality before updating? err := m.State.Ctx.Store.UpdateMessage(&messages[i])
err := m.State.Ctx.Store.UpdateMessage(&m.messages[i])
if err != nil { if err != nil {
return err return shared.MsgError(err)
} }
} else if i > 0 { } else if i > 0 {
// messages is new, so add it as a reply to previous message if messages[i].Content == "" {
saved, err := m.State.Ctx.Store.Reply(&m.messages[i-1], m.messages[i]) continue
if err != nil {
return err
} }
// add this message as a reply to the previous // messages is new, so add it as a reply to previous message
m.messages[i-1].Replies = append(m.messages[i-1].Replies, saved[0]) saved, err := m.State.Ctx.Store.Reply(&messages[i-1], messages[i])
m.messages[i] = saved[0] if err != nil {
return shared.MsgError(err)
}
messages[i] = saved[0]
} else { } else {
// message has no id and no previous messages to add it to // message has no id and no previous messages to add it to
// this shouldn't happen? // this shouldn't happen?
return fmt.Errorf("Error: no messages to reply to") return fmt.Errorf("Error: no messages to reply to")
} }
} }
return msgConversationPersisted{conversation, messages}
return nil }
} }
func (m *Model) promptLLM() tea.Cmd { func (m *Model) promptLLM() tea.Cmd {
m.waitingForReply = true m.waitingForReply = true
m.replyCursor.Blink = false m.replyCursor.Blink = false
m.status = "Press ctrl+c to cancel" m.status = "Press ctrl+c to cancel"
toPrompt := m.messages
// Add response placeholder message
if m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant {
m.addMessage(models.Message{
Role: models.MessageRoleAssistant,
Content: "",
})
}
m.tokenCount = 0 m.tokenCount = 0
m.startTime = time.Now() m.startTime = time.Now()
m.elapsed = 0 m.elapsed = 0
@ -234,7 +281,7 @@ func (m *Model) promptLLM() tea.Cmd {
}() }()
resp, err := provider.CreateChatCompletionStream( resp, err := provider.CreateChatCompletionStream(
ctx, requestParams, toPrompt, replyHandler, m.replyChunkChan, ctx, requestParams, m.messages, replyHandler, m.replyChunkChan,
) )
if err != nil && !canceled { if err != nil && !canceled {

View File

@ -71,10 +71,15 @@ func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) {
m.input.Focus() m.input.Focus()
return true, nil return true, nil
case "e": case "e":
message := m.messages[m.selectedMessage] if m.selectedMessage < len(m.messages) {
cmd := tuiutil.OpenTempfileEditor("message.*.md", message.Content, "# Edit the message below\n")
m.editorTarget = selectedMessage m.editorTarget = selectedMessage
return true, cmd return true, tuiutil.OpenTempfileEditor(
"message.*.md",
m.messages[m.selectedMessage].Content,
"# Edit the message below\n",
)
}
return false, nil
case "ctrl+k": case "ctrl+k":
if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) { if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) {
m.selectedMessage-- m.selectedMessage--
@ -97,34 +102,14 @@ func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) {
dir = CycleNext dir = CycleNext
} }
var err error var cmd tea.Cmd
var selected *models.Message
if m.selectedMessage == 0 { if m.selectedMessage == 0 {
selected, err = m.cycleSelectedRoot(m.conversation, dir) cmd = m.cycleSelectedRoot(m.conversation, dir)
if err != nil {
return true, shared.WrapError(fmt.Errorf("Could not cycle conversation root: %v", err))
}
} else if m.selectedMessage > 0 { } else if m.selectedMessage > 0 {
selected, err = m.cycleSelectedReply(&m.messages[m.selectedMessage-1], dir) cmd = m.cycleSelectedReply(&m.messages[m.selectedMessage-1], dir)
if err != nil {
return true, shared.WrapError(fmt.Errorf("Could not cycle reply: %v", err))
}
} }
if selected == nil { return cmd != nil, cmd
return false, nil
}
// Retrieve updated view at this point
newPath, err := m.State.Ctx.Store.PathToLeaf(selected)
if err != nil {
m.State.Err = fmt.Errorf("Could not fetch messages: %v", err)
}
m.messages = append(m.messages[:m.selectedMessage], newPath...)
m.rebuildMessageCache()
m.updateContent()
return true, nil
case "ctrl+r": case "ctrl+r":
// resubmit the conversation with all messages up until and including the selected message // resubmit the conversation with all messages up until and including the selected message
if m.waitingForReply || len(m.messages) == 0 { if m.waitingForReply || len(m.messages) == 0 {
@ -177,17 +162,16 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
m.input.SetValue("") m.input.SetValue("")
var cmds []tea.Cmd
if m.persistence { if m.persistence {
err := m.persistConversation() cmds = append(cmds, m.persistConversation())
if err != nil {
return true, shared.WrapError(err)
}
} }
cmd := m.promptLLM() cmds = append(cmds, m.promptLLM())
m.updateContent() m.updateContent()
m.content.GotoBottom() m.content.GotoBottom()
return true, cmd return true, tea.Batch(cmds...)
case "ctrl+e": case "ctrl+e":
cmd := tuiutil.OpenTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n") cmd := tuiutil.OpenTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
m.editorTarget = input m.editorTarget = input

View File

@ -1,7 +1,6 @@
package chat package chat
import ( import (
"fmt"
"strings" "strings"
"time" "time"
@ -22,13 +21,13 @@ func (m *Model) HandleResize(width, height int) {
} }
} }
func (m *Model) waitForReply() tea.Cmd { func (m *Model) waitForResponse() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
return msgAssistantReply(<-m.replyChan) return msgResponse(<-m.replyChan)
} }
} }
func (m *Model) waitForChunk() tea.Cmd { func (m *Model) waitForResponseChunk() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
return msgResponseChunk(<-m.replyChunkChan) return msgResponseChunk(<-m.replyChunkChan)
} }
@ -37,46 +36,59 @@ func (m *Model) waitForChunk() tea.Cmd {
func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
var cmds []tea.Cmd var cmds []tea.Cmd
switch msg := msg.(type) { switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.HandleResize(msg.Width, msg.Height)
case shared.MsgViewEnter: case shared.MsgViewEnter:
// wake up spinners and cursors // wake up spinners and cursors
cmds = append(cmds, cursor.Blink, m.spinner.Tick) cmds = append(cmds, cursor.Blink, m.spinner.Tick)
if m.State.Values.ConvShortname != "" && m.conversation.ShortName.String != m.State.Values.ConvShortname { if m.State.Values.ConvShortname != "" {
// (re)load conversation contents
cmds = append(cmds, m.loadConversation(m.State.Values.ConvShortname)) cmds = append(cmds, m.loadConversation(m.State.Values.ConvShortname))
if m.conversation.ShortName.String != m.State.Values.ConvShortname {
// clear existing messages if we're loading a new conversation
m.messages = []models.Message{}
m.selectedMessage = 0
}
} }
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
case tea.WindowSizeMsg:
m.HandleResize(msg.Width, msg.Height)
case tuiutil.MsgTempfileEditorClosed: case tuiutil.MsgTempfileEditorClosed:
contents := string(msg) contents := string(msg)
switch m.editorTarget { switch m.editorTarget {
case input: case input:
m.input.SetValue(contents) m.input.SetValue(contents)
case selectedMessage: case selectedMessage:
m.setMessageContents(m.selectedMessage, contents) toEdit := m.messages[m.selectedMessage]
if m.persistence && m.messages[m.selectedMessage].ID > 0 { if toEdit.Content != contents {
// update persisted message toEdit.Content = contents
err := m.State.Ctx.Store.UpdateMessage(&m.messages[m.selectedMessage]) m.setMessage(m.selectedMessage, toEdit)
if err != nil { if m.persistence && toEdit.ID > 0 {
cmds = append(cmds, shared.WrapError(fmt.Errorf("Could not save edited message: %v", err))) // create clone of message with its new contents
cmds = append(cmds, m.cloneMessage(toEdit, true))
} }
} }
m.updateContent()
} }
case msgConversationLoaded: case msgConversationLoaded:
m.conversation = (*models.Conversation)(msg) m.conversation = msg.conversation
m.rootMessages, _ = m.State.Ctx.Store.RootMessages(m.conversation.ID) m.rootMessages = msg.rootMessages
cmds = append(cmds, m.loadMessages(m.conversation)) m.selectedMessage = -1
if len(m.rootMessages) > 0 {
cmds = append(cmds, m.loadConversationMessages())
}
case msgMessagesLoaded: case msgMessagesLoaded:
m.selectedMessage = len(msg) - 1
m.messages = msg m.messages = msg
if m.selectedMessage == -1 {
m.selectedMessage = len(msg) - 1
} else {
m.selectedMessage = min(m.selectedMessage, len(m.messages))
}
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
m.content.GotoBottom()
case msgResponseChunk: case msgResponseChunk:
cmds = append(cmds, m.waitForChunk()) // wait for the next chunk cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
chunk := string(msg) chunk := string(msg)
if chunk == "" { if chunk == "" {
@ -102,8 +114,8 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
m.tokenCount++ m.tokenCount++
m.elapsed = time.Now().Sub(m.startTime) m.elapsed = time.Now().Sub(m.startTime)
case msgAssistantReply: case msgResponse:
cmds = append(cmds, m.waitForReply()) // wait for the next reply cmds = append(cmds, m.waitForResponse()) // wait for the next response
reply := models.Message(msg) reply := models.Message(msg)
reply.Content = strings.TrimSpace(reply.Content) reply.Content = strings.TrimSpace(reply.Content)
@ -121,10 +133,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
} }
if m.persistence { if m.persistence {
err := m.persistConversation() cmds = append(cmds, m.persistConversation())
if err != nil {
cmds = append(cmds, shared.WrapError(err))
}
} }
if m.conversation.Title == "" { if m.conversation.Title == "" {
@ -146,20 +155,30 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
m.status = "Press ctrl+s to send" m.status = "Press ctrl+s to send"
m.State.Err = error(msg) m.State.Err = error(msg)
m.updateContent() m.updateContent()
case msgConversationTitleChanged: case msgConversationTitleGenerated:
title := string(msg) title := string(msg)
m.conversation.Title = title m.conversation.Title = title
if m.persistence { if m.persistence {
err := m.State.Ctx.Store.UpdateConversation(m.conversation) cmds = append(cmds, m.updateConversationTitle(m.conversation))
if err != nil {
cmds = append(cmds, shared.WrapError(err))
}
} }
case cursor.BlinkMsg: case cursor.BlinkMsg:
if m.waitingForReply { if m.waitingForReply {
// ensure we show updated "wait for response" cursor blink state // ensure we show the updated "wait for response" cursor blink state
m.updateContent() m.updateContent()
} }
case msgConversationPersisted:
m.conversation = msg.conversation
m.messages = msg.messages
m.rebuildMessageCache()
m.updateContent()
case msgMessageCloned:
if msg.Parent == nil {
m.conversation = &msg.Conversation
m.rootMessages = append(m.rootMessages, *msg)
}
cmds = append(cmds, m.loadConversationMessages())
case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated:
cmds = append(cmds, m.loadConversationMessages())
} }
var cmd tea.Cmd var cmd tea.Cmd
@ -218,12 +237,12 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
// move cursor up until content reaches the bottom of the viewport // move cursor up until content reaches the bottom of the viewport
m.input.CursorUp() m.input.CursorUp()
} }
m.input, cmd = m.input.Update(nil) m.input, _ = m.input.Update(nil)
for i := 0; i < dist; i++ { for i := 0; i < dist; i++ {
// move cursor back down to its previous position // move cursor back down to its previous position
m.input.CursorDown() m.input.CursorDown()
} }
m.input, cmd = m.input.Update(nil) m.input, _ = m.input.Update(nil)
} }
} }

View File

@ -89,7 +89,7 @@ func (m *Model) renderMessageHeading(i int, message *models.Message) string {
faint := lipgloss.NewStyle().Faint(true) faint := lipgloss.NewStyle().Faint(true)
if i == 0 && len(m.rootMessages) > 0 { if i == 0 && len(m.rootMessages) > 1 && m.conversation.SelectedRootID != nil {
selectedRootIndex := 0 selectedRootIndex := 0
for j, reply := range m.rootMessages { for j, reply := range m.rootMessages {
if reply.ID == *m.conversation.SelectedRootID { if reply.ID == *m.conversation.SelectedRootID {
@ -139,7 +139,7 @@ func (m *Model) renderMessage(i int) string {
} }
// Show the assistant's cursor // Show the assistant's cursor
if m.waitingForReply && i == len(m.messages)-1 { if m.waitingForReply && i == len(m.messages)-1 && msg.Role == models.MessageRoleAssistant {
sb.WriteString(m.replyCursor.View()) sb.WriteString(m.replyCursor.View())
} }
@ -250,6 +250,17 @@ func (m *Model) conversationMessagesView() string {
lineCnt += lipgloss.Height(rendered) lineCnt += lipgloss.Height(rendered)
} }
// Render a placeholder for the incoming assistant reply
if m.waitingForReply && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant) {
heading := m.renderMessageHeading(-1, &models.Message{
Role: models.MessageRoleAssistant,
})
sb.WriteString(heading)
sb.WriteString("\n")
sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View()))
sb.WriteString("\n")
}
return sb.String() return sb.String()
} }