From 60a474d516f0b23d03d6b59007712f221cba5639 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sat, 1 Jun 2024 06:40:59 +0000 Subject: [PATCH] Implement PathToRoot and PathToLeaf with one query After fetching all of a conversation's messages, we traverse the message's Parent or SelectedReply fields to build the message "path" in-memory --- pkg/lmcli/store.go | 116 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 88 insertions(+), 28 deletions(-) diff --git a/pkg/lmcli/store.go b/pkg/lmcli/store.go index 97a80da..bdac762 100644 --- a/pkg/lmcli/store.go +++ b/pkg/lmcli/store.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "slices" "strings" "time" @@ -299,48 +300,107 @@ func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, ui return &newMessage, replyCount, nil } -// PathToRoot traverses message Parent until reaching the tree root -func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) { - if message == nil { - return nil, fmt.Errorf("Message is nil") +func fetchMessages(db *gorm.DB) ([]model.Message, error) { + var messages []model.Message + if err := db.Find(&messages).Error; err != nil { + return nil, fmt.Errorf("Could not fetch messages: %v", err) } + + messageMap := make(map[uint]model.Message) + for i, message := range messages { + messageMap[messages[i].ID] = message + } + + // Create a map to store replies by their parent ID + repliesMap := make(map[uint][]model.Message) + for i, message := range messages { + if messages[i].ParentID != nil { + repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message) + } + } + + // Assign replies, parent, and selected reply to each message + for i := range messages { + if replies, exists := repliesMap[messages[i].ID]; exists { + messages[i].Replies = make([]model.Message, len(replies)) + for j, m := range replies { + messages[i].Replies[j] = m + } + } + if messages[i].ParentID != nil { + if parent, exists := messageMap[*messages[i].ParentID]; exists { + messages[i].Parent = &parent + } + } + if messages[i].SelectedReplyID != nil { + if selectedReply, exists := messageMap[*messages[i].SelectedReplyID]; exists { + messages[i].SelectedReply = &selectedReply + } + } + } + return messages, nil +} + +func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message) *uint) ([]model.Message, error) { + var messages []model.Message + messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID)) + if err != nil { + return nil, err + } + + // Create a map to store messages by their ID + messageMap := make(map[uint]*model.Message) + for i := range messages { + messageMap[messages[i].ID] = &messages[i] + } + + // Build the path var path []model.Message - current := message + nextID := &message.ID + for { - path = append([]model.Message{*current}, path...) - if current.Parent == nil { - break + current, exists := messageMap[*nextID] + if !exists { + return nil, fmt.Errorf("Message with ID %d not found in conversation", *nextID) } - var err error - current, err = s.MessageByID(*current.ParentID) - if err != nil { - return nil, fmt.Errorf("finding parent message: %w", err) + path = append(path, *current) + + nextID = getNext(current) + if nextID == nil { + break } } + + return path, nil +} + +// PathToRoot traverses message Parent until reaching the tree root +func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) { + if message == nil || message.ID <= 0 { + return nil, fmt.Errorf("Message is nil or has invalid ID") + } + + path, err := s.buildPath(message, func(m *model.Message) *uint { + return m.ParentID + }) + if err != nil { + return nil, err + } + slices.Reverse(path) + return path, nil } // PathToLeaf traverses message SelectedReply until reaching a tree leaf func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) { - if message == nil { - return nil, fmt.Errorf("Message is nil") + if message == nil || message.ID <= 0 { + return nil, fmt.Errorf("Message is nil or has invalid ID") } - var path []model.Message - current := message - for { - path = append(path, *current) - if current.SelectedReplyID == nil { - break - } - var err error - current, err = s.MessageByID(*current.SelectedReplyID) - if err != nil { - return nil, fmt.Errorf("finding selected reply: %w", err) - } - } - return path, nil + return s.buildPath(message, func(m *model.Message) *uint { + return m.SelectedReplyID + }) } func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {