Compare commits

..

6 Commits

3f765234de  tui: ability to cancel request in flight  (2024-03-12 20:41:34 +00:00)

21411c2732  tui: add focus switching between input/messages view  (2024-03-12 20:40:49 +00:00)

99794addee  tui: removed confirm before send, dynamic footer  (2024-03-12 20:40:49 +00:00)
            Footer now rendered based on model data, instead of being set to a
            fixed string.

a47c1a76b4  tui: use ctx chroma highlighter  (2024-03-12 20:40:49 +00:00)

96fdae982e  Add initial TUI  (2024-03-12 20:40:44 +00:00)

91d3c9c2e1  Update ChatCompletionClient  (2024-03-12 20:39:34 +00:00)
            Instead of CreateChatCompletion* accepting a pointer to a slice of
            reply messages, it accepts a callback which is called with each
            successive reply in the conversation.

            This gives the caller more flexibility in how it handles replies
            (e.g. it can react to them immediately, instead of waiting for the
            entire call to finish). A minimal sketch of the new calling
            convention follows the changed-files summary below.
7 changed files with 80 additions and 60 deletions
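
To illustrate the ChatCompletionClient change described in the last commit, here is a minimal sketch of the new callback-based calling convention. It only uses names visible in the diffs below (model.Message, model.RequestParameters, provider.ReplyCallback); the helper collectReplies and the assumption that ChatCompletionClient lives in the provider package are illustrative, not part of the change.

package example

import (
	"context"
	"fmt"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
)

// collectReplies demonstrates the new convention: rather than passing a
// *[]model.Message for the provider to append replies to, the caller hands
// over a callback and reacts to each reply as soon as it is produced.
func collectReplies(
	ctx context.Context,
	client provider.ChatCompletionClient,
	params model.RequestParameters,
	messages []model.Message,
) (string, []model.Message, error) {
	var replies []model.Message
	response, err := client.CreateChatCompletion(ctx, params, messages,
		func(reply model.Message) {
			// handle each reply immediately (render, persist, ...)
			fmt.Println(reply.Content)
			// and, if desired, still collect them as the old API did
			replies = append(replies, reply)
		})
	return response, replies, err
}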

View File

@@ -44,13 +44,13 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 	fmt.Print(lastMessage.Content)
 	// Submit the LLM request, allowing it to continue the last message
-	continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+	continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
 	if err != nil {
 		return fmt.Errorf("error fetching LLM response: %v", err)
 	}
 	// Append the new response to the original message
-	lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")
+	lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
 	// Update the original message
 	err = ctx.Store.UpdateMessage(lastMessage)

View File

@@ -31,7 +31,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 		},
 	}
-	_, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+	_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
 	if err != nil {
 		return fmt.Errorf("Error fetching LLM response: %v", err)
 	}

View File

@@ -15,7 +15,7 @@ import (
 // fetchAndShowCompletion prompts the LLM with the given messages and streams
 // the response to stdout. Returns all model reply messages.
-func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
+func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
 	content := make(chan string) // receives the reponse from LLM
 	defer close(content)
@@ -24,7 +24,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
 	if err != nil {
-		return nil, err
+		return "", err
 	}
 	requestParams := model.RequestParameters{
@@ -34,9 +34,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 		ToolBag: ctx.EnabledTools,
 	}
-	var apiReplies []model.Message
 	response, err := completionProvider.CreateChatCompletionStream(
-		context.Background(), requestParams, messages, &apiReplies, content,
+		context.Background(), requestParams, messages, callback, content,
 	)
 	if response != "" {
 		// there was some content, so break to a new line after it
@@ -47,8 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 			err = nil
 		}
 	}
-	return apiReplies, err
+	return response, nil
 }
 // lookupConversation either returns the conversation found by the
@@ -99,20 +97,21 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
 	// render a message header with no contents
 	RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
-	replies, err := FetchAndShowCompletion(ctx, allMessages)
-	if err != nil {
-		lmcli.Fatal("Error fetching LLM response: %v\n", err)
-	}
-	if persist {
-		for _, reply := range replies {
-			reply.ConversationID = c.ID
-			err = ctx.Store.SaveMessage(&reply)
-			if err != nil {
-				lmcli.Warn("Could not save reply: %v\n", err)
-			}
-		}
+	replyCallback := func(reply model.Message) {
+		if !persist {
+			return
+		}
+		reply.ConversationID = c.ID
+		err = ctx.Store.SaveMessage(&reply)
+		if err != nil {
+			lmcli.Warn("Could not save reply: %v\n", err)
+		}
+	}
+	_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
+	if err != nil {
+		lmcli.Fatal("Error fetching LLM response: %v\n", err)
 	}
 }

View File

@@ -133,7 +133,7 @@ func (c *AnthropicClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 ) (string, error) {
 	request := buildRequest(params, messages)
@@ -162,7 +162,9 @@ func (c *AnthropicClient) CreateChatCompletion(
 		default:
 			return "", fmt.Errorf("unsupported message type: %s", content.Type)
 		}
-		*replies = append(*replies, reply)
+		if callback != nil {
+			callback(reply)
+		}
 	}
 	return sb.String(), nil
@@ -172,7 +174,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
 	request := buildRequest(params, messages)
@@ -291,23 +293,25 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 					ToolResults: toolResults,
 				}
-				if replies != nil {
-					*replies = append(append(*replies, toolCall), toolReply)
+				if callback != nil {
+					callback(toolCall)
+					callback(toolReply)
 				}
 				// Recurse into CreateChatCompletionStream with the tool call replies
 				// added to the original messages
 				messages = append(append(messages, toolCall), toolReply)
-				return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
+				return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 			}
 		}
 	case "message_stop":
 		// return the completed message
-		reply := model.Message{
-			Role: model.MessageRoleAssistant,
-			Content: sb.String(),
-		}
-		*replies = append(*replies, reply)
+		if callback != nil {
+			callback(model.Message{
+				Role: model.MessageRoleAssistant,
+				Content: sb.String(),
+			})
+		}
 		return sb.String(), nil
 	case "error":
 		return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])

View File

@@ -9,6 +9,7 @@ import (
 	"strings"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 	openai "github.com/sashabaranov/go-openai"
 )
@@ -160,7 +161,7 @@ func (c *OpenAIClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
@@ -177,17 +178,19 @@ func (c *OpenAIClient) CreateChatCompletion(
 		if err != nil {
 			return "", err
 		}
-		if results != nil {
-			*replies = append(*replies, results...)
+		if callback != nil {
+			for _, result := range results {
+				callback(result)
+			}
 		}
 		// Recurse into CreateChatCompletion with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletion(ctx, params, messages, replies)
+		return c.CreateChatCompletion(ctx, params, messages, callback)
 	}
-	if replies != nil {
-		*replies = append(*replies, model.Message{
+	if callback != nil {
+		callback(model.Message{
 			Role: model.MessageRoleAssistant,
 			Content: choice.Message.Content,
 		})
@@ -201,7 +204,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callbback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
@@ -252,17 +255,20 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 		if err != nil {
 			return content.String(), err
 		}
-		if results != nil {
-			*replies = append(*replies, results...)
+		if callbback != nil {
+			for _, result := range results {
+				callbback(result)
+			}
 		}
 		// Recurse into CreateChatCompletionStream with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
+		return c.CreateChatCompletionStream(ctx, params, messages, callbback, output)
 	}
-	if replies != nil {
-		*replies = append(*replies, model.Message{
+	if callbback != nil {
+		callbback(model.Message{
 			Role: model.MessageRoleAssistant,
 			Content: content.String(),
 		})

View File

@@ -6,6 +6,8 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 )
+type ReplyCallback func(model.Message)
 type ChatCompletionClient interface {
 	// CreateChatCompletion requests a response to the provided messages.
 	// Replies are appended to the given replies struct, and the
@@ -14,7 +16,7 @@ type ChatCompletionClient interface {
 		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		replies *[]model.Message,
+		callback ReplyCallback,
 	) (string, error)
 	// Like CreateChageCompletion, except the response is streamed via
@@ -23,7 +25,7 @@ type ChatCompletionClient interface {
 		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		replies *[]model.Message,
+		callback ReplyCallback,
 		output chan<- string,
 	) (string, error)
 }

View File

@@ -30,15 +30,16 @@ type model struct {
 	convShortname string
 	// application state
 	conversation *models.Conversation
 	messages []models.Message
-	replyChan chan string
-	err error
+	waitingForReply bool
+	replyChan chan string
+	replyCancelFunc context.CancelFunc
+	err error
 	// ui state
 	focus focusState
-	isWaiting bool
 	status string // a general status message
 	// ui elements
 	content viewport.Model
@@ -90,7 +91,11 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case tea.KeyMsg:
 		switch msg.String() {
 		case "ctrl+c":
-			return m, tea.Quit
+			if m.waitingForReply {
+				m.replyCancelFunc()
+			} else {
+				return m, tea.Quit
+			}
 		case "q":
 			if m.focus != focusInput {
 				return m, tea.Quit
@@ -135,7 +140,8 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		}
 		cmd = waitForChunk(m.replyChan) // wait for the next chunk
 	case msgResponseEnd:
-		m.isWaiting = false
+		m.replyCancelFunc = nil
+		m.waitingForReply = false
 		m.status = "Press ctrl+s to send"
 	}
@@ -184,7 +190,7 @@ func initialModel(ctx *lmcli.Context, convShortname string) model {
 	m.updateContent()
-	m.isWaiting = false
+	m.waitingForReply = false
 	m.status = "Press ctrl+s to send"
 	return m
 }
@@ -217,8 +223,8 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
 		m.updateContent()
 		m.content.GotoBottom()
-		m.isWaiting = true
-		m.status = "Waiting for response... (Press 's' to stop)"
+		m.waitingForReply = true
+		m.status = "Waiting for response, press ctrl+c to cancel..."
 		return m.promptLLM()
 	}
 	return nil
@@ -278,9 +284,12 @@ func (m *model) promptLLM() tea.Cmd {
 			ToolBag: toolBag,
 		}
-		var apiReplies []models.Message
+		ctx, replyCancelFunc := context.WithCancel(context.Background())
+		m.replyCancelFunc = replyCancelFunc
+		// TODO: supply a reply callback and handle error
 		resp, _ := completionProvider.CreateChatCompletionStream(
-			context.Background(), requestParams, m.messages, &apiReplies, m.replyChan,
+			ctx, requestParams, m.messages, nil, m.replyChan,
 		)
 		return msgResponseEnd(resp)
@@ -311,7 +320,7 @@ func (m *model) updateContent() {
 func (m model) inputView() string {
 	var inputView string
-	if m.isWaiting {
+	if m.waitingForReply {
 		inputView = inputStyle.Faint(true).Render(m.input.View())
 	} else {
 		inputView = inputStyle.Render(m.input.View())
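
Reading the TUI hunks above together, the cancel-in-flight behaviour from the first commit amounts to: store the context.CancelFunc on the model when a request starts, call it on ctrl+c while a reply is pending, and reset state when msgResponseEnd arrives. Below is a condensed, self-contained sketch of that flow; the chatModel name and the stripped-down fields are illustrative scaffolding, not the actual tui package.

package main

import (
	"context"

	tea "github.com/charmbracelet/bubbletea"
)

// msgResponseEnd mirrors the message type used in the diff above.
type msgResponseEnd string

// chatModel keeps only the fields involved in cancelling an in-flight request.
type chatModel struct {
	waitingForReply bool
	replyCancelFunc context.CancelFunc
	status          string
}

func (m chatModel) Init() tea.Cmd { return nil }

func (m chatModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case tea.KeyMsg:
		if msg.String() == "ctrl+c" {
			if m.waitingForReply {
				// abort the in-flight request by cancelling the context
				// that was passed to CreateChatCompletionStream
				m.replyCancelFunc()
			} else {
				return m, tea.Quit
			}
		}
	case msgResponseEnd:
		// the stream returned (normally or via cancellation); reset state
		m.replyCancelFunc = nil
		m.waitingForReply = false
		m.status = "Press ctrl+s to send"
	}
	return m, nil
}

func (m chatModel) View() string { return m.status }

func main() {
	if _, err := tea.NewProgram(chatModel{status: "Press ctrl+s to send"}).Run(); err != nil {
		panic(err)
	}
}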