Compare commits
3f7f34812f ... 3f765234de (6 commits)
Commits:
3f765234de
21411c2732
99794addee
a47c1a76b4
96fdae982e
91d3c9c2e1
@@ -44,13 +44,13 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
             fmt.Print(lastMessage.Content)

             // Submit the LLM request, allowing it to continue the last message
-            continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+            continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
             if err != nil {
                 return fmt.Errorf("error fetching LLM response: %v", err)
             }

             // Append the new response to the original message
-            lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")
+            lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")

             // Update the original message
             err = ctx.Store.UpdateMessage(lastMessage)
@@ -31,7 +31,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
             },
         }

-            _, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+            _, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
             if err != nil {
                 return fmt.Errorf("Error fetching LLM response: %v", err)
             }
@@ -15,7 +15,7 @@ import (

 // fetchAndShowCompletion prompts the LLM with the given messages and streams
 // the response to stdout. Returns all model reply messages.
-func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
+func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
     content := make(chan string) // receives the reponse from LLM
     defer close(content)

@@ -24,7 +24,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod

     completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
     if err != nil {
-        return nil, err
+        return "", err
     }

     requestParams := model.RequestParameters{
@@ -34,9 +34,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
         ToolBag: ctx.EnabledTools,
     }

-    var apiReplies []model.Message
     response, err := completionProvider.CreateChatCompletionStream(
-        context.Background(), requestParams, messages, &apiReplies, content,
+        context.Background(), requestParams, messages, callback, content,
     )
     if response != "" {
         // there was some content, so break to a new line after it
@@ -47,8 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
             err = nil
         }
     }

-    return apiReplies, err
+    return response, nil
 }

 // lookupConversation either returns the conversation found by the
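Net effect of the FetchAndShowCompletion hunks above: the function now returns the final response as a string and reports each completed reply through the new callback parameter instead of appending to a caller-owned slice. A minimal sketch of the new calling convention (the saveReply name is hypothetical; ctx, messages, ctx.Store.SaveMessage, and lmcli.Warn are as used elsewhere in this diff):

// Pass nil when per-reply handling isn't needed:
response, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)

// Or supply a callback to act on each reply as it completes:
saveReply := func(reply model.Message) {
    if err := ctx.Store.SaveMessage(&reply); err != nil {
        lmcli.Warn("Could not save reply: %v\n", err)
    }
}
response, err = cmdutil.FetchAndShowCompletion(ctx, messages, saveReply)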
@@ -99,20 +97,21 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
     // render a message header with no contents
     RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))

-    replies, err := FetchAndShowCompletion(ctx, allMessages)
-    if err != nil {
-        lmcli.Fatal("Error fetching LLM response: %v\n", err)
+    replyCallback := func(reply model.Message) {
+        if !persist {
+            return
+        }
+
+        reply.ConversationID = c.ID
+        err = ctx.Store.SaveMessage(&reply)
+        if err != nil {
+            lmcli.Warn("Could not save reply: %v\n", err)
+        }
     }

-    if persist {
-        for _, reply := range replies {
-            reply.ConversationID = c.ID
-
-            err = ctx.Store.SaveMessage(&reply)
-            if err != nil {
-                lmcli.Warn("Could not save reply: %v\n", err)
-            }
-        }
+    _, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
+    if err != nil {
+        lmcli.Fatal("Error fetching LLM response: %v\n", err)
     }
 }

@@ -133,7 +133,7 @@ func (c *AnthropicClient) CreateChatCompletion(
     ctx context.Context,
     params model.RequestParameters,
     messages []model.Message,
-    replies *[]model.Message,
+    callback provider.ReplyCallback,
 ) (string, error) {
     request := buildRequest(params, messages)

@@ -162,7 +162,9 @@ func (c *AnthropicClient) CreateChatCompletion(
         default:
             return "", fmt.Errorf("unsupported message type: %s", content.Type)
         }
-        *replies = append(*replies, reply)
+        if callback != nil {
+            callback(reply)
+        }
     }

     return sb.String(), nil
@@ -172,7 +174,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
     ctx context.Context,
     params model.RequestParameters,
     messages []model.Message,
-    replies *[]model.Message,
+    callback provider.ReplyCallback,
     output chan<- string,
 ) (string, error) {
     request := buildRequest(params, messages)
@@ -291,23 +293,25 @@ func (c *AnthropicClient) CreateChatCompletionStream(
                     ToolResults: toolResults,
                 }

-                if replies != nil {
-                    *replies = append(append(*replies, toolCall), toolReply)
+                if callback != nil {
+                    callback(toolCall)
+                    callback(toolReply)
                 }

                 // Recurse into CreateChatCompletionStream with the tool call replies
                 // added to the original messages
                 messages = append(append(messages, toolCall), toolReply)
-                return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
+                return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
             }
         }
     case "message_stop":
         // return the completed message
-        reply := model.Message{
-            Role:    model.MessageRoleAssistant,
-            Content: sb.String(),
+        if callback != nil {
+            callback(model.Message{
+                Role:    model.MessageRoleAssistant,
+                Content: sb.String(),
+            })
         }
-        *replies = append(*replies, reply)
         return sb.String(), nil
     case "error":
         return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
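The tool-call branch above recurses into CreateChatCompletionStream with the callback threaded through, so intermediate tool-call and tool-result messages reach the caller exactly once per round. A self-contained toy of that control flow (none of these names are lmcli's; it only illustrates the shape):

package main

import "fmt"

type Message struct {
    Role    string
    Content string
}

// ReplyCallback mirrors the type introduced in this diff.
type ReplyCallback func(Message)

// fakeStream stands in for a provider's streaming method: it emits one
// "tool call" round, then a final assistant message, invoking the callback
// (when non-nil) once per completed message before recursing.
func fakeStream(msgs []Message, cb ReplyCallback, depth int) string {
    if depth == 0 { // pretend the model asked for a tool on the first pass
        toolCall := Message{Role: "assistant", Content: "call: get_time()"}
        toolReply := Message{Role: "tool", Content: "12:00"}
        if cb != nil {
            cb(toolCall)
            cb(toolReply)
        }
        // recurse with the tool exchange appended and the callback intact
        return fakeStream(append(append(msgs, toolCall), toolReply), cb, depth+1)
    }
    final := Message{Role: "assistant", Content: "It is noon."}
    if cb != nil {
        cb(final)
    }
    return final.Content
}

func main() {
    var saved []Message
    resp := fakeStream(nil, func(m Message) { saved = append(saved, m) }, 0)
    fmt.Printf("%s (%d messages reported)\n", resp, len(saved)) // 3 messages
}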
@@ -9,6 +9,7 @@ import (
     "strings"

     "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+    "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
     "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
     openai "github.com/sashabaranov/go-openai"
 )
@@ -160,7 +161,7 @@ func (c *OpenAIClient) CreateChatCompletion(
     ctx context.Context,
     params model.RequestParameters,
     messages []model.Message,
-    replies *[]model.Message,
+    callback provider.ReplyCallback,
 ) (string, error) {
     client := openai.NewClient(c.APIKey)
     req := createChatCompletionRequest(c, params, messages)
@@ -177,17 +178,19 @@ func (c *OpenAIClient) CreateChatCompletion(
         if err != nil {
             return "", err
         }
-        if results != nil {
-            *replies = append(*replies, results...)
+        if callback != nil {
+            for _, result := range results {
+                callback(result)
+            }
         }

         // Recurse into CreateChatCompletion with the tool call replies
         messages = append(messages, results...)
-        return c.CreateChatCompletion(ctx, params, messages, replies)
+        return c.CreateChatCompletion(ctx, params, messages, callback)
     }

-    if replies != nil {
-        *replies = append(*replies, model.Message{
+    if callback != nil {
+        callback(model.Message{
             Role:    model.MessageRoleAssistant,
             Content: choice.Message.Content,
         })
@@ -201,7 +204,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
     ctx context.Context,
     params model.RequestParameters,
     messages []model.Message,
-    replies *[]model.Message,
+    callbback provider.ReplyCallback,
     output chan<- string,
 ) (string, error) {
     client := openai.NewClient(c.APIKey)
@@ -252,17 +255,20 @@ func (c *OpenAIClient) CreateChatCompletionStream(
         if err != nil {
             return content.String(), err
         }
-        if results != nil {
-            *replies = append(*replies, results...)

+        if callbback != nil {
+            for _, result := range results {
+                callbback(result)
+            }
         }

         // Recurse into CreateChatCompletionStream with the tool call replies
         messages = append(messages, results...)
-        return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
+        return c.CreateChatCompletionStream(ctx, params, messages, callbback, output)
     }

-    if replies != nil {
-        *replies = append(*replies, model.Message{
+    if callbback != nil {
+        callbback(model.Message{
             Role:    model.MessageRoleAssistant,
             Content: content.String(),
         })
@@ -6,6 +6,8 @@ import (
     "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 )

+type ReplyCallback func(model.Message)
+
 type ChatCompletionClient interface {
     // CreateChatCompletion requests a response to the provided messages.
     // Replies are appended to the given replies struct, and the
@@ -14,7 +16,7 @@ type ChatCompletionClient interface {
         ctx context.Context,
         params model.RequestParameters,
         messages []model.Message,
-        replies *[]model.Message,
+        callback ReplyCallback,
     ) (string, error)

     // Like CreateChageCompletion, except the response is streamed via
@@ -23,7 +25,7 @@ type ChatCompletionClient interface {
         ctx context.Context,
         params model.RequestParameters,
         messages []model.Message,
-        replies *[]model.Message,
+        callback ReplyCallback,
         output chan<- string,
     ) (string, error)
 }
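With ReplyCallback in the ChatCompletionClient interface, the old collect-into-a-slice behavior is trivially recovered by closing over a slice, which is why the replies *[]model.Message parameter could be dropped. A sketch against the interface above (client, ctx, params, and messages stand for values a real caller would already have; this fragment is not from the commits):

var replies []model.Message
collect := func(reply model.Message) {
    replies = append(replies, reply)
}
response, err := client.CreateChatCompletion(ctx, params, messages, collect)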
@@ -30,15 +30,16 @@ type model struct {
     convShortname string

     // application state
-    conversation *models.Conversation
-    messages     []models.Message
-    replyChan    chan string
-    err          error
+    conversation    *models.Conversation
+    messages        []models.Message
+    waitingForReply bool
+    replyChan       chan string
+    replyCancelFunc context.CancelFunc
+    err             error

     // ui state
-    focus     focusState
-    isWaiting bool
-    status    string // a general status message
+    focus  focusState
+    status string // a general status message

     // ui elements
     content viewport.Model
@@ -90,7 +91,11 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
     case tea.KeyMsg:
         switch msg.String() {
         case "ctrl+c":
-            return m, tea.Quit
+            if m.waitingForReply {
+                m.replyCancelFunc()
+            } else {
+                return m, tea.Quit
+            }
         case "q":
             if m.focus != focusInput {
                 return m, tea.Quit
@@ -135,7 +140,8 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
         }
         cmd = waitForChunk(m.replyChan) // wait for the next chunk
     case msgResponseEnd:
-        m.isWaiting = false
+        m.replyCancelFunc = nil
+        m.waitingForReply = false
         m.status = "Press ctrl+s to send"
     }

@@ -184,7 +190,7 @@ func initialModel(ctx *lmcli.Context, convShortname string) model {

     m.updateContent()

-    m.isWaiting = false
+    m.waitingForReply = false
     m.status = "Press ctrl+s to send"
     return m
 }
@@ -217,8 +223,8 @@ func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
         m.updateContent()
         m.content.GotoBottom()

-        m.isWaiting = true
-        m.status = "Waiting for response... (Press 's' to stop)"
+        m.waitingForReply = true
+        m.status = "Waiting for response, press ctrl+c to cancel..."
         return m.promptLLM()
     }
     return nil
@@ -278,9 +284,12 @@ func (m *model) promptLLM() tea.Cmd {
         ToolBag: toolBag,
     }

-    var apiReplies []models.Message
+    ctx, replyCancelFunc := context.WithCancel(context.Background())
+    m.replyCancelFunc = replyCancelFunc
+
+    // TODO: supply a reply callback and handle error
     resp, _ := completionProvider.CreateChatCompletionStream(
-        context.Background(), requestParams, m.messages, &apiReplies, m.replyChan,
+        ctx, requestParams, m.messages, nil, m.replyChan,
     )

     return msgResponseEnd(resp)
@@ -311,7 +320,7 @@ func (m *model) updateContent() {

 func (m model) inputView() string {
     var inputView string
-    if m.isWaiting {
+    if m.waitingForReply {
         inputView = inputStyle.Faint(true).Render(m.input.View())
     } else {
         inputView = inputStyle.Render(m.input.View())
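The TUI hunks wire ctrl+c to a context.CancelFunc stored on the model while a request is in flight, so the first press cancels the stream rather than quitting. A self-contained sketch of that pattern outside Bubble Tea (all names here are illustrative, not lmcli's):

package main

import (
    "context"
    "fmt"
    "time"
)

type app struct {
    waitingForReply bool
    replyCancelFunc context.CancelFunc
}

// startRequest begins a cancellable fake "LLM request" and records its
// cancel function, mirroring promptLLM above.
func (a *app) startRequest() <-chan string {
    ctx, cancel := context.WithCancel(context.Background())
    a.replyCancelFunc = cancel
    a.waitingForReply = true
    done := make(chan string, 1)
    go func() {
        select {
        case <-time.After(5 * time.Second): // stands in for the streamed response
            done <- "full response"
        case <-ctx.Done():
            done <- "(request cancelled)"
        }
    }()
    return done
}

// onCtrlC cancels an in-flight request instead of quitting, as in Update.
func (a *app) onCtrlC() {
    if a.waitingForReply {
        a.replyCancelFunc()
    }
}

func main() {
    a := &app{}
    done := a.startRequest()
    a.onCtrlC() // simulate pressing ctrl+c while waiting
    fmt.Println(<-done) // prints "(request cancelled)"
}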