Update ChatCompletionClient
Instead of CreateChatCompletion* accepting a pointer to a slice of reply messages, it accepts a callback which is called with each successive reply in the conversation. This gives the caller more flexibility in how it handles replies (e.g. it can react to each reply immediately, instead of waiting for the entire call to finish).
commit 91d3c9c2e1
parent 8bdb155bf7
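The net effect on the provider package's ChatCompletionClient interface, reassembled from the hunks below (the package clause and context import are assumed for completeness, and doc comments are omitted):

package provider

import (
	"context"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)

// ReplyCallback is invoked once for each reply produced during a completion.
type ReplyCallback func(model.Message)

type ChatCompletionClient interface {
	CreateChatCompletion(
		ctx context.Context,
		params model.RequestParameters,
		messages []model.Message,
		callback ReplyCallback,
	) (string, error)

	CreateChatCompletionStream(
		ctx context.Context,
		params model.RequestParameters,
		messages []model.Message,
		callback ReplyCallback,
		output chan<- string,
	) (string, error)
}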
@@ -44,13 +44,13 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 	fmt.Print(lastMessage.Content)
 
 	// Submit the LLM request, allowing it to continue the last message
-	continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+	continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
 	if err != nil {
 		return fmt.Errorf("error fetching LLM response: %v", err)
 	}
 
 	// Append the new response to the original message
-	lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")
+	lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
 
 	// Update the original message
 	err = ctx.Store.UpdateMessage(lastMessage)
@@ -31,7 +31,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 		},
 	}
 
-	_, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+	_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
 	if err != nil {
 		return fmt.Errorf("Error fetching LLM response: %v", err)
 	}
@@ -15,7 +15,7 @@ import (
 
 // fetchAndShowCompletion prompts the LLM with the given messages and streams
 // the response to stdout. Returns all model reply messages.
-func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
+func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
 	content := make(chan string) // receives the reponse from LLM
 	defer close(content)
 
@@ -24,7 +24,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 
 	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
 	if err != nil {
-		return nil, err
+		return "", err
 	}
 
 	requestParams := model.RequestParameters{
@@ -34,9 +34,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 		ToolBag: ctx.EnabledTools,
 	}
 
-	var apiReplies []model.Message
 	response, err := completionProvider.CreateChatCompletionStream(
-		context.Background(), requestParams, messages, &apiReplies, content,
+		context.Background(), requestParams, messages, callback, content,
 	)
 	if response != "" {
 		// there was some content, so break to a new line after it
@@ -47,8 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 			err = nil
 		}
 	}
-
-	return apiReplies, err
+	return response, nil
 }
 
 // lookupConversation either returns the conversation found by the
@@ -99,20 +97,21 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
 	// render a message header with no contents
 	RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
 
-	replies, err := FetchAndShowCompletion(ctx, allMessages)
-	if err != nil {
-		lmcli.Fatal("Error fetching LLM response: %v\n", err)
-	}
-
-	if persist {
-		for _, reply := range replies {
-			reply.ConversationID = c.ID
-
-			err = ctx.Store.SaveMessage(&reply)
-			if err != nil {
-				lmcli.Warn("Could not save reply: %v\n", err)
-			}
-		}
+	replyCallback := func(reply model.Message) {
+		if !persist {
+			return
+		}
+
+		reply.ConversationID = c.ID
+		err = ctx.Store.SaveMessage(&reply)
+		if err != nil {
+			lmcli.Warn("Could not save reply: %v\n", err)
+		}
+	}
+
+	_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
+	if err != nil {
+		lmcli.Fatal("Error fetching LLM response: %v\n", err)
 	}
 }
 
@@ -133,7 +133,7 @@ func (c *AnthropicClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 ) (string, error) {
 	request := buildRequest(params, messages)
 
@@ -162,7 +162,9 @@ func (c *AnthropicClient) CreateChatCompletion(
 		default:
 			return "", fmt.Errorf("unsupported message type: %s", content.Type)
 		}
-		*replies = append(*replies, reply)
+		if callback != nil {
+			callback(reply)
+		}
 	}
 
 	return sb.String(), nil
@@ -172,7 +174,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
 	request := buildRequest(params, messages)
@@ -291,23 +293,25 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 					ToolResults: toolResults,
 				}
 
-				if replies != nil {
-					*replies = append(append(*replies, toolCall), toolReply)
+				if callback != nil {
+					callback(toolCall)
+					callback(toolReply)
 				}
 
 				// Recurse into CreateChatCompletionStream with the tool call replies
 				// added to the original messages
 				messages = append(append(messages, toolCall), toolReply)
-				return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
+				return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 			}
 		}
 	case "message_stop":
 		// return the completed message
-		reply := model.Message{
+		if callback != nil {
+			callback(model.Message{
 				Role:    model.MessageRoleAssistant,
 				Content: sb.String(),
+			})
 		}
-		*replies = append(*replies, reply)
 		return sb.String(), nil
 	case "error":
 		return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
@@ -9,6 +9,7 @@ import (
 	"strings"
 
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 	openai "github.com/sashabaranov/go-openai"
 )
@@ -160,7 +161,7 @@ func (c *OpenAIClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
@@ -177,17 +178,19 @@ func (c *OpenAIClient) CreateChatCompletion(
 		if err != nil {
 			return "", err
 		}
-		if results != nil {
-			*replies = append(*replies, results...)
+		if callback != nil {
+			for _, result := range results {
+				callback(result)
+			}
 		}
 
 		// Recurse into CreateChatCompletion with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletion(ctx, params, messages, replies)
+		return c.CreateChatCompletion(ctx, params, messages, callback)
 	}
 
-	if replies != nil {
-		*replies = append(*replies, model.Message{
+	if callback != nil {
+		callback(model.Message{
 			Role:    model.MessageRoleAssistant,
 			Content: choice.Message.Content,
 		})
@@ -201,7 +204,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callbback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
@@ -252,17 +255,20 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 		if err != nil {
 			return content.String(), err
 		}
-		if results != nil {
-			*replies = append(*replies, results...)
+
+		if callbback != nil {
+			for _, result := range results {
+				callbback(result)
+			}
 		}
 
 		// Recurse into CreateChatCompletionStream with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
+		return c.CreateChatCompletionStream(ctx, params, messages, callbback, output)
 	}
 
-	if replies != nil {
-		*replies = append(*replies, model.Message{
+	if callbback != nil {
+		callbback(model.Message{
 			Role:    model.MessageRoleAssistant,
 			Content: content.String(),
 		})
@@ -6,6 +6,8 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 )
 
+type ReplyCallback func(model.Message)
+
 type ChatCompletionClient interface {
 	// CreateChatCompletion requests a response to the provided messages.
 	// Replies are appended to the given replies struct, and the
@@ -14,7 +16,7 @@ type ChatCompletionClient interface {
 		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		replies *[]model.Message,
+		callback ReplyCallback,
 	) (string, error)
 
 	// Like CreateChageCompletion, except the response is streamed via
@@ -23,7 +25,7 @@ type ChatCompletionClient interface {
 		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		replies *[]model.Message,
+		callback ReplyCallback,
 		output chan<- string,
 	) (string, error)
 }
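For illustration, here is a minimal self-contained sketch of the callback-style flow this change enables. Every identifier below is a local stand-in for the corresponding lmcli type or function, not the real API:

package main

import "fmt"

// Message stands in for lmcli's model.Message.
type Message struct {
	Role    string
	Content string
}

// ReplyCallback mirrors the callback type introduced by this commit.
type ReplyCallback func(Message)

// completeChat stands in for a CreateChatCompletion* call: it invokes the
// callback as each reply is produced, rather than appending to a slice the
// caller passed in by pointer.
func completeChat(messages []Message, callback ReplyCallback) (string, error) {
	reply := Message{Role: "assistant", Content: "Hello!"}
	if callback != nil {
		callback(reply) // the caller reacts immediately, e.g. persisting the reply
	}
	return reply.Content, nil
}

func main() {
	messages := []Message{{Role: "user", Content: "Hi"}}

	// The caller decides, per reply, what to do as soon as it arrives.
	saveReply := func(reply Message) {
		fmt.Printf("saving %s reply: %q\n", reply.Role, reply.Content)
	}

	response, err := completeChat(messages, saveReply)
	if err != nil {
		fmt.Println("error fetching LLM response:", err)
		return
	}
	fmt.Println("final response:", response)
}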