Matt Low
0384c7cb66
This refactor splits out all conversation concerns into a new `conversation` package. There is now a split between the `conversation` and `api` packages' representations of `Message`, the latter storing the minimum information required for interaction with LLM providers. A conversion between the two is necessary when making LLM calls.
494 lines
14 KiB
Go
494 lines
14 KiB
Go
package conversation
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
sqids "github.com/sqids/sqids-go"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Repo exposes low-level message and conversation management. See
|
|
// Service for high-level helpers
|
|
type Repo interface {
|
|
// LatestConversationMessages returns a slice of all conversations ordered by when they were last updated (newest to oldest)
|
|
LatestConversationMessages() ([]Message, error)
|
|
|
|
FindConversationByShortName(shortName string) (*Conversation, error)
|
|
ConversationShortNameCompletions(search string) []string
|
|
GetConversationByID(int uint) (*Conversation, error)
|
|
GetRootMessages(conversationID uint) ([]Message, error)
|
|
|
|
CreateConversation(title string) (*Conversation, error)
|
|
UpdateConversation(*Conversation) error
|
|
DeleteConversation(*Conversation) error
|
|
|
|
GetMessageByID(messageID uint) (*Message, error)
|
|
|
|
SaveMessage(message Message) (*Message, error)
|
|
UpdateMessage(message *Message) error
|
|
DeleteMessage(message *Message, prune bool) error
|
|
CloneBranch(toClone Message) (*Message, uint, error)
|
|
Reply(to *Message, messages ...Message) ([]Message, error)
|
|
|
|
PathToRoot(message *Message) ([]Message, error)
|
|
PathToLeaf(message *Message) ([]Message, error)
|
|
|
|
// Retrieves and return the "selected thread" of the conversation.
|
|
// The "selected thread" of the conversation is a chain of messages
|
|
// starting from the Conversation's SelectedRoot Message, following each
|
|
// Message's SelectedReply until the tail Message is reached.
|
|
GetSelectedThread(*Conversation) ([]Message, error)
|
|
|
|
// Start a new conversation with the given messages
|
|
StartConversation(messages ...Message) (*Conversation, []Message, error)
|
|
|
|
CloneConversation(toClone Conversation) (*Conversation, uint, error)
|
|
}
|
|
|
|
type repo struct {
|
|
db *gorm.DB
|
|
sqids *sqids.Sqids
|
|
}
|
|
|
|
func NewRepo(db *gorm.DB) (Repo, error) {
|
|
models := []any{
|
|
&Conversation{},
|
|
&Message{},
|
|
}
|
|
|
|
for _, x := range models {
|
|
err := db.AutoMigrate(x)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
|
|
}
|
|
}
|
|
|
|
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
|
return &repo{db, _sqids}, nil
|
|
}
|
|
|
|
func (s *repo) LatestConversationMessages() ([]Message, error) {
|
|
var latestMessages []Message
|
|
|
|
subQuery := s.db.Model(&Message{}).
|
|
Select("MAX(created_at) as max_created_at, conversation_id").
|
|
Group("conversation_id")
|
|
|
|
err := s.db.Model(&Message{}).
|
|
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
|
|
Group("messages.conversation_id").
|
|
Order("created_at DESC").
|
|
Preload("Conversation.SelectedRoot").
|
|
Find(&latestMessages).Error
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return latestMessages, nil
|
|
}
|
|
|
|
func (s *repo) FindConversationByShortName(shortName string) (*Conversation, error) {
|
|
if shortName == "" {
|
|
return nil, errors.New("shortName is empty")
|
|
}
|
|
var conversation Conversation
|
|
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
|
return &conversation, err
|
|
}
|
|
|
|
func (s *repo) ConversationShortNameCompletions(shortName string) []string {
|
|
var conversations []Conversation
|
|
// ignore error for completions
|
|
s.db.Find(&conversations)
|
|
completions := make([]string, 0, len(conversations))
|
|
for _, conversation := range conversations {
|
|
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
|
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
|
}
|
|
}
|
|
return completions
|
|
}
|
|
|
|
func (s *repo) GetConversationByID(id uint) (*Conversation, error) {
|
|
var conversation Conversation
|
|
err := s.db.Preload("SelectedRoot").Where("id = ?", id).Find(&conversation).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Cannot get conversation %d: %v", id, err)
|
|
}
|
|
|
|
rootMessages, err := s.GetRootMessages(id)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not load conversation's root messages %d: %v", id, err)
|
|
}
|
|
conversation.RootMessages = rootMessages
|
|
return &conversation, nil
|
|
}
|
|
|
|
func (s *repo) CreateConversation(title string) (*Conversation, error) {
|
|
// Create the new conversation
|
|
c := &Conversation{Title: title}
|
|
err := s.db.Save(c).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Generate and save its "short name"
|
|
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
|
|
c.ShortName = sql.NullString{String: shortName, Valid: true}
|
|
err = s.db.Updates(c).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
func (s *repo) UpdateConversation(c *Conversation) error {
|
|
if c == nil || c.ID == 0 {
|
|
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
|
}
|
|
return s.db.Updates(c).Error
|
|
}
|
|
|
|
func (s *repo) DeleteConversation(c *Conversation) error {
|
|
if c == nil || c.ID == 0 {
|
|
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
|
}
|
|
// Delete messages first
|
|
err := s.db.Where("conversation_id = ?", c.ID).Delete(&Message{}).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.db.Delete(c).Error
|
|
}
|
|
|
|
func (s *repo) SaveMessage(m Message) (*Message, error) {
|
|
if m.Conversation == nil {
|
|
return nil, fmt.Errorf("Can't save a message without a conversation (this is a bug)")
|
|
}
|
|
newMessage := m
|
|
newMessage.ID = 0
|
|
return &newMessage, s.db.Create(&newMessage).Error
|
|
}
|
|
|
|
func (s *repo) UpdateMessage(m *Message) error {
|
|
if m == nil || m.ID == 0 {
|
|
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
|
}
|
|
return s.db.Updates(m).Error
|
|
}
|
|
|
|
func (s *repo) DeleteMessage(message *Message, prune bool) error {
|
|
return s.db.Delete(&message).Error
|
|
}
|
|
|
|
func (s *repo) GetMessageByID(messageID uint) (*Message, error) {
|
|
var message Message
|
|
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
|
return &message, err
|
|
}
|
|
|
|
// Reply to a message with a series of messages (each followed by the next)
|
|
func (s *repo) Reply(to *Message, messages ...Message) ([]Message, error) {
|
|
var savedMessages []Message
|
|
|
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
|
currentParent := to
|
|
for i := range messages {
|
|
parent := currentParent
|
|
message := messages[i]
|
|
message.Parent = parent
|
|
message.Conversation = parent.Conversation
|
|
message.ID = 0
|
|
message.CreatedAt = time.Time{}
|
|
|
|
if err := tx.Create(&message).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// update parent selected reply
|
|
parent.Replies = append(parent.Replies, message)
|
|
parent.SelectedReply = &message
|
|
if err := tx.Model(parent).Update("selected_reply_id", message.ID).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
savedMessages = append(savedMessages, message)
|
|
currentParent = &message
|
|
}
|
|
return nil
|
|
})
|
|
|
|
return savedMessages, err
|
|
}
|
|
|
|
// CloneBranch returns a deep clone of the given message and its replies, returning
|
|
// a new message object. The new message will be attached to the same parent as
|
|
// the messageToClone
|
|
func (s *repo) CloneBranch(messageToClone Message) (*Message, uint, error) {
|
|
newMessage := messageToClone
|
|
newMessage.ID = 0
|
|
newMessage.Replies = nil
|
|
newMessage.SelectedReplyID = nil
|
|
newMessage.SelectedReply = nil
|
|
|
|
originalReplies := messageToClone.Replies
|
|
|
|
if err := s.db.Create(&newMessage).Error; err != nil {
|
|
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
|
|
}
|
|
|
|
var replyCount uint = 0
|
|
for _, reply := range originalReplies {
|
|
replyCount++
|
|
|
|
newReply := reply
|
|
newReply.ConversationID = messageToClone.ConversationID
|
|
newReply.ParentID = &newMessage.ID
|
|
newReply.Parent = &newMessage
|
|
|
|
res, c, err := s.CloneBranch(newReply)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
newMessage.Replies = append(newMessage.Replies, *res)
|
|
replyCount += c
|
|
|
|
if reply.ID == *messageToClone.SelectedReplyID {
|
|
newMessage.SelectedReplyID = &res.ID
|
|
if err := s.UpdateMessage(&newMessage); err != nil {
|
|
return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
|
|
}
|
|
}
|
|
}
|
|
return &newMessage, replyCount, nil
|
|
}
|
|
|
|
func fetchMessages(db *gorm.DB) ([]Message, error) {
|
|
var messages []Message
|
|
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
|
|
return nil, fmt.Errorf("Could not fetch messages: %v", err)
|
|
}
|
|
|
|
messageMap := make(map[uint]Message)
|
|
for i, message := range messages {
|
|
messageMap[messages[i].ID] = message
|
|
}
|
|
|
|
// Create a map to store replies by their parent ID
|
|
repliesMap := make(map[uint][]Message)
|
|
for i, message := range messages {
|
|
if messages[i].ParentID != nil {
|
|
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
|
|
}
|
|
}
|
|
|
|
// Assign replies, parent, and selected reply to each message
|
|
for i := range messages {
|
|
if replies, exists := repliesMap[messages[i].ID]; exists {
|
|
messages[i].Replies = make([]Message, len(replies))
|
|
for j, m := range replies {
|
|
messages[i].Replies[j] = m
|
|
}
|
|
}
|
|
if messages[i].ParentID != nil {
|
|
if parent, exists := messageMap[*messages[i].ParentID]; exists {
|
|
messages[i].Parent = &parent
|
|
}
|
|
}
|
|
if messages[i].SelectedReplyID != nil {
|
|
if selectedReply, exists := messageMap[*messages[i].SelectedReplyID]; exists {
|
|
messages[i].SelectedReply = &selectedReply
|
|
}
|
|
}
|
|
}
|
|
return messages, nil
|
|
}
|
|
|
|
func (r repo) GetRootMessages(conversationID uint) ([]Message, error) {
|
|
var rootMessages []Message
|
|
err := r.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not retrieve root messages for conversation %d: %v", conversationID, err)
|
|
}
|
|
return rootMessages, nil
|
|
}
|
|
|
|
func (s *repo) buildPath(message *Message, getNext func(*Message) *uint) ([]Message, error) {
|
|
var messages []Message
|
|
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create a map to store messages by their ID
|
|
messageMap := make(map[uint]*Message, len(messages))
|
|
for i := range messages {
|
|
messageMap[messages[i].ID] = &messages[i]
|
|
}
|
|
|
|
// Construct Replies
|
|
repliesMap := make(map[uint][]*Message, len(messages))
|
|
for _, m := range messageMap {
|
|
if m.ParentID == nil {
|
|
continue
|
|
}
|
|
if p, ok := messageMap[*m.ParentID]; ok {
|
|
repliesMap[p.ID] = append(repliesMap[p.ID], m)
|
|
}
|
|
}
|
|
|
|
// Add replies to messages
|
|
for _, m := range messageMap {
|
|
if replies, ok := repliesMap[m.ID]; ok {
|
|
m.Replies = make([]Message, len(replies))
|
|
for idx, reply := range replies {
|
|
m.Replies[idx] = *reply
|
|
}
|
|
}
|
|
}
|
|
|
|
// Build the path
|
|
var path []Message
|
|
nextID := &message.ID
|
|
|
|
for {
|
|
current, exists := messageMap[*nextID]
|
|
if !exists {
|
|
return nil, fmt.Errorf("Message with ID %d not found in conversation", *nextID)
|
|
}
|
|
|
|
path = append(path, *current)
|
|
|
|
nextID = getNext(current)
|
|
if nextID == nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
return path, nil
|
|
}
|
|
|
|
// PathToRoot traverses the provided message's Parent until reaching the tree
|
|
// root and returns a slice of all messages traversed in chronological order
|
|
// (starting with the root and ending with the message provided)
|
|
func (s *repo) PathToRoot(message *Message) ([]Message, error) {
|
|
if message == nil || message.ID <= 0 {
|
|
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
|
}
|
|
|
|
path, err := s.buildPath(message, func(m *Message) *uint {
|
|
return m.ParentID
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
slices.Reverse(path)
|
|
|
|
return path, nil
|
|
}
|
|
|
|
// PathToLeaf traverses the provided message's SelectedReply until reaching a
|
|
// tree leaf and returns a slice of all messages traversed in chronological
|
|
// order (starting with the message provided and ending with the leaf)
|
|
func (s *repo) PathToLeaf(message *Message) ([]Message, error) {
|
|
if message == nil || message.ID <= 0 {
|
|
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
|
}
|
|
|
|
return s.buildPath(message, func(m *Message) *uint {
|
|
return m.SelectedReplyID
|
|
})
|
|
}
|
|
|
|
func (s *repo) StartConversation(messages ...Message) (*Conversation, []Message, error) {
|
|
if len(messages) == 0 {
|
|
return nil, nil, fmt.Errorf("Must provide at least 1 message")
|
|
}
|
|
|
|
// Create new conversation
|
|
conversation, err := s.CreateConversation("")
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
messages[0].Conversation = conversation
|
|
|
|
// Create first message
|
|
firstMessage, err := s.SaveMessage(messages[0])
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
messages[0] = *firstMessage
|
|
|
|
// Update conversation's selected root message
|
|
conversation.RootMessages = []Message{messages[0]}
|
|
conversation.SelectedRoot = &messages[0]
|
|
err = s.UpdateConversation(conversation)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Add additional replies to conversation
|
|
if len(messages) > 1 {
|
|
newMessages, err := s.Reply(&messages[0], messages[1:]...)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
messages = append([]Message{messages[0]}, newMessages...)
|
|
}
|
|
return conversation, messages, nil
|
|
}
|
|
|
|
|
|
// CloneConversation clones the given conversation and all of its meesages
|
|
func (s *repo) CloneConversation(toClone Conversation) (*Conversation, uint, error) {
|
|
rootMessages, err := s.GetRootMessages(toClone.ID)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("Could not create clone: %v", err)
|
|
}
|
|
|
|
clone, err := s.CreateConversation(toClone.Title + " - Clone")
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("Could not create clone: %v", err)
|
|
}
|
|
|
|
var errors []error
|
|
var messageCnt uint = 0
|
|
for _, root := range rootMessages {
|
|
messageCnt++
|
|
newRoot := root
|
|
newRoot.ConversationID = &clone.ID
|
|
|
|
cloned, count, err := s.CloneBranch(newRoot)
|
|
if err != nil {
|
|
errors = append(errors, err)
|
|
continue
|
|
}
|
|
messageCnt += count
|
|
|
|
if root.ID == *toClone.SelectedRootID {
|
|
clone.SelectedRootID = &cloned.ID
|
|
if err := s.UpdateConversation(clone); err != nil {
|
|
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(errors) > 0 {
|
|
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
|
|
}
|
|
|
|
return clone, messageCnt, nil
|
|
}
|
|
|
|
func (s *repo) GetSelectedThread(c *Conversation) ([]Message, error) {
|
|
if c.SelectedRoot == nil {
|
|
return nil, fmt.Errorf("No SelectedRoot on conversation - this is a bug")
|
|
}
|
|
return s.PathToLeaf(c.SelectedRoot)
|
|
}
|