Add message branching
Updated the behaviour of commands: - `lmcli edit` - by default create a new branch/message branch with the edited contents - add --in-place to avoid creating a branch - no longer delete messages after the edited message - only do the edit, don't fetch a new response - `lmcli retry` - create a new branch rather than replacing old messages - add --offset to change where to retry from
This commit is contained in:
parent
f6e55f6bff
commit
8c53752146
@ -5,7 +5,6 @@ import (
|
|||||||
|
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,36 +27,12 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
messagesToCopy, err := ctx.Store.Messages(toClone)
|
clone, messageCnt, err := ctx.Store.CloneConversation(*toClone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
|
return fmt.Errorf("Failed to clone conversation: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clone := &model.Conversation{
|
fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title)
|
||||||
Title: toClone.Title + " - Clone",
|
|
||||||
}
|
|
||||||
if err := ctx.Store.SaveConversation(clone); err != nil {
|
|
||||||
return fmt.Errorf("Cloud not create clone: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var errors []error
|
|
||||||
messageCnt := 0
|
|
||||||
for _, message := range messagesToCopy {
|
|
||||||
newMessage := message
|
|
||||||
newMessage.ConversationID = clone.ID
|
|
||||||
newMessage.ID = 0
|
|
||||||
if err := ctx.Store.SaveMessage(&newMessage); err != nil {
|
|
||||||
errors = append(errors, err)
|
|
||||||
} else {
|
|
||||||
messageCnt++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errors) > 0 {
|
|
||||||
return fmt.Errorf("Messages failed to be cloned: %v", errors)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
@ -26,7 +26,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
messages, err := ctx.Store.Messages(conversation)
|
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not retrieve conversation messages: %v", err)
|
return fmt.Errorf("could not retrieve conversation messages: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -24,7 +24,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
messages, err := ctx.Store.Messages(conversation)
|
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||||
}
|
}
|
||||||
@ -39,21 +39,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
desiredIdx := len(messages) - 1 - offset
|
desiredIdx := len(messages) - 1 - offset
|
||||||
|
toEdit := messages[desiredIdx]
|
||||||
// walk backwards through the conversation deleting messages until and
|
|
||||||
// including the last user message
|
|
||||||
toRemove := []model.Message{}
|
|
||||||
var toEdit *model.Message
|
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
|
||||||
if i == desiredIdx {
|
|
||||||
toEdit = &messages[i]
|
|
||||||
}
|
|
||||||
toRemove = append(toRemove, messages[i])
|
|
||||||
messages = messages[:i]
|
|
||||||
if toEdit != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
|
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
|
||||||
switch newContents {
|
switch newContents {
|
||||||
@ -63,26 +49,38 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
toEdit.Content = newContents
|
||||||
|
|
||||||
role, _ := cmd.Flags().GetString("role")
|
role, _ := cmd.Flags().GetString("role")
|
||||||
if role == "" {
|
if role != "" {
|
||||||
role = string(toEdit.Role)
|
if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
|
||||||
} else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
|
|
||||||
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
|
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
|
||||||
}
|
}
|
||||||
|
toEdit.Role = model.MessageRole(role)
|
||||||
|
}
|
||||||
|
|
||||||
for _, message := range toRemove {
|
// Update the message in-place
|
||||||
err = ctx.Store.DeleteMessage(&message)
|
inplace, _ := cmd.Flags().GetBool("in-place")
|
||||||
|
if inplace {
|
||||||
|
return ctx.Store.UpdateMessage(&toEdit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, create a branch for the edited message
|
||||||
|
message, _, err := ctx.Store.CloneBranch(toEdit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not delete message: %v\n", err)
|
return err
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
|
if desiredIdx > 0 {
|
||||||
ConversationID: conversation.ID,
|
// update selected reply
|
||||||
Role: model.MessageRole(role),
|
messages[desiredIdx-1].SelectedReply = message
|
||||||
Content: newContents,
|
err = ctx.Store.UpdateMessage(&messages[desiredIdx-1])
|
||||||
})
|
} else {
|
||||||
return nil
|
// update selected root
|
||||||
|
conversation.SelectedRoot = message
|
||||||
|
err = ctx.Store.UpdateConversation(conversation)
|
||||||
|
}
|
||||||
|
return err
|
||||||
},
|
},
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
@ -93,8 +91,9 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd.Flags().BoolP("in-place", "i", true, "Edit the message in-place, rather than creating a branch")
|
||||||
cmd.Flags().Int("offset", 1, "Offset from the last message to edit")
|
cmd.Flags().Int("offset", 1, "Offset from the last message to edit")
|
||||||
cmd.Flags().StringP("role", "r", "", "Role of the edited message (user or assistant)")
|
cmd.Flags().StringP("role", "r", "", "Change the role of the edited message (user or assistant)")
|
||||||
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
@ -21,7 +20,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
Short: "List conversations",
|
Short: "List conversations",
|
||||||
Long: `List conversations in order of recent activity`,
|
Long: `List conversations in order of recent activity`,
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
conversations, err := ctx.Store.Conversations()
|
messages, err := ctx.Store.LatestConversationMessages()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not fetch conversations: %v", err)
|
return fmt.Errorf("Could not fetch conversations: %v", err)
|
||||||
}
|
}
|
||||||
@ -58,13 +57,8 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
|
|
||||||
all, _ := cmd.Flags().GetBool("all")
|
all, _ := cmd.Flags().GetBool("all")
|
||||||
|
|
||||||
for _, conversation := range conversations {
|
for _, message := range messages {
|
||||||
lastMessage, err := ctx.Store.LastMessage(&conversation)
|
messageAge := now.Sub(message.CreatedAt)
|
||||||
if lastMessage == nil || err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
messageAge := now.Sub(lastMessage.CreatedAt)
|
|
||||||
|
|
||||||
var category string
|
var category string
|
||||||
for _, c := range categories {
|
for _, c := range categories {
|
||||||
@ -76,9 +70,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
|
|
||||||
formatted := fmt.Sprintf(
|
formatted := fmt.Sprintf(
|
||||||
"%s - %s - %s",
|
"%s - %s - %s",
|
||||||
conversation.ShortName.String,
|
message.Conversation.ShortName.String,
|
||||||
util.HumanTimeElapsedSince(messageAge),
|
util.HumanTimeElapsedSince(messageAge),
|
||||||
conversation.Title,
|
message.Conversation.Title,
|
||||||
)
|
)
|
||||||
|
|
||||||
categorized[category] = append(
|
categorized[category] = append(
|
||||||
@ -96,14 +90,10 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
|
|
||||||
return int(a.timeSinceReply - b.timeSinceReply)
|
|
||||||
})
|
|
||||||
|
|
||||||
fmt.Printf("%s:\n", category.name)
|
fmt.Printf("%s:\n", category.name)
|
||||||
for _, conv := range conversationLines {
|
for _, conv := range conversationLines {
|
||||||
if conversationsPrinted >= count && !all {
|
if conversationsPrinted >= count && !all {
|
||||||
fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
|
fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted)
|
||||||
break outer
|
break outer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,42 +15,43 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
Short: "Start a new conversation",
|
Short: "Start a new conversation",
|
||||||
Long: `Start a new conversation with the Large Language Model.`,
|
Long: `Start a new conversation with the Large Language Model.`,
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "")
|
||||||
if messageContents == "" {
|
if input == "" {
|
||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
conversation := &model.Conversation{}
|
var messages []model.Message
|
||||||
err := ctx.Store.SaveConversation(conversation)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Could not save new conversation: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
messages := []model.Message{
|
// TODO: probably just make this part of the conversation
|
||||||
{
|
system := ctx.GetSystemPrompt()
|
||||||
ConversationID: conversation.ID,
|
if system != "" {
|
||||||
|
messages = append(messages, model.Message{
|
||||||
Role: model.MessageRoleSystem,
|
Role: model.MessageRoleSystem,
|
||||||
Content: ctx.GetSystemPrompt(),
|
Content: system,
|
||||||
},
|
})
|
||||||
{
|
|
||||||
ConversationID: conversation.ID,
|
|
||||||
Role: model.MessageRoleUser,
|
|
||||||
Content: messageContents,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdutil.HandleConversationReply(ctx, conversation, true, messages...)
|
messages = append(messages, model.Message{
|
||||||
|
Role: model.MessageRoleUser,
|
||||||
|
Content: input,
|
||||||
|
})
|
||||||
|
|
||||||
title, err := cmdutil.GenerateTitle(ctx, conversation)
|
conversation, messages, err := ctx.Store.StartConversation(messages...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not generate title for conversation: %v\n", err)
|
return fmt.Errorf("Could not start a new conversation: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdutil.HandleReply(ctx, &messages[len(messages)-1], true)
|
||||||
|
|
||||||
|
title, err := cmdutil.GenerateTitle(ctx, messages)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conversation.Title = title
|
conversation.Title = title
|
||||||
|
err = ctx.Store.UpdateConversation(conversation)
|
||||||
err = ctx.Store.SaveConversation(conversation)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not save conversation after generating title: %v\n", err)
|
lmcli.Warn("Could not save conversation title: %v\n", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
@ -15,22 +15,27 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
Short: "Do a one-shot prompt",
|
Short: "Do a one-shot prompt",
|
||||||
Long: `Prompt the Large Language Model and get a response.`,
|
Long: `Prompt the Large Language Model and get a response.`,
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "")
|
||||||
if message == "" {
|
if input == "" {
|
||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := []model.Message{
|
var messages []model.Message
|
||||||
{
|
|
||||||
|
// TODO: stop supplying system prompt as a message
|
||||||
|
system := ctx.GetSystemPrompt()
|
||||||
|
if system != "" {
|
||||||
|
messages = append(messages, model.Message{
|
||||||
Role: model.MessageRoleSystem,
|
Role: model.MessageRoleSystem,
|
||||||
Content: ctx.GetSystemPrompt(),
|
Content: system,
|
||||||
},
|
})
|
||||||
{
|
|
||||||
Role: model.MessageRoleUser,
|
|
||||||
Content: message,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
messages = append(messages, model.Message{
|
||||||
|
Role: model.MessageRoleUser,
|
||||||
|
Content: input,
|
||||||
|
})
|
||||||
|
|
||||||
_, err := cmdutil.Prompt(ctx, messages, nil)
|
_, err := cmdutil.Prompt(ctx, messages, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error fetching LLM response: %v", err)
|
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||||
|
@ -24,12 +24,17 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
var title string
|
||||||
|
|
||||||
generate, _ := cmd.Flags().GetBool("generate")
|
generate, _ := cmd.Flags().GetBool("generate")
|
||||||
var title string
|
|
||||||
if generate {
|
if generate {
|
||||||
title, err = cmdutil.GenerateTitle(ctx, conversation)
|
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not retrieve conversation messages: %v", err)
|
||||||
|
}
|
||||||
|
title, err = cmdutil.GenerateTitle(ctx, messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not generate conversation title: %v", err)
|
return fmt.Errorf("Could not generate conversation title: %v", err)
|
||||||
}
|
}
|
||||||
@ -41,9 +46,9 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conversation.Title = title
|
conversation.Title = title
|
||||||
err = ctx.Store.SaveConversation(conversation)
|
err = ctx.Store.UpdateConversation(conversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not save conversation with new title: %v\n", err)
|
lmcli.Warn("Could not update conversation title: %v\n", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
@ -13,7 +13,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
cmd := &cobra.Command{
|
cmd := &cobra.Command{
|
||||||
Use: "retry <conversation>",
|
Use: "retry <conversation>",
|
||||||
Short: "Retry the last user reply in a conversation",
|
Short: "Retry the last user reply in a conversation",
|
||||||
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
|
Long: `Prompt the conversation from the last user response.`,
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
argCount := 1
|
argCount := 1
|
||||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||||
@ -25,25 +25,36 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
messages, err := ctx.Store.Messages(conversation)
|
// Load the complete thread from the root message
|
||||||
|
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||||
}
|
}
|
||||||
|
|
||||||
// walk backwards through the conversation and delete messages, break
|
offset, _ := cmd.Flags().GetInt("offset")
|
||||||
// when we find the latest user response
|
if offset < 0 {
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
offset = -offset
|
||||||
if messages[i].Role == model.MessageRoleUser {
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ctx.Store.DeleteMessage(&messages[i])
|
if offset > len(messages)-1 {
|
||||||
if err != nil {
|
return fmt.Errorf("Offset %d is before the start of the conversation.", offset)
|
||||||
lmcli.Warn("Could not delete previous reply: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdutil.HandleConversationReply(ctx, conversation, true)
|
retryFromIdx := len(messages) - 1 - offset
|
||||||
|
|
||||||
|
// decrease retryFromIdx until we hit a user message
|
||||||
|
for retryFromIdx >= 0 && messages[retryFromIdx].Role != model.MessageRoleUser {
|
||||||
|
retryFromIdx--
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages[retryFromIdx].Role != model.MessageRoleUser {
|
||||||
|
return fmt.Errorf("No user messages to retry")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Idx: %d Message: %v\n", retryFromIdx, messages[retryFromIdx])
|
||||||
|
|
||||||
|
// Start a new branch at the last user message
|
||||||
|
cmdutil.HandleReply(ctx, &messages[retryFromIdx], true)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
@ -55,6 +66,8 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd.Flags().Int("offset", 1, "Offset from the last message retry from.")
|
||||||
|
|
||||||
applyPromptFlags(ctx, cmd)
|
applyPromptFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -73,43 +73,57 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversat
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
||||||
|
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Fatal("Could not load messages: %v\n", err)
|
||||||
|
}
|
||||||
|
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
|
||||||
|
}
|
||||||
|
|
||||||
// handleConversationReply handles sending messages to an existing
|
// handleConversationReply handles sending messages to an existing
|
||||||
// conversation, optionally persisting both the sent replies and responses.
|
// conversation, optionally persisting both the sent replies and responses.
|
||||||
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
|
||||||
existing, err := ctx.Store.Messages(c)
|
if to == nil {
|
||||||
|
lmcli.Fatal("Can't prompt from an empty message.")
|
||||||
|
}
|
||||||
|
|
||||||
|
existing, err := ctx.Store.PathToRoot(to)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
|
lmcli.Fatal("Could not load messages: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if persist {
|
RenderConversation(ctx, append(existing, messages...), true)
|
||||||
for _, message := range toSend {
|
|
||||||
err = ctx.Store.SaveMessage(&message)
|
var savedReplies []model.Message
|
||||||
|
if persist && len(messages) > 0 {
|
||||||
|
savedReplies, err = ctx.Store.Reply(to, messages...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
|
lmcli.Warn("Could not save messages: %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
allMessages := append(existing, toSend...)
|
|
||||||
|
|
||||||
RenderConversation(ctx, allMessages, true)
|
|
||||||
|
|
||||||
// render a message header with no contents
|
// render a message header with no contents
|
||||||
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
||||||
|
|
||||||
|
var lastSavedMessage *model.Message
|
||||||
|
lastSavedMessage = to
|
||||||
|
if len(savedReplies) > 0 {
|
||||||
|
lastSavedMessage = &savedReplies[len(savedReplies)-1]
|
||||||
|
}
|
||||||
|
|
||||||
replyCallback := func(reply model.Message) {
|
replyCallback := func(reply model.Message) {
|
||||||
if !persist {
|
if !persist {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
savedReplies, err = ctx.Store.Reply(lastSavedMessage, reply)
|
||||||
reply.ConversationID = c.ID
|
|
||||||
err = ctx.Store.SaveMessage(&reply)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not save reply: %v\n", err)
|
lmcli.Warn("Could not save reply: %v\n", err)
|
||||||
}
|
}
|
||||||
|
lastSavedMessage = &savedReplies[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = Prompt(ctx, allMessages, replyCallback)
|
_, err = Prompt(ctx, append(existing, messages...), replyCallback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||||
}
|
}
|
||||||
@ -134,12 +148,7 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
|
func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
|
||||||
messages, err := ctx.Store.Messages(c)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
|
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
|
||||||
|
|
||||||
Example conversation:
|
Example conversation:
|
||||||
|
@ -24,9 +24,9 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
messages, err := ctx.Store.Messages(conversation)
|
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdutil.RenderConversation(ctx, messages, false)
|
cmdutil.RenderConversation(ctx, messages, false)
|
||||||
|
@ -36,7 +36,9 @@ func NewContext() (*Context, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
||||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
|
||||||
|
//Logger: logger.Default.LogMode(logger.Info),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -17,18 +17,27 @@ const (
|
|||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
ID uint `gorm:"primaryKey"`
|
ID uint `gorm:"primaryKey"`
|
||||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
ConversationID uint `gorm:"index"`
|
||||||
|
Conversation Conversation `gorm:"foreignKey:ConversationID"`
|
||||||
Content string
|
Content string
|
||||||
Role MessageRole
|
Role MessageRole
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
ToolCalls ToolCalls // a json array of tool calls (from the modl)
|
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||||
ToolResults ToolResults // a json array of tool results
|
ToolResults ToolResults // a json array of tool results
|
||||||
|
ParentID *uint
|
||||||
|
Parent *Message `gorm:"foreignKey:ParentID"`
|
||||||
|
Replies []Message `gorm:"foreignKey:ParentID"`
|
||||||
|
|
||||||
|
SelectedReplyID *uint
|
||||||
|
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Conversation struct {
|
type Conversation struct {
|
||||||
ID uint `gorm:"primaryKey"`
|
ID uint `gorm:"primaryKey"`
|
||||||
ShortName sql.NullString
|
ShortName sql.NullString
|
||||||
Title string
|
Title string
|
||||||
|
SelectedRootID *uint
|
||||||
|
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RequestParameters struct {
|
type RequestParameters struct {
|
||||||
|
@ -13,21 +13,26 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ConversationStore interface {
|
type ConversationStore interface {
|
||||||
Conversations() ([]model.Conversation, error)
|
|
||||||
|
|
||||||
ConversationByShortName(shortName string) (*model.Conversation, error)
|
ConversationByShortName(shortName string) (*model.Conversation, error)
|
||||||
ConversationShortNameCompletions(search string) []string
|
ConversationShortNameCompletions(search string) []string
|
||||||
|
RootMessages(conversationID uint) ([]model.Message, error)
|
||||||
|
LatestConversationMessages() ([]model.Message, error)
|
||||||
|
|
||||||
SaveConversation(conversation *model.Conversation) error
|
StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
|
||||||
|
UpdateConversation(conversation *model.Conversation) error
|
||||||
DeleteConversation(conversation *model.Conversation) error
|
DeleteConversation(conversation *model.Conversation) error
|
||||||
|
CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)
|
||||||
|
|
||||||
Messages(conversation *model.Conversation) ([]model.Message, error)
|
MessageByID(messageID uint) (*model.Message, error)
|
||||||
LastMessage(conversation *model.Conversation) (*model.Message, error)
|
MessageReplies(messageID uint) ([]model.Message, error)
|
||||||
|
|
||||||
SaveMessage(message *model.Message) error
|
|
||||||
DeleteMessage(message *model.Message) error
|
|
||||||
UpdateMessage(message *model.Message) error
|
UpdateMessage(message *model.Message) error
|
||||||
AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error)
|
DeleteMessage(message *model.Message, prune bool) error
|
||||||
|
CloneBranch(toClone model.Message) (*model.Message, uint, error)
|
||||||
|
Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
|
||||||
|
|
||||||
|
PathToRoot(message *model.Message) ([]model.Message, error)
|
||||||
|
PathToLeaf(message *model.Message) ([]model.Message, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type SQLStore struct {
|
type SQLStore struct {
|
||||||
@ -52,47 +57,52 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
|||||||
return &SQLStore{db, _sqids}, nil
|
return &SQLStore{db, _sqids}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
|
func (s *SQLStore) saveNewConversation(c *model.Conversation) error {
|
||||||
err := s.db.Save(&conversation).Error
|
// Save the new conversation
|
||||||
|
err := s.db.Save(&c).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !conversation.ShortName.Valid {
|
// Generate and save its "short name"
|
||||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
|
||||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
c.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||||
err = s.db.Save(&conversation).Error
|
return s.UpdateConversation(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
|
||||||
|
if c == nil || c.ID == 0 {
|
||||||
|
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
||||||
}
|
}
|
||||||
|
return s.db.Updates(&c).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
|
||||||
|
// Delete messages first
|
||||||
|
err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
return s.db.Delete(&c).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
|
func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
|
||||||
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{})
|
panic("Not yet implemented")
|
||||||
return s.db.Delete(&conversation).Error
|
//return s.db.Delete(&message).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) SaveMessage(message *model.Message) error {
|
func (s *SQLStore) UpdateMessage(m *model.Message) error {
|
||||||
return s.db.Create(message).Error
|
if m == nil || m.ID == 0 {
|
||||||
}
|
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
||||||
|
}
|
||||||
func (s *SQLStore) DeleteMessage(message *model.Message) error {
|
return s.db.Updates(&m).Error
|
||||||
return s.db.Delete(&message).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) UpdateMessage(message *model.Message) error {
|
|
||||||
return s.db.Updates(&message).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
|
|
||||||
var conversations []model.Conversation
|
|
||||||
err := s.db.Find(&conversations).Error
|
|
||||||
return conversations, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
||||||
var completions []string
|
var conversations []model.Conversation
|
||||||
conversations, _ := s.Conversations() // ignore error for completions
|
// ignore error for completions
|
||||||
|
s.db.Find(&conversations)
|
||||||
|
completions := make([]string, 0, len(conversations))
|
||||||
for _, conversation := range conversations {
|
for _, conversation := range conversations {
|
||||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||||
@ -106,27 +116,250 @@ func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversatio
|
|||||||
return nil, errors.New("shortName is empty")
|
return nil, errors.New("shortName is empty")
|
||||||
}
|
}
|
||||||
var conversation model.Conversation
|
var conversation model.Conversation
|
||||||
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
||||||
return &conversation, err
|
return &conversation, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
|
func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
|
||||||
var messages []model.Message
|
var rootMessages []model.Message
|
||||||
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
||||||
return messages, err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return rootMessages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
|
func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
|
||||||
var message model.Message
|
var message model.Message
|
||||||
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
||||||
return &message, err
|
return &message, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddReply adds the given messages as a reply to the given conversation, can be
|
func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
|
||||||
// used to easily copy a message associated with one conversation, to another
|
var replies []model.Message
|
||||||
func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) {
|
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
|
||||||
m.ConversationID = c.ID
|
return replies, err
|
||||||
m.ID = 0
|
}
|
||||||
m.CreatedAt = time.Time{}
|
|
||||||
return &m, s.SaveMessage(&m)
|
// StartConversation starts a new conversation with the provided messages
|
||||||
|
func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil, nil, fmt.Errorf("Must provide at least 1 message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new conversation
|
||||||
|
conversation := &model.Conversation{}
|
||||||
|
err := s.saveNewConversation(conversation)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create first message
|
||||||
|
messages[0].ConversationID = conversation.ID
|
||||||
|
err = s.db.Create(&messages[0]).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update conversation's selected root message
|
||||||
|
conversation.SelectedRoot = &messages[0]
|
||||||
|
err = s.UpdateConversation(conversation)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add additional replies to conversation
|
||||||
|
if len(messages) > 1 {
|
||||||
|
newMessages, err := s.Reply(&messages[0], messages[1:]...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
messages = append([]model.Message{messages[0]}, newMessages...)
|
||||||
|
}
|
||||||
|
return conversation, messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloneConversation clones the given conversation and all of its root meesages
|
||||||
|
func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
|
||||||
|
rootMessages, err := s.RootMessages(toClone.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
clone := &model.Conversation{
|
||||||
|
Title: toClone.Title + " - Clone",
|
||||||
|
}
|
||||||
|
if err := s.saveNewConversation(clone); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("Could not create clone: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errors []error
|
||||||
|
var messageCnt uint = 0
|
||||||
|
for _, root := range rootMessages {
|
||||||
|
messageCnt++
|
||||||
|
newRoot := root
|
||||||
|
newRoot.ConversationID = clone.ID
|
||||||
|
|
||||||
|
cloned, count, err := s.CloneBranch(newRoot)
|
||||||
|
if err != nil {
|
||||||
|
errors = append(errors, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
messageCnt += count
|
||||||
|
|
||||||
|
if root.ID == *toClone.SelectedRootID {
|
||||||
|
clone.SelectedRootID = &cloned.ID
|
||||||
|
if err := s.UpdateConversation(clone); err != nil {
|
||||||
|
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
|
||||||
|
}
|
||||||
|
|
||||||
|
return clone, messageCnt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reply to a message with a series of messages (each following the next)
|
||||||
|
func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) {
|
||||||
|
var savedMessages []model.Message
|
||||||
|
|
||||||
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
currentParent := to
|
||||||
|
for i := range messages {
|
||||||
|
message := messages[i]
|
||||||
|
message.ConversationID = currentParent.ConversationID
|
||||||
|
message.ParentID = ¤tParent.ID
|
||||||
|
message.ID = 0
|
||||||
|
message.CreatedAt = time.Time{}
|
||||||
|
|
||||||
|
if err := tx.Create(&message).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// update parent selected reply
|
||||||
|
currentParent.SelectedReply = &message
|
||||||
|
if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
savedMessages = append(savedMessages, message)
|
||||||
|
currentParent = &message
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return savedMessages, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloneBranch returns a deep clone of the given message and its replies, returning
|
||||||
|
// a new message object. The new message will be attached to the same parent as
|
||||||
|
// the message to clone.
|
||||||
|
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
|
||||||
|
newMessage := messageToClone
|
||||||
|
newMessage.ID = 0
|
||||||
|
newMessage.Replies = nil
|
||||||
|
newMessage.SelectedReplyID = nil
|
||||||
|
newMessage.SelectedReply = nil
|
||||||
|
|
||||||
|
originalReplies, err := s.MessageReplies(messageToClone.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.db.Create(&newMessage).Error; err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var replyCount uint = 0
|
||||||
|
for _, reply := range originalReplies {
|
||||||
|
replyCount++
|
||||||
|
|
||||||
|
newReply := reply
|
||||||
|
newReply.ConversationID = messageToClone.ConversationID
|
||||||
|
newReply.ParentID = &newMessage.ID
|
||||||
|
newReply.Parent = &newMessage
|
||||||
|
|
||||||
|
res, c, err := s.CloneBranch(newReply)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
newMessage.Replies = append(newMessage.Replies, *res)
|
||||||
|
replyCount += c
|
||||||
|
|
||||||
|
if reply.ID == *messageToClone.SelectedReplyID {
|
||||||
|
newMessage.SelectedReplyID = &res.ID
|
||||||
|
if err := s.UpdateMessage(&newMessage); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &newMessage, replyCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PathToRoot traverses message Parent until reaching the tree root
|
||||||
|
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
|
||||||
|
if message == nil {
|
||||||
|
return nil, fmt.Errorf("Message is nil")
|
||||||
|
}
|
||||||
|
var path []model.Message
|
||||||
|
current := message
|
||||||
|
for {
|
||||||
|
path = append([]model.Message{*current}, path...)
|
||||||
|
if current.Parent == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
current, err = s.MessageByID(*current.ParentID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("finding parent message: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PathToLeaf traverses message SelectedReply until reaching a tree leaf
|
||||||
|
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
|
||||||
|
if message == nil {
|
||||||
|
return nil, fmt.Errorf("Message is nil")
|
||||||
|
}
|
||||||
|
var path []model.Message
|
||||||
|
current := message
|
||||||
|
for {
|
||||||
|
path = append(path, *current)
|
||||||
|
if current.SelectedReplyID == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
current, err = s.MessageByID(*current.SelectedReplyID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("finding selected reply: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
|
||||||
|
var latestMessages []model.Message
|
||||||
|
|
||||||
|
subQuery := s.db.Model(&model.Message{}).
|
||||||
|
Select("MAX(created_at) as max_created_at, conversation_id").
|
||||||
|
Group("conversation_id")
|
||||||
|
|
||||||
|
err := s.db.Model(&model.Message{}).
|
||||||
|
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
|
||||||
|
Group("messages.conversation_id").
|
||||||
|
Order("created_at DESC").
|
||||||
|
Preload("Conversation").
|
||||||
|
Find(&latestMessages).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return latestMessages, nil
|
||||||
}
|
}
|
||||||
|
119
pkg/tui/chat.go
119
pkg/tui/chat.go
@ -307,14 +307,9 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if m.persistence {
|
if m.persistence {
|
||||||
var err error
|
err := m.persistConversation()
|
||||||
if m.conversation.ID == 0 {
|
|
||||||
err = m.ctx.Store.SaveConversation(m.conversation)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmds = append(cmds, wrapError(err))
|
cmds = append(cmds, wrapError(err))
|
||||||
} else {
|
|
||||||
cmds = append(cmds, m.persistConversation())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -341,7 +336,7 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
|
|||||||
title := string(msg)
|
title := string(msg)
|
||||||
m.conversation.Title = title
|
m.conversation.Title = title
|
||||||
if m.persistence {
|
if m.persistence {
|
||||||
err := m.ctx.Store.SaveConversation(m.conversation)
|
err := m.ctx.Store.UpdateConversation(m.conversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmds = append(cmds, wrapError(err))
|
cmds = append(cmds, wrapError(err))
|
||||||
}
|
}
|
||||||
@ -469,8 +464,8 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
|
|||||||
m.input.Blur()
|
m.input.Blur()
|
||||||
return true, nil
|
return true, nil
|
||||||
case "ctrl+s":
|
case "ctrl+s":
|
||||||
userInput := strings.TrimSpace(m.input.Value())
|
input := strings.TrimSpace(m.input.Value())
|
||||||
if strings.TrimSpace(userInput) == "" {
|
if input == "" {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -478,35 +473,19 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
|
|||||||
return true, wrapError(fmt.Errorf("Can't reply to a user message"))
|
return true, wrapError(fmt.Errorf("Can't reply to a user message"))
|
||||||
}
|
}
|
||||||
|
|
||||||
reply := models.Message{
|
m.addMessage(models.Message{
|
||||||
Role: models.MessageRoleUser,
|
Role: models.MessageRoleUser,
|
||||||
Content: userInput,
|
Content: input,
|
||||||
}
|
})
|
||||||
|
|
||||||
if m.persistence {
|
|
||||||
var err error
|
|
||||||
if m.conversation.ID == 0 {
|
|
||||||
err = m.ctx.Store.SaveConversation(m.conversation)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return true, wrapError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensure all messages up to the one we're about to add are persisted
|
|
||||||
cmd := m.persistConversation()
|
|
||||||
if cmd != nil {
|
|
||||||
return true, cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
|
|
||||||
if err != nil {
|
|
||||||
return true, wrapError(err)
|
|
||||||
}
|
|
||||||
reply = *savedReply
|
|
||||||
}
|
|
||||||
|
|
||||||
m.input.SetValue("")
|
m.input.SetValue("")
|
||||||
m.addMessage(reply)
|
|
||||||
|
if m.persistence {
|
||||||
|
err := m.persistConversation()
|
||||||
|
if err != nil {
|
||||||
|
return true, wrapError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.updateContent()
|
m.updateContent()
|
||||||
m.content.GotoBottom()
|
m.content.GotoBottom()
|
||||||
@ -783,7 +762,7 @@ func (m *chatModel) loadConversation(shortname string) tea.Cmd {
|
|||||||
|
|
||||||
func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
|
func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
messages, err := m.ctx.Store.Messages(c)
|
messages, err := m.ctx.Store.PathToLeaf(c.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
|
return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
|
||||||
}
|
}
|
||||||
@ -791,62 +770,48 @@ func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *chatModel) persistConversation() tea.Cmd {
|
func (m *chatModel) persistConversation() error {
|
||||||
existingMessages, err := m.ctx.Store.Messages(m.conversation)
|
if m.conversation.ID == 0 {
|
||||||
|
// Start a new conversation with all messages so far
|
||||||
|
c, messages, err := m.ctx.Store.StartConversation(m.messages...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err))
|
return err
|
||||||
|
}
|
||||||
|
m.conversation = c
|
||||||
|
m.messages = messages
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
existingById := make(map[uint]*models.Message, len(existingMessages))
|
// else, we'll handle updating an existing conversation's messages
|
||||||
for _, msg := range existingMessages {
|
for i := 0; i < len(m.messages); i++ {
|
||||||
existingById[msg.ID] = &msg
|
if m.messages[i].ID > 0 {
|
||||||
}
|
// message has an ID, update its contents
|
||||||
|
// TODO: check for content/tool equality before updating?
|
||||||
currentById := make(map[uint]*models.Message, len(m.messages))
|
err := m.ctx.Store.UpdateMessage(&m.messages[i])
|
||||||
for _, msg := range m.messages {
|
|
||||||
currentById[msg.ID] = &msg
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, msg := range existingMessages {
|
|
||||||
_, ok := currentById[msg.ID]
|
|
||||||
if !ok {
|
|
||||||
err := m.ctx.Store.DeleteMessage(&msg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapError(fmt.Errorf("Failed to remove messages: %v", err))
|
return err
|
||||||
}
|
}
|
||||||
}
|
} else if i > 0 {
|
||||||
}
|
// messages is new, so add it as a reply to previous message
|
||||||
|
saved, err := m.ctx.Store.Reply(&m.messages[i-1], m.messages[i])
|
||||||
for i, msg := range m.messages {
|
|
||||||
if msg.ID > 0 {
|
|
||||||
exist, ok := existingById[msg.ID]
|
|
||||||
if ok {
|
|
||||||
if msg.Content == exist.Content {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// update message when contents don't match that of store
|
|
||||||
err := m.ctx.Store.UpdateMessage(&msg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return wrapError(err)
|
return err
|
||||||
}
|
}
|
||||||
|
m.messages[i] = saved[0]
|
||||||
} else {
|
} else {
|
||||||
// this would be quite odd... and I'm not sure how to handle
|
// message has no id and no previous messages to add it to
|
||||||
// it at the time of writing this
|
// this shouldn't happen?
|
||||||
}
|
return fmt.Errorf("Error: no messages to reply to")
|
||||||
} else {
|
|
||||||
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err)
|
|
||||||
}
|
|
||||||
m.setMessage(i, *newMessage)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *chatModel) generateConversationTitle() tea.Cmd {
|
func (m *chatModel) generateConversationTitle() tea.Cmd {
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
title, err := cmdutil.GenerateTitle(m.ctx, m.conversation)
|
title, err := cmdutil.GenerateTitle(m.ctx, m.messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return msgError(err)
|
return msgError(err)
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package tui
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -145,24 +144,16 @@ func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
|
|||||||
|
|
||||||
func (m *conversationsModel) loadConversations() tea.Cmd {
|
func (m *conversationsModel) loadConversations() tea.Cmd {
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
conversations, err := m.ctx.Store.Conversations()
|
messages, err := m.ctx.Store.LatestConversationMessages()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return msgError(fmt.Errorf("Could not load conversations: %v", err))
|
return msgError(fmt.Errorf("Could not load conversations: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded := make([]loadedConversation, len(conversations))
|
loaded := make([]loadedConversation, len(messages))
|
||||||
for i, c := range conversations {
|
for i, m := range messages {
|
||||||
lastMessage, err := m.ctx.Store.LastMessage(&c)
|
loaded[i].lastReply = m
|
||||||
if err != nil {
|
loaded[i].conv = m.Conversation
|
||||||
return msgError(err)
|
|
||||||
}
|
}
|
||||||
loaded[i].conv = c
|
|
||||||
loaded[i].lastReply = *lastMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
slices.SortFunc(loaded, func(a, b loadedConversation) int {
|
|
||||||
return b.lastReply.CreatedAt.Compare(a.lastReply.CreatedAt)
|
|
||||||
})
|
|
||||||
|
|
||||||
return msgConversationsLoaded(loaded)
|
return msgConversationsLoaded(loaded)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user