Add message branching
Updated the behaviour of commands: - `lmcli edit` - by default create a new branch/message branch with the edited contents - add --in-place to avoid creating a branch - no longer delete messages after the edited message - only do the edit, don't fetch a new response - `lmcli retry` - create a new branch rather than replacing old messages - add --offset to change where to retry from
This commit is contained in:
@@ -36,7 +36,9 @@ func NewContext() (*Context, error) {
|
||||
}
|
||||
|
||||
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
|
||||
//Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||
}
|
||||
|
||||
@@ -16,19 +16,28 @@ const (
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID uint `gorm:"index"`
|
||||
Conversation Conversation `gorm:"foreignKey:ConversationID"`
|
||||
Content string
|
||||
Role MessageRole
|
||||
CreatedAt time.Time
|
||||
ToolCalls ToolCalls // a json array of tool calls (from the modl)
|
||||
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||
ToolResults ToolResults // a json array of tool results
|
||||
ParentID *uint
|
||||
Parent *Message `gorm:"foreignKey:ParentID"`
|
||||
Replies []Message `gorm:"foreignKey:ParentID"`
|
||||
|
||||
SelectedReplyID *uint
|
||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||
}
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
SelectedRootID *uint
|
||||
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||||
}
|
||||
|
||||
type RequestParameters struct {
|
||||
|
||||
@@ -13,21 +13,26 @@ import (
|
||||
)
|
||||
|
||||
type ConversationStore interface {
|
||||
Conversations() ([]model.Conversation, error)
|
||||
|
||||
ConversationByShortName(shortName string) (*model.Conversation, error)
|
||||
ConversationShortNameCompletions(search string) []string
|
||||
RootMessages(conversationID uint) ([]model.Message, error)
|
||||
LatestConversationMessages() ([]model.Message, error)
|
||||
|
||||
SaveConversation(conversation *model.Conversation) error
|
||||
StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
|
||||
UpdateConversation(conversation *model.Conversation) error
|
||||
DeleteConversation(conversation *model.Conversation) error
|
||||
CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)
|
||||
|
||||
Messages(conversation *model.Conversation) ([]model.Message, error)
|
||||
LastMessage(conversation *model.Conversation) (*model.Message, error)
|
||||
MessageByID(messageID uint) (*model.Message, error)
|
||||
MessageReplies(messageID uint) ([]model.Message, error)
|
||||
|
||||
SaveMessage(message *model.Message) error
|
||||
DeleteMessage(message *model.Message) error
|
||||
UpdateMessage(message *model.Message) error
|
||||
AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error)
|
||||
DeleteMessage(message *model.Message, prune bool) error
|
||||
CloneBranch(toClone model.Message) (*model.Message, uint, error)
|
||||
Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
|
||||
|
||||
PathToRoot(message *model.Message) ([]model.Message, error)
|
||||
PathToLeaf(message *model.Message) ([]model.Message, error)
|
||||
}
|
||||
|
||||
type SQLStore struct {
|
||||
@@ -52,47 +57,52 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
||||
return &SQLStore{db, _sqids}, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
|
||||
err := s.db.Save(&conversation).Error
|
||||
func (s *SQLStore) saveNewConversation(c *model.Conversation) error {
|
||||
// Save the new conversation
|
||||
err := s.db.Save(&c).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !conversation.ShortName.Valid {
|
||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||
err = s.db.Save(&conversation).Error
|
||||
// Generate and save its "short name"
|
||||
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
|
||||
c.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||
return s.UpdateConversation(c)
|
||||
}
|
||||
|
||||
func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
|
||||
if c == nil || c.ID == 0 {
|
||||
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
||||
}
|
||||
|
||||
return err
|
||||
return s.db.Updates(&c).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
|
||||
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{})
|
||||
return s.db.Delete(&conversation).Error
|
||||
func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
|
||||
// Delete messages first
|
||||
err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.db.Delete(&c).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) SaveMessage(message *model.Message) error {
|
||||
return s.db.Create(message).Error
|
||||
func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
|
||||
panic("Not yet implemented")
|
||||
//return s.db.Delete(&message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteMessage(message *model.Message) error {
|
||||
return s.db.Delete(&message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) UpdateMessage(message *model.Message) error {
|
||||
return s.db.Updates(&message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
|
||||
var conversations []model.Conversation
|
||||
err := s.db.Find(&conversations).Error
|
||||
return conversations, err
|
||||
func (s *SQLStore) UpdateMessage(m *model.Message) error {
|
||||
if m == nil || m.ID == 0 {
|
||||
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
||||
}
|
||||
return s.db.Updates(&m).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
||||
var completions []string
|
||||
conversations, _ := s.Conversations() // ignore error for completions
|
||||
var conversations []model.Conversation
|
||||
// ignore error for completions
|
||||
s.db.Find(&conversations)
|
||||
completions := make([]string, 0, len(conversations))
|
||||
for _, conversation := range conversations {
|
||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||
@@ -106,27 +116,250 @@ func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversatio
|
||||
return nil, errors.New("shortName is empty")
|
||||
}
|
||||
var conversation model.Conversation
|
||||
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
||||
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
||||
return &conversation, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
|
||||
var messages []model.Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
||||
return messages, err
|
||||
func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
|
||||
var rootMessages []model.Message
|
||||
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rootMessages, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
|
||||
func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
|
||||
var message model.Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
||||
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
||||
return &message, err
|
||||
}
|
||||
|
||||
// AddReply adds the given messages as a reply to the given conversation, can be
|
||||
// used to easily copy a message associated with one conversation, to another
|
||||
func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) {
|
||||
m.ConversationID = c.ID
|
||||
m.ID = 0
|
||||
m.CreatedAt = time.Time{}
|
||||
return &m, s.SaveMessage(&m)
|
||||
func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
|
||||
var replies []model.Message
|
||||
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
|
||||
return replies, err
|
||||
}
|
||||
|
||||
// StartConversation starts a new conversation with the provided messages
|
||||
func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, nil, fmt.Errorf("Must provide at least 1 message")
|
||||
}
|
||||
|
||||
// Create new conversation
|
||||
conversation := &model.Conversation{}
|
||||
err := s.saveNewConversation(conversation)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Create first message
|
||||
messages[0].ConversationID = conversation.ID
|
||||
err = s.db.Create(&messages[0]).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Update conversation's selected root message
|
||||
conversation.SelectedRoot = &messages[0]
|
||||
err = s.UpdateConversation(conversation)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Add additional replies to conversation
|
||||
if len(messages) > 1 {
|
||||
newMessages, err := s.Reply(&messages[0], messages[1:]...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
messages = append([]model.Message{messages[0]}, newMessages...)
|
||||
}
|
||||
return conversation, messages, nil
|
||||
}
|
||||
|
||||
// CloneConversation clones the given conversation and all of its root meesages
|
||||
func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
|
||||
rootMessages, err := s.RootMessages(toClone.ID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
clone := &model.Conversation{
|
||||
Title: toClone.Title + " - Clone",
|
||||
}
|
||||
if err := s.saveNewConversation(clone); err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not create clone: %s", err)
|
||||
}
|
||||
|
||||
var errors []error
|
||||
var messageCnt uint = 0
|
||||
for _, root := range rootMessages {
|
||||
messageCnt++
|
||||
newRoot := root
|
||||
newRoot.ConversationID = clone.ID
|
||||
|
||||
cloned, count, err := s.CloneBranch(newRoot)
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
continue
|
||||
}
|
||||
messageCnt += count
|
||||
|
||||
if root.ID == *toClone.SelectedRootID {
|
||||
clone.SelectedRootID = &cloned.ID
|
||||
if err := s.UpdateConversation(clone); err != nil {
|
||||
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
|
||||
}
|
||||
|
||||
return clone, messageCnt, nil
|
||||
}
|
||||
|
||||
// Reply to a message with a series of messages (each following the next)
|
||||
func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) {
|
||||
var savedMessages []model.Message
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
currentParent := to
|
||||
for i := range messages {
|
||||
message := messages[i]
|
||||
message.ConversationID = currentParent.ConversationID
|
||||
message.ParentID = ¤tParent.ID
|
||||
message.ID = 0
|
||||
message.CreatedAt = time.Time{}
|
||||
|
||||
if err := tx.Create(&message).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update parent selected reply
|
||||
currentParent.SelectedReply = &message
|
||||
if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
savedMessages = append(savedMessages, message)
|
||||
currentParent = &message
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return savedMessages, err
|
||||
}
|
||||
|
||||
// CloneBranch returns a deep clone of the given message and its replies, returning
|
||||
// a new message object. The new message will be attached to the same parent as
|
||||
// the message to clone.
|
||||
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
|
||||
newMessage := messageToClone
|
||||
newMessage.ID = 0
|
||||
newMessage.Replies = nil
|
||||
newMessage.SelectedReplyID = nil
|
||||
newMessage.SelectedReply = nil
|
||||
|
||||
originalReplies, err := s.MessageReplies(messageToClone.ID)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
|
||||
}
|
||||
|
||||
if err := s.db.Create(&newMessage).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
|
||||
}
|
||||
|
||||
var replyCount uint = 0
|
||||
for _, reply := range originalReplies {
|
||||
replyCount++
|
||||
|
||||
newReply := reply
|
||||
newReply.ConversationID = messageToClone.ConversationID
|
||||
newReply.ParentID = &newMessage.ID
|
||||
newReply.Parent = &newMessage
|
||||
|
||||
res, c, err := s.CloneBranch(newReply)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
newMessage.Replies = append(newMessage.Replies, *res)
|
||||
replyCount += c
|
||||
|
||||
if reply.ID == *messageToClone.SelectedReplyID {
|
||||
newMessage.SelectedReplyID = &res.ID
|
||||
if err := s.UpdateMessage(&newMessage); err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &newMessage, replyCount, nil
|
||||
}
|
||||
|
||||
// PathToRoot traverses message Parent until reaching the tree root
|
||||
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
|
||||
if message == nil {
|
||||
return nil, fmt.Errorf("Message is nil")
|
||||
}
|
||||
var path []model.Message
|
||||
current := message
|
||||
for {
|
||||
path = append([]model.Message{*current}, path...)
|
||||
if current.Parent == nil {
|
||||
break
|
||||
}
|
||||
|
||||
var err error
|
||||
current, err = s.MessageByID(*current.ParentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("finding parent message: %w", err)
|
||||
}
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// PathToLeaf traverses message SelectedReply until reaching a tree leaf
|
||||
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
|
||||
if message == nil {
|
||||
return nil, fmt.Errorf("Message is nil")
|
||||
}
|
||||
var path []model.Message
|
||||
current := message
|
||||
for {
|
||||
path = append(path, *current)
|
||||
if current.SelectedReplyID == nil {
|
||||
break
|
||||
}
|
||||
|
||||
var err error
|
||||
current, err = s.MessageByID(*current.SelectedReplyID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("finding selected reply: %w", err)
|
||||
}
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
|
||||
var latestMessages []model.Message
|
||||
|
||||
subQuery := s.db.Model(&model.Message{}).
|
||||
Select("MAX(created_at) as max_created_at, conversation_id").
|
||||
Group("conversation_id")
|
||||
|
||||
err := s.db.Model(&model.Message{}).
|
||||
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
|
||||
Group("messages.conversation_id").
|
||||
Order("created_at DESC").
|
||||
Preload("Conversation").
|
||||
Find(&latestMessages).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return latestMessages, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user