From 91d3c9c2e126d39225c0e14129b2b7d1da93ae36 Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Tue, 12 Mar 2024 20:36:24 +0000
Subject: [PATCH] Update ChatCompletionClient

Instead of accepting a pointer to a slice of reply messages, the
CreateChatCompletion* methods now accept a callback which is called with
each successive reply in the conversation. This gives the caller more
flexibility in how it handles replies, e.g. it can react to them
immediately instead of waiting for the entire call to finish.
---
 pkg/cmd/continue.go                       |  4 +--
 pkg/cmd/prompt.go                         |  2 +-
 pkg/cmd/util/util.go                      | 35 +++++++++++------------
 pkg/lmcli/provider/anthropic/anthropic.go | 24 +++++++++-------
 pkg/lmcli/provider/openai/openai.go       | 30 +++++++++++--------
 pkg/lmcli/provider/provider.go            |  6 ++--
 6 files changed, 56 insertions(+), 45 deletions(-)

diff --git a/pkg/cmd/continue.go b/pkg/cmd/continue.go
index c164fea..6927db2 100644
--- a/pkg/cmd/continue.go
+++ b/pkg/cmd/continue.go
@@ -44,13 +44,13 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			fmt.Print(lastMessage.Content)
 
 			// Submit the LLM request, allowing it to continue the last message
-			continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+			continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("error fetching LLM response: %v", err)
 			}
 
 			// Append the new response to the original message
-			lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")
+			lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
 
 			// Update the original message
 			err = ctx.Store.UpdateMessage(lastMessage)
diff --git a/pkg/cmd/prompt.go b/pkg/cmd/prompt.go
index 4362c29..7e30d47 100644
--- a/pkg/cmd/prompt.go
+++ b/pkg/cmd/prompt.go
@@ -31,7 +31,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 				},
 			}
 
-			_, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+			_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("Error fetching LLM response: %v", err)
 			}
diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go
index d96cd0d..e5fb923 100644
--- a/pkg/cmd/util/util.go
+++ b/pkg/cmd/util/util.go
@@ -15,7 +15,7 @@ import (
 
 // fetchAndShowCompletion prompts the LLM with the given messages and streams
 // the response to stdout. Returns all model reply messages.
-func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
+func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
 	content := make(chan string) // receives the reponse from LLM
 	defer close(content)
 
@@ -24,7 +24,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 
 	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
 	if err != nil {
-		return nil, err
+		return "", err
 	}
 
 	requestParams := model.RequestParameters{
@@ -34,9 +34,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 		ToolBag: ctx.EnabledTools,
 	}
 
-	var apiReplies []model.Message
 	response, err := completionProvider.CreateChatCompletionStream(
-		context.Background(), requestParams, messages, &apiReplies, content,
+		context.Background(), requestParams, messages, callback, content,
 	)
 	if response != "" {
 		// there was some content, so break to a new line after it
@@ -47,8 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 			err = nil
 		}
 	}
-
-	return apiReplies, err
+	return response, err
 }
 
 // lookupConversation either returns the conversation found by the
@@ -99,20 +97,21 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
 	// render a message header with no contents
 	RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
 
-	replies, err := FetchAndShowCompletion(ctx, allMessages)
-	if err != nil {
-		lmcli.Fatal("Error fetching LLM response: %v\n", err)
+	replyCallback := func(reply model.Message) {
+		if !persist {
+			return
+		}
+
+		reply.ConversationID = c.ID
+		err = ctx.Store.SaveMessage(&reply)
+		if err != nil {
+			lmcli.Warn("Could not save reply: %v\n", err)
+		}
 	}
 
-	if persist {
-		for _, reply := range replies {
-			reply.ConversationID = c.ID
-
-			err = ctx.Store.SaveMessage(&reply)
-			if err != nil {
-				lmcli.Warn("Could not save reply: %v\n", err)
-			}
-		}
+	_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
+	if err != nil {
+		lmcli.Fatal("Error fetching LLM response: %v\n", err)
 	}
 }
 
diff --git a/pkg/lmcli/provider/anthropic/anthropic.go b/pkg/lmcli/provider/anthropic/anthropic.go
index e58954e..d889428 100644
--- a/pkg/lmcli/provider/anthropic/anthropic.go
+++ b/pkg/lmcli/provider/anthropic/anthropic.go
@@ -133,7 +133,7 @@ func (c *AnthropicClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 ) (string, error) {
 	request := buildRequest(params, messages)
 
@@ -162,7 +162,9 @@ func (c *AnthropicClient) CreateChatCompletion(
 		default:
 			return "", fmt.Errorf("unsupported message type: %s", content.Type)
 		}
-		*replies = append(*replies, reply)
+		if callback != nil {
+			callback(reply)
+		}
 	}
 
 	return sb.String(), nil
@@ -172,7 +174,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
 	request := buildRequest(params, messages)
@@ -291,23 +293,25 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 					ToolResults: toolResults,
 				}
 
-				if replies != nil {
-					*replies = append(append(*replies, toolCall), toolReply)
+				if callback != nil {
+					callback(toolCall)
+					callback(toolReply)
 				}
 
 				// Recurse into CreateChatCompletionStream with the tool call replies
 				// added to the original messages
 				messages = append(append(messages, toolCall), toolReply)
-				return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
+				return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 			}
 		}
 	case "message_stop":
 		// return the completed message
-		reply := model.Message{
-			Role:    model.MessageRoleAssistant,
-			Content: sb.String(),
+		if callback != nil {
+			callback(model.Message{
+				Role:    model.MessageRoleAssistant,
+				Content: sb.String(),
+			})
 		}
-		*replies = append(*replies, reply)
 		return sb.String(), nil
 	case "error":
 		return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
diff --git a/pkg/lmcli/provider/openai/openai.go b/pkg/lmcli/provider/openai/openai.go
index 35df832..8a01149 100644
--- a/pkg/lmcli/provider/openai/openai.go
+++ b/pkg/lmcli/provider/openai/openai.go
@@ -9,6 +9,7 @@ import (
 	"strings"
 
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 	openai "github.com/sashabaranov/go-openai"
 )
@@ -160,7 +161,7 @@ func (c *OpenAIClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
@@ -177,17 +178,19 @@ func (c *OpenAIClient) CreateChatCompletion(
 		if err != nil {
 			return "", err
 		}
-		if results != nil {
-			*replies = append(*replies, results...)
+		if callback != nil {
+			for _, result := range results {
+				callback(result)
+			}
 		}
 
 		// Recurse into CreateChatCompletion with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletion(ctx, params, messages, replies)
+		return c.CreateChatCompletion(ctx, params, messages, callback)
 	}
 
-	if replies != nil {
-		*replies = append(*replies, model.Message{
+	if callback != nil {
+		callback(model.Message{
 			Role:    model.MessageRoleAssistant,
 			Content: choice.Message.Content,
 		})
@@ -201,7 +204,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
@@ -252,17 +255,20 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 		if err != nil {
 			return content.String(), err
 		}
-		if results != nil {
-			*replies = append(*replies, results...)
+
+		if callback != nil {
+			for _, result := range results {
+				callback(result)
+			}
 		}
 
 		// Recurse into CreateChatCompletionStream with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
+		return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 	}
 
-	if replies != nil {
-		*replies = append(*replies, model.Message{
+	if callback != nil {
+		callback(model.Message{
 			Role:    model.MessageRoleAssistant,
 			Content: content.String(),
 		})
diff --git a/pkg/lmcli/provider/provider.go b/pkg/lmcli/provider/provider.go
index 6c17a69..1946966 100644
--- a/pkg/lmcli/provider/provider.go
+++ b/pkg/lmcli/provider/provider.go
@@ -6,6 +6,8 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 )
 
+type ReplyCallback func(model.Message)
+
 type ChatCompletionClient interface {
 	// CreateChatCompletion requests a response to the provided messages.
 	// Replies are appended to the given replies struct, and the
@@ -14,7 +16,7 @@ type ChatCompletionClient interface {
 		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		replies *[]model.Message,
+		callback ReplyCallback,
 	) (string, error)
 
 	// Like CreateChageCompletion, except the response is streamed via
@@ -23,7 +25,7 @@ type ChatCompletionClient interface {
 		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		replies *[]model.Message,
+		callback ReplyCallback,
 		output chan<- string,
 	) (string, error)
 }
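
Usage note (illustrative, not part of the patch): a minimal sketch of how a
caller can drive the callback-based interface introduced above. The example
package and helper name below are hypothetical; only the ChatCompletionClient
interface, ReplyCallback type, and model types come from this change. A nil
callback remains valid when per-reply handling is not needed, as continue.go
and prompt.go show.

package example

import (
	"context"
	"fmt"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
)

// printReplies is a hypothetical helper: it requests a completion and reacts
// to each reply (assistant message, tool call, or tool result) as soon as the
// provider produces it, instead of waiting for the whole call to finish.
func printReplies(
	client provider.ChatCompletionClient,
	params model.RequestParameters,
	messages []model.Message,
) error {
	callback := func(reply model.Message) {
		// Handle each reply immediately, e.g. persist it or print it.
		fmt.Println(reply.Content)
	}

	response, err := client.CreateChatCompletion(context.Background(), params, messages, callback)
	if err != nil {
		return err
	}

	// The full completion text is still returned for callers that want it.
	fmt.Println(response)
	return nil
}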