Implement PathToRoot and PathToLeaf with one query

After fetching all of a conversation's messages, we traverse the
message's Parent or SelectedReply fields to build the message "path"
in-memory
This commit is contained in:
Matt Low 2024-06-01 06:40:59 +00:00
parent ea576d24a6
commit 60a474d516
1 changed files with 88 additions and 28 deletions

View File

@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"fmt"
"slices"
"strings"
"time"
@ -299,48 +300,107 @@ func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, ui
return &newMessage, replyCount, nil
}
// PathToRoot traverses message Parent until reaching the tree root
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
if message == nil {
return nil, fmt.Errorf("Message is nil")
func fetchMessages(db *gorm.DB) ([]model.Message, error) {
var messages []model.Message
if err := db.Find(&messages).Error; err != nil {
return nil, fmt.Errorf("Could not fetch messages: %v", err)
}
messageMap := make(map[uint]model.Message)
for i, message := range messages {
messageMap[messages[i].ID] = message
}
// Create a map to store replies by their parent ID
repliesMap := make(map[uint][]model.Message)
for i, message := range messages {
if messages[i].ParentID != nil {
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
}
}
// Assign replies, parent, and selected reply to each message
for i := range messages {
if replies, exists := repliesMap[messages[i].ID]; exists {
messages[i].Replies = make([]model.Message, len(replies))
for j, m := range replies {
messages[i].Replies[j] = m
}
}
if messages[i].ParentID != nil {
if parent, exists := messageMap[*messages[i].ParentID]; exists {
messages[i].Parent = &parent
}
}
if messages[i].SelectedReplyID != nil {
if selectedReply, exists := messageMap[*messages[i].SelectedReplyID]; exists {
messages[i].SelectedReply = &selectedReply
}
}
}
return messages, nil
}
func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message) *uint) ([]model.Message, error) {
var messages []model.Message
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
if err != nil {
return nil, err
}
// Create a map to store messages by their ID
messageMap := make(map[uint]*model.Message)
for i := range messages {
messageMap[messages[i].ID] = &messages[i]
}
// Build the path
var path []model.Message
current := message
nextID := &message.ID
for {
path = append([]model.Message{*current}, path...)
if current.Parent == nil {
break
current, exists := messageMap[*nextID]
if !exists {
return nil, fmt.Errorf("Message with ID %d not found in conversation", *nextID)
}
var err error
current, err = s.MessageByID(*current.ParentID)
if err != nil {
return nil, fmt.Errorf("finding parent message: %w", err)
path = append(path, *current)
nextID = getNext(current)
if nextID == nil {
break
}
}
return path, nil
}
// PathToRoot traverses message Parent until reaching the tree root
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID")
}
path, err := s.buildPath(message, func(m *model.Message) *uint {
return m.ParentID
})
if err != nil {
return nil, err
}
slices.Reverse(path)
return path, nil
}
// PathToLeaf traverses message SelectedReply until reaching a tree leaf
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
if message == nil {
return nil, fmt.Errorf("Message is nil")
if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID")
}
var path []model.Message
current := message
for {
path = append(path, *current)
if current.SelectedReplyID == nil {
break
}
var err error
current, err = s.MessageByID(*current.SelectedReplyID)
if err != nil {
return nil, fmt.Errorf("finding selected reply: %w", err)
}
}
return path, nil
return s.buildPath(message, func(m *model.Message) *uint {
return m.SelectedReplyID
})
}
func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {