Compare commits

..

4 Commits

Author SHA1 Message Date
3f7f34812f tui: add focus switching between input/messages view 2024-03-12 18:26:03 +00:00
98e92d1ff4 tui: removed confirm before send, dynamic footer
footer now rendered based on model data, instead of being set to a fixed
string
2024-03-12 18:26:03 +00:00
e23dc17555 tui: use ctx chroma highlighter 2024-03-12 18:26:03 +00:00
e0cc97e177 Add initial TUI 2024-03-12 18:26:03 +00:00
7 changed files with 60 additions and 80 deletions

View File

@ -44,13 +44,13 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
fmt.Print(lastMessage.Content) fmt.Print(lastMessage.Content)
// Submit the LLM request, allowing it to continue the last message // Submit the LLM request, allowing it to continue the last message
continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil) continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
if err != nil { if err != nil {
return fmt.Errorf("error fetching LLM response: %v", err) return fmt.Errorf("error fetching LLM response: %v", err)
} }
// Append the new response to the original message // Append the new response to the original message
lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ") lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")
// Update the original message // Update the original message
err = ctx.Store.UpdateMessage(lastMessage) err = ctx.Store.UpdateMessage(lastMessage)

View File

@ -31,7 +31,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
}, },
} }
_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil) _, err := cmdutil.FetchAndShowCompletion(ctx, messages)
if err != nil { if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err) return fmt.Errorf("Error fetching LLM response: %v", err)
} }

View File

@ -15,7 +15,7 @@ import (
// fetchAndShowCompletion prompts the LLM with the given messages and streams // fetchAndShowCompletion prompts the LLM with the given messages and streams
// the response to stdout. Returns all model reply messages. // the response to stdout. Returns all model reply messages.
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) { func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
content := make(chan string) // receives the reponse from LLM content := make(chan string) // receives the reponse from LLM
defer close(content) defer close(content)
@ -24,7 +24,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model) completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
if err != nil { if err != nil {
return "", err return nil, err
} }
requestParams := model.RequestParameters{ requestParams := model.RequestParameters{
@ -34,8 +34,9 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba
ToolBag: ctx.EnabledTools, ToolBag: ctx.EnabledTools,
} }
var apiReplies []model.Message
response, err := completionProvider.CreateChatCompletionStream( response, err := completionProvider.CreateChatCompletionStream(
context.Background(), requestParams, messages, callback, content, context.Background(), requestParams, messages, &apiReplies, content,
) )
if response != "" { if response != "" {
// there was some content, so break to a new line after it // there was some content, so break to a new line after it
@ -46,7 +47,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba
err = nil err = nil
} }
} }
return response, nil
return apiReplies, err
} }
// lookupConversation either returns the conversation found by the // lookupConversation either returns the conversation found by the
@ -97,22 +99,21 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
// render a message header with no contents // render a message header with no contents
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant})) RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
replyCallback := func(reply model.Message) { replies, err := FetchAndShowCompletion(ctx, allMessages)
if !persist {
return
}
reply.ConversationID = c.ID
err = ctx.Store.SaveMessage(&reply)
if err != nil {
lmcli.Warn("Could not save reply: %v\n", err)
}
}
_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
if err != nil { if err != nil {
lmcli.Fatal("Error fetching LLM response: %v\n", err) lmcli.Fatal("Error fetching LLM response: %v\n", err)
} }
if persist {
for _, reply := range replies {
reply.ConversationID = c.ID
err = ctx.Store.SaveMessage(&reply)
if err != nil {
lmcli.Warn("Could not save reply: %v\n", err)
}
}
}
} }
func FormatForExternalPrompt(messages []model.Message, system bool) string { func FormatForExternalPrompt(messages []model.Message, system bool) string {

View File

@ -133,7 +133,7 @@ func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, replies *[]model.Message,
) (string, error) { ) (string, error) {
request := buildRequest(params, messages) request := buildRequest(params, messages)
@ -162,9 +162,7 @@ func (c *AnthropicClient) CreateChatCompletion(
default: default:
return "", fmt.Errorf("unsupported message type: %s", content.Type) return "", fmt.Errorf("unsupported message type: %s", content.Type)
} }
if callback != nil { *replies = append(*replies, reply)
callback(reply)
}
} }
return sb.String(), nil return sb.String(), nil
@ -174,7 +172,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, replies *[]model.Message,
output chan<- string, output chan<- string,
) (string, error) { ) (string, error) {
request := buildRequest(params, messages) request := buildRequest(params, messages)
@ -293,25 +291,23 @@ func (c *AnthropicClient) CreateChatCompletionStream(
ToolResults: toolResults, ToolResults: toolResults,
} }
if callback != nil { if replies != nil {
callback(toolCall) *replies = append(append(*replies, toolCall), toolReply)
callback(toolReply)
} }
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages // added to the original messages
messages = append(append(messages, toolCall), toolReply) messages = append(append(messages, toolCall), toolReply)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output) return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
} }
} }
case "message_stop": case "message_stop":
// return the completed message // return the completed message
if callback != nil { reply := model.Message{
callback(model.Message{ Role: model.MessageRoleAssistant,
Role: model.MessageRoleAssistant, Content: sb.String(),
Content: sb.String(),
})
} }
*replies = append(*replies, reply)
return sb.String(), nil return sb.String(), nil
case "error": case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])

View File

@ -9,7 +9,6 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
openai "github.com/sashabaranov/go-openai" openai "github.com/sashabaranov/go-openai"
) )
@ -161,7 +160,7 @@ func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, replies *[]model.Message,
) (string, error) { ) (string, error) {
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages) req := createChatCompletionRequest(c, params, messages)
@ -178,19 +177,17 @@ func (c *OpenAIClient) CreateChatCompletion(
if err != nil { if err != nil {
return "", err return "", err
} }
if callback != nil { if results != nil {
for _, result := range results { *replies = append(*replies, results...)
callback(result)
}
} }
// Recurse into CreateChatCompletion with the tool call replies // Recurse into CreateChatCompletion with the tool call replies
messages = append(messages, results...) messages = append(messages, results...)
return c.CreateChatCompletion(ctx, params, messages, callback) return c.CreateChatCompletion(ctx, params, messages, replies)
} }
if callback != nil { if replies != nil {
callback(model.Message{ *replies = append(*replies, model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: choice.Message.Content, Content: choice.Message.Content,
}) })
@ -204,7 +201,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callbback provider.ReplyCallback, replies *[]model.Message,
output chan<- string, output chan<- string,
) (string, error) { ) (string, error) {
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
@ -255,20 +252,17 @@ func (c *OpenAIClient) CreateChatCompletionStream(
if err != nil { if err != nil {
return content.String(), err return content.String(), err
} }
if results != nil {
if callbback != nil { *replies = append(*replies, results...)
for _, result := range results {
callbback(result)
}
} }
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
messages = append(messages, results...) messages = append(messages, results...)
return c.CreateChatCompletionStream(ctx, params, messages, callbback, output) return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
} }
if callbback != nil { if replies != nil {
callbback(model.Message{ *replies = append(*replies, model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: content.String(), Content: content.String(),
}) })

View File

@ -6,8 +6,6 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
type ReplyCallback func(model.Message)
type ChatCompletionClient interface { type ChatCompletionClient interface {
// CreateChatCompletion requests a response to the provided messages. // CreateChatCompletion requests a response to the provided messages.
// Replies are appended to the given replies struct, and the // Replies are appended to the given replies struct, and the
@ -16,7 +14,7 @@ type ChatCompletionClient interface {
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback ReplyCallback, replies *[]model.Message,
) (string, error) ) (string, error)
// Like CreateChageCompletion, except the response is streamed via // Like CreateChageCompletion, except the response is streamed via
@ -25,7 +23,7 @@ type ChatCompletionClient interface {
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback ReplyCallback, replies *[]model.Message,
output chan<- string, output chan<- string,
) (string, error) ) (string, error)
} }

View File

@ -30,16 +30,15 @@ type model struct {
convShortname string convShortname string
// application state // application state
conversation *models.Conversation conversation *models.Conversation
messages []models.Message messages []models.Message
waitingForReply bool replyChan chan string
replyChan chan string err error
replyCancelFunc context.CancelFunc
err error
// ui state // ui state
focus focusState focus focusState
status string // a general status message isWaiting bool
status string // a general status message
// ui elements // ui elements
content viewport.Model content viewport.Model
@ -91,11 +90,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.KeyMsg: case tea.KeyMsg:
switch msg.String() { switch msg.String() {
case "ctrl+c": case "ctrl+c":
if m.waitingForReply { return m, tea.Quit
m.replyCancelFunc()
} else {
return m, tea.Quit
}
case "q": case "q":
if m.focus != focusInput { if m.focus != focusInput {
return m, tea.Quit return m, tea.Quit
@ -140,8 +135,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
} }
cmd = waitForChunk(m.replyChan) // wait for the next chunk cmd = waitForChunk(m.replyChan) // wait for the next chunk
case msgResponseEnd: case msgResponseEnd:
m.replyCancelFunc = nil m.isWaiting = false
m.waitingForReply = false
m.status = "Press ctrl+s to send" m.status = "Press ctrl+s to send"
} }
@ -190,7 +184,7 @@ func initialModel(ctx *lmcli.Context, convShortname string) model {
m.updateContent() m.updateContent()
m.waitingForReply = false m.isWaiting = false
m.status = "Press ctrl+s to send" m.status = "Press ctrl+s to send"
return m return m
} }
@ -223,8 +217,8 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
m.updateContent() m.updateContent()
m.content.GotoBottom() m.content.GotoBottom()
m.waitingForReply = true m.isWaiting = true
m.status = "Waiting for response, press ctrl+c to cancel..." m.status = "Waiting for response... (Press 's' to stop)"
return m.promptLLM() return m.promptLLM()
} }
return nil return nil
@ -284,12 +278,9 @@ func (m *model) promptLLM() tea.Cmd {
ToolBag: toolBag, ToolBag: toolBag,
} }
ctx, replyCancelFunc := context.WithCancel(context.Background()) var apiReplies []models.Message
m.replyCancelFunc = replyCancelFunc
// TODO: supply a reply callback and handle error
resp, _ := completionProvider.CreateChatCompletionStream( resp, _ := completionProvider.CreateChatCompletionStream(
ctx, requestParams, m.messages, nil, m.replyChan, context.Background(), requestParams, m.messages, &apiReplies, m.replyChan,
) )
return msgResponseEnd(resp) return msgResponseEnd(resp)
@ -320,7 +311,7 @@ func (m *model) updateContent() {
func (m model) inputView() string { func (m model) inputView() string {
var inputView string var inputView string
if m.waitingForReply { if m.isWaiting {
inputView = inputStyle.Faint(true).Render(m.input.View()) inputView = inputStyle.Faint(true).Render(m.input.View())
} else { } else {
inputView = inputStyle.Render(m.input.View()) inputView = inputStyle.Render(m.input.View())