Add message branching

Updated the behaviour of commands:

- `lmcli edit`
  - by default create a new branch/message branch with the edited contents
  - add --in-place to avoid creating a branch
  - no longer delete messages after the edited message
  - only do the edit, don't fetch a new response
- `lmcli retry`
  - create a new branch rather than replacing old messages
  - add --offset to change where to retry from
This commit is contained in:
Matt Low 2024-05-20 18:12:44 +00:00
parent f6e55f6bff
commit 8c53752146
16 changed files with 505 additions and 308 deletions

View File

@ -5,7 +5,6 @@ import (
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -28,36 +27,12 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
return err return err
} }
messagesToCopy, err := ctx.Store.Messages(toClone) clone, messageCnt, err := ctx.Store.CloneConversation(*toClone)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String) return fmt.Errorf("Failed to clone conversation: %v", err)
} }
clone := &model.Conversation{ fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title)
Title: toClone.Title + " - Clone",
}
if err := ctx.Store.SaveConversation(clone); err != nil {
return fmt.Errorf("Cloud not create clone: %s", err)
}
var errors []error
messageCnt := 0
for _, message := range messagesToCopy {
newMessage := message
newMessage.ConversationID = clone.ID
newMessage.ID = 0
if err := ctx.Store.SaveMessage(&newMessage); err != nil {
errors = append(errors, err)
} else {
messageCnt++
}
}
if len(errors) > 0 {
return fmt.Errorf("Messages failed to be cloned: %v", errors)
}
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
return nil return nil
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {

View File

@ -26,7 +26,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.Messages(conversation) messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("could not retrieve conversation messages: %v", err) return fmt.Errorf("could not retrieve conversation messages: %v", err)
} }

View File

@ -24,7 +24,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.Messages(conversation) messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
} }
@ -39,21 +39,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
} }
desiredIdx := len(messages) - 1 - offset desiredIdx := len(messages) - 1 - offset
toEdit := messages[desiredIdx]
// walk backwards through the conversation deleting messages until and
// including the last user message
toRemove := []model.Message{}
var toEdit *model.Message
for i := len(messages) - 1; i >= 0; i-- {
if i == desiredIdx {
toEdit = &messages[i]
}
toRemove = append(toRemove, messages[i])
messages = messages[:i]
if toEdit != nil {
break
}
}
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content) newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
switch newContents { switch newContents {
@ -63,26 +49,38 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
toEdit.Content = newContents
role, _ := cmd.Flags().GetString("role") role, _ := cmd.Flags().GetString("role")
if role == "" { if role != "" {
role = string(toEdit.Role) if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
} else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) { return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
}
for _, message := range toRemove {
err = ctx.Store.DeleteMessage(&message)
if err != nil {
lmcli.Warn("Could not delete message: %v\n", err)
} }
toEdit.Role = model.MessageRole(role)
} }
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{ // Update the message in-place
ConversationID: conversation.ID, inplace, _ := cmd.Flags().GetBool("in-place")
Role: model.MessageRole(role), if inplace {
Content: newContents, return ctx.Store.UpdateMessage(&toEdit)
}) }
return nil
// Otherwise, create a branch for the edited message
message, _, err := ctx.Store.CloneBranch(toEdit)
if err != nil {
return err
}
if desiredIdx > 0 {
// update selected reply
messages[desiredIdx-1].SelectedReply = message
err = ctx.Store.UpdateMessage(&messages[desiredIdx-1])
} else {
// update selected root
conversation.SelectedRoot = message
err = ctx.Store.UpdateConversation(conversation)
}
return err
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp
@ -93,8 +91,9 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
}, },
} }
cmd.Flags().BoolP("in-place", "i", true, "Edit the message in-place, rather than creating a branch")
cmd.Flags().Int("offset", 1, "Offset from the last message to edit") cmd.Flags().Int("offset", 1, "Offset from the last message to edit")
cmd.Flags().StringP("role", "r", "", "Role of the edited message (user or assistant)") cmd.Flags().StringP("role", "r", "", "Change the role of the edited message (user or assistant)")
return cmd return cmd
} }

View File

@ -2,7 +2,6 @@ package cmd
import ( import (
"fmt" "fmt"
"slices"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
@ -21,7 +20,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
Short: "List conversations", Short: "List conversations",
Long: `List conversations in order of recent activity`, Long: `List conversations in order of recent activity`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
conversations, err := ctx.Store.Conversations() messages, err := ctx.Store.LatestConversationMessages()
if err != nil { if err != nil {
return fmt.Errorf("Could not fetch conversations: %v", err) return fmt.Errorf("Could not fetch conversations: %v", err)
} }
@ -58,13 +57,8 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
all, _ := cmd.Flags().GetBool("all") all, _ := cmd.Flags().GetBool("all")
for _, conversation := range conversations { for _, message := range messages {
lastMessage, err := ctx.Store.LastMessage(&conversation) messageAge := now.Sub(message.CreatedAt)
if lastMessage == nil || err != nil {
continue
}
messageAge := now.Sub(lastMessage.CreatedAt)
var category string var category string
for _, c := range categories { for _, c := range categories {
@ -76,9 +70,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
formatted := fmt.Sprintf( formatted := fmt.Sprintf(
"%s - %s - %s", "%s - %s - %s",
conversation.ShortName.String, message.Conversation.ShortName.String,
util.HumanTimeElapsedSince(messageAge), util.HumanTimeElapsedSince(messageAge),
conversation.Title, message.Conversation.Title,
) )
categorized[category] = append( categorized[category] = append(
@ -96,14 +90,10 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
continue continue
} }
slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
return int(a.timeSinceReply - b.timeSinceReply)
})
fmt.Printf("%s:\n", category.name) fmt.Printf("%s:\n", category.name)
for _, conv := range conversationLines { for _, conv := range conversationLines {
if conversationsPrinted >= count && !all { if conversationsPrinted >= count && !all {
fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted) fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted)
break outer break outer
} }

View File

@ -15,42 +15,43 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
Short: "Start a new conversation", Short: "Start a new conversation",
Long: `Start a new conversation with the Large Language Model.`, Long: `Start a new conversation with the Large Language Model.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "") input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "")
if messageContents == "" { if input == "" {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
conversation := &model.Conversation{} var messages []model.Message
err := ctx.Store.SaveConversation(conversation)
if err != nil { // TODO: probably just make this part of the conversation
return fmt.Errorf("Could not save new conversation: %v", err) system := ctx.GetSystemPrompt()
if system != "" {
messages = append(messages, model.Message{
Role: model.MessageRoleSystem,
Content: system,
})
} }
messages := []model.Message{ messages = append(messages, model.Message{
{ Role: model.MessageRoleUser,
ConversationID: conversation.ID, Content: input,
Role: model.MessageRoleSystem, })
Content: ctx.GetSystemPrompt(),
}, conversation, messages, err := ctx.Store.StartConversation(messages...)
{ if err != nil {
ConversationID: conversation.ID, return fmt.Errorf("Could not start a new conversation: %v", err)
Role: model.MessageRoleUser,
Content: messageContents,
},
} }
cmdutil.HandleConversationReply(ctx, conversation, true, messages...) cmdutil.HandleReply(ctx, &messages[len(messages)-1], true)
title, err := cmdutil.GenerateTitle(ctx, conversation) title, err := cmdutil.GenerateTitle(ctx, messages)
if err != nil { if err != nil {
lmcli.Warn("Could not generate title for conversation: %v\n", err) lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err)
} }
conversation.Title = title conversation.Title = title
err = ctx.Store.UpdateConversation(conversation)
err = ctx.Store.SaveConversation(conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not save conversation after generating title: %v\n", err) lmcli.Warn("Could not save conversation title: %v\n", err)
} }
return nil return nil
}, },

View File

@ -15,22 +15,27 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
Short: "Do a one-shot prompt", Short: "Do a one-shot prompt",
Long: `Prompt the Large Language Model and get a response.`, Long: `Prompt the Large Language Model and get a response.`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "") input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "")
if message == "" { if input == "" {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
messages := []model.Message{ var messages []model.Message
{
// TODO: stop supplying system prompt as a message
system := ctx.GetSystemPrompt()
if system != "" {
messages = append(messages, model.Message{
Role: model.MessageRoleSystem, Role: model.MessageRoleSystem,
Content: ctx.GetSystemPrompt(), Content: system,
}, })
{
Role: model.MessageRoleUser,
Content: message,
},
} }
messages = append(messages, model.Message{
Role: model.MessageRoleUser,
Content: input,
})
_, err := cmdutil.Prompt(ctx, messages, nil) _, err := cmdutil.Prompt(ctx, messages, nil)
if err != nil { if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err) return fmt.Errorf("Error fetching LLM response: %v", err)

View File

@ -24,12 +24,17 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
var err error var err error
var title string
generate, _ := cmd.Flags().GetBool("generate") generate, _ := cmd.Flags().GetBool("generate")
var title string
if generate { if generate {
title, err = cmdutil.GenerateTitle(ctx, conversation) messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil {
return fmt.Errorf("Could not retrieve conversation messages: %v", err)
}
title, err = cmdutil.GenerateTitle(ctx, messages)
if err != nil { if err != nil {
return fmt.Errorf("Could not generate conversation title: %v", err) return fmt.Errorf("Could not generate conversation title: %v", err)
} }
@ -41,9 +46,9 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
} }
conversation.Title = title conversation.Title = title
err = ctx.Store.SaveConversation(conversation) err = ctx.Store.UpdateConversation(conversation)
if err != nil { if err != nil {
lmcli.Warn("Could not save conversation with new title: %v\n", err) lmcli.Warn("Could not update conversation title: %v\n", err)
} }
return nil return nil
}, },

View File

@ -31,9 +31,9 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
} }
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{ cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
ConversationID: conversation.ID, ConversationID: conversation.ID,
Role: model.MessageRoleUser, Role: model.MessageRoleUser,
Content: reply, Content: reply,
}) })
return nil return nil
}, },

View File

@ -13,7 +13,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "retry <conversation>", Use: "retry <conversation>",
Short: "Retry the last user reply in a conversation", Short: "Retry the last user reply in a conversation",
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`, Long: `Prompt the conversation from the last user response.`,
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
argCount := 1 argCount := 1
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil { if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
@ -25,25 +25,36 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.Messages(conversation) // Load the complete thread from the root message
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
} }
// walk backwards through the conversation and delete messages, break offset, _ := cmd.Flags().GetInt("offset")
// when we find the latest user response if offset < 0 {
for i := len(messages) - 1; i >= 0; i-- { offset = -offset
if messages[i].Role == model.MessageRoleUser {
break
}
err = ctx.Store.DeleteMessage(&messages[i])
if err != nil {
lmcli.Warn("Could not delete previous reply: %v\n", err)
}
} }
cmdutil.HandleConversationReply(ctx, conversation, true) if offset > len(messages)-1 {
return fmt.Errorf("Offset %d is before the start of the conversation.", offset)
}
retryFromIdx := len(messages) - 1 - offset
// decrease retryFromIdx until we hit a user message
for retryFromIdx >= 0 && messages[retryFromIdx].Role != model.MessageRoleUser {
retryFromIdx--
}
if messages[retryFromIdx].Role != model.MessageRoleUser {
return fmt.Errorf("No user messages to retry")
}
fmt.Printf("Idx: %d Message: %v\n", retryFromIdx, messages[retryFromIdx])
// Start a new branch at the last user message
cmdutil.HandleReply(ctx, &messages[retryFromIdx], true)
return nil return nil
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
@ -55,6 +66,8 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
}, },
} }
cmd.Flags().Int("offset", 1, "Offset from the last message retry from.")
applyPromptFlags(ctx, cmd) applyPromptFlags(ctx, cmd)
return cmd return cmd
} }

View File

@ -73,43 +73,57 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversat
return c, nil return c, nil
} }
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err)
}
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
}
// handleConversationReply handles sending messages to an existing // handleConversationReply handles sending messages to an existing
// conversation, optionally persisting both the sent replies and responses. // conversation, optionally persisting both the sent replies and responses.
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) { func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
existing, err := ctx.Store.Messages(c) if to == nil {
if err != nil { lmcli.Fatal("Can't prompt from an empty message.")
lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
} }
if persist { existing, err := ctx.Store.PathToRoot(to)
for _, message := range toSend { if err != nil {
err = ctx.Store.SaveMessage(&message) lmcli.Fatal("Could not load messages: %v\n", err)
if err != nil { }
lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
} RenderConversation(ctx, append(existing, messages...), true)
var savedReplies []model.Message
if persist && len(messages) > 0 {
savedReplies, err = ctx.Store.Reply(to, messages...)
if err != nil {
lmcli.Warn("Could not save messages: %v\n", err)
} }
} }
allMessages := append(existing, toSend...)
RenderConversation(ctx, allMessages, true)
// render a message header with no contents // render a message header with no contents
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant})) RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
var lastSavedMessage *model.Message
lastSavedMessage = to
if len(savedReplies) > 0 {
lastSavedMessage = &savedReplies[len(savedReplies)-1]
}
replyCallback := func(reply model.Message) { replyCallback := func(reply model.Message) {
if !persist { if !persist {
return return
} }
savedReplies, err = ctx.Store.Reply(lastSavedMessage, reply)
reply.ConversationID = c.ID
err = ctx.Store.SaveMessage(&reply)
if err != nil { if err != nil {
lmcli.Warn("Could not save reply: %v\n", err) lmcli.Warn("Could not save reply: %v\n", err)
} }
lastSavedMessage = &savedReplies[0]
} }
_, err = Prompt(ctx, allMessages, replyCallback) _, err = Prompt(ctx, append(existing, messages...), replyCallback)
if err != nil { if err != nil {
lmcli.Fatal("Error fetching LLM response: %v\n", err) lmcli.Fatal("Error fetching LLM response: %v\n", err)
} }
@ -134,12 +148,7 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string {
return sb.String() return sb.String()
} }
func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) { func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
messages, err := ctx.Store.Messages(c)
if err != nil {
return "", err
}
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective. const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
Example conversation: Example conversation:

View File

@ -24,9 +24,9 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
shortName := args[0] shortName := args[0]
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
messages, err := ctx.Store.Messages(conversation) messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
if err != nil { if err != nil {
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
} }
cmdutil.RenderConversation(ctx, messages, false) cmdutil.RenderConversation(ctx, messages, false)

View File

@ -36,7 +36,9 @@ func NewContext() (*Context, error) {
} }
databaseFile := filepath.Join(dataDir(), "conversations.db") databaseFile := filepath.Join(dataDir(), "conversations.db")
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{}) db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
//Logger: logger.Default.LogMode(logger.Info),
})
if err != nil { if err != nil {
return nil, fmt.Errorf("Error establishing connection to store: %v", err) return nil, fmt.Errorf("Error establishing connection to store: %v", err)
} }

View File

@ -16,19 +16,28 @@ const (
) )
type Message struct { type Message struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
ConversationID uint `gorm:"foreignKey:ConversationID"` ConversationID uint `gorm:"index"`
Conversation Conversation `gorm:"foreignKey:ConversationID"`
Content string Content string
Role MessageRole Role MessageRole
CreatedAt time.Time CreatedAt time.Time
ToolCalls ToolCalls // a json array of tool calls (from the modl) ToolCalls ToolCalls // a json array of tool calls (from the model)
ToolResults ToolResults // a json array of tool results ToolResults ToolResults // a json array of tool results
ParentID *uint
Parent *Message `gorm:"foreignKey:ParentID"`
Replies []Message `gorm:"foreignKey:ParentID"`
SelectedReplyID *uint
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
} }
type Conversation struct { type Conversation struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
ShortName sql.NullString ShortName sql.NullString
Title string Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
} }
type RequestParameters struct { type RequestParameters struct {

View File

@ -13,21 +13,26 @@ import (
) )
type ConversationStore interface { type ConversationStore interface {
Conversations() ([]model.Conversation, error)
ConversationByShortName(shortName string) (*model.Conversation, error) ConversationByShortName(shortName string) (*model.Conversation, error)
ConversationShortNameCompletions(search string) []string ConversationShortNameCompletions(search string) []string
RootMessages(conversationID uint) ([]model.Message, error)
LatestConversationMessages() ([]model.Message, error)
SaveConversation(conversation *model.Conversation) error StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
UpdateConversation(conversation *model.Conversation) error
DeleteConversation(conversation *model.Conversation) error DeleteConversation(conversation *model.Conversation) error
CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)
Messages(conversation *model.Conversation) ([]model.Message, error) MessageByID(messageID uint) (*model.Message, error)
LastMessage(conversation *model.Conversation) (*model.Message, error) MessageReplies(messageID uint) ([]model.Message, error)
SaveMessage(message *model.Message) error
DeleteMessage(message *model.Message) error
UpdateMessage(message *model.Message) error UpdateMessage(message *model.Message) error
AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error) DeleteMessage(message *model.Message, prune bool) error
CloneBranch(toClone model.Message) (*model.Message, uint, error)
Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
PathToRoot(message *model.Message) ([]model.Message, error)
PathToLeaf(message *model.Message) ([]model.Message, error)
} }
type SQLStore struct { type SQLStore struct {
@ -52,47 +57,52 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
return &SQLStore{db, _sqids}, nil return &SQLStore{db, _sqids}, nil
} }
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error { func (s *SQLStore) saveNewConversation(c *model.Conversation) error {
err := s.db.Save(&conversation).Error // Save the new conversation
err := s.db.Save(&c).Error
if err != nil { if err != nil {
return err return err
} }
if !conversation.ShortName.Valid { // Generate and save its "short name"
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)}) shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
conversation.ShortName = sql.NullString{String: shortName, Valid: true} c.ShortName = sql.NullString{String: shortName, Valid: true}
err = s.db.Save(&conversation).Error return s.UpdateConversation(c)
}
func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
} }
return s.db.Updates(&c).Error
return err
} }
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error { func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{}) // Delete messages first
return s.db.Delete(&conversation).Error err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
if err != nil {
return err
}
return s.db.Delete(&c).Error
} }
func (s *SQLStore) SaveMessage(message *model.Message) error { func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
return s.db.Create(message).Error panic("Not yet implemented")
//return s.db.Delete(&message).Error
} }
func (s *SQLStore) DeleteMessage(message *model.Message) error { func (s *SQLStore) UpdateMessage(m *model.Message) error {
return s.db.Delete(&message).Error if m == nil || m.ID == 0 {
} return fmt.Errorf("Message is nil or invalid (missing ID)")
}
func (s *SQLStore) UpdateMessage(message *model.Message) error { return s.db.Updates(&m).Error
return s.db.Updates(&message).Error
}
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
var conversations []model.Conversation
err := s.db.Find(&conversations).Error
return conversations, err
} }
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string { func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
var completions []string var conversations []model.Conversation
conversations, _ := s.Conversations() // ignore error for completions // ignore error for completions
s.db.Find(&conversations)
completions := make([]string, 0, len(conversations))
for _, conversation := range conversations { for _, conversation := range conversations {
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) { if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title)) completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
@ -106,27 +116,250 @@ func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversatio
return nil, errors.New("shortName is empty") return nil, errors.New("shortName is empty")
} }
var conversation model.Conversation var conversation model.Conversation
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err return &conversation, err
} }
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) { func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
var messages []model.Message var rootMessages []model.Message
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
return messages, err if err != nil {
return nil, err
}
return rootMessages, nil
} }
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) { func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
var message model.Message var message model.Message
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
return &message, err return &message, err
} }
// AddReply adds the given messages as a reply to the given conversation, can be func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
// used to easily copy a message associated with one conversation, to another var replies []model.Message
func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) { err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
m.ConversationID = c.ID return replies, err
m.ID = 0 }
m.CreatedAt = time.Time{}
return &m, s.SaveMessage(&m) // StartConversation starts a new conversation with the provided messages
func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
if len(messages) == 0 {
return nil, nil, fmt.Errorf("Must provide at least 1 message")
}
// Create new conversation
conversation := &model.Conversation{}
err := s.saveNewConversation(conversation)
if err != nil {
return nil, nil, err
}
// Create first message
messages[0].ConversationID = conversation.ID
err = s.db.Create(&messages[0]).Error
if err != nil {
return nil, nil, err
}
// Update conversation's selected root message
conversation.SelectedRoot = &messages[0]
err = s.UpdateConversation(conversation)
if err != nil {
return nil, nil, err
}
// Add additional replies to conversation
if len(messages) > 1 {
newMessages, err := s.Reply(&messages[0], messages[1:]...)
if err != nil {
return nil, nil, err
}
messages = append([]model.Message{messages[0]}, newMessages...)
}
return conversation, messages, nil
}
// CloneConversation clones the given conversation and all of its root meesages
func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
rootMessages, err := s.RootMessages(toClone.ID)
if err != nil {
return nil, 0, err
}
clone := &model.Conversation{
Title: toClone.Title + " - Clone",
}
if err := s.saveNewConversation(clone); err != nil {
return nil, 0, fmt.Errorf("Could not create clone: %s", err)
}
var errors []error
var messageCnt uint = 0
for _, root := range rootMessages {
messageCnt++
newRoot := root
newRoot.ConversationID = clone.ID
cloned, count, err := s.CloneBranch(newRoot)
if err != nil {
errors = append(errors, err)
continue
}
messageCnt += count
if root.ID == *toClone.SelectedRootID {
clone.SelectedRootID = &cloned.ID
if err := s.UpdateConversation(clone); err != nil {
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
}
}
}
if len(errors) > 0 {
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
}
return clone, messageCnt, nil
}
// Reply to a message with a series of messages (each following the next)
func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) {
var savedMessages []model.Message
err := s.db.Transaction(func(tx *gorm.DB) error {
currentParent := to
for i := range messages {
message := messages[i]
message.ConversationID = currentParent.ConversationID
message.ParentID = &currentParent.ID
message.ID = 0
message.CreatedAt = time.Time{}
if err := tx.Create(&message).Error; err != nil {
return err
}
// update parent selected reply
currentParent.SelectedReply = &message
if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil {
return err
}
savedMessages = append(savedMessages, message)
currentParent = &message
}
return nil
})
return savedMessages, err
}
// CloneBranch returns a deep clone of the given message and its replies, returning
// a new message object. The new message will be attached to the same parent as
// the message to clone.
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
newMessage := messageToClone
newMessage.ID = 0
newMessage.Replies = nil
newMessage.SelectedReplyID = nil
newMessage.SelectedReply = nil
originalReplies, err := s.MessageReplies(messageToClone.ID)
if err != nil {
return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
}
if err := s.db.Create(&newMessage).Error; err != nil {
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
}
var replyCount uint = 0
for _, reply := range originalReplies {
replyCount++
newReply := reply
newReply.ConversationID = messageToClone.ConversationID
newReply.ParentID = &newMessage.ID
newReply.Parent = &newMessage
res, c, err := s.CloneBranch(newReply)
if err != nil {
return nil, 0, err
}
newMessage.Replies = append(newMessage.Replies, *res)
replyCount += c
if reply.ID == *messageToClone.SelectedReplyID {
newMessage.SelectedReplyID = &res.ID
if err := s.UpdateMessage(&newMessage); err != nil {
return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
}
}
}
return &newMessage, replyCount, nil
}
// PathToRoot traverses message Parent until reaching the tree root
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
if message == nil {
return nil, fmt.Errorf("Message is nil")
}
var path []model.Message
current := message
for {
path = append([]model.Message{*current}, path...)
if current.Parent == nil {
break
}
var err error
current, err = s.MessageByID(*current.ParentID)
if err != nil {
return nil, fmt.Errorf("finding parent message: %w", err)
}
}
return path, nil
}
// PathToLeaf traverses message SelectedReply until reaching a tree leaf
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
if message == nil {
return nil, fmt.Errorf("Message is nil")
}
var path []model.Message
current := message
for {
path = append(path, *current)
if current.SelectedReplyID == nil {
break
}
var err error
current, err = s.MessageByID(*current.SelectedReplyID)
if err != nil {
return nil, fmt.Errorf("finding selected reply: %w", err)
}
}
return path, nil
}
func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
var latestMessages []model.Message
subQuery := s.db.Model(&model.Message{}).
Select("MAX(created_at) as max_created_at, conversation_id").
Group("conversation_id")
err := s.db.Model(&model.Message{}).
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
Group("messages.conversation_id").
Order("created_at DESC").
Preload("Conversation").
Find(&latestMessages).Error
if err != nil {
return nil, err
}
return latestMessages, nil
} }

View File

@ -307,14 +307,9 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
} }
if m.persistence { if m.persistence {
var err error err := m.persistConversation()
if m.conversation.ID == 0 {
err = m.ctx.Store.SaveConversation(m.conversation)
}
if err != nil { if err != nil {
cmds = append(cmds, wrapError(err)) cmds = append(cmds, wrapError(err))
} else {
cmds = append(cmds, m.persistConversation())
} }
} }
@ -341,7 +336,7 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
title := string(msg) title := string(msg)
m.conversation.Title = title m.conversation.Title = title
if m.persistence { if m.persistence {
err := m.ctx.Store.SaveConversation(m.conversation) err := m.ctx.Store.UpdateConversation(m.conversation)
if err != nil { if err != nil {
cmds = append(cmds, wrapError(err)) cmds = append(cmds, wrapError(err))
} }
@ -469,8 +464,8 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
m.input.Blur() m.input.Blur()
return true, nil return true, nil
case "ctrl+s": case "ctrl+s":
userInput := strings.TrimSpace(m.input.Value()) input := strings.TrimSpace(m.input.Value())
if strings.TrimSpace(userInput) == "" { if input == "" {
return true, nil return true, nil
} }
@ -478,35 +473,19 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
return true, wrapError(fmt.Errorf("Can't reply to a user message")) return true, wrapError(fmt.Errorf("Can't reply to a user message"))
} }
reply := models.Message{ m.addMessage(models.Message{
Role: models.MessageRoleUser, Role: models.MessageRoleUser,
Content: userInput, Content: input,
} })
if m.persistence {
var err error
if m.conversation.ID == 0 {
err = m.ctx.Store.SaveConversation(m.conversation)
}
if err != nil {
return true, wrapError(err)
}
// ensure all messages up to the one we're about to add are persisted
cmd := m.persistConversation()
if cmd != nil {
return true, cmd
}
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
if err != nil {
return true, wrapError(err)
}
reply = *savedReply
}
m.input.SetValue("") m.input.SetValue("")
m.addMessage(reply)
if m.persistence {
err := m.persistConversation()
if err != nil {
return true, wrapError(err)
}
}
m.updateContent() m.updateContent()
m.content.GotoBottom() m.content.GotoBottom()
@ -783,7 +762,7 @@ func (m *chatModel) loadConversation(shortname string) tea.Cmd {
func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd { func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
messages, err := m.ctx.Store.Messages(c) messages, err := m.ctx.Store.PathToLeaf(c.SelectedRoot)
if err != nil { if err != nil {
return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err)) return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
} }
@ -791,62 +770,48 @@ func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
} }
} }
func (m *chatModel) persistConversation() tea.Cmd { func (m *chatModel) persistConversation() error {
existingMessages, err := m.ctx.Store.Messages(m.conversation) if m.conversation.ID == 0 {
if err != nil { // Start a new conversation with all messages so far
return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err)) c, messages, err := m.ctx.Store.StartConversation(m.messages...)
} if err != nil {
return err
existingById := make(map[uint]*models.Message, len(existingMessages))
for _, msg := range existingMessages {
existingById[msg.ID] = &msg
}
currentById := make(map[uint]*models.Message, len(m.messages))
for _, msg := range m.messages {
currentById[msg.ID] = &msg
}
for _, msg := range existingMessages {
_, ok := currentById[msg.ID]
if !ok {
err := m.ctx.Store.DeleteMessage(&msg)
if err != nil {
return wrapError(fmt.Errorf("Failed to remove messages: %v", err))
}
} }
m.conversation = c
m.messages = messages
return nil
} }
for i, msg := range m.messages { // else, we'll handle updating an existing conversation's messages
if msg.ID > 0 { for i := 0; i < len(m.messages); i++ {
exist, ok := existingById[msg.ID] if m.messages[i].ID > 0 {
if ok { // message has an ID, update its contents
if msg.Content == exist.Content { // TODO: check for content/tool equality before updating?
continue err := m.ctx.Store.UpdateMessage(&m.messages[i])
} if err != nil {
// update message when contents don't match that of store return err
err := m.ctx.Store.UpdateMessage(&msg)
if err != nil {
return wrapError(err)
}
} else {
// this would be quite odd... and I'm not sure how to handle
// it at the time of writing this
} }
} else if i > 0 {
// messages is new, so add it as a reply to previous message
saved, err := m.ctx.Store.Reply(&m.messages[i-1], m.messages[i])
if err != nil {
return err
}
m.messages[i] = saved[0]
} else { } else {
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg) // message has no id and no previous messages to add it to
if err != nil { // this shouldn't happen?
return wrapError(err) return fmt.Errorf("Error: no messages to reply to")
}
m.setMessage(i, *newMessage)
} }
} }
return nil return nil
} }
func (m *chatModel) generateConversationTitle() tea.Cmd { func (m *chatModel) generateConversationTitle() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
title, err := cmdutil.GenerateTitle(m.ctx, m.conversation) title, err := cmdutil.GenerateTitle(m.ctx, m.messages)
if err != nil { if err != nil {
return msgError(err) return msgError(err)
} }

View File

@ -2,7 +2,6 @@ package tui
import ( import (
"fmt" "fmt"
"slices"
"strings" "strings"
"time" "time"
@ -145,25 +144,17 @@ func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
func (m *conversationsModel) loadConversations() tea.Cmd { func (m *conversationsModel) loadConversations() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
conversations, err := m.ctx.Store.Conversations() messages, err := m.ctx.Store.LatestConversationMessages()
if err != nil { if err != nil {
return msgError(fmt.Errorf("Could not load conversations: %v", err)) return msgError(fmt.Errorf("Could not load conversations: %v", err))
} }
loaded := make([]loadedConversation, len(conversations)) loaded := make([]loadedConversation, len(messages))
for i, c := range conversations { for i, m := range messages {
lastMessage, err := m.ctx.Store.LastMessage(&c) loaded[i].lastReply = m
if err != nil { loaded[i].conv = m.Conversation
return msgError(err)
}
loaded[i].conv = c
loaded[i].lastReply = *lastMessage
} }
slices.SortFunc(loaded, func(a, b loadedConversation) int {
return b.lastReply.CreatedAt.Compare(a.lastReply.CreatedAt)
})
return msgConversationsLoaded(loaded) return msgConversationsLoaded(loaded)
} }
} }