Wrap chunk content in a Chunk type

Preparing to include additional information with each chunk (e.g. token count)
Matt Low 2024-06-08 23:37:58 +00:00
parent c963747066
commit d2d946b776
8 changed files with 35 additions and 20 deletions
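
In short: every streaming code path that previously passed raw string chunks now passes a small provider.Chunk struct, so later commits can attach per-chunk metadata (such as a token count) without touching channel types again. Below is a minimal, self-contained sketch of the pattern. The Chunk type mirrors the one added to the provider package in this commit; the producer and consumer are illustrative, not code from this repository.

package main

import "fmt"

// Chunk mirrors the struct added in the provider package: a wrapper around the
// streamed text, with room to grow (e.g. a token count field later).
type Chunk struct {
	Content string
}

func main() {
	output := make(chan Chunk)

	// Producer: what each provider backend now does with every text delta.
	go func() {
		defer close(output)
		for _, text := range []string{"Hello", ", ", "world"} {
			output <- Chunk{Content: text}
		}
	}()

	// Consumer: what ShowDelayedContent-style code now does with each chunk.
	for chunk := range output {
		fmt.Print(chunk.Content)
	}
	fmt.Println()
}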

View File

@@ -10,6 +10,7 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/util"
 	"github.com/charmbracelet/lipgloss"
 )
@@ -17,7 +18,7 @@ import (
 // Prompt prompts the configured model and streams the response
 // to stdout. Returns all model reply messages.
 func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
-	content := make(chan string) // receives the reponse from LLM
+	content := make(chan provider.Chunk) // receives the reponse from LLM
 	defer close(content)

 	// render all content received over the channel
@@ -251,7 +252,7 @@ func ShowWaitAnimation(signal chan any) {
 // chunked) content is received on the channel, the waiting animation is
 // replaced by the content.
 // Blocks until the channel is closed.
-func ShowDelayedContent(content <-chan string) {
+func ShowDelayedContent(content <-chan provider.Chunk) {
 	waitSignal := make(chan any)
 	go ShowWaitAnimation(waitSignal)
@@ -264,7 +265,7 @@ func ShowDelayedContent(content <-chan string) {
 			<-waitSignal
 			firstChunk = false
 		}
-		fmt.Print(chunk)
+		fmt.Print(chunk.Content)
 	}
 }
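
The hunks above only change the element type of the content channel from string to provider.Chunk; the surrounding wiring is unchanged. For orientation, here is a rough sketch of how such a caller fits together, in the same package as Prompt and ShowDelayedContent. The function body and the abbreviated argument list are assumptions based on these hunks, not the repository's actual Prompt implementation.

// promptSketch is hypothetical: it shows the receive side (ShowDelayedContent)
// and the send side (a provider's CreateChatCompletionStream) sharing one
// chan provider.Chunk, as Prompt does in this file.
func promptSketch(client provider.ChatCompletionClient, params model.RequestParameters, messages []model.Message) (string, error) {
	content := make(chan provider.Chunk) // receives streamed chunks
	defer close(content)

	go ShowDelayedContent(content) // prints chunk.Content as chunks arrive

	// Argument list abbreviated to the parameters visible in these hunks;
	// a nil ReplyCallback is passed purely to keep the sketch short.
	return client.CreateChatCompletionStream(params, messages, nil, content)
}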

View File

@@ -161,7 +161,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	params model.RequestParameters,
 	messages []model.Message,
 	callback provider.ReplyCallback,
-	output chan<- string,
+	output chan<- provider.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -242,7 +242,9 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 				return "", fmt.Errorf("invalid text delta")
 			}
 			sb.WriteString(text)
-			output <- text
+			output <- provider.Chunk{
+				Content: text,
+			}
 		case "content_block_stop":
 			// ignore?
 		case "message_delta":
@@ -262,7 +264,9 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 		}
 		sb.WriteString(FUNCTION_STOP_SEQUENCE)
-		output <- FUNCTION_STOP_SEQUENCE
+		output <- provider.Chunk{
+			Content: FUNCTION_STOP_SEQUENCE,
+		}
 		funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
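
The remaining three backends (google, ollama, openai) below make the same substitution: wherever a text delta used to be sent directly on the output channel, it is now wrapped in a provider.Chunk while the full reply is still accumulated for the return value. Schematically, with a hypothetical helper name and illustrative package placement:

package anthropic // illustrative placement; each backend inlines this instead

import (
	"strings"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
)

// emit is a hypothetical helper showing the shared pattern: append the text
// delta to the builder that accumulates the full reply, and send the same
// text wrapped in a provider.Chunk on the streaming output channel.
func emit(sb *strings.Builder, output chan<- provider.Chunk, text string) {
	sb.WriteString(text)
	output <- provider.Chunk{Content: text}
}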

View File

@@ -326,7 +326,7 @@ func (c *Client) CreateChatCompletionStream(
 	params model.RequestParameters,
 	messages []model.Message,
 	callback provider.ReplyCallback,
-	output chan<- string,
+	output chan<- provider.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -393,7 +393,9 @@ func (c *Client) CreateChatCompletionStream(
 			if part.FunctionCall != nil {
 				toolCalls = append(toolCalls, *part.FunctionCall)
 			} else if part.Text != "" {
-				output <- part.Text
+				output <- provider.Chunk{
+					Content: part.Text,
+				}
 				content.WriteString(part.Text)
 			}
 		}

View File

@@ -132,7 +132,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
 	params model.RequestParameters,
 	messages []model.Message,
 	callback provider.ReplyCallback,
-	output chan<- string,
+	output chan<- provider.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -181,7 +181,9 @@ func (c *OllamaClient) CreateChatCompletionStream(
 		}
 		if len(streamResp.Message.Content) > 0 {
-			output <- streamResp.Message.Content
+			output <- provider.Chunk{
+				Content: streamResp.Message.Content,
+			}
 			content.WriteString(streamResp.Message.Content)
 		}
 	}

View File

@@ -245,7 +245,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	params model.RequestParameters,
 	messages []model.Message,
 	callback provider.ReplyCallback,
-	output chan<- string,
+	output chan<- provider.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -319,7 +319,9 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 			}
 		}
 		if len(delta.Content) > 0 {
-			output <- delta.Content
+			output <- provider.Chunk{
+				Content: delta.Content,
+			}
 			content.WriteString(delta.Content)
 		}
 	}

View File

@@ -8,6 +8,10 @@ import (
 type ReplyCallback func(model.Message)

+type Chunk struct {
+	Content string
+}
+
 type ChatCompletionClient interface {
 	// CreateChatCompletion requests a response to the provided messages.
 	// Replies are appended to the given replies struct, and the
@@ -26,6 +30,6 @@ type ChatCompletionClient interface {
 	params model.RequestParameters,
 	messages []model.Message,
 	callback ReplyCallback,
-	output chan<- string,
+	output chan<- Chunk,
 ) (string, error)
 }
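
Per the commit message, this struct exists so that later commits can attach more information to each chunk, for example a token count. No such field exists yet; a hedged sketch of the likely direction (the TokenCount field and its use are assumptions, not part of this commit):

// Possible future shape of Chunk; only Content exists as of this commit.
type Chunk struct {
	Content    string
	TokenCount int // hypothetical: tokens consumed by this chunk
}

// A backend could then report usage alongside the text, e.g.
//   output <- Chunk{Content: text, TokenCount: n}
// while consumers that only read chunk.Content keep working unchanged.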

View File

@@ -4,6 +4,7 @@ import (
 	"time"

 	models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
 	"github.com/charmbracelet/bubbles/cursor"
 	"github.com/charmbracelet/bubbles/spinner"
@@ -16,7 +17,7 @@ import (
 // custom tea.Msg types
 type (
 	// sent on each chunk received from LLM
-	msgResponseChunk string
+	msgResponseChunk provider.Chunk
 	// sent when response is finished being received
 	msgResponseEnd string
 	// a special case of common.MsgError that stops the response waiting animation
@@ -82,7 +83,7 @@ type Model struct {
 	editorTarget   editorTarget
 	stopSignal     chan struct{}
 	replyChan      chan models.Message
-	replyChunkChan chan string
+	replyChunkChan chan provider.Chunk
 	persistence    bool // whether we will save new messages in the conversation

 	// ui state
@@ -114,7 +115,7 @@ func Chat(shared shared.Shared) Model {
 		stopSignal:      make(chan struct{}),
 		replyChan:       make(chan models.Message),
-		replyChunkChan:  make(chan string),
+		replyChunkChan:  make(chan provider.Chunk),
 		wrap:            true,
 		selectedMessage: -1,
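
waitForResponseChunk, used in the Update hunk below, is not part of this diff. Since msgResponseChunk and replyChunkChan now both carry provider.Chunk, converting between them remains a plain type conversion, so the command likely did not need to change. A sketch of how it plausibly looks (an assumption, not code from the repository):

// Hypothetical: a tea.Cmd that blocks on the chunk channel and re-emits the
// received provider.Chunk as a msgResponseChunk message for Update to handle.
func (m Model) waitForResponseChunk() tea.Cmd {
	return func() tea.Msg {
		return msgResponseChunk(<-m.replyChunkChan)
	}
}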

View File

@@ -90,20 +90,19 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
 	case msgResponseChunk:
 		cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk

-		chunk := string(msg)
-		if chunk == "" {
+		if msg.Content == "" {
 			break
 		}

 		last := len(m.messages) - 1
 		if last >= 0 && m.messages[last].Role.IsAssistant() {
 			// append chunk to existing message
-			m.setMessageContents(last, m.messages[last].Content+chunk)
+			m.setMessageContents(last, m.messages[last].Content+msg.Content)
 		} else {
 			// use chunk in new message
 			m.addMessage(models.Message{
 				Role:    models.MessageRoleAssistant,
-				Content: chunk,
+				Content: msg.Content,
 			})
 		}
 		m.updateContent()