package conversation import ( "database/sql" "errors" "fmt" "slices" "strings" "time" sqids "github.com/sqids/sqids-go" "gorm.io/gorm" ) // Repo exposes low-level message and conversation management. See // Service for high-level helpers type Repo interface { LoadConversationList() (ConversationList, error) FindConversationByShortName(shortName string) (*Conversation, error) ConversationShortNameCompletions(search string) []string GetConversationByID(int uint) (*Conversation, error) GetRootMessages(conversationID uint) ([]Message, error) CreateConversation(title string) (*Conversation, error) UpdateConversation(*Conversation) error DeleteConversation(*Conversation) error DeleteConversationById(id uint) error GetMessageByID(messageID uint) (*Message, error) SaveMessage(message Message) (*Message, error) UpdateMessage(message *Message) error DeleteMessage(message *Message, prune bool) error CloneBranch(toClone Message) (*Message, uint, error) Reply(to *Message, messages ...Message) ([]Message, error) PathToRoot(message *Message) ([]Message, error) PathToLeaf(message *Message) ([]Message, error) // Retrieves and return the "selected thread" of the conversation. // The "selected thread" of the conversation is a chain of messages // starting from the Conversation's SelectedRoot Message, following each // Message's SelectedReply until the tail Message is reached. GetSelectedThread(*Conversation) ([]Message, error) // Start a new conversation with the given messages StartConversation(messages ...Message) (*Conversation, []Message, error) CloneConversation(toClone Conversation) (*Conversation, uint, error) } type repo struct { db *gorm.DB sqids *sqids.Sqids } func NewRepo(db *gorm.DB) (Repo, error) { models := []any{ &Conversation{}, &Message{}, } for _, x := range models { err := db.AutoMigrate(x) if err != nil { return nil, fmt.Errorf("Could not perform database migrations: %v", err) } } _sqids, _ := sqids.New(sqids.Options{MinLength: 4}) return &repo{db, _sqids}, nil } type ConversationListItem struct { ID uint ShortName string Title string LastMessageAt time.Time } type ConversationList struct { Total int Items []ConversationListItem } // LoadConversationList loads existing conversations, ordered by the date // of their latest message, from most recent to oldest. func (s *repo) LoadConversationList() (ConversationList, error) { list := ConversationList{} var convos []Conversation err := s.db.Order("last_message_at DESC").Find(&convos).Error if err != nil { return list, err } for _, c := range convos { list.Items = append(list.Items, ConversationListItem{ ID: c.ID, ShortName: c.ShortName.String, Title: c.Title, LastMessageAt: c.LastMessageAt, }) } list.Total = len(list.Items) return list, nil } func (s *repo) FindConversationByShortName(shortName string) (*Conversation, error) { if shortName == "" { return nil, errors.New("shortName is empty") } var conversation Conversation err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error return &conversation, err } func (s *repo) ConversationShortNameCompletions(shortName string) []string { var conversations []Conversation // ignore error for completions s.db.Find(&conversations) completions := make([]string, 0, len(conversations)) for _, conversation := range conversations { if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) { completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title)) } } return completions } func (s *repo) GetConversationByID(id uint) (*Conversation, error) { var conversation Conversation err := s.db.Preload("SelectedRoot").Where("id = ?", id).Find(&conversation).Error if err != nil { return nil, fmt.Errorf("Cannot get conversation %d: %v", id, err) } rootMessages, err := s.GetRootMessages(id) if err != nil { return nil, fmt.Errorf("Could not load conversation's root messages %d: %v", id, err) } conversation.RootMessages = rootMessages return &conversation, nil } func (s *repo) CreateConversation(title string) (*Conversation, error) { // Create the new conversation c := &Conversation{Title: title} err := s.db.Create(c).Error if err != nil { return nil, err } // Generate and save its "short name" shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)}) c.ShortName = sql.NullString{String: shortName, Valid: true} err = s.db.Updates(c).Error if err != nil { return nil, err } return c, nil } func (s *repo) UpdateConversation(c *Conversation) error { if c == nil || c.ID == 0 { return fmt.Errorf("Conversation is nil or invalid (missing ID)") } return s.db.Updates(c).Error } func (s *repo) DeleteConversation(c *Conversation) error { if c == nil || c.ID == 0 { return fmt.Errorf("Conversation is nil or invalid (missing ID)") } return s.DeleteConversationById(c.ID) } func (s *repo) DeleteConversationById(id uint) error { if id == 0 { return fmt.Errorf("Invalid conversation ID: %d", id) } err := s.db.Where("conversation_id = ?", id).Delete(&Message{}).Error if err != nil { return err } return s.db.Where("id = ?", id).Delete(&Conversation{}).Error } func (s *repo) SaveMessage(m Message) (*Message, error) { if m.Conversation == nil { return nil, fmt.Errorf("Can't save a message without a conversation (this is a bug)") } newMessage := m newMessage.ID = 0 newMessage.CreatedAt = time.Now() return &newMessage, s.db.Create(&newMessage).Error } func (s *repo) UpdateMessage(m *Message) error { if m == nil || m.ID == 0 { return fmt.Errorf("Message is nil or invalid (missing ID)") } return s.db.Updates(m).Error } func (s *repo) DeleteMessage(message *Message, prune bool) error { return s.db.Delete(&message).Error } func (s *repo) GetMessageByID(messageID uint) (*Message, error) { var message Message err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error return &message, err } // Reply to a message with a series of messages (each followed by the next) func (s *repo) Reply(to *Message, messages ...Message) ([]Message, error) { var savedMessages []Message err := s.db.Transaction(func(tx *gorm.DB) error { currentParent := to for i := range messages { parent := currentParent message := messages[i] message.Parent = parent message.Conversation = parent.Conversation message.ID = 0 message.CreatedAt = time.Time{} if err := tx.Create(&message).Error; err != nil { return err } // update parent selected reply parent.Replies = append(parent.Replies, message) parent.SelectedReply = &message if err := tx.Model(parent).Update("selected_reply_id", message.ID).Error; err != nil { return err } savedMessages = append(savedMessages, message) currentParent = &message } return nil }) if err != nil { return savedMessages, err } to.Conversation.LastMessageAt = savedMessages[len(savedMessages)-1].CreatedAt err = s.UpdateConversation(to.Conversation) return savedMessages, err } // CloneBranch returns a deep clone of the given message and its replies, returning // a new message object. The new message will be attached to the same parent as // the messageToClone func (s *repo) CloneBranch(messageToClone Message) (*Message, uint, error) { newMessage := messageToClone newMessage.ID = 0 newMessage.Replies = nil newMessage.SelectedReplyID = nil newMessage.SelectedReply = nil originalReplies := messageToClone.Replies if err := s.db.Create(&newMessage).Error; err != nil { return nil, 0, fmt.Errorf("Could not clone message: %s", err) } var replyCount uint = 0 for _, reply := range originalReplies { replyCount++ newReply := reply newReply.ConversationID = messageToClone.ConversationID newReply.ParentID = &newMessage.ID newReply.Parent = &newMessage res, c, err := s.CloneBranch(newReply) if err != nil { return nil, 0, err } newMessage.Replies = append(newMessage.Replies, *res) replyCount += c if reply.ID == *messageToClone.SelectedReplyID { newMessage.SelectedReplyID = &res.ID if err := s.UpdateMessage(&newMessage); err != nil { return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err) } } } return &newMessage, replyCount, nil } func fetchMessages(db *gorm.DB) ([]Message, error) { var messages []Message if err := db.Preload("Conversation").Find(&messages).Error; err != nil { return nil, fmt.Errorf("Could not fetch messages: %v", err) } messageMap := make(map[uint]Message) for i, message := range messages { messageMap[messages[i].ID] = message } // Create a map to store replies by their parent ID repliesMap := make(map[uint][]Message) for i, message := range messages { if messages[i].ParentID != nil { repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message) } } // Assign replies, parent, and selected reply to each message for i := range messages { if replies, exists := repliesMap[messages[i].ID]; exists { messages[i].Replies = make([]Message, len(replies)) for j, m := range replies { messages[i].Replies[j] = m } } if messages[i].ParentID != nil { if parent, exists := messageMap[*messages[i].ParentID]; exists { messages[i].Parent = &parent } } if messages[i].SelectedReplyID != nil { if selectedReply, exists := messageMap[*messages[i].SelectedReplyID]; exists { messages[i].SelectedReply = &selectedReply } } } return messages, nil } func (r repo) GetRootMessages(conversationID uint) ([]Message, error) { var rootMessages []Message err := r.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error if err != nil { return nil, fmt.Errorf("Could not retrieve root messages for conversation %d: %v", conversationID, err) } return rootMessages, nil } func (s *repo) buildPath(message *Message, getNext func(*Message) *uint) ([]Message, error) { var messages []Message messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID)) if err != nil { return nil, err } // Create a map to store messages by their ID messageMap := make(map[uint]*Message, len(messages)) for i := range messages { messageMap[messages[i].ID] = &messages[i] } // Construct Replies repliesMap := make(map[uint][]*Message, len(messages)) for _, m := range messageMap { if m.ParentID == nil { continue } if p, ok := messageMap[*m.ParentID]; ok { repliesMap[p.ID] = append(repliesMap[p.ID], m) } } // Add replies to messages for _, m := range messageMap { if replies, ok := repliesMap[m.ID]; ok { m.Replies = make([]Message, len(replies)) for idx, reply := range replies { m.Replies[idx] = *reply } } } // Build the path var path []Message nextID := &message.ID for { current, exists := messageMap[*nextID] if !exists { return nil, fmt.Errorf("Message with ID %d not found in conversation", *nextID) } path = append(path, *current) nextID = getNext(current) if nextID == nil { break } } return path, nil } // PathToRoot traverses the provided message's Parent until reaching the tree // root and returns a slice of all messages traversed in chronological order // (starting with the root and ending with the message provided) func (s *repo) PathToRoot(message *Message) ([]Message, error) { if message == nil || message.ID <= 0 { return nil, fmt.Errorf("Message is nil or has invalid ID") } path, err := s.buildPath(message, func(m *Message) *uint { return m.ParentID }) if err != nil { return nil, err } slices.Reverse(path) return path, nil } // PathToLeaf traverses the provided message's SelectedReply until reaching a // tree leaf and returns a slice of all messages traversed in chronological // order (starting with the message provided and ending with the leaf) func (s *repo) PathToLeaf(message *Message) ([]Message, error) { if message == nil || message.ID <= 0 { return nil, fmt.Errorf("Message is nil or has invalid ID") } return s.buildPath(message, func(m *Message) *uint { return m.SelectedReplyID }) } func (s *repo) StartConversation(messages ...Message) (*Conversation, []Message, error) { if len(messages) == 0 { return nil, nil, fmt.Errorf("Must provide at least 1 message") } // Create new conversation conversation, err := s.CreateConversation("") if err != nil { return nil, nil, err } messages[0].Conversation = conversation // Create first message firstMessage, err := s.SaveMessage(messages[0]) if err != nil { return nil, nil, err } messages[0] = *firstMessage // Update conversation's selected root message conversation.RootMessages = []Message{messages[0]} conversation.SelectedRoot = &messages[0] conversation.LastMessageAt = messages[0].CreatedAt // Add additional replies to conversation if len(messages) > 1 { newMessages, err := s.Reply(&messages[0], messages[1:]...) if err != nil { return nil, nil, err } messages = append([]Message{messages[0]}, newMessages...) conversation.LastMessageAt = messages[len(messages)-1].CreatedAt } err = s.UpdateConversation(conversation) return conversation, messages, err } // CloneConversation clones the given conversation and all of its meesages func (s *repo) CloneConversation(toClone Conversation) (*Conversation, uint, error) { rootMessages, err := s.GetRootMessages(toClone.ID) if err != nil { return nil, 0, fmt.Errorf("Could not create clone: %v", err) } clone, err := s.CreateConversation(toClone.Title + " - Clone") if err != nil { return nil, 0, fmt.Errorf("Could not create clone: %v", err) } var errors []error var messageCnt uint = 0 for _, root := range rootMessages { messageCnt++ newRoot := root newRoot.ConversationID = &clone.ID cloned, count, err := s.CloneBranch(newRoot) if err != nil { errors = append(errors, err) continue } messageCnt += count if root.ID == *toClone.SelectedRootID { clone.SelectedRootID = &cloned.ID if err := s.UpdateConversation(clone); err != nil { errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err)) } } } if len(errors) > 0 { return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors) } return clone, messageCnt, nil } func (s *repo) GetSelectedThread(c *Conversation) ([]Message, error) { if c.SelectedRoot == nil { return nil, fmt.Errorf("No SelectedRoot on conversation - this is a bug") } return s.PathToLeaf(c.SelectedRoot) }