Add message branching

Updated the behaviour of commands:
- `lmcli edit` - no longer deletes messages past the edited message
- `lmcli retry` - creates a branch from the previous user message
This commit is contained in:
Matt Low 2024-05-20 18:12:44 +00:00
parent f6e55f6bff
commit db465f1bf0
16 changed files with 470 additions and 305 deletions

View File

@ -5,7 +5,6 @@ import (
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -28,36 +27,12 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
return err return err
} }
messagesToCopy, err := ctx.Store.Messages(toClone) clone, messageCnt, err := ctx.Store.CloneConversation(*toClone)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String) return fmt.Errorf("Failed to clone conversation: %v", err)
} }
clone := &model.Conversation{ fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title)
Title: toClone.Title + " - Clone",
}
if err := ctx.Store.SaveConversation(clone); err != nil {
return fmt.Errorf("Cloud not create clone: %s", err)
}
var errors []error
messageCnt := 0
for _, message := range messagesToCopy {
newMessage := message
newMessage.ConversationID = clone.ID
newMessage.ID = 0
if err := ctx.Store.SaveMessage(&newMessage); err != nil {
errors = append(errors, err)
} else {
messageCnt++
}
}
if len(errors) > 0 {
return fmt.Errorf("Messages failed to be cloned: %v", errors)
}
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
return nil return nil
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {

View File

@ -26,7 +26,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.Messages(conversation) messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("could not retrieve conversation messages: %v", err) return fmt.Errorf("could not retrieve conversation messages: %v", err)
} }

View File

@ -24,7 +24,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.Messages(conversation) messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
} }
@ -39,21 +39,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
} }
desiredIdx := len(messages) - 1 - offset desiredIdx := len(messages) - 1 - offset
toEdit := messages[desiredIdx]
// walk backwards through the conversation deleting messages until and
// including the last user message
toRemove := []model.Message{}
var toEdit *model.Message
for i := len(messages) - 1; i >= 0; i-- {
if i == desiredIdx {
toEdit = &messages[i]
}
toRemove = append(toRemove, messages[i])
messages = messages[:i]
if toEdit != nil {
break
}
}
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content) newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
switch newContents { switch newContents {
@ -63,26 +49,17 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
toEdit.Content = newContents
role, _ := cmd.Flags().GetString("role") role, _ := cmd.Flags().GetString("role")
if role == "" { if role != "" {
role = string(toEdit.Role) if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
} else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.") return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
} }
toEdit.Role = model.MessageRole(role)
for _, message := range toRemove {
err = ctx.Store.DeleteMessage(&message)
if err != nil {
lmcli.Warn("Could not delete message: %v\n", err)
}
} }
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{ return ctx.Store.UpdateMessage(&toEdit)
ConversationID: conversation.ID,
Role: model.MessageRole(role),
Content: newContents,
})
return nil
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp

View File

@ -2,7 +2,6 @@ package cmd
import ( import (
"fmt" "fmt"
"slices"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
@ -21,7 +20,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
Short: "List conversations", Short: "List conversations",
Long: `List conversations in order of recent activity`, Long: `List conversations in order of recent activity`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
conversations, err := ctx.Store.Conversations() messages, err := ctx.Store.LatestConversationMessages()
if err != nil { if err != nil {
return fmt.Errorf("Could not fetch conversations: %v", err) return fmt.Errorf("Could not fetch conversations: %v", err)
} }
@ -58,13 +57,8 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
all, _ := cmd.Flags().GetBool("all") all, _ := cmd.Flags().GetBool("all")
for _, conversation := range conversations { for _, message := range messages {
lastMessage, err := ctx.Store.LastMessage(&conversation) messageAge := now.Sub(message.CreatedAt)
if lastMessage == nil || err != nil {
continue
}
messageAge := now.Sub(lastMessage.CreatedAt)
var category string var category string
for _, c := range categories { for _, c := range categories {
@ -76,9 +70,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
formatted := fmt.Sprintf( formatted := fmt.Sprintf(
"%s - %s - %s", "%s - %s - %s",
conversation.ShortName.String, message.Conversation.ShortName.String,
util.HumanTimeElapsedSince(messageAge), util.HumanTimeElapsedSince(messageAge),
conversation.Title, message.Conversation.Title,
) )
categorized[category] = append( categorized[category] = append(
@ -96,14 +90,10 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
continue continue
} }
slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
return int(a.timeSinceReply - b.timeSinceReply)
})
fmt.Printf("%s:\n", category.name) fmt.Printf("%s:\n", category.name)
for _, conv := range conversationLines { for _, conv := range conversationLines {
if conversationsPrinted >= count && !all { if conversationsPrinted >= count && !all {
fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted) fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted)
break outer break outer
} }

View File

@ -15,42 +15,43 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
Short: "Start a new conversation", Short: "Start a new conversation",
Long: `Start a new conversation with the Large Language Model.`, Long: `Start a new conversation with the Large Language Model.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "") input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "")
if messageContents == "" { if input == "" {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
conversation := &model.Conversation{} var messages []model.Message
err := ctx.Store.SaveConversation(conversation)
if err != nil {
return fmt.Errorf("Could not save new conversation: %v", err)
}
messages := []model.Message{ // TODO: probably just make this part of the conversation
{ system := ctx.GetSystemPrompt()
ConversationID: conversation.ID, if system != "" {
messages = append(messages, model.Message{
Role: model.MessageRoleSystem, Role: model.MessageRoleSystem,
Content: ctx.GetSystemPrompt(), Content: system,
}, })
{
ConversationID: conversation.ID,
Role: model.MessageRoleUser,
Content: messageContents,
},
} }
cmdutil.HandleConversationReply(ctx, conversation, true, messages...) messages = append(messages, model.Message{
Role: model.MessageRoleUser,
Content: input,
})
title, err := cmdutil.GenerateTitle(ctx, conversation) conversation, messages, err := ctx.Store.StartConversation(messages...)
if err != nil { if err != nil {
lmcli.Warn("Could not generate title for conversation: %v\n", err) return fmt.Errorf("Could not start a new conversation: %v", err)
}
cmdutil.HandleReply(ctx, &messages[len(messages)-1], true)
title, err := cmdutil.GenerateTitle(ctx, messages)
if err != nil {
lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err)
} }
conversation.Title = title conversation.Title = title
err = ctx.Store.UpdateConversation(conversation)
err = ctx.Store.SaveConversation(conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not save conversation after generating title: %v\n", err) lmcli.Warn("Could not save conversation title: %v\n", err)
} }
return nil return nil
}, },

View File

@ -15,22 +15,27 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
Short: "Do a one-shot prompt", Short: "Do a one-shot prompt",
Long: `Prompt the Large Language Model and get a response.`, Long: `Prompt the Large Language Model and get a response.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "") input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "")
if message == "" { if input == "" {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
messages := []model.Message{ var messages []model.Message
{
// TODO: stop supplying system prompt as a message
system := ctx.GetSystemPrompt()
if system != "" {
messages = append(messages, model.Message{
Role: model.MessageRoleSystem, Role: model.MessageRoleSystem,
Content: ctx.GetSystemPrompt(), Content: system,
}, })
{
Role: model.MessageRoleUser,
Content: message,
},
} }
messages = append(messages, model.Message{
Role: model.MessageRoleUser,
Content: input,
})
_, err := cmdutil.Prompt(ctx, messages, nil) _, err := cmdutil.Prompt(ctx, messages, nil)
if err != nil { if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err) return fmt.Errorf("Error fetching LLM response: %v", err)

View File

@ -24,12 +24,17 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
var err error var err error
var title string
generate, _ := cmd.Flags().GetBool("generate") generate, _ := cmd.Flags().GetBool("generate")
var title string
if generate { if generate {
title, err = cmdutil.GenerateTitle(ctx, conversation) messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil {
return fmt.Errorf("Could not retrieve conversation messages: %v", err)
}
title, err = cmdutil.GenerateTitle(ctx, messages)
if err != nil { if err != nil {
return fmt.Errorf("Could not generate conversation title: %v", err) return fmt.Errorf("Could not generate conversation title: %v", err)
} }
@ -41,9 +46,9 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
} }
conversation.Title = title conversation.Title = title
err = ctx.Store.SaveConversation(conversation) err = ctx.Store.UpdateConversation(conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not save conversation with new title: %v\n", err) lmcli.Warn("Could not update conversation title: %v\n", err)
} }
return nil return nil
}, },

View File

@ -13,7 +13,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "retry <conversation>", Use: "retry <conversation>",
Short: "Retry the last user reply in a conversation", Short: "Retry the last user reply in a conversation",
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`, Long: `Prompt the conversation from the last user response.`,
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
argCount := 1 argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
@ -25,25 +25,28 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.Messages(conversation) // Load the complete thread from the root message
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
} }
// walk backwards through the conversation and delete messages, break // Find the last user message in the conversation
// when we find the latest user response var lastUserMessage *model.Message
for i := len(messages) - 1; i >= 0; i-- { var i int
for i = len(messages) - 1; i >= 0; i-- {
if messages[i].Role == model.MessageRoleUser { if messages[i].Role == model.MessageRoleUser {
lastUserMessage = &messages[i]
break break
} }
err = ctx.Store.DeleteMessage(&messages[i])
if err != nil {
lmcli.Warn("Could not delete previous reply: %v\n", err)
}
} }
cmdutil.HandleConversationReply(ctx, conversation, true) if lastUserMessage == nil {
return fmt.Errorf("No user message found in the conversation: %s", conversation.Title)
}
// Start a new branch at the last user message
cmdutil.HandleReply(ctx, lastUserMessage, true)
return nil return nil
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {

View File

@ -73,43 +73,58 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversat
return c, nil return c, nil
} }
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err)
}
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
}
// handleConversationReply handles sending messages to an existing // handleConversationReply handles sending messages to an existing
// conversation, optionally persisting both the sent replies and responses. // conversation, optionally persisting both the sent replies and responses.
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) { func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
existing, err := ctx.Store.Messages(c) if to == nil {
lmcli.Fatal("Can't prompt from an empty message.")
}
existing, err := ctx.Store.PathToRoot(to)
if err != nil { if err != nil {
lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title) lmcli.Fatal("Could not load messages: %v\n", err)
} }
if persist { RenderConversation(ctx, append(existing, messages...), true)
for _, message := range toSend {
err = ctx.Store.SaveMessage(&message) var savedReplies []model.Message
if persist && len(messages) > 0 {
savedReplies, err = ctx.Store.Reply(to, messages...)
if err != nil { if err != nil {
lmcli.Warn("Could not save %s message: %v\n", message.Role, err) lmcli.Warn("Could not save messages: %v\n", err)
} }
} }
}
allMessages := append(existing, toSend...)
RenderConversation(ctx, allMessages, true)
// render a message header with no contents // render a message header with no contents
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant})) RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
var lastMessage *model.Message
lastMessage = to
if len(savedReplies) > 1 {
lastMessage = &savedReplies[len(savedReplies)-1]
}
replyCallback := func(reply model.Message) { replyCallback := func(reply model.Message) {
if !persist { if !persist {
return return
} }
savedReplies, err = ctx.Store.Reply(lastMessage, reply)
reply.ConversationID = c.ID
err = ctx.Store.SaveMessage(&reply)
if err != nil { if err != nil {
lmcli.Warn("Could not save reply: %v\n", err) lmcli.Warn("Could not save reply: %v\n", err)
} }
lastMessage = &savedReplies[0]
} }
_, err = Prompt(ctx, allMessages, replyCallback) _, err = Prompt(ctx, append(existing, messages...), replyCallback)
if err != nil { if err != nil {
lmcli.Fatal("Error fetching LLM response: %v\n", err) lmcli.Fatal("Error fetching LLM response: %v\n", err)
} }
@ -134,12 +149,7 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string {
return sb.String() return sb.String()
} }
func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) { func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
messages, err := ctx.Store.Messages(c)
if err != nil {
return "", err
}
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective. const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
Example conversation: Example conversation:

View File

@ -24,9 +24,9 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.Messages(conversation) messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
} }
cmdutil.RenderConversation(ctx, messages, false) cmdutil.RenderConversation(ctx, messages, false)

View File

@ -36,7 +36,9 @@ func NewContext() (*Context, error) {
} }
databaseFile := filepath.Join(dataDir(), "conversations.db") databaseFile := filepath.Join(dataDir(), "conversations.db")
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{}) db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
//Logger: logger.Default.LogMode(logger.Info),
})
if err != nil { if err != nil {
return nil, fmt.Errorf("Error establishing connection to store: %v", err) return nil, fmt.Errorf("Error establishing connection to store: %v", err)
} }

View File

@ -17,18 +17,27 @@ const (
type Message struct { type Message struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
ConversationID uint `gorm:"foreignKey:ConversationID"` ConversationID uint `gorm:"index"`
Conversation Conversation `gorm:"foreignKey:ConversationID"`
Content string Content string
Role MessageRole Role MessageRole
CreatedAt time.Time CreatedAt time.Time `gorm:"index"`
ToolCalls ToolCalls // a json array of tool calls (from the modl) ToolCalls ToolCalls // a json array of tool calls (from the model)
ToolResults ToolResults // a json array of tool results ToolResults ToolResults // a json array of tool results
ParentID *uint `gorm:"index"`
Parent *Message `gorm:"foreignKey:ParentID"`
Replies []Message `gorm:"foreignKey:ParentID"`
SelectedReplyID *uint
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
} }
type Conversation struct { type Conversation struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
ShortName sql.NullString ShortName sql.NullString
Title string Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
} }
type RequestParameters struct { type RequestParameters struct {

View File

@ -13,21 +13,26 @@ import (
) )
type ConversationStore interface { type ConversationStore interface {
Conversations() ([]model.Conversation, error)
ConversationByShortName(shortName string) (*model.Conversation, error) ConversationByShortName(shortName string) (*model.Conversation, error)
ConversationShortNameCompletions(search string) []string ConversationShortNameCompletions(search string) []string
RootMessages(conversationID uint) ([]model.Message, error)
LatestConversationMessages() ([]model.Message, error)
SaveConversation(conversation *model.Conversation) error StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
UpdateConversation(conversation *model.Conversation) error
DeleteConversation(conversation *model.Conversation) error DeleteConversation(conversation *model.Conversation) error
CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)
Messages(conversation *model.Conversation) ([]model.Message, error) MessageByID(messageID uint) (*model.Message, error)
LastMessage(conversation *model.Conversation) (*model.Message, error) MessageReplies(messageID uint) ([]model.Message, error)
SaveMessage(message *model.Message) error
DeleteMessage(message *model.Message) error
UpdateMessage(message *model.Message) error UpdateMessage(message *model.Message) error
AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error) DeleteMessage(message *model.Message, prune bool) error
CloneBranch(toClone model.Message) (*model.Message, uint, error)
Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
PathToRoot(message *model.Message) ([]model.Message, error)
PathToLeaf(message *model.Message) ([]model.Message, error)
} }
type SQLStore struct { type SQLStore struct {
@ -52,47 +57,52 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
return &SQLStore{db, _sqids}, nil return &SQLStore{db, _sqids}, nil
} }
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error { func (s *SQLStore) saveNewConversation(c *model.Conversation) error {
err := s.db.Save(&conversation).Error // Save the new conversation
err := s.db.Save(&c).Error
if err != nil { if err != nil {
return err return err
} }
if !conversation.ShortName.Valid { // Generate and save its "short name"
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)}) shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
conversation.ShortName = sql.NullString{String: shortName, Valid: true} c.ShortName = sql.NullString{String: shortName, Valid: true}
err = s.db.Save(&conversation).Error return s.UpdateConversation(c)
}
func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
} }
return s.db.Updates(&c).Error
}
func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
// Delete messages first
err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
if err != nil {
return err return err
}
return s.db.Delete(&c).Error
} }
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error { func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{}) panic("Not yet implemented")
return s.db.Delete(&conversation).Error //return s.db.Delete(&message).Error
} }
func (s *SQLStore) SaveMessage(message *model.Message) error { func (s *SQLStore) UpdateMessage(m *model.Message) error {
return s.db.Create(message).Error if m == nil || m.ID == 0 {
} return fmt.Errorf("Message is nil or invalid (missing ID)")
}
func (s *SQLStore) DeleteMessage(message *model.Message) error { return s.db.Updates(&m).Error
return s.db.Delete(&message).Error
}
func (s *SQLStore) UpdateMessage(message *model.Message) error {
return s.db.Updates(&message).Error
}
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
var conversations []model.Conversation
err := s.db.Find(&conversations).Error
return conversations, err
} }
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string { func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
var completions []string var conversations []model.Conversation
conversations, _ := s.Conversations() // ignore error for completions // ignore error for completions
s.db.Find(&conversations)
completions := make([]string, 0, len(conversations))
for _, conversation := range conversations { for _, conversation := range conversations {
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) { if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title)) completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
@ -106,27 +116,249 @@ func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversatio
return nil, errors.New("shortName is empty") return nil, errors.New("shortName is empty")
} }
var conversation model.Conversation var conversation model.Conversation
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err return &conversation, err
} }
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) { func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
var messages []model.Message var rootMessages []model.Message
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
return messages, err if err != nil {
return nil, err
}
return rootMessages, nil
} }
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) { func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
var message model.Message var message model.Message
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
return &message, err return &message, err
} }
// AddReply adds the given messages as a reply to the given conversation, can be func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
// used to easily copy a message associated with one conversation, to another var replies []model.Message
func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) { err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
m.ConversationID = c.ID return replies, err
m.ID = 0 }
m.CreatedAt = time.Time{}
return &m, s.SaveMessage(&m) // StartConversation starts a new conversation with the provided messages
func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
if len(messages) == 0 {
return nil, nil, fmt.Errorf("Must provide at least 1 message")
}
// Create new conversation
conversation := &model.Conversation{}
err := s.saveNewConversation(conversation)
if err != nil {
return nil, nil, err
}
// Create first message
messages[0].ConversationID = conversation.ID
err = s.db.Create(&messages[0]).Error
if err != nil {
return nil, nil, err
}
// Update conversation's selected root message
conversation.SelectedRoot = &messages[0]
err = s.UpdateConversation(conversation)
if err != nil {
return nil, nil, err
}
// Add additional replies to conversation
if len(messages) > 1 {
newMessages, err := s.Reply(&messages[0], messages[1:]...)
if err != nil {
return nil, nil, err
}
messages = append([]model.Message{messages[0]}, newMessages...)
}
return conversation, messages, nil
}
// CloneConversation clones the given conversation and all of its root meesages
func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
rootMessages, err := s.RootMessages(toClone.ID)
if err != nil {
return nil, 0, err
}
clone := &model.Conversation{
Title: toClone.Title + " - Clone",
}
if err := s.saveNewConversation(clone); err != nil {
return nil, 0, fmt.Errorf("Could not create clone: %s", err)
}
var errors []error
var messageCnt uint = 0
for _, root := range rootMessages {
messageCnt++
newRoot := root
newRoot.ConversationID = clone.ID
cloned, count, err := s.CloneBranch(newRoot)
if err != nil {
errors = append(errors, err)
continue
}
messageCnt += count
if root.ID == *toClone.SelectedRootID {
clone.SelectedRootID = &cloned.ID
if err := s.UpdateConversation(clone); err != nil {
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
}
}
}
if len(errors) > 0 {
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
}
return clone, messageCnt, nil
}
// Reply replies to the given parentMessage with a series of messages
func (s *SQLStore) Reply(parentMessage *model.Message, messages ...model.Message) ([]model.Message, error) {
var savedMessages []model.Message
currentParent := parentMessage
err := s.db.Transaction(func(tx *gorm.DB) error {
for i := range messages {
message := messages[i]
message.ConversationID = currentParent.ConversationID
message.ParentID = &currentParent.ID
message.ID = 0
message.CreatedAt = time.Time{}
if err := tx.Create(&message).Error; err != nil {
return err
}
// update parent selected reply
currentParent.SelectedReply = &message
if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil {
return err
}
savedMessages = append(savedMessages, message)
currentParent = &message
}
return nil
})
return savedMessages, err
}
// CloneBranch returns a deep clone of the given message and its replies, returning
// a new message object. The new message will be attached to the same parent as
// the message to clone.
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
newMessage := messageToClone
newMessage.ID = 0
newMessage.Replies = nil
newMessage.SelectedReplyID = nil
newMessage.SelectedReply = nil
if err := s.db.Create(&newMessage).Error; err != nil {
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
}
originalReplies, err := s.MessageReplies(messageToClone.ID)
if err != nil {
return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
}
var replyCount uint = 0
for _, reply := range originalReplies {
replyCount++
newReply := reply
newReply.ConversationID = messageToClone.ConversationID
newReply.ParentID = &newMessage.ID
newReply.Parent = &newMessage
res, c, err := s.CloneBranch(newReply)
if err != nil {
return nil, 0, err
}
newMessage.Replies = append(newMessage.Replies, *res)
replyCount += c
if reply.ID == *messageToClone.SelectedReplyID {
newMessage.SelectedReplyID = &res.ID
if err := s.UpdateMessage(&newMessage); err != nil {
return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
}
}
}
return &newMessage, replyCount, nil
}
// PathToRoot traverses message Parent until reaching the tree root
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
if message == nil {
return nil, fmt.Errorf("Message is nil")
}
var path []model.Message
current := message
for {
path = append([]model.Message{*current}, path...)
if current.Parent == nil {
break
}
var err error
current, err = s.MessageByID(*current.ParentID)
if err != nil {
return nil, fmt.Errorf("finding parent message: %w", err)
}
}
return path, nil
}
// PathToLeaf traverses message SelectedReply until reaching a tree leaf
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
if message == nil {
return nil, fmt.Errorf("Message is nil")
}
var path []model.Message
current := message
for {
path = append(path, *current)
if current.SelectedReplyID == nil {
break
}
var err error
current, err = s.MessageByID(*current.SelectedReplyID)
if err != nil {
return nil, fmt.Errorf("finding selected reply: %w", err)
}
}
return path, nil
}
func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
var latestMessages []model.Message
subQuery := s.db.Model(&model.Message{}).
Select("MAX(created_at) as max_created_at, conversation_id").
Group("conversation_id")
err := s.db.Model(&model.Message{}).
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
Order("created_at DESC").
Preload("Conversation").
Find(&latestMessages).Error
if err != nil {
return nil, err
}
return latestMessages, nil
} }

View File

@ -307,14 +307,9 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
} }
if m.persistence { if m.persistence {
var err error err := m.persistConversation()
if m.conversation.ID == 0 {
err = m.ctx.Store.SaveConversation(m.conversation)
}
if err != nil { if err != nil {
cmds = append(cmds, wrapError(err)) cmds = append(cmds, wrapError(err))
} else {
cmds = append(cmds, m.persistConversation())
} }
} }
@ -341,7 +336,7 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
title := string(msg) title := string(msg)
m.conversation.Title = title m.conversation.Title = title
if m.persistence { if m.persistence {
err := m.ctx.Store.SaveConversation(m.conversation) err := m.ctx.Store.UpdateConversation(m.conversation)
if err != nil { if err != nil {
cmds = append(cmds, wrapError(err)) cmds = append(cmds, wrapError(err))
} }
@ -469,8 +464,8 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
m.input.Blur() m.input.Blur()
return true, nil return true, nil
case "ctrl+s": case "ctrl+s":
userInput := strings.TrimSpace(m.input.Value()) input := strings.TrimSpace(m.input.Value())
if strings.TrimSpace(userInput) == "" { if input == "" {
return true, nil return true, nil
} }
@ -478,35 +473,19 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
return true, wrapError(fmt.Errorf("Can't reply to a user message")) return true, wrapError(fmt.Errorf("Can't reply to a user message"))
} }
reply := models.Message{ m.addMessage(models.Message{
Role: models.MessageRoleUser, Role: models.MessageRoleUser,
Content: userInput, Content: input,
} })
if m.persistence {
var err error
if m.conversation.ID == 0 {
err = m.ctx.Store.SaveConversation(m.conversation)
}
if err != nil {
return true, wrapError(err)
}
// ensure all messages up to the one we're about to add are persisted
cmd := m.persistConversation()
if cmd != nil {
return true, cmd
}
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
if err != nil {
return true, wrapError(err)
}
reply = *savedReply
}
m.input.SetValue("") m.input.SetValue("")
m.addMessage(reply)
if m.persistence {
err := m.persistConversation()
if err != nil {
return true, wrapError(err)
}
}
m.updateContent() m.updateContent()
m.content.GotoBottom() m.content.GotoBottom()
@ -783,7 +762,7 @@ func (m *chatModel) loadConversation(shortname string) tea.Cmd {
func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd { func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
messages, err := m.ctx.Store.Messages(c) messages, err := m.ctx.Store.PathToLeaf(c.SelectedRoot)
if err != nil { if err != nil {
return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err)) return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
} }
@ -791,62 +770,48 @@ func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
} }
} }
func (m *chatModel) persistConversation() tea.Cmd { func (m *chatModel) persistConversation() error {
existingMessages, err := m.ctx.Store.Messages(m.conversation) if m.conversation.ID == 0 {
// Start a new conversation with all messages so far
c, messages, err := m.ctx.Store.StartConversation(m.messages...)
if err != nil { if err != nil {
return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err)) return err
}
m.conversation = c
m.messages = messages
return nil
} }
existingById := make(map[uint]*models.Message, len(existingMessages)) // else, we'll handle updating an existing conversation's messages
for _, msg := range existingMessages { for i := 0; i < len(m.messages); i++ {
existingById[msg.ID] = &msg if m.messages[i].ID > 0 {
} // message has an ID, update its contents
// TODO: check for content/tool equality before updating?
currentById := make(map[uint]*models.Message, len(m.messages)) err := m.ctx.Store.UpdateMessage(&m.messages[i])
for _, msg := range m.messages {
currentById[msg.ID] = &msg
}
for _, msg := range existingMessages {
_, ok := currentById[msg.ID]
if !ok {
err := m.ctx.Store.DeleteMessage(&msg)
if err != nil { if err != nil {
return wrapError(fmt.Errorf("Failed to remove messages: %v", err)) return err
} }
} } else if i > 0 {
} // messages is new, so add it as a reply to previous message
saved, err := m.ctx.Store.Reply(&m.messages[i-1], m.messages[i])
for i, msg := range m.messages {
if msg.ID > 0 {
exist, ok := existingById[msg.ID]
if ok {
if msg.Content == exist.Content {
continue
}
// update message when contents don't match that of store
err := m.ctx.Store.UpdateMessage(&msg)
if err != nil { if err != nil {
return wrapError(err) return err
} }
m.messages[i] = saved[0]
} else { } else {
// this would be quite odd... and I'm not sure how to handle // message has no id and no previous messages to add it to
// it at the time of writing this // this shouldn't happen?
} return fmt.Errorf("Error: no messages to reply to")
} else {
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
if err != nil {
return wrapError(err)
}
m.setMessage(i, *newMessage)
} }
} }
return nil return nil
} }
func (m *chatModel) generateConversationTitle() tea.Cmd { func (m *chatModel) generateConversationTitle() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
title, err := cmdutil.GenerateTitle(m.ctx, m.conversation) title, err := cmdutil.GenerateTitle(m.ctx, m.messages)
if err != nil { if err != nil {
return msgError(err) return msgError(err)
} }

View File

@ -2,7 +2,6 @@ package tui
import ( import (
"fmt" "fmt"
"slices"
"strings" "strings"
"time" "time"
@ -145,24 +144,16 @@ func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
func (m *conversationsModel) loadConversations() tea.Cmd { func (m *conversationsModel) loadConversations() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
conversations, err := m.ctx.Store.Conversations() messages, err := m.ctx.Store.LatestConversationMessages()
if err != nil { if err != nil {
return msgError(fmt.Errorf("Could not load conversations: %v", err)) return msgError(fmt.Errorf("Could not load conversations: %v", err))
} }
loaded := make([]loadedConversation, len(conversations)) loaded := make([]loadedConversation, len(messages))
for i, c := range conversations { for i, m := range messages {
lastMessage, err := m.ctx.Store.LastMessage(&c) loaded[i].lastReply = m
if err != nil { loaded[i].conv = m.Conversation
return msgError(err)
} }
loaded[i].conv = c
loaded[i].lastReply = *lastMessage
}
slices.SortFunc(loaded, func(a, b loadedConversation) int {
return b.lastReply.CreatedAt.Compare(a.lastReply.CreatedAt)
})
return msgConversationsLoaded(loaded) return msgConversationsLoaded(loaded)
} }