Compare commits

..

7 Commits

Author SHA1 Message Date
657416780d Lead anthropic function call XML with newline 2024-03-17 18:26:32 +00:00
c143d863cb tui: support for message retry/continue
Better handling of persistence, and we now ensure the response we
persist is trimmed of whitespace, particularly important when a response
is cancelled mid-stream
2024-03-17 18:18:45 +00:00
3aff5514e4 Fix double reply callback on tool calls 2024-03-17 01:07:52 +00:00
5acdbb5675 tui: handle text wrapping ourselves, add ctrl+w wrap toggle
Gets rid of those pesky trailing characters
2024-03-17 00:43:07 +00:00
c53e952acc tui: open input/messages for editing in $EDITOR 2024-03-17 00:11:27 +00:00
3d8d3b61b3 tui: add ability to select a message 2024-03-16 05:49:04 +00:00
4fb059c850 tui: conversation rendering tweaks, remove input character limit 2024-03-16 00:37:08 +00:00
4 changed files with 279 additions and 68 deletions

View File

@ -94,7 +94,11 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
if err != nil { if err != nil {
panic("Could not serialize []ToolCall to XMLFunctionCall") panic("Could not serialize []ToolCall to XMLFunctionCall")
} }
message.Content += xmlString if len(message.Content) > 0 {
message.Content += fmt.Sprintf("\n\n%s", xmlString)
} else {
message.Content = xmlString
}
case model.MessageRoleToolResult: case model.MessageRoleToolResult:
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults) xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
xmlString, err := xmlFuncResults.XMLString() xmlString, err := xmlFuncResults.XMLString()
@ -197,6 +201,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
sb := strings.Builder{} sb := strings.Builder{}
isToolCall := false
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
@ -271,6 +277,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found") return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
} }
isToolCall = true
funcCallXml := content[start:] funcCallXml := content[start:]
funcCallXml += FUNCTION_STOP_SEQUENCE funcCallXml += FUNCTION_STOP_SEQUENCE
@ -316,10 +324,12 @@ func (c *AnthropicClient) CreateChatCompletionStream(
case "message_stop": case "message_stop":
// return the completed message // return the completed message
if callback != nil { if callback != nil {
callback(model.Message{ if !isToolCall {
Role: model.MessageRoleAssistant, callback(model.Message{
Content: sb.String(), Role: model.MessageRoleAssistant,
}) Content: sb.String(),
})
}
} }
return sb.String(), nil return sb.String(), nil
case "error": case "error":

View File

@ -204,7 +204,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callbback provider.ReplyCallback, callback provider.ReplyCallback,
output chan<- string, output chan<- string,
) (string, error) { ) (string, error) {
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
@ -256,22 +256,22 @@ func (c *OpenAIClient) CreateChatCompletionStream(
return content.String(), err return content.String(), err
} }
if callbback != nil { if callback != nil {
for _, result := range results { for _, result := range results {
callbback(result) callback(result)
} }
} }
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
messages = append(messages, results...) messages = append(messages, results...)
return c.CreateChatCompletionStream(ctx, params, messages, callbback, output) return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} } else {
if callback != nil {
if callbback != nil { callback(model.Message{
callbback(model.Message{ Role: model.MessageRoleAssistant,
Role: model.MessageRoleAssistant, Content: content.String(),
Content: content.String(), })
}) }
} }
return content.String(), err return content.String(), err

View File

@ -2,8 +2,6 @@ package tui
// The terminal UI for lmcli, launched from the `lmcli chat` command // The terminal UI for lmcli, launched from the `lmcli chat` command
// TODO: // TODO:
// - binding to open selected message/input in $EDITOR
// - ability to continue or retry previous response
// - conversation list view // - conversation list view
// - change model // - change model
// - rename conversation // - rename conversation
@ -24,6 +22,7 @@ import (
"github.com/charmbracelet/bubbles/viewport" "github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
"github.com/muesli/reflow/wordwrap"
) )
type focusState int type focusState int
@ -33,6 +32,13 @@ const (
focusMessages focusMessages
) )
type editorTarget int
const (
input editorTarget = iota
selectedMessage
)
type model struct { type model struct {
width int width int
height int height int
@ -44,16 +50,20 @@ type model struct {
conversation *models.Conversation conversation *models.Conversation
messages []models.Message messages []models.Message
waitingForReply bool waitingForReply bool
editorTarget editorTarget
stopSignal chan interface{} stopSignal chan interface{}
replyChan chan models.Message replyChan chan models.Message
replyChunkChan chan string replyChunkChan chan string
err error
persistence bool // whether we will save new messages in the conversation persistence bool // whether we will save new messages in the conversation
err error
// ui state // ui state
focus focusState focus focusState
status string // a general status message wrap bool // whether message content is wrapped to viewport width
highlightCache []string // a cache of syntax highlighted message content status string // a general status message
highlightCache []string // a cache of syntax highlighted message content
messageOffsets []int
selectedMessage int
// ui elements // ui elements
content viewport.Model content viewport.Model
@ -90,11 +100,12 @@ type (
var ( var (
userStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("10")) userStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("10"))
assistantStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("12")) assistantStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("12"))
messageStyle = lipgloss.NewStyle().PaddingLeft(1) messageStyle = lipgloss.NewStyle().PaddingLeft(2).PaddingRight(2)
headerStyle = lipgloss.NewStyle(). headerStyle = lipgloss.NewStyle().
Background(lipgloss.Color("0")) Background(lipgloss.Color("0"))
contentStyle = lipgloss.NewStyle(). conversationStyle = lipgloss.NewStyle().
Padding(1) MarginTop(1).
MarginBottom(1)
footerStyle = lipgloss.NewStyle(). footerStyle = lipgloss.NewStyle().
BorderTop(true). BorderTop(true).
BorderStyle(lipgloss.NormalBorder()) BorderStyle(lipgloss.NormalBorder())
@ -120,16 +131,35 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd var cmds []tea.Cmd
switch msg := msg.(type) { switch msg := msg.(type) {
case msgTempfileEditorClosed:
contents := string(msg)
switch m.editorTarget {
case input:
m.input.SetValue(contents)
case selectedMessage:
m.setMessageContents(m.selectedMessage, contents)
if m.persistence && m.messages[m.selectedMessage].ID > 0 {
// update persisted message
err := m.ctx.Store.UpdateMessage(&m.messages[m.selectedMessage])
if err != nil {
cmds = append(cmds, wrapError(fmt.Errorf("Could not save edited message: %v", err)))
}
}
m.updateContent()
}
case tea.KeyMsg: case tea.KeyMsg:
switch msg.String() { switch msg.String() {
case "ctrl+c": case "ctrl+c":
if m.waitingForReply { if m.waitingForReply {
m.stopSignal <- "stahp!" m.stopSignal <- ""
} else { } else {
return m, tea.Quit return m, tea.Quit
} }
case "ctrl+p": case "ctrl+p":
m.persistence = !m.persistence m.persistence = !m.persistence
case "ctrl+w":
m.wrap = !m.wrap
m.updateContent()
case "q": case "q":
if m.focus != focusInput { if m.focus != focusInput {
return m, tea.Quit return m, tea.Quit
@ -177,11 +207,16 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
reply := models.Message(msg) reply := models.Message(msg)
last := len(m.messages) - 1 last := len(m.messages) - 1
if last < 0 { if last < 0 {
panic("Unexpected messages length handling msgReply") panic("Unexpected empty messages handling msgReply")
} }
if reply.Role == models.MessageRoleToolCall && m.messages[last].Role == models.MessageRoleAssistant { m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
m.setMessage(last, reply) if m.messages[last].Role == models.MessageRoleAssistant {
} else if reply.Role != models.MessageRoleAssistant { // the last message was an assistant message, so this is a continuation
if reply.Role == models.MessageRoleToolCall {
// update last message rrole to tool call
m.messages[last].Role = models.MessageRoleToolCall
}
} else {
m.addMessage(reply) m.addMessage(reply)
} }
@ -193,7 +228,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if err != nil { if err != nil {
cmds = append(cmds, wrapError(err)) cmds = append(cmds, wrapError(err))
} else { } else {
cmds = append(cmds, m.persistRecentMessages()) cmds = append(cmds, m.persistConversation())
} }
} }
@ -205,6 +240,12 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
cmds = append(cmds, m.waitForReply()) cmds = append(cmds, m.waitForReply())
case msgResponseEnd: case msgResponseEnd:
m.waitingForReply = false m.waitingForReply = false
last := len(m.messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgResponseEnd")
}
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
m.updateContent()
m.status = "Press ctrl+s to send" m.status = "Press ctrl+s to send"
case msgResponseError: case msgResponseError:
m.waitingForReply = false m.waitingForReply = false
@ -376,11 +417,15 @@ func initialModel(ctx *lmcli.Context, convShortname string) model {
stopSignal: make(chan interface{}), stopSignal: make(chan interface{}),
replyChan: make(chan models.Message), replyChan: make(chan models.Message),
replyChunkChan: make(chan string), replyChunkChan: make(chan string),
wrap: true,
selectedMessage: -1,
} }
m.content = viewport.New(0, 0) m.content = viewport.New(0, 0)
m.input = textarea.New() m.input = textarea.New()
m.input.CharLimit = 0
m.input.Placeholder = "Enter a message" m.input.Placeholder = "Enter a message"
m.input.FocusedStyle.CursorLine = lipgloss.NewStyle() m.input.FocusedStyle.CursorLine = lipgloss.NewStyle()
@ -407,11 +452,60 @@ func initialModel(ctx *lmcli.Context, convShortname string) model {
return m return m
} }
// fraction is the fraction of the total screen height into view the offset
// should be scrolled into view. 0.5 = items will be snapped to middle of
// view
func scrollIntoView(vp *viewport.Model, offset int, fraction float32) {
currentOffset := vp.YOffset
if offset >= currentOffset && offset < currentOffset+vp.Height {
return
}
distance := currentOffset - offset
if distance < 0 {
// we should scroll down until it just comes into view
vp.SetYOffset(currentOffset - (distance + (vp.Height - int(float32(vp.Height)*fraction))) + 1)
} else {
// we should scroll up
vp.SetYOffset(currentOffset - distance - int(float32(vp.Height)*fraction))
}
}
func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd { func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() { switch msg.String() {
case "tab": case "tab":
m.focus = focusInput m.focus = focusInput
m.updateContent()
m.input.Focus() m.input.Focus()
case "e":
message := m.messages[m.selectedMessage]
cmd := openTempfileEditor("message.*.md", message.Content, "# Edit the message below\n")
m.editorTarget = selectedMessage
return cmd
case "ctrl+k":
if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) {
m.selectedMessage--
m.updateContent()
offset := m.messageOffsets[m.selectedMessage]
scrollIntoView(&m.content, offset, 0.1)
}
case "ctrl+j":
if m.selectedMessage < len(m.messages)-1 && len(m.messages) == len(m.messageOffsets) {
m.selectedMessage++
m.updateContent()
offset := m.messageOffsets[m.selectedMessage]
scrollIntoView(&m.content, offset, 0.1)
}
case "ctrl+r":
// resubmit the conversation with all messages up until and including
// the selected message
if len(m.messages) == 0 {
return nil
}
m.messages = m.messages[:m.selectedMessage+1]
m.highlightCache = m.highlightCache[:m.selectedMessage+1]
m.updateContent()
m.content.GotoBottom()
return m.promptLLM()
} }
return nil return nil
} }
@ -420,6 +514,10 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() { switch msg.String() {
case "esc": case "esc":
m.focus = focusMessages m.focus = focusMessages
if m.selectedMessage < 0 || m.selectedMessage >= len(m.messages) {
m.selectedMessage = len(m.messages) - 1
}
m.updateContent()
m.input.Blur() m.input.Blur()
case "ctrl+s": case "ctrl+s":
userInput := strings.TrimSpace(m.input.Value()) userInput := strings.TrimSpace(m.input.Value())
@ -447,7 +545,7 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
// ensure all messages up to the one we're about to add are // ensure all messages up to the one we're about to add are
// persistent // persistent
cmd := m.persistRecentMessages() cmd := m.persistConversation()
if cmd != nil { if cmd != nil {
return cmd return cmd
} }
@ -464,19 +562,11 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
m.updateContent() m.updateContent()
m.content.GotoBottom() m.content.GotoBottom()
return m.promptLLM()
case "ctrl+r":
if len(m.messages) == 0 {
return nil
}
// TODO: retry from selected message
if m.messages[len(m.messages)-1].Role != models.MessageRoleUser {
m.messages = m.messages[:len(m.messages)-1]
m.updateContent()
}
m.content.GotoBottom()
return m.promptLLM() return m.promptLLM()
case "ctrl+e":
cmd := openTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
m.editorTarget = input
return cmd
} }
return nil return nil
} }
@ -573,16 +663,55 @@ func (m *model) promptLLM() tea.Cmd {
} }
} }
func (m *model) persistRecentMessages() tea.Cmd { func (m *model) persistConversation() tea.Cmd {
existingMessages, err := m.ctx.Store.Messages(m.conversation)
if err != nil {
return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err))
}
existingById := make(map[uint]*models.Message, len(existingMessages))
for _, msg := range existingMessages {
existingById[msg.ID] = &msg
}
currentById := make(map[uint]*models.Message, len(m.messages))
for _, msg := range m.messages {
currentById[msg.ID] = &msg
}
for _, msg := range existingMessages {
_, ok := currentById[msg.ID]
if !ok {
err := m.ctx.Store.DeleteMessage(&msg)
if err != nil {
return wrapError(fmt.Errorf("Failed to remove messages: %v", err))
}
}
}
for i, msg := range m.messages { for i, msg := range m.messages {
if msg.ID > 0 { if msg.ID > 0 {
continue exist, ok := existingById[msg.ID]
if ok {
if msg.Content == exist.Content {
continue
}
// update message when contents don't match that of store
err := m.ctx.Store.UpdateMessage(&msg)
if err != nil {
return wrapError(err)
}
} else {
// this would be quite odd... and I'm not sure how to handle
// it at the time of writing this
}
} else {
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
if err != nil {
return wrapError(err)
}
m.setMessage(i, *newMessage)
} }
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
if err != nil {
return wrapError(err)
}
m.setMessage(i, *newMessage)
} }
return nil return nil
} }
@ -620,11 +749,25 @@ func (m *model) setMessageContents(i int, content string) {
m.highlightCache[i] = highlighted m.highlightCache[i] = highlighted
} }
// render the conversation into the main viewport
func (m *model) updateContent() { func (m *model) updateContent() {
atBottom := m.content.AtBottom()
m.content.SetContent(m.conversationView())
if atBottom {
// if we were at bottom before the update, scroll with the output
m.content.GotoBottom()
}
}
// render the conversation into a string
func (m *model) conversationView() string {
sb := strings.Builder{} sb := strings.Builder{}
msgCnt := len(m.messages) msgCnt := len(m.messages)
m.messageOffsets = make([]int, len(m.messages))
lineCnt := conversationStyle.GetMarginTop()
for i, message := range m.messages { for i, message := range m.messages {
m.messageOffsets[i] = lineCnt
icon := "⚙️" icon := "⚙️"
friendly := message.Role.FriendlyRole() friendly := message.Role.FriendlyRole()
style := lipgloss.NewStyle().Bold(true).Faint(true) style := lipgloss.NewStyle().Bold(true).Faint(true)
@ -640,37 +783,53 @@ func (m *model) updateContent() {
icon = "🔧" icon = "🔧"
} }
// write message heading with space for content
user := style.Render(icon + friendly)
var saved string var saved string
if message.ID == 0 { if message.ID == 0 {
saved = lipgloss.NewStyle().Faint(true).Render(" (not saved)") saved = lipgloss.NewStyle().Faint(true).Render(" (not saved)")
} }
// write message heading with space for content var selectedPrefix string
header := fmt.Sprintf("%s\n\n", style.Render(icon+friendly)+saved) if m.focus == focusMessages && i == m.selectedMessage {
selectedPrefix = "> "
}
header := lipgloss.NewStyle().PaddingLeft(1).Render(selectedPrefix + user + saved)
sb.WriteString(header) sb.WriteString(header)
lineCnt += lipgloss.Height(header)
// TODO: special rendering for tool calls/results? // TODO: special rendering for tool calls/results?
if message.Content != "" {
sb.WriteString("\n\n")
lineCnt += 1
// write message contents // write message contents
var highlighted string var highlighted string
if m.highlightCache[i] == "" { if m.highlightCache[i] == "" {
highlighted = message.Content highlighted = message.Content
} else { } else {
highlighted = m.highlightCache[i] highlighted = m.highlightCache[i]
}
var contents string
if m.wrap {
wrapWidth := m.content.Width - messageStyle.GetHorizontalPadding() - 2
wrapped := wordwrap.String(highlighted, wrapWidth)
contents = wrapped
} else {
contents = highlighted
}
sb.WriteString(messageStyle.Width(0).Render(contents))
lineCnt += lipgloss.Height(contents)
} }
contents := messageStyle.Width(m.content.Width - 5).Render(highlighted)
sb.WriteString(contents)
if i < msgCnt-1 { if i < msgCnt-1 {
sb.WriteString("\n\n") sb.WriteString("\n\n")
lineCnt += 1
} }
} }
atBottom := m.content.AtBottom() return conversationStyle.Render(sb.String())
m.content.SetContent(contentStyle.Render(sb.String()))
if atBottom {
// if we were at bottom before the update, scroll with the output
m.content.GotoBottom()
}
} }
func Launch(ctx *lmcli.Context, convShortname string) error { func Launch(ctx *lmcli.Context, convShortname string) error {

42
pkg/tui/util.go Normal file
View File

@ -0,0 +1,42 @@
package tui
import (
"os"
"os/exec"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type msgTempfileEditorClosed string
// openTempfileEditor opens an $EDITOR on a new temporary file with the given
// content. Upon closing, the contents of the file are read back returned
// wrapped in a msgTempfileEditorClosed returned by the tea.Cmd
func openTempfileEditor(pattern string, content string, placeholder string) tea.Cmd {
msgFile, _ := os.CreateTemp("/tmp", pattern)
err := os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
if err != nil {
return wrapError(err)
}
editor := os.Getenv("EDITOR")
if editor == "" {
editor = "vim"
}
c := exec.Command(editor, msgFile.Name())
return tea.ExecProcess(c, func(err error) tea.Msg {
bytes, err := os.ReadFile(msgFile.Name())
if err != nil {
return msgError(err)
}
fileContents := string(bytes)
if strings.HasPrefix(fileContents, placeholder) {
fileContents = fileContents[len(placeholder):]
}
stripped := strings.Trim(fileContents, "\n \t")
return msgTempfileEditorClosed(stripped)
})
}