Wrap chunk content in a Chunk type

Preparing to include additional information with each chunk (e.g. token count)

commit d2d946b776 (parent c963747066)
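In short: every streaming channel in the codebase changes element type from string to a new provider.Chunk struct, so per-chunk metadata can be added later without another sweep over every channel signature. A minimal sketch of the type this commit introduces; the commented-out field is hypothetical, suggested only by the commit message:

package provider

// Chunk wraps one streamed piece of model output. As of this commit
// it carries only the text content.
type Chunk struct {
	Content string
	// Hypothetical future field hinted at by the commit message:
	// TokenCount int
}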
@@ -10,6 +10,7 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/util"
 	"github.com/charmbracelet/lipgloss"
 )
@@ -17,7 +18,7 @@ import (
 // Prompt prompts the configured the configured model and streams the response
 // to stdout. Returns all model reply messages.
 func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
-	content := make(chan string) // receives the reponse from LLM
+	content := make(chan provider.Chunk) // receives the reponse from LLM
 	defer close(content)
 
 	// render all content received over the channel
@@ -251,7 +252,7 @@ func ShowWaitAnimation(signal chan any) {
 // chunked) content is received on the channel, the waiting animation is
 // replaced by the content.
 // Blocks until the channel is closed.
-func ShowDelayedContent(content <-chan string) {
+func ShowDelayedContent(content <-chan provider.Chunk) {
 	waitSignal := make(chan any)
 	go ShowWaitAnimation(waitSignal)
 
@@ -264,7 +265,7 @@ func ShowDelayedContent(content <-chan string) {
 			<-waitSignal
 			firstChunk = false
 		}
-		fmt.Print(chunk)
+		fmt.Print(chunk.Content)
 	}
 }
 
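Consumers like ShowDelayedContent keep the same shape; only the element type changes, and they read chunk.Content instead of the chunk itself. A self-contained sketch of the producer/consumer pattern in play (Chunk is inlined as a stand-in for provider.Chunk so the snippet runs on its own):

package main

import "fmt"

// Stand-in for provider.Chunk, inlined for a runnable demo.
type Chunk struct {
	Content string
}

func main() {
	content := make(chan Chunk)
	go func() {
		defer close(content)
		for _, s := range []string{"Hello", ", ", "world"} {
			content <- Chunk{Content: s} // was: content <- s
		}
	}()
	// Mirrors ShowDelayedContent's loop: print each chunk's text as it
	// arrives, until the channel is closed.
	for chunk := range content {
		fmt.Print(chunk.Content)
	}
	fmt.Println()
}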
@@ -161,7 +161,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	params model.RequestParameters,
 	messages []model.Message,
 	callback provider.ReplyCallback,
-	output chan<- string,
+	output chan<- provider.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -242,7 +242,9 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 					return "", fmt.Errorf("invalid text delta")
 				}
 				sb.WriteString(text)
-				output <- text
+				output <- provider.Chunk{
+					Content: text,
+				}
 			case "content_block_stop":
 				// ignore?
 			case "message_delta":
@@ -262,7 +264,9 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 				}
 
 				sb.WriteString(FUNCTION_STOP_SEQUENCE)
-				output <- FUNCTION_STOP_SEQUENCE
+				output <- provider.Chunk{
+					Content: FUNCTION_STOP_SEQUENCE,
+				}
 
 				funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
 
@@ -326,7 +326,7 @@ func (c *Client) CreateChatCompletionStream(
 	params model.RequestParameters,
 	messages []model.Message,
 	callback provider.ReplyCallback,
-	output chan<- string,
+	output chan<- provider.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -393,7 +393,9 @@ func (c *Client) CreateChatCompletionStream(
 			if part.FunctionCall != nil {
 				toolCalls = append(toolCalls, *part.FunctionCall)
 			} else if part.Text != "" {
-				output <- part.Text
+				output <- provider.Chunk{
+					Content: part.Text,
+				}
 				content.WriteString(part.Text)
 			}
 		}
@@ -132,7 +132,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
 	params model.RequestParameters,
 	messages []model.Message,
 	callback provider.ReplyCallback,
-	output chan<- string,
+	output chan<- provider.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -181,7 +181,9 @@ func (c *OllamaClient) CreateChatCompletionStream(
 		}
 
 		if len(streamResp.Message.Content) > 0 {
-			output <- streamResp.Message.Content
+			output <- provider.Chunk{
+				Content: streamResp.Message.Content,
+			}
 			content.WriteString(streamResp.Message.Content)
 		}
 	}
@@ -245,7 +245,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	params model.RequestParameters,
 	messages []model.Message,
 	callback provider.ReplyCallback,
-	output chan<- string,
+	output chan<- provider.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -319,7 +319,9 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 			}
 		}
 		if len(delta.Content) > 0 {
-			output <- provider.Chunk{
+				Content: delta.Content,
+			}
 			content.WriteString(delta.Content)
 		}
 	}
@@ -8,6 +8,10 @@ import (
 
 type ReplyCallback func(model.Message)
 
+type Chunk struct {
+	Content string
+}
+
 type ChatCompletionClient interface {
 	// CreateChatCompletion requests a response to the provided messages.
 	// Replies are appended to the given replies struct, and the
@@ -26,6 +30,6 @@ type ChatCompletionClient interface {
 		params model.RequestParameters,
 		messages []model.Message,
 		callback ReplyCallback,
-		output chan<- string,
+		output chan<- Chunk,
 	) (string, error)
 }
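Because the ChatCompletionClient interface changes, the compiler forces every provider to switch in lockstep. A simplified, stand-alone sketch of the provider-side pattern (stand-in types; the real method also takes params, messages, and a callback, as the diff shows):

package main

import (
	"fmt"
	"strings"
)

type Chunk struct{ Content string }

// Simplified stand-in for a CreateChatCompletionStream implementation:
// forward each delta over the channel and return the accumulated reply,
// mirroring the sb.WriteString + output <- pattern above.
func stream(deltas []string, output chan<- Chunk) (string, error) {
	var sb strings.Builder
	for _, d := range deltas {
		sb.WriteString(d)
		output <- Chunk{Content: d} // was: output <- d
	}
	return sb.String(), nil
}

func main() {
	// Buffered only so this demo needs no reader goroutine.
	output := make(chan Chunk, 8)
	reply, _ := stream([]string{"a ", "streamed ", "reply"}, output)
	close(output)
	for c := range output {
		fmt.Print(c.Content)
	}
	fmt.Printf("\nfull reply: %q\n", reply)
}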
@@ -4,6 +4,7 @@ import (
 	"time"
 
 	models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
 	"github.com/charmbracelet/bubbles/cursor"
 	"github.com/charmbracelet/bubbles/spinner"
@@ -16,7 +17,7 @@ import (
 // custom tea.Msg types
 type (
 	// sent on each chunk received from LLM
-	msgResponseChunk string
+	msgResponseChunk provider.Chunk
 	// sent when response is finished being received
 	msgResponseEnd string
 	// a special case of common.MsgError that stops the response waiting animation
@@ -82,7 +83,7 @@ type Model struct {
 	editorTarget   editorTarget
 	stopSignal     chan struct{}
 	replyChan      chan models.Message
-	replyChunkChan chan string
+	replyChunkChan chan provider.Chunk
 	persistence    bool // whether we will save new messages in the conversation
 
 	// ui state
@@ -114,7 +115,7 @@ func Chat(shared shared.Shared) Model {
 
 		stopSignal:     make(chan struct{}),
 		replyChan:      make(chan models.Message),
-		replyChunkChan: make(chan string),
+		replyChunkChan: make(chan provider.Chunk),
 
 		wrap:            true,
 		selectedMessage: -1,
@@ -90,20 +90,19 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
 	case msgResponseChunk:
 		cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
 
-		chunk := string(msg)
-		if chunk == "" {
+		if msg.Content == "" {
 			break
 		}
 
 		last := len(m.messages) - 1
 		if last >= 0 && m.messages[last].Role.IsAssistant() {
 			// append chunk to existing message
-			m.setMessageContents(last, m.messages[last].Content+chunk)
+			m.setMessageContents(last, m.messages[last].Content+msg.Content)
 		} else {
 			// use chunk in new message
 			m.addMessage(models.Message{
 				Role:    models.MessageRoleAssistant,
-				Content: chunk,
+				Content: msg.Content,
 			})
 		}
 		m.updateContent()
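On the TUI side, each chunk now reaches Update as a provider.Chunk-backed tea.Msg, and the handler reads msg.Content directly. The diff does not show waitForResponseChunk itself; a hedged reconstruction of the usual bubbletea channel-pump it presumably follows (Model, replyChunkChan, and msgResponseChunk as defined in the diffed file):

import tea "github.com/charmbracelet/bubbletea"

// Hypothetical reconstruction: return a tea.Cmd that blocks on the
// chunk channel and surfaces the next value as a msgResponseChunk,
// which Update handles in the case shown above.
func (m Model) waitForResponseChunk() tea.Cmd {
	return func() tea.Msg {
		return msgResponseChunk(<-m.replyChunkChan)
	}
}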