From 07c96082e79113bbe792afe19fd41171aebd145b Mon Sep 17 00:00:00 2001 From: Matt Low Date: Mon, 21 Oct 2024 15:33:20 +0000 Subject: [PATCH] Add LastMessageAt field to conversation Replaced `LatestConversationMessages` with `LoadConversationList`, which utilizes `LastMessageAt` for much faster conversation loading in the conversation listing TUI and `lmcli list` command. --- pkg/cmd/list.go | 20 +++--- pkg/conversation/conversation.go | 1 + pkg/conversation/repo.go | 58 +++++++++++------- pkg/tui/model/model.go | 23 +------ pkg/tui/views/conversations/conversations.go | 64 ++++++++++++-------- 5 files changed, 88 insertions(+), 78 deletions(-) diff --git a/pkg/cmd/list.go b/pkg/cmd/list.go index 2acf770..6651aa9 100644 --- a/pkg/cmd/list.go +++ b/pkg/cmd/list.go @@ -20,9 +20,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command { Short: "List conversations", Long: `List conversations in order of recent activity`, RunE: func(cmd *cobra.Command, args []string) error { - messages, err := ctx.Conversations.LatestConversationMessages() + list, err := ctx.Conversations.LoadConversationList() if err != nil { - return fmt.Errorf("Could not fetch conversations: %v", err) + return fmt.Errorf("Could not load conversations: %v", err) } type Category struct { @@ -57,12 +57,12 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command { all, _ := cmd.Flags().GetBool("all") - for _, message := range messages { - messageAge := now.Sub(message.CreatedAt) + for _, item := range list.Items { + age := now.Sub(item.LastMessageAt) var category string for _, c := range categories { - if messageAge < c.cutoff { + if age < c.cutoff { category = c.name break } @@ -70,14 +70,14 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command { formatted := fmt.Sprintf( "%s - %s - %s", - message.Conversation.ShortName.String, - util.HumanTimeElapsedSince(messageAge), - message.Conversation.Title, + item.ShortName, + util.HumanTimeElapsedSince(age), + item.Title, ) categorized[category] = append( categorized[category], - ConversationLine{messageAge, formatted}, + ConversationLine{age, formatted}, ) } @@ -93,7 +93,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command { fmt.Printf("%s:\n", category.name) for _, conv := range conversationLines { if conversationsPrinted >= count && !all { - fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted) + fmt.Printf("%d remaining conversation(s), use --all to view.\n", list.Total-conversationsPrinted) break outer } diff --git a/pkg/conversation/conversation.go b/pkg/conversation/conversation.go index 356f2a6..7c7db06 100644 --- a/pkg/conversation/conversation.go +++ b/pkg/conversation/conversation.go @@ -17,6 +17,7 @@ type Conversation struct { SelectedRootID *uint SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"` RootMessages []Message `gorm:"-:all"` + LastMessageAt time.Time } type MessageMeta struct { diff --git a/pkg/conversation/repo.go b/pkg/conversation/repo.go index 47db744..4e033b6 100644 --- a/pkg/conversation/repo.go +++ b/pkg/conversation/repo.go @@ -15,8 +15,7 @@ import ( // Repo exposes low-level message and conversation management. See // Service for high-level helpers type Repo interface { - // LatestConversationMessages returns a slice of all conversations ordered by when they were last updated (newest to oldest) - LatestConversationMessages() ([]Message, error) + LoadConversationList() (ConversationList, error) FindConversationByShortName(shortName string) (*Conversation, error) ConversationShortNameCompletions(search string) []string @@ -72,25 +71,40 @@ func NewRepo(db *gorm.DB) (Repo, error) { return &repo{db, _sqids}, nil } -func (s *repo) LatestConversationMessages() ([]Message, error) { - var latestMessages []Message +type conversationListItem struct { + ID uint + ShortName string + Title string + LastMessageAt time.Time +} - subQuery := s.db.Model(&Message{}). - Select("MAX(created_at) as max_created_at, conversation_id"). - Group("conversation_id") +type ConversationList struct { + Total int + Items []conversationListItem +} - err := s.db.Model(&Message{}). - Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery). - Group("messages.conversation_id"). - Order("created_at DESC"). - Preload("Conversation.SelectedRoot"). - Find(&latestMessages).Error +// LoadConversationList loads existing conversations, ordered by the date +// of their latest message, from most recent to oldest. +func (s *repo) LoadConversationList() (ConversationList, error) { + list := ConversationList{} + var convos []Conversation + err := s.db.Order("last_message_at DESC").Find(&convos).Error if err != nil { - return nil, err + return list, err } - return latestMessages, nil + for _, c := range convos { + list.Items = append(list.Items, conversationListItem{ + ID: c.ID, + ShortName: c.ShortName.String, + Title: c.Title, + LastMessageAt: c.LastMessageAt, + }) + } + + list.Total = len(list.Items) + return list, nil } func (s *repo) FindConversationByShortName(shortName string) (*Conversation, error) { @@ -220,6 +234,9 @@ func (s *repo) Reply(to *Message, messages ...Message) ([]Message, error) { savedMessages = append(savedMessages, message) currentParent = &message } + + to.Conversation.LastMessageAt = savedMessages[len(savedMessages)-1].CreatedAt + s.UpdateConversation(to.Conversation) return nil }) @@ -427,10 +444,7 @@ func (s *repo) StartConversation(messages ...Message) (*Conversation, []Message, // Update conversation's selected root message conversation.RootMessages = []Message{messages[0]} conversation.SelectedRoot = &messages[0] - err = s.UpdateConversation(conversation) - if err != nil { - return nil, nil, err - } + conversation.LastMessageAt = messages[0].CreatedAt // Add additional replies to conversation if len(messages) > 1 { @@ -439,10 +453,12 @@ func (s *repo) StartConversation(messages ...Message) (*Conversation, []Message, return nil, nil, err } messages = append([]Message{messages[0]}, newMessages...) + conversation.LastMessageAt = messages[len(messages)-1].CreatedAt } - return conversation, messages, nil -} + err = s.UpdateConversation(conversation) + return conversation, messages, err +} // CloneConversation clones the given conversation and all of its meesages func (s *repo) CloneConversation(toClone Conversation) (*Conversation, uint, error) { diff --git a/pkg/tui/model/model.go b/pkg/tui/model/model.go index 529e869..0d98531 100644 --- a/pkg/tui/model/model.go +++ b/pkg/tui/model/model.go @@ -13,14 +13,9 @@ import ( "github.com/charmbracelet/lipgloss" ) -type LoadedConversation struct { - Conv conversation.Conversation - LastReply conversation.Message -} - type AppModel struct { Ctx *lmcli.Context - Conversations []LoadedConversation + Conversations conversation.ConversationList Conversation *conversation.Conversation Messages []conversation.Message Model string @@ -89,22 +84,6 @@ func (m *AppModel) NewConversation() { m.ApplySystemPrompt() } -func (m *AppModel) LoadConversations() (error, []LoadedConversation) { - messages, err := m.Ctx.Conversations.LatestConversationMessages() - if err != nil { - return fmt.Errorf("Could not load conversations: %v", err), nil - } - - conversations := make([]LoadedConversation, len(messages)) - for i, msg := range messages { - conversations[i] = LoadedConversation{ - Conv: *msg.Conversation, - LastReply: msg, - } - } - return nil, conversations -} - func (a *AppModel) LoadConversationMessages() ([]conversation.Message, error) { messages, err := a.Ctx.Conversations.PathToLeaf(a.Conversation.SelectedRoot) if err != nil { diff --git a/pkg/tui/views/conversations/conversations.go b/pkg/tui/views/conversations/conversations.go index 1a4318d..52f68eb 100644 --- a/pkg/tui/views/conversations/conversations.go +++ b/pkg/tui/views/conversations/conversations.go @@ -19,9 +19,9 @@ import ( type ( // sent when conversation list is loaded - msgConversationsLoaded ([]model.LoadedConversation) - // sent when a conversation is selected - msgConversationSelected conversation.Conversation + msgConversationsLoaded conversation.ConversationList + // sent when a single conversation is loaded + msgConversationLoaded *conversation.Conversation // sent when a conversation is deleted msgConversationDeleted struct{} ) @@ -56,19 +56,17 @@ func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd { } } + conversations := m.App.Conversations.Items + switch msg.String() { case "enter": - if len(m.App.Conversations) > 0 && m.cursor < len(m.App.Conversations) { - m.App.ClearConversation() - m.App.Conversation = &m.App.Conversations[m.cursor].Conv - return func() tea.Msg { - return shared.MsgViewChange(shared.ViewChat) - } + if len(conversations) > 0 && m.cursor < len(conversations) { + return m.loadConversation(conversations[m.cursor].ID) } case "j", "down": - if m.cursor < len(m.App.Conversations)-1 { + if m.cursor < len(conversations)-1 { m.cursor++ - if m.cursor == len(m.App.Conversations)-1 { + if m.cursor == len(conversations)-1 { m.content.GotoBottom() } else { // this hack positions the *next* conversatoin slightly @@ -78,7 +76,7 @@ func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd { } m.content.SetContent(m.renderConversationList()) } else { - m.cursor = len(m.App.Conversations) - 1 + m.cursor = len(conversations) - 1 m.content.GotoBottom() } return shared.KeyHandled(msg) @@ -100,14 +98,14 @@ func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd { m.App.NewConversation() return shared.ChangeView(shared.ViewChat) case "d": - if !m.confirmPrompt.Focused() && len(m.App.Conversations) > 0 && m.cursor < len(m.App.Conversations) { - title := m.App.Conversations[m.cursor].Conv.Title + if !m.confirmPrompt.Focused() && len(conversations) > 0 && m.cursor < len(conversations) { + title := conversations[m.cursor].Title if title == "" { title = "(untitled)" } m.confirmPrompt = bubbles.NewConfirmPrompt( fmt.Sprintf("Delete '%s'?", title), - m.App.Conversations[m.cursor].Conv, + conversations[m.cursor], ) m.confirmPrompt.Style = lipgloss.NewStyle(). Bold(true). @@ -148,9 +146,15 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { m.width, m.height = msg.Width, msg.Height m.content.SetContent(m.renderConversationList()) case msgConversationsLoaded: - m.App.Conversations = msg - m.cursor = max(0, min(len(m.App.Conversations), m.cursor)) + m.App.Conversations = conversation.ConversationList(msg) + m.cursor = max(0, min(len(m.App.Conversations.Items), m.cursor)) m.content.SetContent(m.renderConversationList()) + case msgConversationLoaded: + m.App.ClearConversation() + m.App.Conversation = msg + cmds = append(cmds, func() tea.Msg { + return shared.MsgViewChange(shared.ViewChat) + }) case bubbles.MsgConfirmPromptAnswered: m.confirmPrompt.Blur() if msg.Value { @@ -180,11 +184,21 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { func (m *Model) loadConversations() tea.Cmd { return func() tea.Msg { - err, conversations := m.App.LoadConversations() + list, err := m.App.Ctx.Conversations.LoadConversationList() if err != nil { return shared.AsMsgError(fmt.Errorf("Could not load conversations: %v", err)) } - return msgConversationsLoaded(conversations) + return msgConversationsLoaded(list) + } +} + +func (m *Model) loadConversation(conversationID uint) tea.Cmd { + return func() tea.Msg { + conversation, err := m.App.Ctx.Conversations.GetConversationByID(conversationID) + if err != nil { + return shared.AsMsgError(fmt.Errorf("Could not load conversation %d: %v", conversationID, err)) + } + return msgConversationLoaded(conversation) } } @@ -259,12 +273,12 @@ func (m *Model) renderConversationList() string { sb strings.Builder ) - m.itemOffsets = make([]int, len(m.App.Conversations)) + m.itemOffsets = make([]int, len(m.App.Conversations.Items)) sb.WriteRune('\n') currentOffset += 1 - for i, c := range m.App.Conversations { - lastReplyAge := now.Sub(c.LastReply.CreatedAt) + for i, c := range m.App.Conversations.Items { + lastReplyAge := now.Sub(c.LastMessageAt) var category string for _, g := range categories { @@ -284,14 +298,14 @@ func (m *Model) renderConversationList() string { } tStyle := titleStyle - if c.Conv.Title == "" { + if c.Title == "" { tStyle = tStyle.Inherit(untitledStyle).SetString("(untitled)") } if i == m.cursor { tStyle = tStyle.Inherit(selectedStyle) } - title := tStyle.Width(m.width - 3).PaddingLeft(2).Render(c.Conv.Title) + title := tStyle.Width(m.width - 3).PaddingLeft(2).Render(c.Title) if i == m.cursor { title = ">" + title[1:] } @@ -304,7 +318,7 @@ func (m *Model) renderConversationList() string { )) sb.WriteString(item) currentOffset += tuiutil.Height(item) - if i < len(m.App.Conversations)-1 { + if i < len(m.App.Conversations.Items)-1 { sb.WriteRune('\n') } }