From d2d946b7766e4189d71410616ec446bd544e99e0 Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Sat, 8 Jun 2024 23:37:58 +0000
Subject: [PATCH] Wrap chunk content in a Chunk type

Preparing to include additional information with each chunk (e.g. token count)
---
 pkg/cmd/util/util.go                      |  7 ++++---
 pkg/lmcli/provider/anthropic/anthropic.go | 10 +++++++---
 pkg/lmcli/provider/google/google.go       |  6 ++++--
 pkg/lmcli/provider/ollama/ollama.go       |  6 ++++--
 pkg/lmcli/provider/openai/openai.go       |  6 ++++--
 pkg/lmcli/provider/provider.go            |  6 +++++-
 pkg/tui/views/chat/chat.go                |  7 ++++---
 pkg/tui/views/chat/update.go              |  7 +++----
 8 files changed, 35 insertions(+), 20 deletions(-)

diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go
index 40cbe92..9c85376 100644
--- a/pkg/cmd/util/util.go
+++ b/pkg/cmd/util/util.go
@@ -10,6 +10,7 @@ import (
 
     "git.mlow.ca/mlow/lmcli/pkg/lmcli"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+    "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
     "git.mlow.ca/mlow/lmcli/pkg/util"
     "github.com/charmbracelet/lipgloss"
 )
@@ -17,7 +18,7 @@ import (
 // Prompt prompts the configured the configured model and streams the response
 // to stdout. Returns all model reply messages.
 func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
-    content := make(chan string) // receives the reponse from LLM
+    content := make(chan provider.Chunk) // receives the reponse from LLM
     defer close(content)
 
     // render all content received over the channel
@@ -251,7 +252,7 @@ func ShowWaitAnimation(signal chan any) {
 // chunked) content is received on the channel, the waiting animation is
 // replaced by the content.
 // Blocks until the channel is closed.
-func ShowDelayedContent(content <-chan string) {
+func ShowDelayedContent(content <-chan provider.Chunk) {
     waitSignal := make(chan any)
     go ShowWaitAnimation(waitSignal)
 
@@ -264,7 +265,7 @@ func ShowDelayedContent(content <-chan string) {
             <-waitSignal
             firstChunk = false
         }
-        fmt.Print(chunk)
+        fmt.Print(chunk.Content)
     }
 }
 
diff --git a/pkg/lmcli/provider/anthropic/anthropic.go b/pkg/lmcli/provider/anthropic/anthropic.go
index 51586c4..2b0b075 100644
--- a/pkg/lmcli/provider/anthropic/anthropic.go
+++ b/pkg/lmcli/provider/anthropic/anthropic.go
@@ -161,7 +161,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
     params model.RequestParameters,
     messages []model.Message,
     callback provider.ReplyCallback,
-    output chan<- string,
+    output chan<- provider.Chunk,
 ) (string, error) {
     if len(messages) == 0 {
         return "", fmt.Errorf("Can't create completion from no messages")
@@ -242,7 +242,9 @@
                 return "", fmt.Errorf("invalid text delta")
             }
             sb.WriteString(text)
-            output <- text
+            output <- provider.Chunk{
+                Content: text,
+            }
         case "content_block_stop":
             // ignore?
         case "message_delta":
@@ -262,7 +264,9 @@
         }
 
         sb.WriteString(FUNCTION_STOP_SEQUENCE)
-        output <- FUNCTION_STOP_SEQUENCE
+        output <- provider.Chunk{
+            Content: FUNCTION_STOP_SEQUENCE,
+        }
 
         funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
 
diff --git a/pkg/lmcli/provider/google/google.go b/pkg/lmcli/provider/google/google.go
index 289fb7b..c7a3d35 100644
--- a/pkg/lmcli/provider/google/google.go
+++ b/pkg/lmcli/provider/google/google.go
@@ -326,7 +326,7 @@ func (c *Client) CreateChatCompletionStream(
     params model.RequestParameters,
     messages []model.Message,
     callback provider.ReplyCallback,
-    output chan<- string,
+    output chan<- provider.Chunk,
 ) (string, error) {
     if len(messages) == 0 {
         return "", fmt.Errorf("Can't create completion from no messages")
@@ -393,7 +393,9 @@
             if part.FunctionCall != nil {
                 toolCalls = append(toolCalls, *part.FunctionCall)
             } else if part.Text != "" {
-                output <- part.Text
+                output <- provider.Chunk {
+                    Content: part.Text,
+                }
                 content.WriteString(part.Text)
             }
         }
diff --git a/pkg/lmcli/provider/ollama/ollama.go b/pkg/lmcli/provider/ollama/ollama.go
index dc233d4..b2df01e 100644
--- a/pkg/lmcli/provider/ollama/ollama.go
+++ b/pkg/lmcli/provider/ollama/ollama.go
@@ -132,7 +132,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
     params model.RequestParameters,
     messages []model.Message,
     callback provider.ReplyCallback,
-    output chan<- string,
+    output chan<- provider.Chunk,
 ) (string, error) {
     if len(messages) == 0 {
         return "", fmt.Errorf("Can't create completion from no messages")
@@ -181,7 +181,9 @@
         }
 
         if len(streamResp.Message.Content) > 0 {
-            output <- streamResp.Message.Content
+            output <- provider.Chunk{
+                Content: streamResp.Message.Content,
+            }
             content.WriteString(streamResp.Message.Content)
         }
     }
diff --git a/pkg/lmcli/provider/openai/openai.go b/pkg/lmcli/provider/openai/openai.go
index 1a34a17..79915cc 100644
--- a/pkg/lmcli/provider/openai/openai.go
+++ b/pkg/lmcli/provider/openai/openai.go
@@ -245,7 +245,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
     params model.RequestParameters,
     messages []model.Message,
     callback provider.ReplyCallback,
-    output chan<- string,
+    output chan<- provider.Chunk,
 ) (string, error) {
     if len(messages) == 0 {
         return "", fmt.Errorf("Can't create completion from no messages")
@@ -319,7 +319,9 @@
             }
         }
         if len(delta.Content) > 0 {
-            output <- delta.Content
+            output <- provider.Chunk {
+                Content: delta.Content,
+            }
             content.WriteString(delta.Content)
         }
     }
diff --git a/pkg/lmcli/provider/provider.go b/pkg/lmcli/provider/provider.go
index 1946966..d3fc1c8 100644
--- a/pkg/lmcli/provider/provider.go
+++ b/pkg/lmcli/provider/provider.go
@@ -8,6 +8,10 @@ import (
 
 type ReplyCallback func(model.Message)
 
+type Chunk struct {
+    Content string
+}
+
 type ChatCompletionClient interface {
     // CreateChatCompletion requests a response to the provided messages.
     // Replies are appended to the given replies struct, and the
@@ -26,6 +30,6 @@
         params model.RequestParameters,
         messages []model.Message,
         callback ReplyCallback,
-        output chan<- string,
+        output chan<- Chunk,
     ) (string, error)
 }
diff --git a/pkg/tui/views/chat/chat.go b/pkg/tui/views/chat/chat.go
index 1b83112..4ccd983 100644
--- a/pkg/tui/views/chat/chat.go
+++ b/pkg/tui/views/chat/chat.go
@@ -4,6 +4,7 @@ import (
     "time"
 
     models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+    "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
     "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
     "github.com/charmbracelet/bubbles/cursor"
     "github.com/charmbracelet/bubbles/spinner"
@@ -16,7 +17,7 @@ import (
 // custom tea.Msg types
 type (
     // sent on each chunk received from LLM
-    msgResponseChunk string
+    msgResponseChunk provider.Chunk
     // sent when response is finished being received
     msgResponseEnd string
     // a special case of common.MsgError that stops the response waiting animation
@@ -82,7 +83,7 @@ type Model struct {
     editorTarget   editorTarget
     stopSignal     chan struct{}
     replyChan      chan models.Message
-    replyChunkChan chan string
+    replyChunkChan chan provider.Chunk
     persistence    bool // whether we will save new messages in the conversation
 
     // ui state
@@ -114,7 +115,7 @@ func Chat(shared shared.Shared) Model {
 
         stopSignal:      make(chan struct{}),
         replyChan:       make(chan models.Message),
-        replyChunkChan:  make(chan string),
+        replyChunkChan:  make(chan provider.Chunk),
 
         wrap:            true,
         selectedMessage: -1,
diff --git a/pkg/tui/views/chat/update.go b/pkg/tui/views/chat/update.go
index a897b28..19ee29f 100644
--- a/pkg/tui/views/chat/update.go
+++ b/pkg/tui/views/chat/update.go
@@ -90,20 +90,19 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
 
     case msgResponseChunk:
         cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
-        chunk := string(msg)
-        if chunk == "" {
+        if msg.Content == "" {
             break
         }
 
         last := len(m.messages) - 1
         if last >= 0 && m.messages[last].Role.IsAssistant() {
             // append chunk to existing message
-            m.setMessageContents(last, m.messages[last].Content+chunk)
+            m.setMessageContents(last, m.messages[last].Content+msg.Content)
         } else {
             // use chunk in new message
             m.addMessage(models.Message{
                 Role:    models.MessageRoleAssistant,
-                Content: chunk,
+                Content: msg.Content,
             })
         }
         m.updateContent()
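
A note on where this is heading: with chunks wrapped in a struct, per-chunk metadata such as a token count can later ride along without touching the chan<- provider.Chunk signatures again. The sketch below is a minimal, self-contained illustration of that idea, not code from this patch: the TokenCount field and the consume helper are hypothetical (only Content exists in this commit), and the consumer loop simply mirrors how ShowDelayedContent and the TUI update loop skip empty chunks and accumulate Content.

package main

import (
    "fmt"
    "strings"
)

// Chunk mirrors provider.Chunk from this patch. TokenCount is a
// hypothetical future field (the commit message mentions token counts);
// it does not exist in this commit.
type Chunk struct {
    Content    string
    TokenCount int
}

// consume drains a chunk channel: skip empty chunks, accumulate Content,
// and tally the hypothetical per-chunk token counts.
func consume(chunks <-chan Chunk) (string, int) {
    var sb strings.Builder
    tokens := 0
    for chunk := range chunks {
        if chunk.Content == "" {
            continue
        }
        sb.WriteString(chunk.Content)
        tokens += chunk.TokenCount
    }
    return sb.String(), tokens
}

func main() {
    chunks := make(chan Chunk)
    go func() {
        defer close(chunks)
        // A provider would send one Chunk per streamed delta.
        chunks <- Chunk{Content: "Hello, ", TokenCount: 3}
        chunks <- Chunk{Content: "world!", TokenCount: 2}
    }()

    content, tokens := consume(chunks)
    fmt.Printf("%s (%d tokens)\n", content, tokens)
}

Since every provider now sends through the same Chunk type, adding a field along these lines later would only require updating the provider send sites and whichever consumers care about the new metadata.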