Tweaks/cleanups to conversation management in tui

- Pass around message/conversation values instead of pointers where it
makes more sense, and store values instead of pointers in the globally
(within the TUI) shared `App` (pointers provide no utility here).

- Split conversation persistence into separate conversation/message
  saving stages
This commit is contained in:
Matt Low 2024-10-25 16:57:15 +00:00
parent 07c96082e7
commit ec21a02ec0
9 changed files with 116 additions and 86 deletions

View File

@ -72,6 +72,14 @@ func (m MessageRole) IsAssistant() bool {
return false return false
} }
func (m MessageRole) IsUser() bool {
switch m {
case MessageRoleUser, MessageRoleToolResult:
return true
}
return false
}
func (m MessageRole) IsSystem() bool { func (m MessageRole) IsSystem() bool {
switch m { switch m {
case MessageRoleSystem: case MessageRoleSystem:

View File

@ -25,6 +25,7 @@ type Repo interface {
CreateConversation(title string) (*Conversation, error) CreateConversation(title string) (*Conversation, error)
UpdateConversation(*Conversation) error UpdateConversation(*Conversation) error
DeleteConversation(*Conversation) error DeleteConversation(*Conversation) error
DeleteConversationById(id uint) error
GetMessageByID(messageID uint) (*Message, error) GetMessageByID(messageID uint) (*Message, error)
@ -71,7 +72,7 @@ func NewRepo(db *gorm.DB) (Repo, error) {
return &repo{db, _sqids}, nil return &repo{db, _sqids}, nil
} }
type conversationListItem struct { type ConversationListItem struct {
ID uint ID uint
ShortName string ShortName string
Title string Title string
@ -80,7 +81,7 @@ type conversationListItem struct {
type ConversationList struct { type ConversationList struct {
Total int Total int
Items []conversationListItem Items []ConversationListItem
} }
// LoadConversationList loads existing conversations, ordered by the date // LoadConversationList loads existing conversations, ordered by the date
@ -95,7 +96,7 @@ func (s *repo) LoadConversationList() (ConversationList, error) {
} }
for _, c := range convos { for _, c := range convos {
list.Items = append(list.Items, conversationListItem{ list.Items = append(list.Items, ConversationListItem{
ID: c.ID, ID: c.ID,
ShortName: c.ShortName.String, ShortName: c.ShortName.String,
Title: c.Title, Title: c.Title,
@ -147,7 +148,7 @@ func (s *repo) GetConversationByID(id uint) (*Conversation, error) {
func (s *repo) CreateConversation(title string) (*Conversation, error) { func (s *repo) CreateConversation(title string) (*Conversation, error) {
// Create the new conversation // Create the new conversation
c := &Conversation{Title: title} c := &Conversation{Title: title}
err := s.db.Save(c).Error err := s.db.Create(c).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -172,12 +173,18 @@ func (s *repo) DeleteConversation(c *Conversation) error {
if c == nil || c.ID == 0 { if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)") return fmt.Errorf("Conversation is nil or invalid (missing ID)")
} }
// Delete messages first return s.DeleteConversationById(c.ID)
err := s.db.Where("conversation_id = ?", c.ID).Delete(&Message{}).Error }
func (s *repo) DeleteConversationById(id uint) error {
if id == 0 {
return fmt.Errorf("Invalid conversation ID: %d", id)
}
err := s.db.Where("conversation_id = ?", id).Delete(&Message{}).Error
if err != nil { if err != nil {
return err return err
} }
return s.db.Delete(c).Error return s.db.Where("id = ?", id).Delete(&Conversation{}).Error
} }
func (s *repo) SaveMessage(m Message) (*Message, error) { func (s *repo) SaveMessage(m Message) (*Message, error) {
@ -186,6 +193,7 @@ func (s *repo) SaveMessage(m Message) (*Message, error) {
} }
newMessage := m newMessage := m
newMessage.ID = 0 newMessage.ID = 0
newMessage.CreatedAt = time.Now()
return &newMessage, s.db.Create(&newMessage).Error return &newMessage, s.db.Create(&newMessage).Error
} }
@ -234,12 +242,15 @@ func (s *repo) Reply(to *Message, messages ...Message) ([]Message, error) {
savedMessages = append(savedMessages, message) savedMessages = append(savedMessages, message)
currentParent = &message currentParent = &message
} }
to.Conversation.LastMessageAt = savedMessages[len(savedMessages)-1].CreatedAt
s.UpdateConversation(to.Conversation)
return nil return nil
}) })
if err != nil {
return savedMessages, err
}
to.Conversation.LastMessageAt = savedMessages[len(savedMessages)-1].CreatedAt
err = s.UpdateConversation(to.Conversation)
return savedMessages, err return savedMessages, err
} }

View File

@ -16,7 +16,7 @@ import (
type AppModel struct { type AppModel struct {
Ctx *lmcli.Context Ctx *lmcli.Context
Conversations conversation.ConversationList Conversations conversation.ConversationList
Conversation *conversation.Conversation Conversation conversation.Conversation
Messages []conversation.Message Messages []conversation.Message
Model string Model string
ProviderName string ProviderName string
@ -27,12 +27,13 @@ type AppModel struct {
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel { func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
app := &AppModel{ app := &AppModel{
Ctx: ctx, Ctx: ctx,
Conversation: initialConversation,
Model: *ctx.Config.Defaults.Model, Model: *ctx.Config.Defaults.Model,
} }
if initialConversation == nil { if initialConversation == nil {
app.NewConversation() app.NewConversation()
} else {
} }
model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
@ -61,7 +62,7 @@ const (
) )
func (m *AppModel) ClearConversation() { func (m *AppModel) ClearConversation() {
m.Conversation = nil m.Conversation = conversation.Conversation{}
m.Messages = []conversation.Message{} m.Messages = []conversation.Message{}
} }
@ -96,10 +97,6 @@ func (a *AppModel) GenerateConversationTitle(messages []conversation.Message) (s
return cmdutil.GenerateTitle(a.Ctx, messages) return cmdutil.GenerateTitle(a.Ctx, messages)
} }
func (a *AppModel) UpdateConversationTitle(conversation *conversation.Conversation) error {
return a.Ctx.Conversations.UpdateConversation(conversation)
}
func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) { func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) {
msg, _, err := a.Ctx.Conversations.CloneBranch(message) msg, _, err := a.Ctx.Conversations.CloneBranch(message)
if err != nil { if err != nil {
@ -182,33 +179,54 @@ func (a *AppModel) CycleSelectedReply(message *conversation.Message, dir Message
return nextReply, nil return nextReply, nil
} }
func (a *AppModel) PersistConversation(conversation *conversation.Conversation, messages []conversation.Message) (*conversation.Conversation, []conversation.Message, error) { func (a *AppModel) PersistMessages() ([]conversation.Message, error) {
var err error messages := make([]conversation.Message, len(a.Messages))
if conversation == nil || conversation.ID == 0 { for i, m := range a.Messages {
conversation, messages, err = a.Ctx.Conversations.StartConversation(messages...) if i == 0 && m.ID == 0 {
if err != nil { m.Conversation = &a.Conversation
return nil, nil, fmt.Errorf("Could not start new conversation: %v", err) m, err := a.Ctx.Conversations.SaveMessage(m)
}
return conversation, messages, nil
}
for i := range messages {
if messages[i].ID > 0 {
err := a.Ctx.Conversations.UpdateMessage(&messages[i])
if err != nil { if err != nil {
return nil, nil, err return nil, fmt.Errorf("Could not create first message %d: %v", a.Messages[i].ID, err)
} }
messages[i] = *m
// let's set the conversation root message(s), as this is the first message
m.Conversation.RootMessages = []conversation.Message{*m}
m.Conversation.SelectedRoot = &m.Conversation.RootMessages[0]
a.Ctx.Conversations.UpdateConversation(m.Conversation)
} else if m.ID > 0 {
// Existing message, update it
err := a.Ctx.Conversations.UpdateMessage(&m)
if err != nil {
return nil, fmt.Errorf("Could not update message %d: %v", a.Messages[i].ID, err)
}
messages[i] = m
} else if i > 0 { } else if i > 0 {
saved, err := a.Ctx.Conversations.Reply(&messages[i-1], messages[i]) // New message, reply to previous
replies, err := a.Ctx.Conversations.Reply(&messages[i-1], m)
if err != nil { if err != nil {
return nil, nil, err return nil, fmt.Errorf("Could not reply with new message: %v", err)
} }
messages[i] = saved[0] messages[i] = replies[0]
} else { } else {
return nil, nil, fmt.Errorf("Error: no messages to reply to") return nil, fmt.Errorf("No messages to reply to (this is a bug)")
} }
} }
return conversation, messages, nil return messages, nil
}
func (a *AppModel) PersistConversation() (conversation.Conversation, error) {
conv := a.Conversation
var err error
if a.Conversation.ID > 0 {
err = a.Ctx.Conversations.UpdateConversation(&conv)
} else {
c, e := a.Ctx.Conversations.CreateConversation("")
err = e
if e == nil && c != nil {
conv = *c
}
}
return conv, err
} }
func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) { func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) {

View File

@ -4,8 +4,8 @@ import (
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"git.mlow.ca/mlow/lmcli/pkg/conversation" "git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"git.mlow.ca/mlow/lmcli/pkg/tui/model" "git.mlow.ca/mlow/lmcli/pkg/tui/model"
"github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/cursor"
"github.com/charmbracelet/bubbles/spinner" "github.com/charmbracelet/bubbles/spinner"
@ -20,10 +20,8 @@ type (
// sent when a new conversation title generated // sent when a new conversation title generated
msgConversationTitleGenerated string msgConversationTitleGenerated string
// sent when the conversation has been persisted, triggers a reload of contents // sent when the conversation has been persisted, triggers a reload of contents
msgConversationPersisted struct { msgConversationPersisted conversation.Conversation
conversation *conversation.Conversation msgMessagesPersisted []conversation.Message
messages []conversation.Message
}
// sent when a conversation's messages are laoded // sent when a conversation's messages are laoded
msgConversationMessagesLoaded struct { msgConversationMessagesLoaded struct {
messages []conversation.Message messages []conversation.Message
@ -35,7 +33,7 @@ type (
// sent on each chunk received from LLM // sent on each chunk received from LLM
msgChatResponseChunk provider.Chunk msgChatResponseChunk provider.Chunk
// sent on each completed reply // sent on each completed reply
msgChatResponse *conversation.Message msgChatResponse conversation.Message
// sent when the response is canceled // sent when the response is canceled
msgChatResponseCanceled struct{} msgChatResponseCanceled struct{}
// sent when results from a tool call are returned // sent when results from a tool call are returned

View File

@ -36,16 +36,6 @@ func (m *Model) generateConversationTitle() tea.Cmd {
} }
} }
func (m *Model) updateConversationTitle(conversation *conversation.Conversation) tea.Cmd {
return func() tea.Msg {
err := m.App.UpdateConversationTitle(conversation)
if err != nil {
return shared.WrapError(err)
}
return nil
}
}
func (m *Model) cloneMessage(message conversation.Message, selected bool) tea.Cmd { func (m *Model) cloneMessage(message conversation.Message, selected bool) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
msg, err := m.App.CloneMessage(message, selected) msg, err := m.App.CloneMessage(message, selected)
@ -96,11 +86,21 @@ func (m *Model) cycleSelectedReply(message *conversation.Message, dir model.Mess
func (m *Model) persistConversation() tea.Cmd { func (m *Model) persistConversation() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
conversation, messages, err := m.App.PersistConversation(m.App.Conversation, m.App.Messages) conversation, err := m.App.PersistConversation()
if err != nil { if err != nil {
return shared.AsMsgError(err) return shared.AsMsgError(err)
} }
return msgConversationPersisted{conversation, messages} return msgConversationPersisted(conversation)
}
}
func (m *Model) persistMessages() tea.Cmd {
return func() tea.Msg {
messages, err := m.App.PersistMessages()
if err != nil {
return shared.AsMsgError(err)
}
return msgMessagesPersisted(messages)
} }
} }
@ -130,7 +130,7 @@ func (m *Model) promptLLM() tea.Cmd {
if err != nil { if err != nil {
return msgChatResponseError{Err: err} return msgChatResponseError{Err: err}
} }
return msgChatResponse(resp) return msgChatResponse(*resp)
}, },
) )
} }

View File

@ -127,7 +127,7 @@ func (m *Model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
var cmd tea.Cmd var cmd tea.Cmd
if m.selectedMessage == 0 { if m.selectedMessage == 0 {
cmd = m.cycleSelectedRoot(m.App.Conversation, dir) cmd = m.cycleSelectedRoot(&m.App.Conversation, dir)
} else if m.selectedMessage > 0 { } else if m.selectedMessage > 0 {
cmd = m.cycleSelectedReply(&m.App.Messages[m.selectedMessage-1], dir) cmd = m.cycleSelectedReply(&m.App.Messages[m.selectedMessage-1], dir)
} }
@ -162,7 +162,6 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
m.input.Blur() m.input.Blur()
return shared.KeyHandled(msg) return shared.KeyHandled(msg)
case "ctrl+s": case "ctrl+s":
// TODO: call a "handleSend" function which returns a tea.Cmd
if m.state != idle { if m.state != idle {
return nil return nil
} }
@ -172,7 +171,7 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
return shared.KeyHandled(msg) return shared.KeyHandled(msg)
} }
if len(m.App.Messages) > 0 && m.App.Messages[len(m.App.Messages)-1].Role == api.MessageRoleUser { if len(m.App.Messages) > 0 && m.App.Messages[len(m.App.Messages)-1].Role.IsUser() {
return shared.WrapError(fmt.Errorf("Can't reply to a user message")) return shared.WrapError(fmt.Errorf("Can't reply to a user message"))
} }

View File

@ -75,7 +75,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
if m.App.Conversation != nil && m.App.Conversation.ID > 0 { if m.App.Conversation.ID > 0 {
// (re)load conversation contents // (re)load conversation contents
cmds = append(cmds, m.loadConversationMessages()) cmds = append(cmds, m.loadConversationMessages())
} }
@ -133,7 +133,7 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
case msgChatResponse: case msgChatResponse:
m.state = idle m.state = idle
reply := (*conversation.Message)(msg) reply := conversation.Message(msg)
reply.Content = strings.TrimSpace(reply.Content) reply.Content = strings.TrimSpace(reply.Content)
last := len(m.App.Messages) - 1 last := len(m.App.Messages) - 1
@ -142,16 +142,15 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
} }
if m.App.Messages[last].Role.IsAssistant() { if m.App.Messages[last].Role.IsAssistant() {
// TODO: handle continuations gracefully - some models support them well, others fail horribly. // TODO: handle continuations gracefully - only some models support them
m.setMessage(last, *reply) m.setMessage(last, reply)
} else { } else {
m.addMessage(*reply) m.addMessage(reply)
} }
switch reply.Role { if reply.Role == api.MessageRoleToolCall {
case api.MessageRoleToolCall:
// TODO: user confirmation before execution // TODO: user confirmation before execution
// m.state = waitingForConfirmation // m.state = confirmToolUse
cmds = append(cmds, m.executeToolCalls(reply.ToolCalls)) cmds = append(cmds, m.executeToolCalls(reply.ToolCalls))
} }
@ -159,11 +158,9 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
cmds = append(cmds, m.persistConversation()) cmds = append(cmds, m.persistConversation())
} }
if m.App.Conversation.Title == "" { if m.App.Conversation.Title == "" && len(m.App.Messages) > 0 {
cmds = append(cmds, m.generateConversationTitle()) cmds = append(cmds, m.generateConversationTitle())
} }
m.updateContent()
case msgChatResponseCanceled: case msgChatResponseCanceled:
m.state = idle m.state = idle
m.updateContent() m.updateContent()
@ -194,8 +191,8 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
case msgConversationTitleGenerated: case msgConversationTitleGenerated:
title := string(msg) title := string(msg)
m.App.Conversation.Title = title m.App.Conversation.Title = title
if m.persistence { if m.persistence && m.App.Conversation.ID > 0 {
cmds = append(cmds, m.updateConversationTitle(m.App.Conversation)) cmds = append(cmds, m.persistConversation())
} }
case cursor.BlinkMsg: case cursor.BlinkMsg:
if m.state == pendingResponse { if m.state == pendingResponse {
@ -205,14 +202,13 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
m.updateContent() m.updateContent()
} }
case msgConversationPersisted: case msgConversationPersisted:
m.App.Conversation = msg.conversation m.App.Conversation = conversation.Conversation(msg)
m.App.Messages = msg.messages cmds = append(cmds, m.persistMessages())
case msgMessagesPersisted:
m.App.Messages = msg
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
case msgMessageCloned: case msgMessageCloned:
if msg.Parent == nil {
m.App.Conversation = msg.Conversation
}
cmds = append(cmds, m.loadConversationMessages()) cmds = append(cmds, m.loadConversationMessages())
case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated: case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated:
cmds = append(cmds, m.loadConversationMessages()) cmds = append(cmds, m.loadConversationMessages())

View File

@ -71,7 +71,7 @@ func (m *Model) renderMessageHeading(i int, message *conversation.Message) strin
prefix = " " prefix = " "
} }
if i == 0 && len(m.App.Conversation.RootMessages) > 1 && m.App.Conversation.SelectedRootID != nil { if i == 0 && m.App.Conversation.SelectedRootID != nil && len(m.App.Conversation.RootMessages) > 1 {
selectedRootIndex := 0 selectedRootIndex := 0
for j, reply := range m.App.Conversation.RootMessages { for j, reply := range m.App.Conversation.RootMessages {
if reply.ID == *m.App.Conversation.SelectedRootID { if reply.ID == *m.App.Conversation.SelectedRootID {
@ -261,7 +261,7 @@ func (m *Model) Content(width, height int) string {
func (m *Model) Header(width int) string { func (m *Model) Header(width int) string {
titleStyle := lipgloss.NewStyle().Bold(true) titleStyle := lipgloss.NewStyle().Bold(true)
var title string var title string
if m.App.Conversation != nil && m.App.Conversation.Title != "" { if m.App.Conversation.Title != "" {
title = m.App.Conversation.Title title = m.App.Conversation.Title
} else { } else {
title = "Untitled" title = "Untitled"

View File

@ -21,14 +21,14 @@ type (
// sent when conversation list is loaded // sent when conversation list is loaded
msgConversationsLoaded conversation.ConversationList msgConversationsLoaded conversation.ConversationList
// sent when a single conversation is loaded // sent when a single conversation is loaded
msgConversationLoaded *conversation.Conversation msgConversationLoaded conversation.Conversation
// sent when a conversation is deleted // sent when a conversation is deleted
msgConversationDeleted struct{} msgConversationDeleted struct{}
) )
type Model struct { type Model struct {
App *model.AppModel App *model.AppModel
width int width int
height int height int
cursor int cursor int
@ -151,14 +151,14 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
m.content.SetContent(m.renderConversationList()) m.content.SetContent(m.renderConversationList())
case msgConversationLoaded: case msgConversationLoaded:
m.App.ClearConversation() m.App.ClearConversation()
m.App.Conversation = msg m.App.Conversation = conversation.Conversation(msg)
cmds = append(cmds, func() tea.Msg { cmds = append(cmds, func() tea.Msg {
return shared.MsgViewChange(shared.ViewChat) return shared.MsgViewChange(shared.ViewChat)
}) })
case bubbles.MsgConfirmPromptAnswered: case bubbles.MsgConfirmPromptAnswered:
m.confirmPrompt.Blur() m.confirmPrompt.Blur()
if msg.Value { if msg.Value {
conv, ok := msg.Payload.(conversation.Conversation) conv, ok := msg.Payload.(conversation.ConversationListItem)
if ok { if ok {
cmds = append(cmds, m.deleteConversation(conv)) cmds = append(cmds, m.deleteConversation(conv))
} }
@ -198,13 +198,13 @@ func (m *Model) loadConversation(conversationID uint) tea.Cmd {
if err != nil { if err != nil {
return shared.AsMsgError(fmt.Errorf("Could not load conversation %d: %v", conversationID, err)) return shared.AsMsgError(fmt.Errorf("Could not load conversation %d: %v", conversationID, err))
} }
return msgConversationLoaded(conversation) return msgConversationLoaded(*conversation)
} }
} }
func (m *Model) deleteConversation(conv conversation.Conversation) tea.Cmd { func (m *Model) deleteConversation(conv conversation.ConversationListItem) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
err := m.App.Ctx.Conversations.DeleteConversation(&conv) err := m.App.Ctx.Conversations.DeleteConversationById(conv.ID)
if err != nil { if err != nil {
return shared.AsMsgError(fmt.Errorf("Could not delete conversation: %v", err)) return shared.AsMsgError(fmt.Errorf("Could not delete conversation: %v", err))
} }