diff --git a/pkg/cmd/continue.go b/pkg/cmd/continue.go
index c164fea..6927db2 100644
--- a/pkg/cmd/continue.go
+++ b/pkg/cmd/continue.go
@@ -44,13 +44,13 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			fmt.Print(lastMessage.Content)
 
 			// Submit the LLM request, allowing it to continue the last message
-			continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+			continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("error fetching LLM response: %v", err)
 			}
 
 			// Append the new response to the original message
-			lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")
+			lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
 
 			// Update the original message
 			err = ctx.Store.UpdateMessage(lastMessage)
diff --git a/pkg/cmd/prompt.go b/pkg/cmd/prompt.go
index 4362c29..7e30d47 100644
--- a/pkg/cmd/prompt.go
+++ b/pkg/cmd/prompt.go
@@ -31,7 +31,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 				},
 			}
 
-			_, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+			_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("Error fetching LLM response: %v", err)
 			}
diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go
index d96cd0d..e5fb923 100644
--- a/pkg/cmd/util/util.go
+++ b/pkg/cmd/util/util.go
@@ -15,7 +15,7 @@ import (
 
 // fetchAndShowCompletion prompts the LLM with the given messages and streams
 // the response to stdout. Returns all model reply messages.
-func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
+func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
 	content := make(chan string) // receives the reponse from LLM
 	defer close(content)
 
@@ -24,7 +24,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 
 	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
 	if err != nil {
-		return nil, err
+		return "", err
 	}
 
 	requestParams := model.RequestParameters{
@@ -34,9 +34,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 		ToolBag:     ctx.EnabledTools,
 	}
 
-	var apiReplies []model.Message
 	response, err := completionProvider.CreateChatCompletionStream(
-		context.Background(), requestParams, messages, &apiReplies, content,
+		context.Background(), requestParams, messages, callback, content,
 	)
 	if response != "" {
 		// there was some content, so break to a new line after it
@@ -47,8 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 			err = nil
 		}
 	}
-
-	return apiReplies, err
+	return response, nil
 }
 
 // lookupConversation either returns the conversation found by the
@@ -99,20 +97,21 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
 	// render a message header with no contents
 	RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
 
-	replies, err := FetchAndShowCompletion(ctx, allMessages)
-	if err != nil {
-		lmcli.Fatal("Error fetching LLM response: %v\n", err)
+	replyCallback := func(reply model.Message) {
+		if !persist {
+			return
+		}
+
+		reply.ConversationID = c.ID
+		err = ctx.Store.SaveMessage(&reply)
+		if err != nil {
+			lmcli.Warn("Could not save reply: %v\n", err)
+		}
 	}
 
-	if persist {
-		for _, reply := range replies {
-			reply.ConversationID = c.ID
-
-			err = ctx.Store.SaveMessage(&reply)
-			if err != nil {
-				lmcli.Warn("Could not save reply: %v\n", err)
-			}
-		}
+	_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
+	if err != nil {
+		lmcli.Fatal("Error fetching LLM response: %v\n", err)
 	}
 }
 
diff --git a/pkg/lmcli/provider/anthropic/anthropic.go b/pkg/lmcli/provider/anthropic/anthropic.go
index e58954e..d889428 100644
--- a/pkg/lmcli/provider/anthropic/anthropic.go
+++ b/pkg/lmcli/provider/anthropic/anthropic.go
@@ -133,7 +133,7 @@ func (c *AnthropicClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 ) (string, error) {
 	request := buildRequest(params, messages)
 
@@ -162,7 +162,9 @@ func (c *AnthropicClient) CreateChatCompletion(
 		default:
 			return "", fmt.Errorf("unsupported message type: %s", content.Type)
 		}
-		*replies = append(*replies, reply)
+		if callback != nil {
+			callback(reply)
+		}
 	}
 
 	return sb.String(), nil
@@ -172,7 +174,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
 	request := buildRequest(params, messages)
@@ -291,23 +293,25 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 					ToolResults: toolResults,
 				}
 
-				if replies != nil {
-					*replies = append(append(*replies, toolCall), toolReply)
+				if callback != nil {
+					callback(toolCall)
+					callback(toolReply)
 				}
 
 				// Recurse into CreateChatCompletionStream with the tool call replies
 				// added to the original messages
 				messages = append(append(messages, toolCall), toolReply)
-				return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
+				return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 			}
 		}
 		case "message_stop":
 			// return the completed message
-			reply := model.Message{
-				Role:    model.MessageRoleAssistant,
-				Content: sb.String(),
+			if callback != nil {
+				callback(model.Message{
+					Role:    model.MessageRoleAssistant,
+					Content: sb.String(),
+				})
 			}
-			*replies = append(*replies, reply)
 			return sb.String(), nil
 		case "error":
 			return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
diff --git a/pkg/lmcli/provider/openai/openai.go b/pkg/lmcli/provider/openai/openai.go
index 35df832..8a01149 100644
--- a/pkg/lmcli/provider/openai/openai.go
+++ b/pkg/lmcli/provider/openai/openai.go
@@ -9,6 +9,7 @@ import (
 	"strings"
 
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 	openai "github.com/sashabaranov/go-openai"
 )
@@ -160,7 +161,7 @@ func (c *OpenAIClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
@@ -177,17 +178,19 @@ func (c *OpenAIClient) CreateChatCompletion(
 		if err != nil {
 			return "", err
 		}
-		if results != nil {
-			*replies = append(*replies, results...)
+		if callback != nil {
+			for _, result := range results {
+				callback(result)
+			}
 		}
 
 		// Recurse into CreateChatCompletion with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletion(ctx, params, messages, replies)
+		return c.CreateChatCompletion(ctx, params, messages, callback)
 	}
 
-	if replies != nil {
-		*replies = append(*replies, model.Message{
+	if callback != nil {
+		callback(model.Message{
 			Role:    model.MessageRoleAssistant,
 			Content: choice.Message.Content,
 		})
@@ -201,7 +204,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
@@ -252,17 +255,20 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 		if err != nil {
 			return content.String(), err
 		}
-		if results != nil {
-			*replies = append(*replies, results...)
+
+		if callback != nil {
+			for _, result := range results {
+				callback(result)
+			}
 		}
 
 		// Recurse into CreateChatCompletionStream with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
+		return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 	}
 
-	if replies != nil {
-		*replies = append(*replies, model.Message{
+	if callback != nil {
+		callback(model.Message{
 			Role:    model.MessageRoleAssistant,
 			Content: content.String(),
 		})
diff --git a/pkg/lmcli/provider/provider.go b/pkg/lmcli/provider/provider.go
index 6c17a69..1946966 100644
--- a/pkg/lmcli/provider/provider.go
+++ b/pkg/lmcli/provider/provider.go
@@ -6,6 +6,8 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 )
 
+type ReplyCallback func(model.Message)
+
 type ChatCompletionClient interface {
 	// CreateChatCompletion requests a response to the provided messages.
 	// Replies are appended to the given replies struct, and the
@@ -14,7 +16,7 @@ type ChatCompletionClient interface {
 		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		replies *[]model.Message,
+		callback ReplyCallback,
 	) (string, error)
 
 	// Like CreateChageCompletion, except the response is streamed via
@@ -23,7 +25,7 @@ type ChatCompletionClient interface {
 		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		replies *[]model.Message,
+		callback ReplyCallback,
 		output chan<- string,
 	) (string, error)
 }