Update ChatCompletionClient
Instead of CreateChatCompletion* accepting a pointer to a slice of reply messages, it accepts a callback which is called with each successive reply the conversation. This gives the caller more flexibility in how it handles replies (e.g. it can react to them immediately now, instead of waiting for the entire call to finish)
This commit is contained in:
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
// fetchAndShowCompletion prompts the LLM with the given messages and streams
|
||||
// the response to stdout. Returns all model reply messages.
|
||||
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
|
||||
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
|
||||
content := make(chan string) // receives the reponse from LLM
|
||||
defer close(content)
|
||||
|
||||
@@ -24,7 +24,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
|
||||
|
||||
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
requestParams := model.RequestParameters{
|
||||
@@ -34,9 +34,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
|
||||
ToolBag: ctx.EnabledTools,
|
||||
}
|
||||
|
||||
var apiReplies []model.Message
|
||||
response, err := completionProvider.CreateChatCompletionStream(
|
||||
context.Background(), requestParams, messages, &apiReplies, content,
|
||||
context.Background(), requestParams, messages, callback, content,
|
||||
)
|
||||
if response != "" {
|
||||
// there was some content, so break to a new line after it
|
||||
@@ -47,8 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return apiReplies, err
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// lookupConversation either returns the conversation found by the
|
||||
@@ -99,20 +97,21 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
|
||||
// render a message header with no contents
|
||||
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
||||
|
||||
replies, err := FetchAndShowCompletion(ctx, allMessages)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||
replyCallback := func(reply model.Message) {
|
||||
if !persist {
|
||||
return
|
||||
}
|
||||
|
||||
reply.ConversationID = c.ID
|
||||
err = ctx.Store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
if persist {
|
||||
for _, reply := range replies {
|
||||
reply.ConversationID = c.ID
|
||||
|
||||
err = ctx.Store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user