From 811ec4b251a5b64f774f78a1b7542f56db543bcd Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sun, 31 Mar 2024 02:03:53 +0000 Subject: [PATCH] tui: split up conversation related code into conversation.go moved some things to util, re-ordered some functions --- pkg/tui/conversation.go | 687 ++++++++++++++++++++++++++++ pkg/tui/conversation_list.go | 5 + pkg/tui/tui.go | 846 +++-------------------------------- pkg/tui/util.go | 46 ++ 4 files changed, 801 insertions(+), 783 deletions(-) create mode 100644 pkg/tui/conversation.go diff --git a/pkg/tui/conversation.go b/pkg/tui/conversation.go new file mode 100644 index 0000000..22faba5 --- /dev/null +++ b/pkg/tui/conversation.go @@ -0,0 +1,687 @@ +package tui + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" + models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + "github.com/muesli/reflow/wordwrap" + "gopkg.in/yaml.v2" +) + +// custom tea.Msg types +type ( + // sent on each chunk received from LLM + msgResponseChunk string + // sent when response is finished being received + msgResponseEnd string + // a special case of msgError that stops the response waiting animation + msgResponseError error + // sent on each completed reply + msgAssistantReply models.Message + // sent when a conversation is (re)loaded + msgConversationLoaded *models.Conversation + // sent when a new conversation title is set + msgConversationTitleChanged string + // sent when a conversation's messages are laoded + msgMessagesLoaded []models.Message +) + +// styles +var ( + headerStyle = lipgloss.NewStyle(). + PaddingLeft(1). + PaddingRight(1). + Background(lipgloss.Color("0")) + + messageHeadingStyle = lipgloss.NewStyle(). + MarginTop(1). + MarginBottom(1). + PaddingLeft(1). + Bold(true) + + userStyle = lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("10")) + + assistantStyle = lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("12")) + + messageStyle = lipgloss.NewStyle(). + PaddingLeft(2). + PaddingRight(2) + + inputFocusedStyle = lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder(), true, true, true, false) + + inputBlurredStyle = lipgloss.NewStyle(). + Faint(true). + Border(lipgloss.RoundedBorder(), true, true, true, false) + + footerStyle = lipgloss.NewStyle() +) + +func (m *model) handleConversationInput(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "esc": + m.state = stateConversationList + return m.loadConversations() + case "ctrl+p": + m.persistence = !m.persistence + case "ctrl+t": + m.showToolResults = !m.showToolResults + m.rebuildMessageCache() + m.updateContent() + case "ctrl+w": + m.wrap = !m.wrap + m.rebuildMessageCache() + m.updateContent() + default: + switch m.focus { + case focusInput: + return m.handleInputKey(msg) + case focusMessages: + return m.handleMessagesKey(msg) + } + } + return nil +} + +func (m *model) handleConversationUpdate(msg tea.Msg) []tea.Cmd { + var cmds []tea.Cmd + switch msg := msg.(type) { + case msgTempfileEditorClosed: + contents := string(msg) + switch m.editorTarget { + case input: + m.input.SetValue(contents) + case selectedMessage: + m.setMessageContents(m.selectedMessage, contents) + if m.persistence && m.messages[m.selectedMessage].ID > 0 { + // update persisted message + err := m.ctx.Store.UpdateMessage(&m.messages[m.selectedMessage]) + if err != nil { + cmds = append(cmds, wrapError(fmt.Errorf("Could not save edited message: %v", err))) + } + } + m.updateContent() + } + case msgConversationLoaded: + m.conversation = (*models.Conversation)(msg) + cmds = append(cmds, m.loadMessages(m.conversation)) + case msgMessagesLoaded: + m.setMessages(msg) + m.updateContent() + case msgResponseChunk: + chunk := string(msg) + last := len(m.messages) - 1 + if last >= 0 && m.messages[last].Role.IsAssistant() { + m.setMessageContents(last, m.messages[last].Content+chunk) + } else { + m.addMessage(models.Message{ + Role: models.MessageRoleAssistant, + Content: chunk, + }) + } + m.updateContent() + cmds = append(cmds, m.waitForChunk()) // wait for the next chunk + case msgAssistantReply: + // the last reply that was being worked on is finished + reply := models.Message(msg) + reply.Content = strings.TrimSpace(reply.Content) + + last := len(m.messages) - 1 + if last < 0 { + panic("Unexpected empty messages handling msgAssistantReply") + } + + if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() { + // this was a continuation, so replace the previous message with the completed reply + m.setMessage(last, reply) + } else { + m.addMessage(reply) + } + + if m.persistence { + var err error + if m.conversation.ID == 0 { + err = m.ctx.Store.SaveConversation(m.conversation) + } + if err != nil { + cmds = append(cmds, wrapError(err)) + } else { + cmds = append(cmds, m.persistConversation()) + } + } + + if m.conversation.Title == "" { + cmds = append(cmds, m.generateConversationTitle()) + } + + m.updateContent() + cmds = append(cmds, m.waitForReply()) + case msgResponseEnd: + m.waitingForReply = false + last := len(m.messages) - 1 + if last < 0 { + panic("Unexpected empty messages handling msgResponseEnd") + } + m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content)) + m.updateContent() + m.status = "Press ctrl+s to send" + case msgResponseError: + m.waitingForReply = false + m.status = "Press ctrl+s to send" + m.err = error(msg) + case msgConversationTitleChanged: + title := string(msg) + m.conversation.Title = title + if m.persistence { + err := m.ctx.Store.SaveConversation(m.conversation) + if err != nil { + cmds = append(cmds, wrapError(err)) + } + } + } + + var cmd tea.Cmd + m.spinner, cmd = m.spinner.Update(msg) + if cmd != nil { + cmds = append(cmds, cmd) + } + + prevInputLineCnt := m.input.LineCount() + inputCaptured := false + m.input, cmd = m.input.Update(msg) + if cmd != nil { + inputCaptured = true + cmds = append(cmds, cmd) + } + + if !inputCaptured { + m.content, cmd = m.content.Update(msg) + if cmd != nil { + cmds = append(cmds, cmd) + } + } + + // update views once window dimensions are known + if m.width > 0 { + m.views.header = m.headerView() + m.views.footer = m.footerView() + m.views.error = m.errorView() + fixedHeight := height(m.views.header) + height(m.views.error) + height(m.views.footer) + + // calculate clamped input height to accomodate input text + newHeight := max(4, min((m.height-fixedHeight-1)/2, m.input.LineCount())) + m.input.SetHeight(newHeight) + m.views.input = m.input.View() + + m.content.Height = m.height - fixedHeight - height(m.views.input) + m.views.content = m.content.View() + } + + // this is a pretty nasty hack to ensure the input area viewport doesn't + // scroll below its content, which can happen when the input viewport + // height has grown, or previously entered lines have been deleted + if prevInputLineCnt != m.input.LineCount() { + // dist is the distance we'd need to scroll up from the current cursor + // position to position the last input line at the bottom of the + // viewport. if negative, we're already scrolled above the bottom + dist := m.input.Line() - (m.input.LineCount() - m.input.Height()) + if dist > 0 { + for i := 0; i < dist; i++ { + // move cursor up until content reaches the bottom of the viewport + m.input.CursorUp() + } + m.input, cmd = m.input.Update(nil) + for i := 0; i < dist; i++ { + // move cursor back down to its previous position + m.input.CursorDown() + } + m.input, cmd = m.input.Update(nil) + } + } + + return cmds +} + +func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "tab": + m.focus = focusInput + m.updateContent() + m.input.Focus() + case "e": + message := m.messages[m.selectedMessage] + cmd := openTempfileEditor("message.*.md", message.Content, "# Edit the message below\n") + m.editorTarget = selectedMessage + return cmd + case "ctrl+k": + if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) { + m.selectedMessage-- + m.updateContent() + offset := m.messageOffsets[m.selectedMessage] + scrollIntoView(&m.content, offset, 0.1) + } + case "ctrl+j": + if m.selectedMessage < len(m.messages)-1 && len(m.messages) == len(m.messageOffsets) { + m.selectedMessage++ + m.updateContent() + offset := m.messageOffsets[m.selectedMessage] + scrollIntoView(&m.content, offset, 0.1) + } + case "ctrl+r": + // resubmit the conversation with all messages up until and including the selected message + if m.waitingForReply || len(m.messages) == 0 { + return nil + } + m.messages = m.messages[:m.selectedMessage+1] + m.messageCache = m.messageCache[:m.selectedMessage+1] + m.updateContent() + m.content.GotoBottom() + return m.promptLLM() + } + return nil +} + +func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "esc": + m.focus = focusMessages + if len(m.messages) > 0 { + if m.selectedMessage < 0 || m.selectedMessage >= len(m.messages) { + m.selectedMessage = len(m.messages) - 1 + } + offset := m.messageOffsets[m.selectedMessage] + scrollIntoView(&m.content, offset, 0.1) + } + m.updateContent() + m.input.Blur() + case "ctrl+s": + userInput := strings.TrimSpace(m.input.Value()) + if strings.TrimSpace(userInput) == "" { + return nil + } + + if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == models.MessageRoleUser { + return wrapError(fmt.Errorf("Can't reply to a user message")) + } + + reply := models.Message{ + Role: models.MessageRoleUser, + Content: userInput, + } + + if m.persistence { + var err error + if m.conversation.ID == 0 { + err = m.ctx.Store.SaveConversation(m.conversation) + } + if err != nil { + return wrapError(err) + } + + // ensure all messages up to the one we're about to add are persisted + cmd := m.persistConversation() + if cmd != nil { + return cmd + } + + savedReply, err := m.ctx.Store.AddReply(m.conversation, reply) + if err != nil { + return wrapError(err) + } + reply = *savedReply + } + + m.input.SetValue("") + m.addMessage(reply) + + m.updateContent() + m.content.GotoBottom() + return m.promptLLM() + case "ctrl+e": + cmd := openTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n") + m.editorTarget = input + return cmd + } + return nil +} + +func (m *model) renderMessageHeading(i int, message *models.Message) string { + icon := "" + friendly := message.Role.FriendlyRole() + style := lipgloss.NewStyle().Faint(true).Bold(true) + + switch message.Role { + case models.MessageRoleSystem: + icon = "⚙️" + case models.MessageRoleUser: + style = userStyle + case models.MessageRoleAssistant: + style = assistantStyle + case models.MessageRoleToolCall: + style = assistantStyle + friendly = models.MessageRoleAssistant.FriendlyRole() + case models.MessageRoleToolResult: + icon = "🔧" + } + + user := style.Render(icon + friendly) + + var prefix string + var suffix string + + faint := lipgloss.NewStyle().Faint(true) + if m.focus == focusMessages { + if i == m.selectedMessage { + prefix = "> " + } + } + + if message.ID == 0 { + suffix += faint.Render(" (not saved)") + } + + return messageHeadingStyle.Render(prefix + user + suffix) +} + +func (m *model) renderMessage(msg *models.Message) string { + sb := &strings.Builder{} + sb.Grow(len(msg.Content) * 2) + if msg.Content != "" { + err := m.ctx.Chroma.Highlight(sb, msg.Content) + if err != nil { + sb.Reset() + sb.WriteString(msg.Content) + } + } + + var toolString string + switch msg.Role { + case models.MessageRoleToolCall: + bytes, err := yaml.Marshal(msg.ToolCalls) + if err != nil { + toolString = "Could not serialize ToolCalls" + } else { + toolString = "tool_calls:\n" + string(bytes) + } + case models.MessageRoleToolResult: + if !m.showToolResults { + break + } + + type renderedResult struct { + ToolName string `yaml:"tool"` + Result any + } + + var toolResults []renderedResult + for _, result := range msg.ToolResults { + var jsonResult interface{} + err := json.Unmarshal([]byte(result.Result), &jsonResult) + if err != nil { + // If parsing as JSON fails, treat Result as a plain string + toolResults = append(toolResults, renderedResult{ + ToolName: result.ToolName, + Result: result.Result, + }) + } else { + // If parsing as JSON succeeds, marshal the parsed JSON into YAML + toolResults = append(toolResults, renderedResult{ + ToolName: result.ToolName, + Result: &jsonResult, + }) + } + } + + bytes, err := yaml.Marshal(toolResults) + if err != nil { + toolString = "Could not serialize ToolResults" + } else { + toolString = "tool_results:\n" + string(bytes) + } + } + + if toolString != "" { + toolString = strings.TrimRight(toolString, "\n") + if msg.Content != "" { + sb.WriteString("\n\n") + } + _ = m.ctx.Chroma.HighlightLang(sb, toolString, "yaml") + } + + content := strings.TrimRight(sb.String(), "\n") + + if m.wrap { + wrapWidth := m.content.Width - messageStyle.GetHorizontalPadding() - 1 + content = wordwrap.String(content, wrapWidth) + } + + return messageStyle.Width(0).Render(content) +} + +// render the conversation into a string +func (m *model) conversationMessagesView() string { + sb := strings.Builder{} + + m.messageOffsets = make([]int, len(m.messages)) + lineCnt := 1 + for i, message := range m.messages { + m.messageOffsets[i] = lineCnt + + switch message.Role { + case models.MessageRoleToolCall: + if !m.showToolResults && message.Content == "" { + continue + } + case models.MessageRoleToolResult: + if !m.showToolResults { + continue + } + } + + heading := m.renderMessageHeading(i, &message) + sb.WriteString(heading) + sb.WriteString("\n") + lineCnt += lipgloss.Height(heading) + + cached := m.messageCache[i] + sb.WriteString(cached) + sb.WriteString("\n") + lineCnt += lipgloss.Height(cached) + } + + return sb.String() +} + +func (m *model) setMessages(messages []models.Message) { + m.messages = messages + m.rebuildMessageCache() +} + +func (m *model) setMessage(i int, msg models.Message) { + if i >= len(m.messages) { + panic("i out of range") + } + m.messages[i] = msg + m.messageCache[i] = m.renderMessage(&msg) +} + +func (m *model) addMessage(msg models.Message) { + m.messages = append(m.messages, msg) + m.messageCache = append(m.messageCache, m.renderMessage(&msg)) +} + +func (m *model) setMessageContents(i int, content string) { + if i >= len(m.messages) { + panic("i out of range") + } + m.messages[i].Content = content + m.messageCache[i] = m.renderMessage(&m.messages[i]) +} + +func (m *model) rebuildMessageCache() { + m.messageCache = make([]string, len(m.messages)) + for i, msg := range m.messages { + m.messageCache[i] = m.renderMessage(&msg) + } +} + +func (m *model) updateContent() { + atBottom := m.content.AtBottom() + m.content.SetContent(m.conversationMessagesView()) + if atBottom { + // if we were at bottom before the update, scroll with the output + m.content.GotoBottom() + } +} + +func (m *model) loadConversation(shortname string) tea.Cmd { + return func() tea.Msg { + if shortname == "" { + return nil + } + c, err := m.ctx.Store.ConversationByShortName(shortname) + if err != nil { + return msgError(fmt.Errorf("Could not lookup conversation: %v", err)) + } + if c.ID == 0 { + return msgError(fmt.Errorf("Conversation not found: %s", shortname)) + } + return msgConversationLoaded(c) + } +} + +func (m *model) loadMessages(c *models.Conversation) tea.Cmd { + return func() tea.Msg { + messages, err := m.ctx.Store.Messages(c) + if err != nil { + return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err)) + } + return msgMessagesLoaded(messages) + } +} + +func (m *model) persistConversation() tea.Cmd { + existingMessages, err := m.ctx.Store.Messages(m.conversation) + if err != nil { + return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err)) + } + + existingById := make(map[uint]*models.Message, len(existingMessages)) + for _, msg := range existingMessages { + existingById[msg.ID] = &msg + } + + currentById := make(map[uint]*models.Message, len(m.messages)) + for _, msg := range m.messages { + currentById[msg.ID] = &msg + } + + for _, msg := range existingMessages { + _, ok := currentById[msg.ID] + if !ok { + err := m.ctx.Store.DeleteMessage(&msg) + if err != nil { + return wrapError(fmt.Errorf("Failed to remove messages: %v", err)) + } + } + } + + for i, msg := range m.messages { + if msg.ID > 0 { + exist, ok := existingById[msg.ID] + if ok { + if msg.Content == exist.Content { + continue + } + // update message when contents don't match that of store + err := m.ctx.Store.UpdateMessage(&msg) + if err != nil { + return wrapError(err) + } + } else { + // this would be quite odd... and I'm not sure how to handle + // it at the time of writing this + } + } else { + newMessage, err := m.ctx.Store.AddReply(m.conversation, msg) + if err != nil { + return wrapError(err) + } + m.setMessage(i, *newMessage) + } + } + return nil +} + +func (m *model) generateConversationTitle() tea.Cmd { + return func() tea.Msg { + title, err := cmdutil.GenerateTitle(m.ctx, m.conversation) + if err != nil { + return msgError(err) + } + return msgConversationTitleChanged(title) + } +} + +func (m *model) waitForReply() tea.Cmd { + return func() tea.Msg { + return msgAssistantReply(<-m.replyChan) + } +} + +func (m *model) waitForChunk() tea.Cmd { + return func() tea.Msg { + return msgResponseChunk(<-m.replyChunkChan) + } +} + + +func (m *model) promptLLM() tea.Cmd { + m.waitingForReply = true + m.status = "Press ctrl+c to cancel" + + return func() tea.Msg { + completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model) + if err != nil { + return msgError(err) + } + + requestParams := models.RequestParameters{ + Model: *m.ctx.Config.Defaults.Model, + MaxTokens: *m.ctx.Config.Defaults.MaxTokens, + Temperature: *m.ctx.Config.Defaults.Temperature, + ToolBag: m.ctx.EnabledTools, + } + + replyHandler := func(msg models.Message) { + m.replyChan <- msg + } + + ctx, cancel := context.WithCancel(context.Background()) + + canceled := false + go func() { + select { + case <-m.stopSignal: + canceled = true + cancel() + } + }() + + resp, err := completionProvider.CreateChatCompletionStream( + ctx, requestParams, m.messages, replyHandler, m.replyChunkChan, + ) + + if err != nil && !canceled { + return msgResponseError(err) + } + + return msgResponseEnd(resp) + } +} diff --git a/pkg/tui/conversation_list.go b/pkg/tui/conversation_list.go index 4887e22..efcb299 100644 --- a/pkg/tui/conversation_list.go +++ b/pkg/tui/conversation_list.go @@ -12,6 +12,11 @@ import ( "github.com/charmbracelet/lipgloss" ) +type ( + // send when conversation list is loaded + msgConversationsLoaded []models.Conversation +) + func (m *model) handleConversationListInput(msg tea.KeyMsg) tea.Cmd { switch msg.String() { case "enter": diff --git a/pkg/tui/tui.go b/pkg/tui/tui.go index 6669902..50b1c47 100644 --- a/pkg/tui/tui.go +++ b/pkg/tui/tui.go @@ -6,16 +6,12 @@ package tui // - change model // - rename conversation // - set system prompt -// - system prompt library? import ( - "context" - "encoding/json" "fmt" "strings" "time" - cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli" models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "github.com/charmbracelet/bubbles/spinner" @@ -23,9 +19,6 @@ import ( "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/muesli/reflow/ansi" - "github.com/muesli/reflow/wordwrap" - "gopkg.in/yaml.v2" ) type appState int @@ -51,6 +44,21 @@ const ( selectedMessage ) +// we populate these fields as part of Update(), and let View() be +// responsible for returning the final composition of elements +type views struct { + header string + content string + error string + input string + footer string +} + +type ( + // sent when an error occurs + msgError error +) + type model struct { width int height int @@ -61,7 +69,7 @@ type model struct { // application state state appState conversations []models.Conversation - lastReplies []models.Message + lastReplies []models.Message conversation *models.Conversation messages []models.Message selectedMessage int @@ -88,73 +96,55 @@ type model struct { views *views } -// we populate these fields in the main Update() function, and let View() -// be responsible for returning the final composition of elements -type views struct { - header string - content string - error string - input string - footer string -} +func initialModel(ctx *lmcli.Context, convShortname string) model { + m := model{ + ctx: ctx, + convShortname: convShortname, + conversation: &models.Conversation{}, + persistence: true, -// styles -var ( - headerStyle = lipgloss.NewStyle(). - PaddingLeft(1). - PaddingRight(1). - Background(lipgloss.Color("0")) + stopSignal: make(chan interface{}), + replyChan: make(chan models.Message), + replyChunkChan: make(chan string), - messageHeadingStyle = lipgloss.NewStyle(). - MarginTop(1). - MarginBottom(1). - PaddingLeft(1). - Bold(true) + wrap: true, + selectedMessage: -1, - userStyle = lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("10")) - - assistantStyle = lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("12")) - - messageStyle = lipgloss.NewStyle(). - PaddingLeft(2). - PaddingRight(2) - - inputFocusedStyle = lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder(), true, true, true, false) - - inputBlurredStyle = lipgloss.NewStyle(). - Faint(true). - Border(lipgloss.RoundedBorder(), true, true, true, false) - - footerStyle = lipgloss.NewStyle() -) - -// custom tea.Msg types -type ( - // sent on each chunk received from LLM - msgResponseChunk string - // sent when response is finished being received - msgResponseEnd string - // a special case of msgError that stops the response waiting animation - msgResponseError error - // sent on each completed reply - msgAssistantReply models.Message - // sent when a conversation is (re)loaded - msgConversationLoaded *models.Conversation - // sent when a new conversation title is set - msgConversationTitleChanged string - // sent when a conversation's messages are laoded - msgMessagesLoaded []models.Message - // send when conversation list is loaded - msgConversationsLoaded []models.Conversation - // sent when an error occurs - msgError error -) - -func wrapError(err error) tea.Cmd { - return func() tea.Msg { - return msgError(err) + views: &views{}, } + + m.state = stateConversation + + m.content = viewport.New(0, 0) + + m.input = textarea.New() + m.input.MaxHeight = 0 + m.input.CharLimit = 0 + m.input.Placeholder = "Enter a message" + + m.input.Focus() + m.input.FocusedStyle.CursorLine = lipgloss.NewStyle() + m.input.FocusedStyle.Base = inputFocusedStyle + m.input.BlurredStyle.Base = inputBlurredStyle + m.input.ShowLineNumbers = false + + m.spinner = spinner.New(spinner.WithSpinner( + spinner.Spinner{ + Frames: []string{ + ". ", + ".. ", + "...", + ".. ", + ". ", + " ", + }, + FPS: time.Second / 3, + }, + )) + + m.waitingForReply = false + m.status = "Press ctrl+s to send" + return m } func (m model) Init() tea.Cmd { @@ -179,7 +169,6 @@ func (m *model) handleGlobalInput(msg tea.KeyMsg) tea.Cmd { switch msg.String() { case "ctrl+c": if m.waitingForReply { - m.status = "Cancelling..." m.stopSignal <- "" return nil } else { @@ -200,193 +189,6 @@ func (m *model) handleGlobalInput(msg tea.KeyMsg) tea.Cmd { return nil } -func (m *model) handleConversationInput(msg tea.KeyMsg) tea.Cmd { - switch msg.String() { - case "esc": - m.state = stateConversationList - return m.loadConversations() - case "ctrl+p": - m.persistence = !m.persistence - case "ctrl+t": - m.showToolResults = !m.showToolResults - m.rebuildMessageCache() - m.updateContent() - case "ctrl+w": - m.wrap = !m.wrap - m.rebuildMessageCache() - m.updateContent() - default: - switch m.focus { - case focusInput: - return m.handleInputKey(msg) - case focusMessages: - return m.handleMessagesKey(msg) - } - } - return nil -} - -func (m *model) handleConversationUpdate(msg tea.Msg) []tea.Cmd { - var cmds []tea.Cmd - switch msg := msg.(type) { - case msgTempfileEditorClosed: - contents := string(msg) - switch m.editorTarget { - case input: - m.input.SetValue(contents) - case selectedMessage: - m.setMessageContents(m.selectedMessage, contents) - if m.persistence && m.messages[m.selectedMessage].ID > 0 { - // update persisted message - err := m.ctx.Store.UpdateMessage(&m.messages[m.selectedMessage]) - if err != nil { - cmds = append(cmds, wrapError(fmt.Errorf("Could not save edited message: %v", err))) - } - } - m.updateContent() - } - case msgConversationLoaded: - m.conversation = (*models.Conversation)(msg) - cmds = append(cmds, m.loadMessages(m.conversation)) - case msgMessagesLoaded: - m.setMessages(msg) - m.updateContent() - case msgResponseChunk: - chunk := string(msg) - last := len(m.messages) - 1 - if last >= 0 && m.messages[last].Role.IsAssistant() { - m.setMessageContents(last, m.messages[last].Content+chunk) - } else { - m.addMessage(models.Message{ - Role: models.MessageRoleAssistant, - Content: chunk, - }) - } - m.updateContent() - cmds = append(cmds, m.waitForChunk()) // wait for the next chunk - case msgAssistantReply: - // the last reply that was being worked on is finished - reply := models.Message(msg) - reply.Content = strings.TrimSpace(reply.Content) - - last := len(m.messages) - 1 - if last < 0 { - panic("Unexpected empty messages handling msgAssistantReply") - } - - if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() { - // this was a continuation, so replace the previous message with the completed reply - m.setMessage(last, reply) - } else { - m.addMessage(reply) - } - - if m.persistence { - var err error - if m.conversation.ID == 0 { - err = m.ctx.Store.SaveConversation(m.conversation) - } - if err != nil { - cmds = append(cmds, wrapError(err)) - } else { - cmds = append(cmds, m.persistConversation()) - } - } - - if m.conversation.Title == "" { - cmds = append(cmds, m.generateConversationTitle()) - } - - m.updateContent() - cmds = append(cmds, m.waitForReply()) - case msgResponseEnd: - m.waitingForReply = false - last := len(m.messages) - 1 - if last < 0 { - panic("Unexpected empty messages handling msgResponseEnd") - } - m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content)) - m.updateContent() - m.status = "Press ctrl+s to send" - case msgResponseError: - m.waitingForReply = false - m.status = "Press ctrl+s to send" - m.err = error(msg) - case msgConversationTitleChanged: - title := string(msg) - m.conversation.Title = title - if m.persistence { - err := m.ctx.Store.SaveConversation(m.conversation) - if err != nil { - cmds = append(cmds, wrapError(err)) - } - } - } - - var cmd tea.Cmd - m.spinner, cmd = m.spinner.Update(msg) - if cmd != nil { - cmds = append(cmds, cmd) - } - - prevInputLineCnt := m.input.LineCount() - inputCaptured := false - m.input, cmd = m.input.Update(msg) - if cmd != nil { - inputCaptured = true - cmds = append(cmds, cmd) - } - - if !inputCaptured { - m.content, cmd = m.content.Update(msg) - if cmd != nil { - cmds = append(cmds, cmd) - } - } - - // update views once window dimensions are known - if m.width > 0 { - m.views.header = m.headerView() - m.views.footer = m.footerView() - m.views.error = m.errorView() - fixedHeight := height(m.views.header) + height(m.views.error) + height(m.views.footer) - - // calculate clamped input height to accomodate input text - newHeight := max(4, min((m.height-fixedHeight-1)/2, m.input.LineCount())) - m.input.SetHeight(newHeight) - m.views.input = m.input.View() - - m.content.Height = m.height - fixedHeight - height(m.views.input) - m.views.content = m.content.View() - } - - // this is a pretty nasty hack to ensure the input area viewport doesn't - // scroll below its content, which can happen when the input viewport - // height has grown, or previously entered lines have been deleted - if prevInputLineCnt != m.input.LineCount() { - // dist is the distance we'd need to scroll up from the current cursor - // position to position the last input line at the bottom of the - // viewport. if negative, we're already scrolled above the bottom - dist := m.input.Line() - (m.input.LineCount() - m.input.Height()) - if dist > 0 { - for i := 0; i < dist; i++ { - // move cursor up until content reaches the bottom of the viewport - m.input.CursorUp() - } - m.input, cmd = m.input.Update(nil) - cmds = append(cmds, cmd) - for i := 0; i < dist; i++ { - // move cursor back down to its previous position - m.input.CursorDown() - } - m.input, cmd = m.input.Update(nil) - cmds = append(cmds, cmd) - } - } - - return cmds -} - func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { var cmds []tea.Cmd @@ -417,29 +219,6 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Batch(cmds...) } -func height(str string) int { - if str == "" { - return 0 - } - return strings.Count(str, "\n") + 1 -} - -func truncateToCellWidth(str string, width int, tail string) string { - cellWidth := ansi.PrintableRuneWidth(str) - if cellWidth <= width { - return str - } - tailWidth := ansi.PrintableRuneWidth(tail) - for { - str = str[:len(str)-((cellWidth+tailWidth)-width)] - cellWidth = ansi.PrintableRuneWidth(str) - if cellWidth+tailWidth <= max(width, 0) { - break - } - } - return str + tail -} - func (m model) View() string { if m.width == 0 { // this is the case upon initial startup, but it's also a safe bet that @@ -548,511 +327,12 @@ func (m *model) footerView() string { return footerStyle.Width(m.width).Render(footer) } -func initialModel(ctx *lmcli.Context, convShortname string) model { - m := model{ - ctx: ctx, - convShortname: convShortname, - conversation: &models.Conversation{}, - persistence: true, - - stopSignal: make(chan interface{}), - replyChan: make(chan models.Message), - replyChunkChan: make(chan string), - - wrap: true, - selectedMessage: -1, - - views: &views{}, - } - - m.state = stateConversation - - m.content = viewport.New(0, 0) - - m.input = textarea.New() - m.input.CharLimit = 0 - m.input.Placeholder = "Enter a message" - - m.input.Focus() - m.input.FocusedStyle.CursorLine = lipgloss.NewStyle() - m.input.FocusedStyle.Base = inputFocusedStyle - m.input.BlurredStyle.Base = inputBlurredStyle - m.input.ShowLineNumbers = false - - m.spinner = spinner.New(spinner.WithSpinner( - spinner.Spinner{ - Frames: []string{ - ". ", - ".. ", - "...", - ".. ", - ". ", - " ", - }, - FPS: time.Second / 3, - }, - )) - - m.waitingForReply = false - m.status = "Press ctrl+s to send" - return m -} - -// fraction is the fraction of the total screen height into view the offset -// should be scrolled into view. 0.5 = items will be snapped to middle of -// view -func scrollIntoView(vp *viewport.Model, offset int, fraction float32) { - currentOffset := vp.YOffset - if offset >= currentOffset && offset < currentOffset+vp.Height { - return - } - distance := currentOffset - offset - if distance < 0 { - // we should scroll down until it just comes into view - vp.SetYOffset(currentOffset - (distance + (vp.Height - int(float32(vp.Height)*fraction))) + 1) - } else { - // we should scroll up - vp.SetYOffset(currentOffset - distance - int(float32(vp.Height)*fraction)) - } -} - -func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd { - switch msg.String() { - case "tab": - m.focus = focusInput - m.updateContent() - m.input.Focus() - case "e": - message := m.messages[m.selectedMessage] - cmd := openTempfileEditor("message.*.md", message.Content, "# Edit the message below\n") - m.editorTarget = selectedMessage - return cmd - case "ctrl+k": - if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) { - m.selectedMessage-- - m.updateContent() - offset := m.messageOffsets[m.selectedMessage] - scrollIntoView(&m.content, offset, 0.1) - } - case "ctrl+j": - if m.selectedMessage < len(m.messages)-1 && len(m.messages) == len(m.messageOffsets) { - m.selectedMessage++ - m.updateContent() - offset := m.messageOffsets[m.selectedMessage] - scrollIntoView(&m.content, offset, 0.1) - } - case "ctrl+r": - // resubmit the conversation with all messages up until and including the selected message - if m.waitingForReply || len(m.messages) == 0 { - return nil - } - m.messages = m.messages[:m.selectedMessage+1] - m.messageCache = m.messageCache[:m.selectedMessage+1] - m.updateContent() - m.content.GotoBottom() - return m.promptLLM() - } - return nil -} - -func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd { - switch msg.String() { - case "esc": - m.focus = focusMessages - if len(m.messages) > 0 { - if m.selectedMessage < 0 || m.selectedMessage >= len(m.messages) { - m.selectedMessage = len(m.messages) - 1 - } - offset := m.messageOffsets[m.selectedMessage] - scrollIntoView(&m.content, offset, 0.1) - } - m.updateContent() - m.input.Blur() - case "ctrl+s": - userInput := strings.TrimSpace(m.input.Value()) - if strings.TrimSpace(userInput) == "" { - return nil - } - - if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == models.MessageRoleUser { - return wrapError(fmt.Errorf("Can't reply to a user message")) - } - - reply := models.Message{ - Role: models.MessageRoleUser, - Content: userInput, - } - - if m.persistence { - var err error - if m.conversation.ID == 0 { - err = m.ctx.Store.SaveConversation(m.conversation) - } - if err != nil { - return wrapError(err) - } - - // ensure all messages up to the one we're about to add are persisted - cmd := m.persistConversation() - if cmd != nil { - return cmd - } - - savedReply, err := m.ctx.Store.AddReply(m.conversation, reply) - if err != nil { - return wrapError(err) - } - reply = *savedReply - } - - m.input.SetValue("") - m.addMessage(reply) - - m.updateContent() - m.content.GotoBottom() - return m.promptLLM() - case "ctrl+e": - cmd := openTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n") - m.editorTarget = input - return cmd - } - return nil -} - -func (m *model) loadConversation(shortname string) tea.Cmd { +func wrapError(err error) tea.Cmd { return func() tea.Msg { - if shortname == "" { - return nil - } - c, err := m.ctx.Store.ConversationByShortName(shortname) - if err != nil { - return msgError(fmt.Errorf("Could not lookup conversation: %v", err)) - } - if c.ID == 0 { - return msgError(fmt.Errorf("Conversation not found: %s", shortname)) - } - return msgConversationLoaded(c) + return msgError(err) } } -func (m *model) loadMessages(c *models.Conversation) tea.Cmd { - return func() tea.Msg { - messages, err := m.ctx.Store.Messages(c) - if err != nil { - return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err)) - } - return msgMessagesLoaded(messages) - } -} - -func (m *model) waitForReply() tea.Cmd { - return func() tea.Msg { - return msgAssistantReply(<-m.replyChan) - } -} - -func (m *model) waitForChunk() tea.Cmd { - return func() tea.Msg { - return msgResponseChunk(<-m.replyChunkChan) - } -} - -func (m *model) generateConversationTitle() tea.Cmd { - return func() tea.Msg { - title, err := cmdutil.GenerateTitle(m.ctx, m.conversation) - if err != nil { - return msgError(err) - } - return msgConversationTitleChanged(title) - } -} - -func (m *model) promptLLM() tea.Cmd { - m.waitingForReply = true - m.status = "Press ctrl+c to cancel" - - return func() tea.Msg { - completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model) - if err != nil { - return msgError(err) - } - - requestParams := models.RequestParameters{ - Model: *m.ctx.Config.Defaults.Model, - MaxTokens: *m.ctx.Config.Defaults.MaxTokens, - Temperature: *m.ctx.Config.Defaults.Temperature, - ToolBag: m.ctx.EnabledTools, - } - - replyHandler := func(msg models.Message) { - m.replyChan <- msg - } - - ctx, cancel := context.WithCancel(context.Background()) - - canceled := false - go func() { - select { - case <-m.stopSignal: - canceled = true - cancel() - } - }() - - resp, err := completionProvider.CreateChatCompletionStream( - ctx, requestParams, m.messages, replyHandler, m.replyChunkChan, - ) - - if err != nil && !canceled { - return msgResponseError(err) - } - - return msgResponseEnd(resp) - } -} - -func (m *model) persistConversation() tea.Cmd { - existingMessages, err := m.ctx.Store.Messages(m.conversation) - if err != nil { - return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err)) - } - - existingById := make(map[uint]*models.Message, len(existingMessages)) - for _, msg := range existingMessages { - existingById[msg.ID] = &msg - } - - currentById := make(map[uint]*models.Message, len(m.messages)) - for _, msg := range m.messages { - currentById[msg.ID] = &msg - } - - for _, msg := range existingMessages { - _, ok := currentById[msg.ID] - if !ok { - err := m.ctx.Store.DeleteMessage(&msg) - if err != nil { - return wrapError(fmt.Errorf("Failed to remove messages: %v", err)) - } - } - } - - for i, msg := range m.messages { - if msg.ID > 0 { - exist, ok := existingById[msg.ID] - if ok { - if msg.Content == exist.Content { - continue - } - // update message when contents don't match that of store - err := m.ctx.Store.UpdateMessage(&msg) - if err != nil { - return wrapError(err) - } - } else { - // this would be quite odd... and I'm not sure how to handle - // it at the time of writing this - } - } else { - newMessage, err := m.ctx.Store.AddReply(m.conversation, msg) - if err != nil { - return wrapError(err) - } - m.setMessage(i, *newMessage) - } - } - return nil -} - -func (m *model) renderMessageHeading(i int, message *models.Message) string { - icon := "" - friendly := message.Role.FriendlyRole() - style := lipgloss.NewStyle().Faint(true).Bold(true) - - switch message.Role { - case models.MessageRoleSystem: - icon = "⚙️" - case models.MessageRoleUser: - style = userStyle - case models.MessageRoleAssistant: - style = assistantStyle - case models.MessageRoleToolCall: - style = assistantStyle - friendly = models.MessageRoleAssistant.FriendlyRole() - case models.MessageRoleToolResult: - icon = "🔧" - } - - user := style.Render(icon + friendly) - - var prefix string - var suffix string - - faint := lipgloss.NewStyle().Faint(true) - if m.focus == focusMessages { - if i == m.selectedMessage { - prefix = "> " - } - } - - if message.ID == 0 { - suffix += faint.Render(" (not saved)") - } - - return messageHeadingStyle.Render(prefix + user + suffix) -} - -func (m *model) renderMessage(msg *models.Message) string { - sb := &strings.Builder{} - sb.Grow(len(msg.Content) * 2) - if msg.Content != "" { - err := m.ctx.Chroma.Highlight(sb, msg.Content) - if err != nil { - sb.Reset() - sb.WriteString(msg.Content) - } - } - - var toolString string - switch msg.Role { - case models.MessageRoleToolCall: - bytes, err := yaml.Marshal(msg.ToolCalls) - if err != nil { - toolString = "Could not serialize ToolCalls" - } else { - toolString = "tool_calls:\n" + string(bytes) - } - case models.MessageRoleToolResult: - if !m.showToolResults { - break - } - - type renderedResult struct { - ToolName string `yaml:"tool"` - Result any - } - - var toolResults []renderedResult - for _, result := range msg.ToolResults { - var jsonResult interface{} - err := json.Unmarshal([]byte(result.Result), &jsonResult) - if err != nil { - // If parsing as JSON fails, treat Result as a plain string - toolResults = append(toolResults, renderedResult{ - ToolName: result.ToolName, - Result: result.Result, - }) - } else { - // If parsing as JSON succeeds, marshal the parsed JSON into YAML - toolResults = append(toolResults, renderedResult{ - ToolName: result.ToolName, - Result: &jsonResult, - }) - } - } - - bytes, err := yaml.Marshal(toolResults) - if err != nil { - toolString = "Could not serialize ToolResults" - } else { - toolString = "tool_results:\n" + string(bytes) - } - } - - if toolString != "" { - toolString = strings.TrimRight(toolString, "\n") - if msg.Content != "" { - sb.WriteString("\n\n") - } - _ = m.ctx.Chroma.HighlightLang(sb, toolString, "yaml") - } - - content := strings.TrimRight(sb.String(), "\n") - - if m.wrap { - wrapWidth := m.content.Width - messageStyle.GetHorizontalPadding() - 1 - content = wordwrap.String(content, wrapWidth) - } - - return messageStyle.Width(0).Render(content) -} - -func (m *model) setMessages(messages []models.Message) { - m.messages = messages - m.rebuildMessageCache() -} - -func (m *model) setMessage(i int, msg models.Message) { - if i >= len(m.messages) { - panic("i out of range") - } - m.messages[i] = msg - m.messageCache[i] = m.renderMessage(&msg) -} - -func (m *model) addMessage(msg models.Message) { - m.messages = append(m.messages, msg) - m.messageCache = append(m.messageCache, m.renderMessage(&msg)) -} - -func (m *model) setMessageContents(i int, content string) { - if i >= len(m.messages) { - panic("i out of range") - } - m.messages[i].Content = content - m.messageCache[i] = m.renderMessage(&m.messages[i]) -} - -func (m *model) rebuildMessageCache() { - m.messageCache = make([]string, len(m.messages)) - for i, msg := range m.messages { - m.messageCache[i] = m.renderMessage(&msg) - } -} - -func (m *model) updateContent() { - atBottom := m.content.AtBottom() - m.content.SetContent(m.conversationMessagesView()) - if atBottom { - // if we were at bottom before the update, scroll with the output - m.content.GotoBottom() - } -} - -// render the conversation into a string -func (m *model) conversationMessagesView() string { - sb := strings.Builder{} - - m.messageOffsets = make([]int, len(m.messages)) - lineCnt := 1 - for i, message := range m.messages { - m.messageOffsets[i] = lineCnt - - switch message.Role { - case models.MessageRoleToolCall: - if !m.showToolResults && message.Content == "" { - continue - } - case models.MessageRoleToolResult: - if !m.showToolResults { - continue - } - } - - heading := m.renderMessageHeading(i, &message) - sb.WriteString(heading) - sb.WriteString("\n") - lineCnt += lipgloss.Height(heading) - - cached := m.messageCache[i] - sb.WriteString(cached) - sb.WriteString("\n") - lineCnt += lipgloss.Height(cached) - } - - return sb.String() -} - func Launch(ctx *lmcli.Context, convShortname string) error { p := tea.NewProgram(initialModel(ctx, convShortname), tea.WithAltScreen()) if _, err := p.Run(); err != nil { diff --git a/pkg/tui/util.go b/pkg/tui/util.go index c0d4fff..f236f50 100644 --- a/pkg/tui/util.go +++ b/pkg/tui/util.go @@ -5,7 +5,9 @@ import ( "os/exec" "strings" + "github.com/charmbracelet/bubbles/viewport" tea "github.com/charmbracelet/bubbletea" + "github.com/muesli/reflow/ansi" ) type msgTempfileEditorClosed string @@ -41,3 +43,47 @@ func openTempfileEditor(pattern string, content string, placeholder string) tea. return msgTempfileEditorClosed(stripped) }) } + +// similar to lipgloss.Height, except returns 0 on empty strings +func height(str string) int { + if str == "" { + return 0 + } + return strings.Count(str, "\n") + 1 +} + +// truncate a string until its rendered cell width + the provided tail fits +// within the given width +func truncateToCellWidth(str string, width int, tail string) string { + cellWidth := ansi.PrintableRuneWidth(str) + if cellWidth <= width { + return str + } + tailWidth := ansi.PrintableRuneWidth(tail) + for { + str = str[:len(str)-((cellWidth+tailWidth)-width)] + cellWidth = ansi.PrintableRuneWidth(str) + if cellWidth+tailWidth <= max(width, 0) { + break + } + } + return str + tail +} + +// fraction is the fraction of the total screen height into view the offset +// should be scrolled into view. 0.5 = items will be snapped to middle of +// view +func scrollIntoView(vp *viewport.Model, offset int, fraction float32) { + currentOffset := vp.YOffset + if offset >= currentOffset && offset < currentOffset+vp.Height { + return + } + distance := currentOffset - offset + if distance < 0 { + // we should scroll down until it just comes into view + vp.SetYOffset(currentOffset - (distance + (vp.Height - int(float32(vp.Height)*fraction))) + 1) + } else { + // we should scroll up + vp.SetYOffset(currentOffset - distance - int(float32(vp.Height)*fraction)) + } +}