tui: add reply persistence

This commit is contained in:
Matt Low 2024-03-13 21:20:03 +00:00
parent c3a3cb0181
commit ac0e380244
1 changed files with 73 additions and 13 deletions

View File

@ -39,6 +39,7 @@ type model struct {
replyChunkChan chan string replyChunkChan chan string
replyCancelFunc context.CancelFunc replyCancelFunc context.CancelFunc
err error err error
persistence bool // whether we will save new messages in the conversation
// ui state // ui state
focus focusState focus focusState
@ -81,7 +82,6 @@ var (
contentStyle = lipgloss.NewStyle(). contentStyle = lipgloss.NewStyle().
Padding(1) Padding(1)
footerStyle = lipgloss.NewStyle(). footerStyle = lipgloss.NewStyle().
Faint(true).
BorderTop(true). BorderTop(true).
BorderStyle(lipgloss.NormalBorder()) BorderStyle(lipgloss.NormalBorder())
) )
@ -95,8 +95,14 @@ func (m model) Init() tea.Cmd {
) )
} }
func wrapError(err error) tea.Cmd {
return func() tea.Msg {
return msgError(err)
}
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd var cmds []tea.Cmd
switch msg := msg.(type) { switch msg := msg.(type) {
case tea.KeyMsg: case tea.KeyMsg:
@ -107,6 +113,8 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
} else { } else {
return m, tea.Quit return m, tea.Quit
} }
case "ctrl+p":
m.persistence = !m.persistence
case "q": case "q":
if m.focus != focusInput { if m.focus != focusInput {
return m, tea.Quit return m, tea.Quit
@ -130,7 +138,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.updateContent() m.updateContent()
case msgConversationLoaded: case msgConversationLoaded:
m.conversation = (*models.Conversation)(msg) m.conversation = (*models.Conversation)(msg)
cmd = m.loadMessages(m.conversation) cmds = append(cmds, m.loadMessages(m.conversation))
case msgMessagesLoaded: case msgMessagesLoaded:
m.setMessages(msg) m.setMessages(msg)
m.updateContent() m.updateContent()
@ -146,7 +154,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}) })
} }
m.updateContent() m.updateContent()
cmd = m.waitForChunk() // wait for the next chunk cmds = append(cmds, m.waitForChunk()) // wait for the next chunk
case msgReply: case msgReply:
// the last reply that was being worked on is finished // the last reply that was being worked on is finished
reply := models.Message(msg) reply := models.Message(msg)
@ -159,18 +167,24 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
} else if reply.Role != models.MessageRoleAssistant { } else if reply.Role != models.MessageRoleAssistant {
m.addMessage(reply) m.addMessage(reply)
} }
if m.persistence && m.conversation != nil && m.conversation.ID > 0 {
cmds = append(cmds, m.persistRecentMessages())
}
m.updateContent() m.updateContent()
cmd = m.waitForReply() cmds = append(cmds, m.waitForReply())
case msgResponseEnd: case msgResponseEnd:
m.replyCancelFunc = nil m.replyCancelFunc = nil
m.waitingForReply = false m.waitingForReply = false
m.status = "Press ctrl+s to send" m.status = "Press ctrl+s to send"
} }
if cmd != nil { if len(cmds) > 0 {
return m, cmd return m, tea.Batch(cmds...)
} }
var cmd tea.Cmd
m.input, cmd = m.input.Update(msg) m.input, cmd = m.input.Update(msg)
if cmd != nil { if cmd != nil {
return m, cmd return m, cmd
@ -225,10 +239,18 @@ func (m *model) inputView() string {
} }
func (m *model) footerView() string { func (m *model) footerView() string {
segmentStyle := lipgloss.NewStyle().PaddingLeft(1).PaddingRight(1) segmentStyle := lipgloss.NewStyle().PaddingLeft(1).PaddingRight(1).Faint(true)
segmentSeparator := "|" segmentSeparator := "|"
saving := ""
if m.persistence {
saving = segmentStyle.Copy().Bold(true).Foreground(lipgloss.Color("2")).Render("✅💾")
} else {
saving = segmentStyle.Copy().Bold(true).Foreground(lipgloss.Color("1")).Render("❌💾")
}
leftSegments := []string{ leftSegments := []string{
saving,
segmentStyle.Render(m.status), segmentStyle.Render(m.status),
} }
rightSegments := []string{ rightSegments := []string{
@ -258,6 +280,7 @@ func initialModel(ctx *lmcli.Context, convShortname string) model {
m := model{ m := model{
ctx: ctx, ctx: ctx,
convShortname: convShortname, convShortname: convShortname,
persistence: true,
replyChan: make(chan models.Message), replyChan: make(chan models.Message),
replyChunkChan: make(chan string), replyChunkChan: make(chan string),
@ -298,11 +321,29 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
if strings.TrimSpace(userInput) == "" { if strings.TrimSpace(userInput) == "" {
return nil return nil
} }
m.input.SetValue("")
m.addMessage(models.Message{ reply := models.Message{
Role: models.MessageRoleUser, Role: models.MessageRoleUser,
Content: userInput, Content: userInput,
}) }
if m.persistence && m.conversation != nil && m.conversation.ID > 0 {
// ensure all messages up to the one we're about to add are
// persistent
cmd := m.persistRecentMessages()
if cmd != nil {
return cmd
}
// persist our new message, returning with any possible errors
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
if err != nil {
return wrapError(err)
}
reply = *savedReply
}
m.input.SetValue("")
m.addMessage(reply)
m.updateContent() m.updateContent()
m.content.GotoBottom() m.content.GotoBottom()
@ -382,6 +423,20 @@ func (m *model) promptLLM() tea.Cmd {
} }
} }
func (m *model) persistRecentMessages() tea.Cmd {
for i, msg := range m.messages {
if msg.ID > 0 {
continue
}
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
if err != nil {
return wrapError(err)
}
m.setMessage(i, *newMessage)
}
return nil
}
func (m *model) setMessages(messages []models.Message) { func (m *model) setMessages(messages []models.Message) {
m.messages = messages m.messages = messages
m.highlightCache = make([]string, len(messages)) m.highlightCache = make([]string, len(messages))
@ -436,11 +491,16 @@ func (m *model) updateContent() {
icon = "🔧" icon = "🔧"
} }
var saved string
if message.ID == 0 {
saved = lipgloss.NewStyle().Faint(true).Render(" (not saved)")
}
// write message heading with space for content // write message heading with space for content
header := fmt.Sprintf("%s\n\n", style.Render(icon+friendly)) header := fmt.Sprintf("%s\n\n", style.Render(icon+friendly)+saved)
sb.WriteString(header) sb.WriteString(header)
// TODO: render something for tool calls/results? // TODO: special rendering for tool calls/results?
// write message contents // write message contents
var highlighted string var highlighted string