diff --git a/pkg/api/api.go b/pkg/api/api.go new file mode 100644 index 0000000..c26c5b1 --- /dev/null +++ b/pkg/api/api.go @@ -0,0 +1,118 @@ +package api + +import ( + "encoding/json" + "fmt" +) + +type MessageRole string + +const ( + MessageRoleSystem MessageRole = "system" + MessageRoleUser MessageRole = "user" + MessageRoleAssistant MessageRole = "assistant" + MessageRoleToolCall MessageRole = "tool_call" + MessageRoleToolResult MessageRole = "tool_result" +) + +type Message struct { + Content string // TODO: support multi-part messages + Role MessageRole + ToolCalls []ToolCall + ToolResults []ToolResult +} + +type ToolSpec struct { + Name string + Description string + Parameters []ToolParameter + Impl func(*ToolSpec, map[string]interface{}) (string, error) +} + +type ToolParameter struct { + Name string `json:"name"` + Type string `json:"type"` // "string", "integer", "boolean" + Required bool `json:"required"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` +} + +type ToolCall struct { + ID string `json:"id" yaml:"-"` + Name string `json:"name" yaml:"tool"` + Parameters map[string]interface{} `json:"parameters" yaml:"parameters"` +} + +type ToolResult struct { + ToolCallID string `json:"toolCallID" yaml:"-"` + ToolName string `json:"toolName,omitempty" yaml:"tool"` + Result string `json:"result,omitempty" yaml:"result"` +} + +func NewMessageWithAssistant(content string) *Message { + return &Message{ + Role: MessageRoleAssistant, + Content: content, + } +} + +func NewMessageWithToolCalls(content string, toolCalls []ToolCall) *Message { + return &Message{ + Role: MessageRoleToolCall, + Content: content, + ToolCalls: toolCalls, + } +} + +func (m MessageRole) IsAssistant() bool { + switch m { + case MessageRoleAssistant, MessageRoleToolCall: + return true + } + return false +} + +func (m MessageRole) IsSystem() bool { + switch m { + case MessageRoleSystem: + return true + } + return false +} + +// FriendlyRole returns a human friendly signifier for the message's role. +func (m MessageRole) FriendlyRole() string { + switch m { + case MessageRoleUser: + return "You" + case MessageRoleSystem: + return "System" + case MessageRoleAssistant: + return "Assistant" + case MessageRoleToolCall: + return "Tool Call" + case MessageRoleToolResult: + return "Tool Result" + default: + return string(m) + } +} + +// TODO: remove this +type CallResult struct { + Message string `json:"message"` + Result any `json:"result,omitempty"` +} + +func (r CallResult) ToJson() (string, error) { + if r.Message == "" { + // When message not supplied, assume success + r.Message = "success" + } + + jsonBytes, err := json.Marshal(r) + if err != nil { + return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err) + } + return string(jsonBytes), nil +} diff --git a/pkg/api/conversation.go b/pkg/api/conversation.go deleted file mode 100644 index c35ebd3..0000000 --- a/pkg/api/conversation.go +++ /dev/null @@ -1,106 +0,0 @@ -package api - -import ( - "database/sql" - "database/sql/driver" - "encoding/json" - "time" -) - -type Conversation struct { - ID uint `gorm:"primaryKey"` - ShortName sql.NullString - Title string - SelectedRootID *uint - SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"` -} - -type MessageRole string - -const ( - MessageRoleSystem MessageRole = "system" - MessageRoleUser MessageRole = "user" - MessageRoleAssistant MessageRole = "assistant" - MessageRoleToolCall MessageRole = "tool_call" - MessageRoleToolResult MessageRole = "tool_result" -) - -type MessageMeta struct { - GenerationProvider *string `json:"generation_provider,omitempty"` - GenerationModel *string `json:"generation_model,omitempty"` -} - -type Message struct { - ID uint `gorm:"primaryKey"` - CreatedAt time.Time - Metadata MessageMeta - - ConversationID *uint `gorm:"index"` - Conversation *Conversation `gorm:"foreignKey:ConversationID"` - ParentID *uint - Parent *Message `gorm:"foreignKey:ParentID"` - Replies []Message `gorm:"foreignKey:ParentID"` - SelectedReplyID *uint - SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` - - Role MessageRole - Content string - ToolCalls ToolCalls // a json array of tool calls (from the model) - ToolResults ToolResults // a json array of tool results -} - -func (m *MessageMeta) Scan(value interface{}) error { - return json.Unmarshal(value.([]byte), m) -} - -func (m MessageMeta) Value() (driver.Value, error) { - return json.Marshal(m) -} - -func ApplySystemPrompt(m []Message, system string, force bool) []Message { - if len(m) > 0 && m[0].Role == MessageRoleSystem { - if force { - m[0].Content = system - } - return m - } else { - return append([]Message{{ - Role: MessageRoleSystem, - Content: system, - }}, m...) - } -} - -func (m MessageRole) IsAssistant() bool { - switch m { - case MessageRoleAssistant, MessageRoleToolCall: - return true - } - return false -} - -func (m MessageRole) IsSystem() bool { - switch m { - case MessageRoleSystem: - return true - } - return false -} - -// FriendlyRole returns a human friendly signifier for the message's role. -func (m MessageRole) FriendlyRole() string { - switch m { - case MessageRoleUser: - return "You" - case MessageRoleSystem: - return "System" - case MessageRoleAssistant: - return "Assistant" - case MessageRoleToolCall: - return "Tool Call" - case MessageRoleToolResult: - return "Tool Result" - default: - return string(m) - } -} diff --git a/pkg/api/tools.go b/pkg/api/tools.go deleted file mode 100644 index f4e3c72..0000000 --- a/pkg/api/tools.go +++ /dev/null @@ -1,98 +0,0 @@ -package api - -import ( - "database/sql/driver" - "encoding/json" - "fmt" -) - -type ToolSpec struct { - Name string - Description string - Parameters []ToolParameter - Impl func(*ToolSpec, map[string]interface{}) (string, error) -} - -type ToolParameter struct { - Name string `json:"name"` - Type string `json:"type"` // "string", "integer", "boolean" - Required bool `json:"required"` - Description string `json:"description"` - Enum []string `json:"enum,omitempty"` -} - -type ToolCall struct { - ID string `json:"id" yaml:"-"` - Name string `json:"name" yaml:"tool"` - Parameters map[string]interface{} `json:"parameters" yaml:"parameters"` -} - -type ToolResult struct { - ToolCallID string `json:"toolCallID" yaml:"-"` - ToolName string `json:"toolName,omitempty" yaml:"tool"` - Result string `json:"result,omitempty" yaml:"result"` -} - -type ToolCalls []ToolCall - -func (tc *ToolCalls) Scan(value any) (err error) { - s := value.(string) - if value == nil || s == "" { - *tc = nil - return - } - err = json.Unmarshal([]byte(s), tc) - return -} - -func (tc ToolCalls) Value() (driver.Value, error) { - if len(tc) == 0 { - return "", nil - } - jsonBytes, err := json.Marshal(tc) - if err != nil { - return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err) - } - return string(jsonBytes), nil -} - -type ToolResults []ToolResult - -func (tr *ToolResults) Scan(value any) (err error) { - s := value.(string) - if value == nil || s == "" { - *tr = nil - return - } - err = json.Unmarshal([]byte(s), tr) - return -} - -func (tr ToolResults) Value() (driver.Value, error) { - if len(tr) == 0 { - return "", nil - } - jsonBytes, err := json.Marshal([]ToolResult(tr)) - if err != nil { - return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err) - } - return string(jsonBytes), nil -} - -type CallResult struct { - Message string `json:"message"` - Result any `json:"result,omitempty"` -} - -func (r CallResult) ToJson() (string, error) { - if r.Message == "" { - // When message not supplied, assume success - r.Message = "success" - } - - jsonBytes, err := json.Marshal(r) - if err != nil { - return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err) - } - return string(jsonBytes), nil -} diff --git a/pkg/cmd/chat.go b/pkg/cmd/chat.go index 4d02f2a..4a2cece 100644 --- a/pkg/cmd/chat.go +++ b/pkg/cmd/chat.go @@ -54,7 +54,7 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command { if len(args) != 0 { return nil, compMode } - return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode }, } diff --git a/pkg/cmd/clone.go b/pkg/cmd/clone.go index 5055257..18c192b 100644 --- a/pkg/cmd/clone.go +++ b/pkg/cmd/clone.go @@ -27,7 +27,7 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command { return err } - clone, messageCnt, err := ctx.Store.CloneConversation(*toClone) + clone, messageCnt, err := ctx.Conversations.CloneConversation(*toClone) if err != nil { return fmt.Errorf("Failed to clone conversation: %v", err) } @@ -40,7 +40,7 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command { if len(args) != 0 { return nil, compMode } - return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode }, } return cmd diff --git a/pkg/cmd/continue.go b/pkg/cmd/continue.go index 965efd1..e634cf4 100644 --- a/pkg/cmd/continue.go +++ b/pkg/cmd/continue.go @@ -29,9 +29,9 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command { } shortName := args[0] - conversation := cmdutil.LookupConversation(ctx, shortName) + c := cmdutil.LookupConversation(ctx, shortName) - messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) + messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot) if err != nil { return fmt.Errorf("could not retrieve conversation messages: %v", err) } @@ -58,7 +58,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command { lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ") // Update the original message - err = ctx.Store.UpdateMessage(lastMessage) + err = ctx.Conversations.UpdateMessage(lastMessage) if err != nil { return fmt.Errorf("could not update the last message: %v", err) } @@ -70,7 +70,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command { if len(args) != 0 { return nil, compMode } - return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode }, } applyGenerationFlags(ctx, cmd) diff --git a/pkg/cmd/edit.go b/pkg/cmd/edit.go index fe6dd28..9ffbeb0 100644 --- a/pkg/cmd/edit.go +++ b/pkg/cmd/edit.go @@ -22,11 +22,11 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command { }, RunE: func(cmd *cobra.Command, args []string) error { shortName := args[0] - conversation := cmdutil.LookupConversation(ctx, shortName) + c := cmdutil.LookupConversation(ctx, shortName) - messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) + messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot) if err != nil { - return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) + return fmt.Errorf("Could not retrieve messages for conversation: %s", c.Title) } offset, _ := cmd.Flags().GetInt("offset") @@ -62,11 +62,11 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command { // Update the message in-place inplace, _ := cmd.Flags().GetBool("in-place") if inplace { - return ctx.Store.UpdateMessage(&toEdit) + return ctx.Conversations.UpdateMessage(&toEdit) } // Otherwise, create a branch for the edited message - message, _, err := ctx.Store.CloneBranch(toEdit) + message, _, err := ctx.Conversations.CloneBranch(toEdit) if err != nil { return err } @@ -74,11 +74,11 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command { if desiredIdx > 0 { // update selected reply messages[desiredIdx-1].SelectedReply = message - err = ctx.Store.UpdateMessage(&messages[desiredIdx-1]) + err = ctx.Conversations.UpdateMessage(&messages[desiredIdx-1]) } else { // update selected root - conversation.SelectedRoot = message - err = ctx.Store.UpdateConversation(conversation) + c.SelectedRoot = message + err = ctx.Conversations.UpdateConversation(c) } return err }, @@ -87,7 +87,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command { if len(args) != 0 { return nil, compMode } - return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode }, } diff --git a/pkg/cmd/list.go b/pkg/cmd/list.go index 6367ec0..2acf770 100644 --- a/pkg/cmd/list.go +++ b/pkg/cmd/list.go @@ -20,7 +20,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command { Short: "List conversations", Long: `List conversations in order of recent activity`, RunE: func(cmd *cobra.Command, args []string) error { - messages, err := ctx.Store.LatestConversationMessages() + messages, err := ctx.Conversations.LatestConversationMessages() if err != nil { return fmt.Errorf("Could not fetch conversations: %v", err) } diff --git a/pkg/cmd/new.go b/pkg/cmd/new.go index e4a61fc..9f86095 100644 --- a/pkg/cmd/new.go +++ b/pkg/cmd/new.go @@ -5,6 +5,7 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/api" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/lmcli" "github.com/spf13/cobra" ) @@ -25,12 +26,12 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command { return fmt.Errorf("No message was provided.") } - messages := []api.Message{{ + messages := []conversation.Message{{ Role: api.MessageRoleUser, Content: input, }} - conversation, messages, err := ctx.Store.StartConversation(messages...) + conversation, messages, err := ctx.Conversations.StartConversation(messages...) if err != nil { return fmt.Errorf("Could not start a new conversation: %v", err) } @@ -43,7 +44,7 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command { } conversation.Title = title - err = ctx.Store.UpdateConversation(conversation) + err = ctx.Conversations.UpdateConversation(conversation) if err != nil { lmcli.Warn("Could not save conversation title: %v\n", err) } diff --git a/pkg/cmd/prompt.go b/pkg/cmd/prompt.go index abab6b9..953d51f 100644 --- a/pkg/cmd/prompt.go +++ b/pkg/cmd/prompt.go @@ -5,6 +5,7 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/api" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/lmcli" "github.com/spf13/cobra" ) @@ -25,7 +26,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command { return fmt.Errorf("No message was provided.") } - messages := []api.Message{{ + messages := []conversation.Message{{ Role: api.MessageRoleUser, Content: input, }} diff --git a/pkg/cmd/remove.go b/pkg/cmd/remove.go index 8079ffb..13f6f64 100644 --- a/pkg/cmd/remove.go +++ b/pkg/cmd/remove.go @@ -4,8 +4,8 @@ import ( "fmt" "strings" - "git.mlow.ca/mlow/lmcli/pkg/api" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/lmcli" "github.com/spf13/cobra" ) @@ -23,14 +23,14 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command { return nil }, RunE: func(cmd *cobra.Command, args []string) error { - var toRemove []*api.Conversation + var toRemove []*conversation.Conversation for _, shortName := range args { conversation := cmdutil.LookupConversation(ctx, shortName) toRemove = append(toRemove, conversation) } var errors []error for _, c := range toRemove { - err := ctx.Store.DeleteConversation(c) + err := ctx.Conversations.DeleteConversation(c) if err != nil { errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err)) } @@ -44,7 +44,7 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command { compMode := cobra.ShellCompDirectiveNoFileComp var completions []string outer: - for _, completion := range ctx.Store.ConversationShortNameCompletions(toComplete) { + for _, completion := range ctx.Conversations.ConversationShortNameCompletions(toComplete) { parts := strings.Split(completion, "\t") for _, arg := range args { if parts[0] == arg { diff --git a/pkg/cmd/rename.go b/pkg/cmd/rename.go index c45bbfd..352fce4 100644 --- a/pkg/cmd/rename.go +++ b/pkg/cmd/rename.go @@ -30,7 +30,7 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command { generate, _ := cmd.Flags().GetBool("generate") if generate { - messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) + messages, err := ctx.Conversations.PathToLeaf(conversation.SelectedRoot) if err != nil { return fmt.Errorf("Could not retrieve conversation messages: %v", err) } @@ -46,7 +46,7 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command { } conversation.Title = title - err = ctx.Store.UpdateConversation(conversation) + err = ctx.Conversations.UpdateConversation(conversation) if err != nil { lmcli.Warn("Could not update conversation title: %v\n", err) } @@ -57,7 +57,7 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command { if len(args) != 0 { return nil, compMode } - return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode }, } diff --git a/pkg/cmd/reply.go b/pkg/cmd/reply.go index a0c0a65..d274200 100644 --- a/pkg/cmd/reply.go +++ b/pkg/cmd/reply.go @@ -5,6 +5,7 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/api" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/lmcli" "github.com/spf13/cobra" ) @@ -28,14 +29,14 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command { } shortName := args[0] - conversation := cmdutil.LookupConversation(ctx, shortName) + c := cmdutil.LookupConversation(ctx, shortName) reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "") if reply == "" { return fmt.Errorf("No reply was provided.") } - cmdutil.HandleConversationReply(ctx, conversation, true, api.Message{ + cmdutil.HandleConversationReply(ctx, c, true, conversation.Message{ Role: api.MessageRoleUser, Content: reply, }) @@ -46,7 +47,7 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command { if len(args) != 0 { return nil, compMode } - return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode }, } diff --git a/pkg/cmd/retry.go b/pkg/cmd/retry.go index e2ba866..d66e9ba 100644 --- a/pkg/cmd/retry.go +++ b/pkg/cmd/retry.go @@ -28,12 +28,12 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command { } shortName := args[0] - conversation := cmdutil.LookupConversation(ctx, shortName) + c := cmdutil.LookupConversation(ctx, shortName) // Load the complete thread from the root message - messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) + messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot) if err != nil { - return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title) + return fmt.Errorf("Could not retrieve messages for conversation: %s", c.Title) } offset, _ := cmd.Flags().GetInt("offset") @@ -67,7 +67,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command { if len(args) != 0 { return nil, compMode } - return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode }, } diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go index 2ca2d6f..19f481a 100644 --- a/pkg/cmd/util/util.go +++ b/pkg/cmd/util/util.go @@ -9,7 +9,8 @@ import ( "time" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/api/provider" + "git.mlow.ca/mlow/lmcli/pkg/provider" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/util" "github.com/charmbracelet/lipgloss" @@ -17,7 +18,7 @@ import ( // Prompt prompts the configured the configured model and streams the response // to stdout. Returns all model reply messages. -func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) { +func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) { m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") if err != nil { return nil, err @@ -40,7 +41,7 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag } if system != "" { - messages = api.ApplySystemPrompt(messages, system, false) + messages = conversation.ApplySystemPrompt(messages, system, false) } content := make(chan provider.Chunk) @@ -50,7 +51,7 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag go ShowDelayedContent(content) reply, err := p.CreateChatCompletionStream( - context.Background(), params, messages, content, + context.Background(), params, conversation.MessagesToAPI(messages), content, ) if reply.Content != "" { @@ -67,8 +68,8 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag // lookupConversation either returns the conversation found by the // short name or exits the program -func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation { - c, err := ctx.Store.ConversationByShortName(shortName) +func LookupConversation(ctx *lmcli.Context, shortName string) *conversation.Conversation { + c, err := ctx.Conversations.FindConversationByShortName(shortName) if err != nil { lmcli.Fatal("Could not lookup conversation: %v\n", err) } @@ -78,8 +79,8 @@ func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation return c } -func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversation, error) { - c, err := ctx.Store.ConversationByShortName(shortName) +func LookupConversationE(ctx *lmcli.Context, shortName string) (*conversation.Conversation, error) { + c, err := ctx.Conversations.FindConversationByShortName(shortName) if err != nil { return nil, fmt.Errorf("Could not lookup conversation: %v", err) } @@ -89,8 +90,8 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversatio return c, nil } -func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bool, toSend ...api.Message) { - messages, err := ctx.Store.PathToLeaf(c.SelectedRoot) +func HandleConversationReply(ctx *lmcli.Context, c *conversation.Conversation, persist bool, toSend ...conversation.Message) { + messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot) if err != nil { lmcli.Fatal("Could not load messages: %v\n", err) } @@ -99,40 +100,40 @@ func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bo // handleConversationReply handles sending messages to an existing // conversation, optionally persisting both the sent replies and responses. -func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...api.Message) { +func HandleReply(ctx *lmcli.Context, to *conversation.Message, persist bool, messages ...conversation.Message) { if to == nil { lmcli.Fatal("Can't prompt from an empty message.") } - existing, err := ctx.Store.PathToRoot(to) + existing, err := ctx.Conversations.PathToRoot(to) if err != nil { lmcli.Fatal("Could not load messages: %v\n", err) } RenderConversation(ctx, append(existing, messages...), true) - var savedReplies []api.Message + var savedReplies []conversation.Message if persist && len(messages) > 0 { - savedReplies, err = ctx.Store.Reply(to, messages...) + savedReplies, err = ctx.Conversations.Reply(to, messages...) if err != nil { lmcli.Warn("Could not save messages: %v\n", err) } } // render a message header with no contents - RenderMessage(ctx, (&api.Message{Role: api.MessageRoleAssistant})) + RenderMessage(ctx, (&conversation.Message{Role: api.MessageRoleAssistant})) - var lastSavedMessage *api.Message + var lastSavedMessage *conversation.Message lastSavedMessage = to if len(savedReplies) > 0 { lastSavedMessage = &savedReplies[len(savedReplies)-1] } - replyCallback := func(reply api.Message) { + replyCallback := func(reply conversation.Message) { if !persist { return } - savedReplies, err = ctx.Store.Reply(lastSavedMessage, reply) + savedReplies, err = ctx.Conversations.Reply(lastSavedMessage, reply) if err != nil { lmcli.Warn("Could not save reply: %v\n", err) } @@ -145,7 +146,7 @@ func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ... } } -func FormatForExternalPrompt(messages []api.Message, system bool) string { +func FormatForExternalPrompt(messages []conversation.Message, system bool) string { sb := strings.Builder{} for _, message := range messages { if message.Content == "" { @@ -164,7 +165,7 @@ func FormatForExternalPrompt(messages []api.Message, system bool) string { return sb.String() } -func GenerateTitle(ctx *lmcli.Context, messages []api.Message) (string, error) { +func GenerateTitle(ctx *lmcli.Context, messages []conversation.Message) (string, error) { const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below. Example conversation: @@ -189,19 +190,19 @@ Example response: } // Serialize the conversation to JSON - conversation, err := json.Marshal(msgs) + jsonBytes, err := json.Marshal(msgs) if err != nil { return "", err } - generateRequest := []api.Message{ + generateRequest := []conversation.Message{ { Role: api.MessageRoleSystem, Content: systemPrompt, }, { Role: api.MessageRoleUser, - Content: string(conversation), + Content: string(jsonBytes), }, } @@ -218,7 +219,7 @@ Example response: } response, err := p.CreateChatCompletion( - context.Background(), requestParams, generateRequest, + context.Background(), requestParams, conversation.MessagesToAPI(generateRequest), ) if err != nil { return "", err @@ -293,7 +294,7 @@ func ShowDelayedContent(content <-chan provider.Chunk) { // RenderConversation renders the given messages to TTY, with optional space // for a subsequent message. spaceForResponse controls how many '\n' characters // are printed immediately after the final message (1 if false, 2 if true) -func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResponse bool) { +func RenderConversation(ctx *lmcli.Context, messages []conversation.Message, spaceForResponse bool) { l := len(messages) for i, message := range messages { RenderMessage(ctx, &message) @@ -304,7 +305,7 @@ func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResp } } -func RenderMessage(ctx *lmcli.Context, m *api.Message) { +func RenderMessage(ctx *lmcli.Context, m *conversation.Message) { var messageAge string if m.CreatedAt.IsZero() { messageAge = "now" diff --git a/pkg/cmd/view.go b/pkg/cmd/view.go index 0da608a..ee122da 100644 --- a/pkg/cmd/view.go +++ b/pkg/cmd/view.go @@ -24,7 +24,7 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command { shortName := args[0] conversation := cmdutil.LookupConversation(ctx, shortName) - messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot) + messages, err := ctx.Conversations.PathToLeaf(conversation.SelectedRoot) if err != nil { return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err) } @@ -37,7 +37,7 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command { if len(args) != 0 { return nil, compMode } - return ctx.Store.ConversationShortNameCompletions(toComplete), compMode + return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode }, } diff --git a/pkg/conversation/conversation.go b/pkg/conversation/conversation.go new file mode 100644 index 0000000..356f2a6 --- /dev/null +++ b/pkg/conversation/conversation.go @@ -0,0 +1,98 @@ +package conversation + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + "time" + + "git.mlow.ca/mlow/lmcli/pkg/api" +) + +type Conversation struct { + ID uint `gorm:"primaryKey"` + ShortName sql.NullString + Title string + SelectedRootID *uint + SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"` + RootMessages []Message `gorm:"-:all"` +} + +type MessageMeta struct { + GenerationProvider *string `json:"generation_provider,omitempty"` + GenerationModel *string `json:"generation_model,omitempty"` +} + +type Message struct { + ID uint `gorm:"primaryKey"` + CreatedAt time.Time + Metadata MessageMeta + + ConversationID *uint `gorm:"index"` + Conversation *Conversation `gorm:"foreignKey:ConversationID"` + ParentID *uint + Parent *Message `gorm:"foreignKey:ParentID"` + Replies []Message `gorm:"foreignKey:ParentID"` + SelectedReplyID *uint + SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` + + Role api.MessageRole + Content string + ToolCalls ToolCalls // a json array of tool calls (from the model) + ToolResults ToolResults // a json array of tool results +} + +func (m *MessageMeta) Scan(value interface{}) error { + return json.Unmarshal(value.([]byte), m) +} + +func (m MessageMeta) Value() (driver.Value, error) { + return json.Marshal(m) +} + +type ToolCalls []api.ToolCall + +func (tc *ToolCalls) Scan(value any) (err error) { + s := value.(string) + if value == nil || s == "" { + *tc = nil + return + } + err = json.Unmarshal([]byte(s), tc) + return +} + +func (tc ToolCalls) Value() (driver.Value, error) { + if len(tc) == 0 { + return "", nil + } + jsonBytes, err := json.Marshal(tc) + if err != nil { + return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err) + } + return string(jsonBytes), nil +} + +type ToolResults []api.ToolResult + +func (tr *ToolResults) Scan(value any) (err error) { + s := value.(string) + if value == nil || s == "" { + *tr = nil + return + } + err = json.Unmarshal([]byte(s), tr) + return +} + +func (tr ToolResults) Value() (driver.Value, error) { + if len(tr) == 0 { + return "", nil + } + jsonBytes, err := json.Marshal([]api.ToolResult(tr)) + if err != nil { + return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err) + } + return string(jsonBytes), nil +} diff --git a/pkg/lmcli/store.go b/pkg/conversation/repo.go similarity index 57% rename from pkg/lmcli/store.go rename to pkg/conversation/repo.go index b21eacb..47db744 100644 --- a/pkg/lmcli/store.go +++ b/pkg/conversation/repo.go @@ -1,4 +1,4 @@ -package lmcli +package conversation import ( "database/sql" @@ -8,43 +8,57 @@ import ( "strings" "time" - "git.mlow.ca/mlow/lmcli/pkg/api" sqids "github.com/sqids/sqids-go" "gorm.io/gorm" ) -type ConversationStore interface { - ConversationByShortName(shortName string) (*api.Conversation, error) +// Repo exposes low-level message and conversation management. See +// Service for high-level helpers +type Repo interface { + // LatestConversationMessages returns a slice of all conversations ordered by when they were last updated (newest to oldest) + LatestConversationMessages() ([]Message, error) + + FindConversationByShortName(shortName string) (*Conversation, error) ConversationShortNameCompletions(search string) []string - RootMessages(conversationID uint) ([]api.Message, error) - LatestConversationMessages() ([]api.Message, error) + GetConversationByID(int uint) (*Conversation, error) + GetRootMessages(conversationID uint) ([]Message, error) - StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) - UpdateConversation(conversation *api.Conversation) error - DeleteConversation(conversation *api.Conversation) error - CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) + CreateConversation(title string) (*Conversation, error) + UpdateConversation(*Conversation) error + DeleteConversation(*Conversation) error - MessageByID(messageID uint) (*api.Message, error) - MessageReplies(messageID uint) ([]api.Message, error) + GetMessageByID(messageID uint) (*Message, error) - UpdateMessage(message *api.Message) error - DeleteMessage(message *api.Message, prune bool) error - CloneBranch(toClone api.Message) (*api.Message, uint, error) - Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) + SaveMessage(message Message) (*Message, error) + UpdateMessage(message *Message) error + DeleteMessage(message *Message, prune bool) error + CloneBranch(toClone Message) (*Message, uint, error) + Reply(to *Message, messages ...Message) ([]Message, error) - PathToRoot(message *api.Message) ([]api.Message, error) - PathToLeaf(message *api.Message) ([]api.Message, error) + PathToRoot(message *Message) ([]Message, error) + PathToLeaf(message *Message) ([]Message, error) + + // Retrieves and return the "selected thread" of the conversation. + // The "selected thread" of the conversation is a chain of messages + // starting from the Conversation's SelectedRoot Message, following each + // Message's SelectedReply until the tail Message is reached. + GetSelectedThread(*Conversation) ([]Message, error) + + // Start a new conversation with the given messages + StartConversation(messages ...Message) (*Conversation, []Message, error) + + CloneConversation(toClone Conversation) (*Conversation, uint, error) } -type SQLStore struct { +type repo struct { db *gorm.DB sqids *sqids.Sqids } -func NewSQLStore(db *gorm.DB) (*SQLStore, error) { +func NewRepo(db *gorm.DB) (Repo, error) { models := []any{ - &api.Conversation{}, - &api.Message{}, + &Conversation{}, + &Message{}, } for _, x := range models { @@ -55,12 +69,70 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) { } _sqids, _ := sqids.New(sqids.Options{MinLength: 4}) - return &SQLStore{db, _sqids}, nil + return &repo{db, _sqids}, nil } -func (s *SQLStore) createConversation() (*api.Conversation, error) { +func (s *repo) LatestConversationMessages() ([]Message, error) { + var latestMessages []Message + + subQuery := s.db.Model(&Message{}). + Select("MAX(created_at) as max_created_at, conversation_id"). + Group("conversation_id") + + err := s.db.Model(&Message{}). + Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery). + Group("messages.conversation_id"). + Order("created_at DESC"). + Preload("Conversation.SelectedRoot"). + Find(&latestMessages).Error + + if err != nil { + return nil, err + } + + return latestMessages, nil +} + +func (s *repo) FindConversationByShortName(shortName string) (*Conversation, error) { + if shortName == "" { + return nil, errors.New("shortName is empty") + } + var conversation Conversation + err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error + return &conversation, err +} + +func (s *repo) ConversationShortNameCompletions(shortName string) []string { + var conversations []Conversation + // ignore error for completions + s.db.Find(&conversations) + completions := make([]string, 0, len(conversations)) + for _, conversation := range conversations { + if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) { + completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title)) + } + } + return completions +} + +func (s *repo) GetConversationByID(id uint) (*Conversation, error) { + var conversation Conversation + err := s.db.Preload("SelectedRoot").Where("id = ?", id).Find(&conversation).Error + if err != nil { + return nil, fmt.Errorf("Cannot get conversation %d: %v", id, err) + } + + rootMessages, err := s.GetRootMessages(id) + if err != nil { + return nil, fmt.Errorf("Could not load conversation's root messages %d: %v", id, err) + } + conversation.RootMessages = rootMessages + return &conversation, nil +} + +func (s *repo) CreateConversation(title string) (*Conversation, error) { // Create the new conversation - c := &api.Conversation{} + c := &Conversation{Title: title} err := s.db.Save(c).Error if err != nil { return nil, err @@ -75,159 +147,54 @@ func (s *SQLStore) createConversation() (*api.Conversation, error) { return c, nil } -func (s *SQLStore) UpdateConversation(c *api.Conversation) error { +func (s *repo) UpdateConversation(c *Conversation) error { if c == nil || c.ID == 0 { return fmt.Errorf("Conversation is nil or invalid (missing ID)") } return s.db.Updates(c).Error } -func (s *SQLStore) DeleteConversation(c *api.Conversation) error { +func (s *repo) DeleteConversation(c *Conversation) error { + if c == nil || c.ID == 0 { + return fmt.Errorf("Conversation is nil or invalid (missing ID)") + } // Delete messages first - err := s.db.Where("conversation_id = ?", c.ID).Delete(&api.Message{}).Error + err := s.db.Where("conversation_id = ?", c.ID).Delete(&Message{}).Error if err != nil { return err } return s.db.Delete(c).Error } -func (s *SQLStore) DeleteMessage(message *api.Message, prune bool) error { - panic("Not yet implemented") - //return s.db.Delete(&message).Error +func (s *repo) SaveMessage(m Message) (*Message, error) { + if m.Conversation == nil { + return nil, fmt.Errorf("Can't save a message without a conversation (this is a bug)") + } + newMessage := m + newMessage.ID = 0 + return &newMessage, s.db.Create(&newMessage).Error } -func (s *SQLStore) UpdateMessage(m *api.Message) error { +func (s *repo) UpdateMessage(m *Message) error { if m == nil || m.ID == 0 { return fmt.Errorf("Message is nil or invalid (missing ID)") } return s.db.Updates(m).Error } -func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string { - var conversations []api.Conversation - // ignore error for completions - s.db.Find(&conversations) - completions := make([]string, 0, len(conversations)) - for _, conversation := range conversations { - if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) { - completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title)) - } - } - return completions +func (s *repo) DeleteMessage(message *Message, prune bool) error { + return s.db.Delete(&message).Error } -func (s *SQLStore) ConversationByShortName(shortName string) (*api.Conversation, error) { - if shortName == "" { - return nil, errors.New("shortName is empty") - } - var conversation api.Conversation - err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error - return &conversation, err -} - -func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) { - var rootMessages []api.Message - err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error - if err != nil { - return nil, err - } - return rootMessages, nil -} - -func (s *SQLStore) MessageByID(messageID uint) (*api.Message, error) { - var message api.Message +func (s *repo) GetMessageByID(messageID uint) (*Message, error) { + var message Message err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error return &message, err } -func (s *SQLStore) MessageReplies(messageID uint) ([]api.Message, error) { - var replies []api.Message - err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error - return replies, err -} - -// StartConversation starts a new conversation with the provided messages -func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) { - if len(messages) == 0 { - return nil, nil, fmt.Errorf("Must provide at least 1 message") - } - - // Create new conversation - conversation, err := s.createConversation() - if err != nil { - return nil, nil, err - } - - // Create first message - messages[0].Conversation = conversation - err = s.db.Create(&messages[0]).Error - if err != nil { - return nil, nil, err - } - - // Update conversation's selected root message - conversation.SelectedRoot = &messages[0] - err = s.UpdateConversation(conversation) - if err != nil { - return nil, nil, err - } - - // Add additional replies to conversation - if len(messages) > 1 { - newMessages, err := s.Reply(&messages[0], messages[1:]...) - if err != nil { - return nil, nil, err - } - messages = append([]api.Message{messages[0]}, newMessages...) - } - return conversation, messages, nil -} - -// CloneConversation clones the given conversation and all of its root meesages -func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) { - rootMessages, err := s.RootMessages(toClone.ID) - if err != nil { - return nil, 0, err - } - - clone, err := s.createConversation() - if err != nil { - return nil, 0, fmt.Errorf("Could not create clone: %s", err) - } - clone.Title = toClone.Title + " - Clone" - - var errors []error - var messageCnt uint = 0 - for _, root := range rootMessages { - messageCnt++ - newRoot := root - newRoot.ConversationID = &clone.ID - - cloned, count, err := s.CloneBranch(newRoot) - if err != nil { - errors = append(errors, err) - continue - } - messageCnt += count - - if root.ID == *toClone.SelectedRootID { - clone.SelectedRootID = &cloned.ID - if err := s.UpdateConversation(clone); err != nil { - errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err)) - } - } - } - - if len(errors) > 0 { - return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors) - } - - return clone, messageCnt, nil -} - -// Reply to a message with a series of messages (each following the next) -func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) { - var savedMessages []api.Message +// Reply to a message with a series of messages (each followed by the next) +func (s *repo) Reply(to *Message, messages ...Message) ([]Message, error) { + var savedMessages []Message err := s.db.Transaction(func(tx *gorm.DB) error { currentParent := to @@ -262,17 +229,14 @@ func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Messag // CloneBranch returns a deep clone of the given message and its replies, returning // a new message object. The new message will be attached to the same parent as // the messageToClone -func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, error) { +func (s *repo) CloneBranch(messageToClone Message) (*Message, uint, error) { newMessage := messageToClone newMessage.ID = 0 newMessage.Replies = nil newMessage.SelectedReplyID = nil newMessage.SelectedReply = nil - originalReplies, err := s.MessageReplies(messageToClone.ID) - if err != nil { - return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err) - } + originalReplies := messageToClone.Replies if err := s.db.Create(&newMessage).Error; err != nil { return nil, 0, fmt.Errorf("Could not clone message: %s", err) @@ -304,19 +268,19 @@ func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, return &newMessage, replyCount, nil } -func fetchMessages(db *gorm.DB) ([]api.Message, error) { - var messages []api.Message +func fetchMessages(db *gorm.DB) ([]Message, error) { + var messages []Message if err := db.Preload("Conversation").Find(&messages).Error; err != nil { return nil, fmt.Errorf("Could not fetch messages: %v", err) } - messageMap := make(map[uint]api.Message) + messageMap := make(map[uint]Message) for i, message := range messages { messageMap[messages[i].ID] = message } // Create a map to store replies by their parent ID - repliesMap := make(map[uint][]api.Message) + repliesMap := make(map[uint][]Message) for i, message := range messages { if messages[i].ParentID != nil { repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message) @@ -326,7 +290,7 @@ func fetchMessages(db *gorm.DB) ([]api.Message, error) { // Assign replies, parent, and selected reply to each message for i := range messages { if replies, exists := repliesMap[messages[i].ID]; exists { - messages[i].Replies = make([]api.Message, len(replies)) + messages[i].Replies = make([]Message, len(replies)) for j, m := range replies { messages[i].Replies[j] = m } @@ -345,21 +309,51 @@ func fetchMessages(db *gorm.DB) ([]api.Message, error) { return messages, nil } -func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *uint) ([]api.Message, error) { - var messages []api.Message +func (r repo) GetRootMessages(conversationID uint) ([]Message, error) { + var rootMessages []Message + err := r.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error + if err != nil { + return nil, fmt.Errorf("Could not retrieve root messages for conversation %d: %v", conversationID, err) + } + return rootMessages, nil +} + +func (s *repo) buildPath(message *Message, getNext func(*Message) *uint) ([]Message, error) { + var messages []Message messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID)) if err != nil { return nil, err } // Create a map to store messages by their ID - messageMap := make(map[uint]*api.Message) + messageMap := make(map[uint]*Message, len(messages)) for i := range messages { messageMap[messages[i].ID] = &messages[i] } + // Construct Replies + repliesMap := make(map[uint][]*Message, len(messages)) + for _, m := range messageMap { + if m.ParentID == nil { + continue + } + if p, ok := messageMap[*m.ParentID]; ok { + repliesMap[p.ID] = append(repliesMap[p.ID], m) + } + } + + // Add replies to messages + for _, m := range messageMap { + if replies, ok := repliesMap[m.ID]; ok { + m.Replies = make([]Message, len(replies)) + for idx, reply := range replies { + m.Replies[idx] = *reply + } + } + } + // Build the path - var path []api.Message + var path []Message nextID := &message.ID for { @@ -382,12 +376,12 @@ func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *u // PathToRoot traverses the provided message's Parent until reaching the tree // root and returns a slice of all messages traversed in chronological order // (starting with the root and ending with the message provided) -func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) { +func (s *repo) PathToRoot(message *Message) ([]Message, error) { if message == nil || message.ID <= 0 { return nil, fmt.Errorf("Message is nil or has invalid ID") } - path, err := s.buildPath(message, func(m *api.Message) *uint { + path, err := s.buildPath(message, func(m *Message) *uint { return m.ParentID }) if err != nil { @@ -401,33 +395,99 @@ func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) { // PathToLeaf traverses the provided message's SelectedReply until reaching a // tree leaf and returns a slice of all messages traversed in chronological // order (starting with the message provided and ending with the leaf) -func (s *SQLStore) PathToLeaf(message *api.Message) ([]api.Message, error) { +func (s *repo) PathToLeaf(message *Message) ([]Message, error) { if message == nil || message.ID <= 0 { return nil, fmt.Errorf("Message is nil or has invalid ID") } - return s.buildPath(message, func(m *api.Message) *uint { + return s.buildPath(message, func(m *Message) *uint { return m.SelectedReplyID }) } -func (s *SQLStore) LatestConversationMessages() ([]api.Message, error) { - var latestMessages []api.Message - - subQuery := s.db.Model(&api.Message{}). - Select("MAX(created_at) as max_created_at, conversation_id"). - Group("conversation_id") - - err := s.db.Model(&api.Message{}). - Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery). - Group("messages.conversation_id"). - Order("created_at DESC"). - Preload("Conversation.SelectedRoot"). - Find(&latestMessages).Error - - if err != nil { - return nil, err +func (s *repo) StartConversation(messages ...Message) (*Conversation, []Message, error) { + if len(messages) == 0 { + return nil, nil, fmt.Errorf("Must provide at least 1 message") } - return latestMessages, nil + // Create new conversation + conversation, err := s.CreateConversation("") + if err != nil { + return nil, nil, err + } + messages[0].Conversation = conversation + + // Create first message + firstMessage, err := s.SaveMessage(messages[0]) + if err != nil { + return nil, nil, err + } + messages[0] = *firstMessage + + // Update conversation's selected root message + conversation.RootMessages = []Message{messages[0]} + conversation.SelectedRoot = &messages[0] + err = s.UpdateConversation(conversation) + if err != nil { + return nil, nil, err + } + + // Add additional replies to conversation + if len(messages) > 1 { + newMessages, err := s.Reply(&messages[0], messages[1:]...) + if err != nil { + return nil, nil, err + } + messages = append([]Message{messages[0]}, newMessages...) + } + return conversation, messages, nil +} + + +// CloneConversation clones the given conversation and all of its meesages +func (s *repo) CloneConversation(toClone Conversation) (*Conversation, uint, error) { + rootMessages, err := s.GetRootMessages(toClone.ID) + if err != nil { + return nil, 0, fmt.Errorf("Could not create clone: %v", err) + } + + clone, err := s.CreateConversation(toClone.Title + " - Clone") + if err != nil { + return nil, 0, fmt.Errorf("Could not create clone: %v", err) + } + + var errors []error + var messageCnt uint = 0 + for _, root := range rootMessages { + messageCnt++ + newRoot := root + newRoot.ConversationID = &clone.ID + + cloned, count, err := s.CloneBranch(newRoot) + if err != nil { + errors = append(errors, err) + continue + } + messageCnt += count + + if root.ID == *toClone.SelectedRootID { + clone.SelectedRootID = &cloned.ID + if err := s.UpdateConversation(clone); err != nil { + errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err)) + } + } + } + + if len(errors) > 0 { + return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors) + } + + return clone, messageCnt, nil +} + +func (s *repo) GetSelectedThread(c *Conversation) ([]Message, error) { + if c.SelectedRoot == nil { + return nil, fmt.Errorf("No SelectedRoot on conversation - this is a bug") + } + return s.PathToLeaf(c.SelectedRoot) } diff --git a/pkg/conversation/tools.go b/pkg/conversation/tools.go new file mode 100644 index 0000000..ca61e88 --- /dev/null +++ b/pkg/conversation/tools.go @@ -0,0 +1,55 @@ +package conversation + +import ( + "git.mlow.ca/mlow/lmcli/pkg/api" +) + +// ApplySystemPrompt updates the contents of an existing system Message if it +// exists, or returns a new slice with the system Message prepended. +func ApplySystemPrompt(m []Message, system string, force bool) []Message { + if len(m) > 0 && m[0].Role == api.MessageRoleSystem { + if force { + m[0].Content = system + } + return m + } else { + return append([]Message{{ + Role: api.MessageRoleSystem, + Content: system, + }}, m...) + } +} + +func MessageToAPI(m Message) api.Message { + return api.Message{ + Role: m.Role, + Content: m.Content, + ToolCalls: m.ToolCalls, + ToolResults: m.ToolResults, + } +} + +func MessagesToAPI(messages []Message) []api.Message { + ret := make([]api.Message, 0, len(messages)) + for _, m := range messages { + ret = append(ret, MessageToAPI(m)) + } + return ret +} + +func MessageFromAPI(m api.Message) Message { + return Message{ + Role: m.Role, + Content: m.Content, + ToolCalls: m.ToolCalls, + ToolResults: m.ToolResults, + } +} + +func MessagesFromAPI(messages []api.Message) []Message { + ret := make([]Message, 0, len(messages)) + for _, m := range messages { + ret = append(ret, MessageFromAPI(m)) + } + return ret +} diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index e8b3d48..c216278 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -12,11 +12,12 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/agents" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/api/provider" - "git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic" - "git.mlow.ca/mlow/lmcli/pkg/api/provider/google" - "git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama" - "git.mlow.ca/mlow/lmcli/pkg/api/provider/openai" + "git.mlow.ca/mlow/lmcli/pkg/provider" + "git.mlow.ca/mlow/lmcli/pkg/provider/anthropic" + "git.mlow.ca/mlow/lmcli/pkg/provider/google" + "git.mlow.ca/mlow/lmcli/pkg/provider/ollama" + "git.mlow.ca/mlow/lmcli/pkg/provider/openai" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util/tty" "gorm.io/driver/sqlite" @@ -33,7 +34,7 @@ type Agent struct { type Context struct { // high level app configuration, may be mutated at runtime Config Config - Store ConversationStore + Conversations conversation.Repo Chroma *tty.ChromaHighlighter } @@ -44,7 +45,7 @@ func NewContext() (*Context, error) { return nil, err } - store, err := getConversationStore() + store, err := getConversationService() if err != nil { return nil, err } @@ -69,17 +70,16 @@ func createOrOpenAppend(path string) (*os.File, error) { return file, nil } -func getConversationStore() (ConversationStore, error) { +func getConversationService() (conversation.Repo, error) { databaseFile := filepath.Join(dataDir(), "conversations.db") - gormLogFile, err := createOrOpenAppend(filepath.Join(dataDir(), "database.log")) if err != nil { return nil, fmt.Errorf("Could not open database log file: %v", err) } db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{ - Logger: logger.New(log.New(gormLogFile, "", log.LstdFlags), logger.Config{ + Logger: logger.New(log.New(gormLogFile, "\n", log.LstdFlags), logger.Config{ SlowThreshold: 200 * time.Millisecond, - LogLevel: logger.Warn, + LogLevel: logger.Info, IgnoreRecordNotFoundError: false, Colorful: true, }), @@ -87,11 +87,11 @@ func getConversationStore() (ConversationStore, error) { if err != nil { return nil, fmt.Errorf("Error establishing connection to store: %v", err) } - store, err := NewSQLStore(db) + repo, err := conversation.NewRepo(db) if err != nil { return nil, err } - return store, nil + return repo, nil } func (c *Context) GetModels() (models []string) { diff --git a/pkg/api/provider/anthropic/anthropic.go b/pkg/provider/anthropic/anthropic.go similarity index 98% rename from pkg/api/provider/anthropic/anthropic.go rename to pkg/provider/anthropic/anthropic.go index e471959..5b4cd9a 100644 --- a/pkg/api/provider/anthropic/anthropic.go +++ b/pkg/provider/anthropic/anthropic.go @@ -11,7 +11,7 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/api/provider" + "git.mlow.ca/mlow/lmcli/pkg/provider" ) const ANTHROPIC_VERSION = "2023-06-01" @@ -439,15 +439,9 @@ func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) } } - message := &api.Message{ - Role: api.MessageRoleAssistant, - Content: content.String(), - ToolCalls: toolCalls, - } - if len(toolCalls) > 0 { - message.Role = api.MessageRoleToolCall + return api.NewMessageWithToolCalls(content.String(), toolCalls), nil } - return message, nil + return api.NewMessageWithAssistant(content.String()), nil } diff --git a/pkg/api/provider/google/google.go b/pkg/provider/google/google.go similarity index 94% rename from pkg/api/provider/google/google.go rename to pkg/provider/google/google.go index d061d24..1d8bfe0 100644 --- a/pkg/api/provider/google/google.go +++ b/pkg/provider/google/google.go @@ -11,7 +11,7 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/api/provider" + "git.mlow.ca/mlow/lmcli/pkg/provider" ) type Client struct { @@ -337,17 +337,10 @@ func (c *Client) CreateChatCompletion( } if len(toolCalls) > 0 { - return &api.Message{ - Role: api.MessageRoleToolCall, - Content: content, - ToolCalls: convertToolCallToAPI(toolCalls), - }, nil + return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil } - return &api.Message{ - Role: api.MessageRoleAssistant, - Content: content, - }, nil + return api.NewMessageWithAssistant(content), nil } func (c *Client) CreateChatCompletionStream( @@ -435,17 +428,9 @@ func (c *Client) CreateChatCompletionStream( } } - // If there are function calls, handle them and recurse if len(toolCalls) > 0 { - return &api.Message{ - Role: api.MessageRoleToolCall, - Content: content.String(), - ToolCalls: convertToolCallToAPI(toolCalls), - }, nil + return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil } - return &api.Message{ - Role: api.MessageRoleAssistant, - Content: content.String(), - }, nil + return api.NewMessageWithAssistant(content.String()), nil } diff --git a/pkg/api/provider/ollama/ollama.go b/pkg/provider/ollama/ollama.go similarity index 94% rename from pkg/api/provider/ollama/ollama.go rename to pkg/provider/ollama/ollama.go index 264aca7..5b860bb 100644 --- a/pkg/api/provider/ollama/ollama.go +++ b/pkg/provider/ollama/ollama.go @@ -11,7 +11,7 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/api/provider" + "git.mlow.ca/mlow/lmcli/pkg/provider" ) type OllamaClient struct { @@ -115,10 +115,7 @@ func (c *OllamaClient) CreateChatCompletion( return nil, err } - return &api.Message{ - Role: api.MessageRoleAssistant, - Content: completionResp.Message.Content, - }, nil + return api.NewMessageWithAssistant(completionResp.Message.Content), nil } func (c *OllamaClient) CreateChatCompletionStream( @@ -182,8 +179,5 @@ func (c *OllamaClient) CreateChatCompletionStream( } } - return &api.Message{ - Role: api.MessageRoleAssistant, - Content: content.String(), - }, nil + return api.NewMessageWithAssistant(content.String()), nil } diff --git a/pkg/api/provider/openai/openai.go b/pkg/provider/openai/openai.go similarity index 94% rename from pkg/api/provider/openai/openai.go rename to pkg/provider/openai/openai.go index 318c392..9d1f567 100644 --- a/pkg/api/provider/openai/openai.go +++ b/pkg/provider/openai/openai.go @@ -11,7 +11,7 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/api/provider" + "git.mlow.ca/mlow/lmcli/pkg/provider" ) type OpenAIClient struct { @@ -253,17 +253,10 @@ func (c *OpenAIClient) CreateChatCompletion( toolCalls := choice.Message.ToolCalls if len(toolCalls) > 0 { - return &api.Message{ - Role: api.MessageRoleToolCall, - Content: content, - ToolCalls: convertToolCallToAPI(toolCalls), - }, nil + return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil } - return &api.Message{ - Role: api.MessageRoleAssistant, - Content: content, - }, nil + return api.NewMessageWithAssistant(content), nil } func (c *OpenAIClient) CreateChatCompletionStream( @@ -343,15 +336,8 @@ func (c *OpenAIClient) CreateChatCompletionStream( } if len(toolCalls) > 0 { - return &api.Message{ - Role: api.MessageRoleToolCall, - Content: content.String(), - ToolCalls: convertToolCallToAPI(toolCalls), - }, nil + return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil } - return &api.Message{ - Role: api.MessageRoleAssistant, - Content: content.String(), - }, nil + return api.NewMessageWithAssistant(content.String()), nil } diff --git a/pkg/api/provider/provider.go b/pkg/provider/provider.go similarity index 100% rename from pkg/api/provider/provider.go rename to pkg/provider/provider.go diff --git a/pkg/tui/model/model.go b/pkg/tui/model/model.go index 198220c..529e869 100644 --- a/pkg/tui/model/model.go +++ b/pkg/tui/model/model.go @@ -6,30 +6,30 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/agents" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/api/provider" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "git.mlow.ca/mlow/lmcli/pkg/provider" "github.com/charmbracelet/lipgloss" ) type LoadedConversation struct { - Conv api.Conversation - LastReply api.Message + Conv conversation.Conversation + LastReply conversation.Message } type AppModel struct { Ctx *lmcli.Context Conversations []LoadedConversation - Conversation *api.Conversation - RootMessages []api.Message - Messages []api.Message + Conversation *conversation.Conversation + Messages []conversation.Message Model string ProviderName string Provider provider.ChatCompletionProvider - Agent *lmcli.Agent + Agent *lmcli.Agent } -func NewAppModel(ctx *lmcli.Context, initialConversation *api.Conversation) *AppModel { +func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel { app := &AppModel{ Ctx: ctx, Conversation: initialConversation, @@ -67,8 +67,7 @@ const ( func (m *AppModel) ClearConversation() { m.Conversation = nil - m.Messages = []api.Message{} - m.RootMessages = []api.Message{} + m.Messages = []conversation.Message{} } func (m *AppModel) ApplySystemPrompt() { @@ -81,7 +80,7 @@ func (m *AppModel) ApplySystemPrompt() { system = m.Ctx.DefaultSystemPrompt() } if system != "" { - m.Messages = api.ApplySystemPrompt(m.Messages, system, false) + m.Messages = conversation.ApplySystemPrompt(m.Messages, system, false) } } @@ -91,7 +90,7 @@ func (m *AppModel) NewConversation() { } func (m *AppModel) LoadConversations() (error, []LoadedConversation) { - messages, err := m.Ctx.Store.LatestConversationMessages() + messages, err := m.Ctx.Conversations.LatestConversationMessages() if err != nil { return fmt.Errorf("Could not load conversations: %v", err), nil } @@ -106,42 +105,34 @@ func (m *AppModel) LoadConversations() (error, []LoadedConversation) { return nil, conversations } -func (a *AppModel) LoadConversationRootMessages() ([]api.Message, error) { - messages, err := a.Ctx.Store.RootMessages(a.Conversation.ID) - if err != nil { - return nil, fmt.Errorf("Could not load conversation root messages: %v %v", a.Conversation.SelectedRoot, err) - } - return messages, nil -} - -func (a *AppModel) LoadConversationMessages() ([]api.Message, error) { - messages, err := a.Ctx.Store.PathToLeaf(a.Conversation.SelectedRoot) +func (a *AppModel) LoadConversationMessages() ([]conversation.Message, error) { + messages, err := a.Ctx.Conversations.PathToLeaf(a.Conversation.SelectedRoot) if err != nil { return nil, fmt.Errorf("Could not load conversation messages: %v %v", a.Conversation.SelectedRoot, err) } return messages, nil } -func (a *AppModel) GenerateConversationTitle(messages []api.Message) (string, error) { +func (a *AppModel) GenerateConversationTitle(messages []conversation.Message) (string, error) { return cmdutil.GenerateTitle(a.Ctx, messages) } -func (a *AppModel) UpdateConversationTitle(conversation *api.Conversation) error { - return a.Ctx.Store.UpdateConversation(conversation) +func (a *AppModel) UpdateConversationTitle(conversation *conversation.Conversation) error { + return a.Ctx.Conversations.UpdateConversation(conversation) } -func (a *AppModel) CloneMessage(message api.Message, selected bool) (*api.Message, error) { - msg, _, err := a.Ctx.Store.CloneBranch(message) +func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) { + msg, _, err := a.Ctx.Conversations.CloneBranch(message) if err != nil { return nil, fmt.Errorf("Could not clone message: %v", err) } if selected { if msg.Parent == nil { msg.Conversation.SelectedRoot = msg - err = a.Ctx.Store.UpdateConversation(msg.Conversation) + err = a.Ctx.Conversations.UpdateConversation(msg.Conversation) } else { msg.Parent.SelectedReply = msg - err = a.Ctx.Store.UpdateMessage(msg.Parent) + err = a.Ctx.Conversations.UpdateMessage(msg.Parent) } if err != nil { return nil, fmt.Errorf("Could not update selected message: %v", err) @@ -150,11 +141,11 @@ func (a *AppModel) CloneMessage(message api.Message, selected bool) (*api.Messag return msg, nil } -func (a *AppModel) UpdateMessageContent(message *api.Message) error { - return a.Ctx.Store.UpdateMessage(message) +func (a *AppModel) UpdateMessageContent(message *conversation.Message) error { + return a.Ctx.Conversations.UpdateMessage(message) } -func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) { +func cycleSelectedMessage(selected *conversation.Message, choices []conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) { currentIndex := -1 for i, reply := range choices { if reply.ID == selected.ID { @@ -176,25 +167,25 @@ func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir Mess return &choices[next], nil } -func (a *AppModel) CycleSelectedRoot(conv *api.Conversation, rootMessages []api.Message, dir MessageCycleDirection) (*api.Message, error) { - if len(rootMessages) < 2 { +func (a *AppModel) CycleSelectedRoot(conv *conversation.Conversation, dir MessageCycleDirection) (*conversation.Message, error) { + if len(conv.RootMessages) < 2 { return nil, nil } - nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, rootMessages, dir) + nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, conv.RootMessages, dir) if err != nil { return nil, err } conv.SelectedRoot = nextRoot - err = a.Ctx.Store.UpdateConversation(conv) + err = a.Ctx.Conversations.UpdateConversation(conv) if err != nil { return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err) } return nextRoot, nil } -func (a *AppModel) CycleSelectedReply(message *api.Message, dir MessageCycleDirection) (*api.Message, error) { +func (a *AppModel) CycleSelectedReply(message *conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) { if len(message.Replies) < 2 { return nil, nil } @@ -205,17 +196,17 @@ func (a *AppModel) CycleSelectedReply(message *api.Message, dir MessageCycleDire } message.SelectedReply = nextReply - err = a.Ctx.Store.UpdateMessage(message) + err = a.Ctx.Conversations.UpdateMessage(message) if err != nil { return nil, fmt.Errorf("Could not update message SelectedReply: %v", err) } return nextReply, nil } -func (a *AppModel) PersistConversation(conversation *api.Conversation, messages []api.Message) (*api.Conversation, []api.Message, error) { +func (a *AppModel) PersistConversation(conversation *conversation.Conversation, messages []conversation.Message) (*conversation.Conversation, []conversation.Message, error) { var err error if conversation == nil || conversation.ID == 0 { - conversation, messages, err = a.Ctx.Store.StartConversation(messages...) + conversation, messages, err = a.Ctx.Conversations.StartConversation(messages...) if err != nil { return nil, nil, fmt.Errorf("Could not start new conversation: %v", err) } @@ -224,12 +215,12 @@ func (a *AppModel) PersistConversation(conversation *api.Conversation, messages for i := range messages { if messages[i].ID > 0 { - err := a.Ctx.Store.UpdateMessage(&messages[i]) + err := a.Ctx.Conversations.UpdateMessage(&messages[i]) if err != nil { return nil, nil, err } } else if i > 0 { - saved, err := a.Ctx.Store.Reply(&messages[i-1], messages[i]) + saved, err := a.Ctx.Conversations.Reply(&messages[i-1], messages[i]) if err != nil { return nil, nil, err } @@ -251,10 +242,10 @@ func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, } func (a *AppModel) Prompt( - messages []api.Message, + messages []conversation.Message, chatReplyChunks chan provider.Chunk, stopSignal chan struct{}, -) (*api.Message, error) { +) (*conversation.Message, error) { model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName) if err != nil { return nil, err @@ -280,11 +271,14 @@ func (a *AppModel) Prompt( }() msg, err := p.CreateChatCompletionStream( - ctx, params, messages, chatReplyChunks, + ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks, ) + if msg != nil { + msg := conversation.MessageFromAPI(*msg) msg.Metadata.GenerationProvider = &a.ProviderName msg.Metadata.GenerationModel = &a.Model + return &msg, err } - return msg, err + return nil, err } diff --git a/pkg/tui/tui.go b/pkg/tui/tui.go index 23db52c..751193d 100644 --- a/pkg/tui/tui.go +++ b/pkg/tui/tui.go @@ -3,7 +3,7 @@ package tui import ( "fmt" - "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/tui/model" "git.mlow.ca/mlow/lmcli/pkg/tui/shared" @@ -130,13 +130,13 @@ func (m *Model) View() string { } type LaunchOptions struct { - InitialConversation *api.Conversation + InitialConversation *conversation.Conversation InitialView shared.View } type LaunchOption func(*LaunchOptions) -func WithInitialConversation(conv *api.Conversation) LaunchOption { +func WithInitialConversation(conv *conversation.Conversation) LaunchOption { return func(opts *LaunchOptions) { opts.InitialConversation = conv } diff --git a/pkg/tui/views/chat/chat.go b/pkg/tui/views/chat/chat.go index 41ba187..848766e 100644 --- a/pkg/tui/views/chat/chat.go +++ b/pkg/tui/views/chat/chat.go @@ -4,7 +4,8 @@ import ( "time" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/api/provider" + "git.mlow.ca/mlow/lmcli/pkg/provider" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/tui/model" "github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/spinner" @@ -20,14 +21,12 @@ type ( msgConversationTitleGenerated string // sent when the conversation has been persisted, triggers a reload of contents msgConversationPersisted struct { - isNew bool - conversation *api.Conversation - messages []api.Message + conversation *conversation.Conversation + messages []conversation.Message } // sent when a conversation's messages are laoded msgConversationMessagesLoaded struct { - messages []api.Message - rootMessages []api.Message + messages []conversation.Message } // a special case of common.MsgError that stops the response waiting animation msgChatResponseError struct { @@ -36,19 +35,19 @@ type ( // sent on each chunk received from LLM msgChatResponseChunk provider.Chunk // sent on each completed reply - msgChatResponse *api.Message + msgChatResponse *conversation.Message // sent when the response is canceled msgChatResponseCanceled struct{} // sent when results from a tool call are returned msgToolResults []api.ToolResult // sent when the given message is made the new selected reply of its parent - msgSelectedReplyCycled *api.Message + msgSelectedReplyCycled *conversation.Message // sent when the given message is made the new selected root of the current conversation - msgSelectedRootCycled *api.Message + msgSelectedRootCycled *conversation.Message // sent when a message's contents are updated and saved - msgMessageUpdated *api.Message + msgMessageUpdated *conversation.Message // sent when a message is cloned, with the cloned message - msgMessageCloned *api.Message + msgMessageCloned *conversation.Message ) type focusState int @@ -84,7 +83,7 @@ type Model struct { selectedMessage int editorTarget editorTarget stopSignal chan struct{} - replyChan chan api.Message + replyChan chan conversation.Message chatReplyChunks chan provider.Chunk persistence bool // whether we will save new messages in the conversation @@ -137,7 +136,7 @@ func Chat(app *model.AppModel) *Model { persistence: true, stopSignal: make(chan struct{}), - replyChan: make(chan api.Message), + replyChan: make(chan conversation.Message), chatReplyChunks: make(chan provider.Chunk), wrap: true, diff --git a/pkg/tui/views/chat/cmds.go b/pkg/tui/views/chat/cmds.go index 078c4c6..9b2e69e 100644 --- a/pkg/tui/views/chat/cmds.go +++ b/pkg/tui/views/chat/cmds.go @@ -4,6 +4,7 @@ import ( "time" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/tui/model" "git.mlow.ca/mlow/lmcli/pkg/tui/shared" tea "github.com/charmbracelet/bubbletea" @@ -21,13 +22,7 @@ func (m *Model) loadConversationMessages() tea.Cmd { if err != nil { return shared.AsMsgError(err) } - rootMessages, err := m.App.LoadConversationRootMessages() - if err != nil { - return shared.AsMsgError(err) - } - return msgConversationMessagesLoaded{ - messages, rootMessages, - } + return msgConversationMessagesLoaded{messages} } } @@ -41,7 +36,7 @@ func (m *Model) generateConversationTitle() tea.Cmd { } } -func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd { +func (m *Model) updateConversationTitle(conversation *conversation.Conversation) tea.Cmd { return func() tea.Msg { err := m.App.UpdateConversationTitle(conversation) if err != nil { @@ -51,7 +46,7 @@ func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd } } -func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd { +func (m *Model) cloneMessage(message conversation.Message, selected bool) tea.Cmd { return func() tea.Msg { msg, err := m.App.CloneMessage(message, selected) if err != nil { @@ -61,7 +56,7 @@ func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd { } } -func (m *Model) updateMessageContent(message *api.Message) tea.Cmd { +func (m *Model) updateMessageContent(message *conversation.Message) tea.Cmd { return func() tea.Msg { err := m.App.UpdateMessageContent(message) if err != nil { @@ -71,14 +66,13 @@ func (m *Model) updateMessageContent(message *api.Message) tea.Cmd { } } -func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir model.MessageCycleDirection) tea.Cmd { - if len(m.App.RootMessages) < 2 { - +func (m *Model) cycleSelectedRoot(conv *conversation.Conversation, dir model.MessageCycleDirection) tea.Cmd { + if len(conv.RootMessages) < 2 { return nil } return func() tea.Msg { - nextRoot, err := m.App.CycleSelectedRoot(conv, m.App.RootMessages, dir) + nextRoot, err := m.App.CycleSelectedRoot(conv, dir) if err != nil { return shared.WrapError(err) } @@ -86,7 +80,7 @@ func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir model.MessageCycle } } -func (m *Model) cycleSelectedReply(message *api.Message, dir model.MessageCycleDirection) tea.Cmd { +func (m *Model) cycleSelectedReply(message *conversation.Message, dir model.MessageCycleDirection) tea.Cmd { if len(message.Replies) < 2 { return nil } @@ -106,7 +100,7 @@ func (m *Model) persistConversation() tea.Cmd { if err != nil { return shared.AsMsgError(err) } - return msgConversationPersisted{conversation.ID == 0, conversation, messages} + return msgConversationPersisted{conversation, messages} } } diff --git a/pkg/tui/views/chat/input.go b/pkg/tui/views/chat/input.go index 9f2ca91..4cca580 100644 --- a/pkg/tui/views/chat/input.go +++ b/pkg/tui/views/chat/input.go @@ -5,6 +5,7 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/tui/model" "git.mlow.ca/mlow/lmcli/pkg/tui/shared" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" @@ -70,12 +71,12 @@ func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd { } func (m *Model) scrollSelection(dir int) { - if m.selectedMessage + dir < 0 || m.selectedMessage + dir >= len(m.App.Messages) { + if m.selectedMessage+dir < 0 || m.selectedMessage+dir >= len(m.App.Messages) { return } newIdx := m.selectedMessage - for i := newIdx + dir; i >= 0 && i < len(m.App.Messages); i += dir{ + for i := newIdx + dir; i >= 0 && i < len(m.App.Messages); i += dir { if !m.showDetails && m.App.Messages[i].Role.IsSystem() { continue } @@ -175,7 +176,7 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) tea.Cmd { return shared.WrapError(fmt.Errorf("Can't reply to a user message")) } - m.addMessage(api.Message{ + m.addMessage(conversation.Message{ Role: api.MessageRoleUser, Content: input, }) diff --git a/pkg/tui/views/chat/update.go b/pkg/tui/views/chat/update.go index 83d275d..ad06a4c 100644 --- a/pkg/tui/views/chat/update.go +++ b/pkg/tui/views/chat/update.go @@ -5,13 +5,14 @@ import ( "time" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/tui/shared" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" "github.com/charmbracelet/bubbles/cursor" tea "github.com/charmbracelet/bubbletea" ) -func (m *Model) setMessage(i int, msg api.Message) { +func (m *Model) setMessage(i int, msg conversation.Message) { if i >= len(m.App.Messages) { panic("i out of range") } @@ -19,7 +20,7 @@ func (m *Model) setMessage(i int, msg api.Message) { m.messageCache[i] = m.renderMessage(i) } -func (m *Model) addMessage(msg api.Message) { +func (m *Model) addMessage(msg conversation.Message) { m.App.Messages = append(m.App.Messages, msg) m.messageCache = append(m.messageCache, m.renderMessage(len(m.App.Messages)-1)) } @@ -95,7 +96,6 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { } } case msgConversationMessagesLoaded: - m.App.RootMessages = msg.rootMessages m.App.Messages = msg.messages if m.selectedMessage == -1 { m.selectedMessage = len(msg.messages) - 1 @@ -117,7 +117,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { m.setMessageContents(last, m.App.Messages[last].Content+msg.Content) } else { // use chunk in a new message - m.addMessage(api.Message{ + m.addMessage(conversation.Message{ Role: api.MessageRoleAssistant, Content: msg.Content, }) @@ -133,7 +133,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { case msgChatResponse: m.state = idle - reply := (*api.Message)(msg) + reply := (*conversation.Message)(msg) reply.Content = strings.TrimSpace(reply.Content) last := len(m.App.Messages) - 1 @@ -181,9 +181,9 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { panic("Previous message not a tool call, unexpected") } - m.addMessage(api.Message{ + m.addMessage(conversation.Message{ Role: api.MessageRoleToolResult, - ToolResults: api.ToolResults(msg), + ToolResults: conversation.ToolResults(msg), }) if m.persistence { @@ -207,15 +207,11 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { case msgConversationPersisted: m.App.Conversation = msg.conversation m.App.Messages = msg.messages - if msg.isNew { - m.App.RootMessages = []api.Message{m.App.Messages[0]} - } m.rebuildMessageCache() m.updateContent() case msgMessageCloned: if msg.Parent == nil { m.App.Conversation = msg.Conversation - m.App.RootMessages = append(m.App.RootMessages, *msg) } cmds = append(cmds, m.loadConversationMessages()) case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated: diff --git a/pkg/tui/views/chat/view.go b/pkg/tui/views/chat/view.go index 01b173b..7dd3c6a 100644 --- a/pkg/tui/views/chat/view.go +++ b/pkg/tui/views/chat/view.go @@ -6,6 +6,7 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/tui/styles" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" "github.com/charmbracelet/lipgloss" @@ -44,7 +45,7 @@ var ( footerStyle = lipgloss.NewStyle().Padding(0, 1) ) -func (m *Model) renderMessageHeading(i int, message *api.Message) string { +func (m *Model) renderMessageHeading(i int, message *conversation.Message) string { friendly := message.Role.FriendlyRole() style := systemStyle @@ -70,15 +71,15 @@ func (m *Model) renderMessageHeading(i int, message *api.Message) string { prefix = " " } - if i == 0 && len(m.App.RootMessages) > 1 && m.App.Conversation.SelectedRootID != nil { + if i == 0 && len(m.App.Conversation.RootMessages) > 1 && m.App.Conversation.SelectedRootID != nil { selectedRootIndex := 0 - for j, reply := range m.App.RootMessages { + for j, reply := range m.App.Conversation.RootMessages { if reply.ID == *m.App.Conversation.SelectedRootID { selectedRootIndex = j break } } - suffix += faintStyle.Render(fmt.Sprintf(" <%d/%d>", selectedRootIndex+1, len(m.App.RootMessages))) + suffix += faintStyle.Render(fmt.Sprintf(" <%d/%d>", selectedRootIndex+1, len(m.App.Conversation.RootMessages))) } if i > 0 && len(m.App.Messages[i-1].Replies) > 1 { // Find the selected reply index @@ -230,9 +231,9 @@ func (m *Model) conversationMessagesView() string { // Render a placeholder for the incoming assistant reply if m.state == pendingResponse && m.App.Messages[len(m.App.Messages)-1].Role != api.MessageRoleAssistant { - heading := m.renderMessageHeading(-1, &api.Message{ + heading := m.renderMessageHeading(-1, &conversation.Message{ Role: api.MessageRoleAssistant, - Metadata: api.MessageMeta{ + Metadata: conversation.MessageMeta{ GenerationModel: &m.App.Model, }, }) diff --git a/pkg/tui/views/conversations/conversations.go b/pkg/tui/views/conversations/conversations.go index 1dd7d4d..1a4318d 100644 --- a/pkg/tui/views/conversations/conversations.go +++ b/pkg/tui/views/conversations/conversations.go @@ -5,7 +5,7 @@ import ( "strings" "time" - "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/tui/bubbles" "git.mlow.ca/mlow/lmcli/pkg/tui/model" "git.mlow.ca/mlow/lmcli/pkg/tui/shared" @@ -21,7 +21,7 @@ type ( // sent when conversation list is loaded msgConversationsLoaded ([]model.LoadedConversation) // sent when a conversation is selected - msgConversationSelected api.Conversation + msgConversationSelected conversation.Conversation // sent when a conversation is deleted msgConversationDeleted struct{} ) @@ -154,7 +154,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { case bubbles.MsgConfirmPromptAnswered: m.confirmPrompt.Blur() if msg.Value { - conv, ok := msg.Payload.(api.Conversation) + conv, ok := msg.Payload.(conversation.Conversation) if ok { cmds = append(cmds, m.deleteConversation(conv)) } @@ -188,9 +188,9 @@ func (m *Model) loadConversations() tea.Cmd { } } -func (m *Model) deleteConversation(conv api.Conversation) tea.Cmd { +func (m *Model) deleteConversation(conv conversation.Conversation) tea.Cmd { return func() tea.Msg { - err := m.App.Ctx.Store.DeleteConversation(&conv) + err := m.App.Ctx.Conversations.DeleteConversation(&conv) if err != nil { return shared.AsMsgError(fmt.Errorf("Could not delete conversation: %v", err)) }