From ec21a02ec07b1286f795e2a475e85ee558ba8010 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Fri, 25 Oct 2024 16:57:15 +0000 Subject: [PATCH] Tweaks/cleanups to conversation management in tui - Pass around message/conversation values instead of pointers where it makes more sense, and store values instead of pointers in the globally (within the TUI) shared `App` (pointers provide no utility here). - Split conversation persistence into separate conversation/message saving stages --- pkg/api/api.go | 8 +++ pkg/conversation/repo.go | 31 ++++++--- pkg/tui/model/model.go | 70 ++++++++++++-------- pkg/tui/views/chat/chat.go | 10 ++- pkg/tui/views/chat/cmds.go | 26 ++++---- pkg/tui/views/chat/input.go | 5 +- pkg/tui/views/chat/update.go | 32 ++++----- pkg/tui/views/chat/view.go | 4 +- pkg/tui/views/conversations/conversations.go | 16 ++--- 9 files changed, 116 insertions(+), 86 deletions(-) diff --git a/pkg/api/api.go b/pkg/api/api.go index c26c5b1..6042af8 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -72,6 +72,14 @@ func (m MessageRole) IsAssistant() bool { return false } +func (m MessageRole) IsUser() bool { + switch m { + case MessageRoleUser, MessageRoleToolResult: + return true + } + return false +} + func (m MessageRole) IsSystem() bool { switch m { case MessageRoleSystem: diff --git a/pkg/conversation/repo.go b/pkg/conversation/repo.go index 4e033b6..f817cc3 100644 --- a/pkg/conversation/repo.go +++ b/pkg/conversation/repo.go @@ -25,6 +25,7 @@ type Repo interface { CreateConversation(title string) (*Conversation, error) UpdateConversation(*Conversation) error DeleteConversation(*Conversation) error + DeleteConversationById(id uint) error GetMessageByID(messageID uint) (*Message, error) @@ -71,7 +72,7 @@ func NewRepo(db *gorm.DB) (Repo, error) { return &repo{db, _sqids}, nil } -type conversationListItem struct { +type ConversationListItem struct { ID uint ShortName string Title string @@ -80,7 +81,7 @@ type conversationListItem struct { type ConversationList struct { Total int - Items []conversationListItem + Items []ConversationListItem } // LoadConversationList loads existing conversations, ordered by the date @@ -95,7 +96,7 @@ func (s *repo) LoadConversationList() (ConversationList, error) { } for _, c := range convos { - list.Items = append(list.Items, conversationListItem{ + list.Items = append(list.Items, ConversationListItem{ ID: c.ID, ShortName: c.ShortName.String, Title: c.Title, @@ -147,7 +148,7 @@ func (s *repo) GetConversationByID(id uint) (*Conversation, error) { func (s *repo) CreateConversation(title string) (*Conversation, error) { // Create the new conversation c := &Conversation{Title: title} - err := s.db.Save(c).Error + err := s.db.Create(c).Error if err != nil { return nil, err } @@ -172,12 +173,18 @@ func (s *repo) DeleteConversation(c *Conversation) error { if c == nil || c.ID == 0 { return fmt.Errorf("Conversation is nil or invalid (missing ID)") } - // Delete messages first - err := s.db.Where("conversation_id = ?", c.ID).Delete(&Message{}).Error + return s.DeleteConversationById(c.ID) +} + +func (s *repo) DeleteConversationById(id uint) error { + if id == 0 { + return fmt.Errorf("Invalid conversation ID: %d", id) + } + err := s.db.Where("conversation_id = ?", id).Delete(&Message{}).Error if err != nil { return err } - return s.db.Delete(c).Error + return s.db.Where("id = ?", id).Delete(&Conversation{}).Error } func (s *repo) SaveMessage(m Message) (*Message, error) { @@ -186,6 +193,7 @@ func (s *repo) SaveMessage(m Message) (*Message, error) { } newMessage := m newMessage.ID = 0 + newMessage.CreatedAt = time.Now() return &newMessage, s.db.Create(&newMessage).Error } @@ -234,12 +242,15 @@ func (s *repo) Reply(to *Message, messages ...Message) ([]Message, error) { savedMessages = append(savedMessages, message) currentParent = &message } - - to.Conversation.LastMessageAt = savedMessages[len(savedMessages)-1].CreatedAt - s.UpdateConversation(to.Conversation) return nil }) + if err != nil { + return savedMessages, err + } + + to.Conversation.LastMessageAt = savedMessages[len(savedMessages)-1].CreatedAt + err = s.UpdateConversation(to.Conversation) return savedMessages, err } diff --git a/pkg/tui/model/model.go b/pkg/tui/model/model.go index 0d98531..5737a8d 100644 --- a/pkg/tui/model/model.go +++ b/pkg/tui/model/model.go @@ -16,7 +16,7 @@ import ( type AppModel struct { Ctx *lmcli.Context Conversations conversation.ConversationList - Conversation *conversation.Conversation + Conversation conversation.Conversation Messages []conversation.Message Model string ProviderName string @@ -27,12 +27,13 @@ type AppModel struct { func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel { app := &AppModel{ Ctx: ctx, - Conversation: initialConversation, Model: *ctx.Config.Defaults.Model, } if initialConversation == nil { app.NewConversation() + } else { + } model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") @@ -61,7 +62,7 @@ const ( ) func (m *AppModel) ClearConversation() { - m.Conversation = nil + m.Conversation = conversation.Conversation{} m.Messages = []conversation.Message{} } @@ -96,10 +97,6 @@ func (a *AppModel) GenerateConversationTitle(messages []conversation.Message) (s return cmdutil.GenerateTitle(a.Ctx, messages) } -func (a *AppModel) UpdateConversationTitle(conversation *conversation.Conversation) error { - return a.Ctx.Conversations.UpdateConversation(conversation) -} - func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) { msg, _, err := a.Ctx.Conversations.CloneBranch(message) if err != nil { @@ -182,33 +179,54 @@ func (a *AppModel) CycleSelectedReply(message *conversation.Message, dir Message return nextReply, nil } -func (a *AppModel) PersistConversation(conversation *conversation.Conversation, messages []conversation.Message) (*conversation.Conversation, []conversation.Message, error) { - var err error - if conversation == nil || conversation.ID == 0 { - conversation, messages, err = a.Ctx.Conversations.StartConversation(messages...) - if err != nil { - return nil, nil, fmt.Errorf("Could not start new conversation: %v", err) - } - return conversation, messages, nil - } - - for i := range messages { - if messages[i].ID > 0 { - err := a.Ctx.Conversations.UpdateMessage(&messages[i]) +func (a *AppModel) PersistMessages() ([]conversation.Message, error) { + messages := make([]conversation.Message, len(a.Messages)) + for i, m := range a.Messages { + if i == 0 && m.ID == 0 { + m.Conversation = &a.Conversation + m, err := a.Ctx.Conversations.SaveMessage(m) if err != nil { - return nil, nil, err + return nil, fmt.Errorf("Could not create first message %d: %v", a.Messages[i].ID, err) } + messages[i] = *m + // let's set the conversation root message(s), as this is the first message + m.Conversation.RootMessages = []conversation.Message{*m} + m.Conversation.SelectedRoot = &m.Conversation.RootMessages[0] + a.Ctx.Conversations.UpdateConversation(m.Conversation) + } else if m.ID > 0 { + // Existing message, update it + err := a.Ctx.Conversations.UpdateMessage(&m) + if err != nil { + return nil, fmt.Errorf("Could not update message %d: %v", a.Messages[i].ID, err) + } + messages[i] = m } else if i > 0 { - saved, err := a.Ctx.Conversations.Reply(&messages[i-1], messages[i]) + // New message, reply to previous + replies, err := a.Ctx.Conversations.Reply(&messages[i-1], m) if err != nil { - return nil, nil, err + return nil, fmt.Errorf("Could not reply with new message: %v", err) } - messages[i] = saved[0] + messages[i] = replies[0] } else { - return nil, nil, fmt.Errorf("Error: no messages to reply to") + return nil, fmt.Errorf("No messages to reply to (this is a bug)") } } - return conversation, messages, nil + return messages, nil +} + +func (a *AppModel) PersistConversation() (conversation.Conversation, error) { + conv := a.Conversation + var err error + if a.Conversation.ID > 0 { + err = a.Ctx.Conversations.UpdateConversation(&conv) + } else { + c, e := a.Ctx.Conversations.CreateConversation("") + err = e + if e == nil && c != nil { + conv = *c + } + } + return conv, err } func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) { diff --git a/pkg/tui/views/chat/chat.go b/pkg/tui/views/chat/chat.go index 848766e..1755418 100644 --- a/pkg/tui/views/chat/chat.go +++ b/pkg/tui/views/chat/chat.go @@ -4,8 +4,8 @@ import ( "time" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/provider" "git.mlow.ca/mlow/lmcli/pkg/conversation" + "git.mlow.ca/mlow/lmcli/pkg/provider" "git.mlow.ca/mlow/lmcli/pkg/tui/model" "github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/spinner" @@ -20,10 +20,8 @@ type ( // sent when a new conversation title generated msgConversationTitleGenerated string // sent when the conversation has been persisted, triggers a reload of contents - msgConversationPersisted struct { - conversation *conversation.Conversation - messages []conversation.Message - } + msgConversationPersisted conversation.Conversation + msgMessagesPersisted []conversation.Message // sent when a conversation's messages are laoded msgConversationMessagesLoaded struct { messages []conversation.Message @@ -35,7 +33,7 @@ type ( // sent on each chunk received from LLM msgChatResponseChunk provider.Chunk // sent on each completed reply - msgChatResponse *conversation.Message + msgChatResponse conversation.Message // sent when the response is canceled msgChatResponseCanceled struct{} // sent when results from a tool call are returned diff --git a/pkg/tui/views/chat/cmds.go b/pkg/tui/views/chat/cmds.go index 9b2e69e..3a20acb 100644 --- a/pkg/tui/views/chat/cmds.go +++ b/pkg/tui/views/chat/cmds.go @@ -36,16 +36,6 @@ func (m *Model) generateConversationTitle() tea.Cmd { } } -func (m *Model) updateConversationTitle(conversation *conversation.Conversation) tea.Cmd { - return func() tea.Msg { - err := m.App.UpdateConversationTitle(conversation) - if err != nil { - return shared.WrapError(err) - } - return nil - } -} - func (m *Model) cloneMessage(message conversation.Message, selected bool) tea.Cmd { return func() tea.Msg { msg, err := m.App.CloneMessage(message, selected) @@ -96,11 +86,21 @@ func (m *Model) cycleSelectedReply(message *conversation.Message, dir model.Mess func (m *Model) persistConversation() tea.Cmd { return func() tea.Msg { - conversation, messages, err := m.App.PersistConversation(m.App.Conversation, m.App.Messages) + conversation, err := m.App.PersistConversation() if err != nil { return shared.AsMsgError(err) } - return msgConversationPersisted{conversation, messages} + return msgConversationPersisted(conversation) + } +} + +func (m *Model) persistMessages() tea.Cmd { + return func() tea.Msg { + messages, err := m.App.PersistMessages() + if err != nil { + return shared.AsMsgError(err) + } + return msgMessagesPersisted(messages) } } @@ -130,7 +130,7 @@ func (m *Model) promptLLM() tea.Cmd { if err != nil { return msgChatResponseError{Err: err} } - return msgChatResponse(resp) + return msgChatResponse(*resp) }, ) } diff --git a/pkg/tui/views/chat/input.go b/pkg/tui/views/chat/input.go index 4cca580..4afab6b 100644 --- a/pkg/tui/views/chat/input.go +++ b/pkg/tui/views/chat/input.go @@ -127,7 +127,7 @@ func (m *Model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd { var cmd tea.Cmd if m.selectedMessage == 0 { - cmd = m.cycleSelectedRoot(m.App.Conversation, dir) + cmd = m.cycleSelectedRoot(&m.App.Conversation, dir) } else if m.selectedMessage > 0 { cmd = m.cycleSelectedReply(&m.App.Messages[m.selectedMessage-1], dir) } @@ -162,7 +162,6 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) tea.Cmd { m.input.Blur() return shared.KeyHandled(msg) case "ctrl+s": - // TODO: call a "handleSend" function which returns a tea.Cmd if m.state != idle { return nil } @@ -172,7 +171,7 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) tea.Cmd { return shared.KeyHandled(msg) } - if len(m.App.Messages) > 0 && m.App.Messages[len(m.App.Messages)-1].Role == api.MessageRoleUser { + if len(m.App.Messages) > 0 && m.App.Messages[len(m.App.Messages)-1].Role.IsUser() { return shared.WrapError(fmt.Errorf("Can't reply to a user message")) } diff --git a/pkg/tui/views/chat/update.go b/pkg/tui/views/chat/update.go index ad06a4c..e9b4b44 100644 --- a/pkg/tui/views/chat/update.go +++ b/pkg/tui/views/chat/update.go @@ -75,7 +75,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { m.rebuildMessageCache() m.updateContent() - if m.App.Conversation != nil && m.App.Conversation.ID > 0 { + if m.App.Conversation.ID > 0 { // (re)load conversation contents cmds = append(cmds, m.loadConversationMessages()) } @@ -133,7 +133,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { case msgChatResponse: m.state = idle - reply := (*conversation.Message)(msg) + reply := conversation.Message(msg) reply.Content = strings.TrimSpace(reply.Content) last := len(m.App.Messages) - 1 @@ -142,16 +142,15 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { } if m.App.Messages[last].Role.IsAssistant() { - // TODO: handle continuations gracefully - some models support them well, others fail horribly. - m.setMessage(last, *reply) + // TODO: handle continuations gracefully - only some models support them + m.setMessage(last, reply) } else { - m.addMessage(*reply) + m.addMessage(reply) } - switch reply.Role { - case api.MessageRoleToolCall: + if reply.Role == api.MessageRoleToolCall { // TODO: user confirmation before execution - // m.state = waitingForConfirmation + // m.state = confirmToolUse cmds = append(cmds, m.executeToolCalls(reply.ToolCalls)) } @@ -159,11 +158,9 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { cmds = append(cmds, m.persistConversation()) } - if m.App.Conversation.Title == "" { + if m.App.Conversation.Title == "" && len(m.App.Messages) > 0 { cmds = append(cmds, m.generateConversationTitle()) } - - m.updateContent() case msgChatResponseCanceled: m.state = idle m.updateContent() @@ -194,8 +191,8 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { case msgConversationTitleGenerated: title := string(msg) m.App.Conversation.Title = title - if m.persistence { - cmds = append(cmds, m.updateConversationTitle(m.App.Conversation)) + if m.persistence && m.App.Conversation.ID > 0 { + cmds = append(cmds, m.persistConversation()) } case cursor.BlinkMsg: if m.state == pendingResponse { @@ -205,14 +202,13 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { m.updateContent() } case msgConversationPersisted: - m.App.Conversation = msg.conversation - m.App.Messages = msg.messages + m.App.Conversation = conversation.Conversation(msg) + cmds = append(cmds, m.persistMessages()) + case msgMessagesPersisted: + m.App.Messages = msg m.rebuildMessageCache() m.updateContent() case msgMessageCloned: - if msg.Parent == nil { - m.App.Conversation = msg.Conversation - } cmds = append(cmds, m.loadConversationMessages()) case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated: cmds = append(cmds, m.loadConversationMessages()) diff --git a/pkg/tui/views/chat/view.go b/pkg/tui/views/chat/view.go index 7dd3c6a..337ad0f 100644 --- a/pkg/tui/views/chat/view.go +++ b/pkg/tui/views/chat/view.go @@ -71,7 +71,7 @@ func (m *Model) renderMessageHeading(i int, message *conversation.Message) strin prefix = " " } - if i == 0 && len(m.App.Conversation.RootMessages) > 1 && m.App.Conversation.SelectedRootID != nil { + if i == 0 && m.App.Conversation.SelectedRootID != nil && len(m.App.Conversation.RootMessages) > 1 { selectedRootIndex := 0 for j, reply := range m.App.Conversation.RootMessages { if reply.ID == *m.App.Conversation.SelectedRootID { @@ -261,7 +261,7 @@ func (m *Model) Content(width, height int) string { func (m *Model) Header(width int) string { titleStyle := lipgloss.NewStyle().Bold(true) var title string - if m.App.Conversation != nil && m.App.Conversation.Title != "" { + if m.App.Conversation.Title != "" { title = m.App.Conversation.Title } else { title = "Untitled" diff --git a/pkg/tui/views/conversations/conversations.go b/pkg/tui/views/conversations/conversations.go index 52f68eb..48af809 100644 --- a/pkg/tui/views/conversations/conversations.go +++ b/pkg/tui/views/conversations/conversations.go @@ -21,14 +21,14 @@ type ( // sent when conversation list is loaded msgConversationsLoaded conversation.ConversationList // sent when a single conversation is loaded - msgConversationLoaded *conversation.Conversation + msgConversationLoaded conversation.Conversation // sent when a conversation is deleted msgConversationDeleted struct{} ) type Model struct { - App *model.AppModel - width int + App *model.AppModel + width int height int cursor int @@ -151,14 +151,14 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { m.content.SetContent(m.renderConversationList()) case msgConversationLoaded: m.App.ClearConversation() - m.App.Conversation = msg + m.App.Conversation = conversation.Conversation(msg) cmds = append(cmds, func() tea.Msg { return shared.MsgViewChange(shared.ViewChat) }) case bubbles.MsgConfirmPromptAnswered: m.confirmPrompt.Blur() if msg.Value { - conv, ok := msg.Payload.(conversation.Conversation) + conv, ok := msg.Payload.(conversation.ConversationListItem) if ok { cmds = append(cmds, m.deleteConversation(conv)) } @@ -198,13 +198,13 @@ func (m *Model) loadConversation(conversationID uint) tea.Cmd { if err != nil { return shared.AsMsgError(fmt.Errorf("Could not load conversation %d: %v", conversationID, err)) } - return msgConversationLoaded(conversation) + return msgConversationLoaded(*conversation) } } -func (m *Model) deleteConversation(conv conversation.Conversation) tea.Cmd { +func (m *Model) deleteConversation(conv conversation.ConversationListItem) tea.Cmd { return func() tea.Msg { - err := m.App.Ctx.Conversations.DeleteConversation(&conv) + err := m.App.Ctx.Conversations.DeleteConversationById(conv.ID) if err != nil { return shared.AsMsgError(fmt.Errorf("Could not delete conversation: %v", err)) }