From 1bd6baa83709a30ddeb988261e0fb7385ff9975b Mon Sep 17 00:00:00 2001 From: Matt Low Date: Wed, 13 Mar 2024 02:05:48 +0000 Subject: [PATCH] tui: handle multi part responses --- pkg/tui/tui.go | 45 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/pkg/tui/tui.go b/pkg/tui/tui.go index ab15909..f20c663 100644 --- a/pkg/tui/tui.go +++ b/pkg/tui/tui.go @@ -33,7 +33,8 @@ type model struct { conversation *models.Conversation messages []models.Message waitingForReply bool - replyChan chan string + replyChan chan models.Message + replyChunkChan chan string replyCancelFunc context.CancelFunc err error @@ -57,6 +58,8 @@ type ( msgResponseChunk string // sent when response is finished being received msgResponseEnd string + // sent on each completed reply + msgReply models.Message // sent when a conversation is (re)loaded msgConversationLoaded *models.Conversation // send when a conversation's messages are laoded @@ -80,7 +83,8 @@ func (m model) Init() tea.Cmd { return tea.Batch( textarea.Blink, m.loadConversation(m.convShortname), - waitForChunk(m.replyChan), + m.waitForChunk(), + m.waitForReply(), ) } @@ -135,7 +139,21 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { }) } m.updateContent() - cmd = waitForChunk(m.replyChan) // wait for the next chunk + cmd = m.waitForChunk() // wait for the next chunk + case msgReply: + // the last reply that was being worked on is finished + reply := models.Message(msg) + last := len(m.messages) - 1 + if last < 0 { + panic("Unexpected messages length handling msgReply") + } + if reply.Role == models.MessageRoleToolCall && m.messages[last].Role == models.MessageRoleAssistant { + m.messages[last] = reply + } else if reply.Role != models.MessageRoleAssistant { + m.messages = append(m.messages, reply) + } + m.updateContent() + cmd = m.waitForReply() case msgResponseEnd: m.replyCancelFunc = nil m.waitingForReply = false @@ -173,7 +191,8 @@ func initialModel(ctx *lmcli.Context, convShortname string) model { ctx: ctx, convShortname: convShortname, - replyChan: make(chan string), + replyChan: make(chan models.Message), + replyChunkChan: make(chan string), } m.content = viewport.New(0, 0) @@ -253,9 +272,15 @@ func (m *model) loadMessages(c *models.Conversation) tea.Cmd { } } -func waitForChunk(ch chan string) tea.Cmd { +func (m *model) waitForReply() tea.Cmd { return func() tea.Msg { - return msgResponseChunk(<-ch) + return msgReply(<-m.replyChan) + } +} + +func (m *model) waitForChunk() tea.Cmd { + return func() tea.Msg { + return msgResponseChunk(<-m.replyChunkChan) } } @@ -281,12 +306,16 @@ func (m *model) promptLLM() tea.Cmd { ToolBag: toolBag, } + replyHandler := func(msg models.Message) { + m.replyChan <- msg + } + ctx, replyCancelFunc := context.WithCancel(context.Background()) m.replyCancelFunc = replyCancelFunc - // TODO: supply a reply callback and handle error + // TODO: handle error resp, _ := completionProvider.CreateChatCompletionStream( - ctx, requestParams, m.messages, nil, m.replyChan, + ctx, requestParams, m.messages, replyHandler, m.replyChunkChan, ) return msgResponseEnd(resp)