Compare commits

...

2 Commits

Author SHA1 Message Date
91d3c9c2e1 Update ChatCompletionClient
Instead of CreateChatCompletion* accepting a pointer to a slice of reply
messages, it accepts a callback which is called with each successive
reply the conversation.

This gives the caller more flexibility in how it handles replies (e.g.
it can react to them immediately now, instead of waiting for the entire
call to finish)
2024-03-12 20:39:34 +00:00
8bdb155bf7 Update ChatCompletionClient to accept context.Context 2024-03-12 18:24:46 +00:00
6 changed files with 77 additions and 54 deletions

View File

@ -44,13 +44,13 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
fmt.Print(lastMessage.Content)
// Submit the LLM request, allowing it to continue the last message
continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
if err != nil {
return fmt.Errorf("error fetching LLM response: %v", err)
}
// Append the new response to the original message
lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")
lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
// Update the original message
err = ctx.Store.UpdateMessage(lastMessage)

View File

@ -31,7 +31,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
},
}
_, err := cmdutil.FetchAndShowCompletion(ctx, messages)
_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err)
}

View File

@ -1,6 +1,7 @@
package util
import (
"context"
"fmt"
"os"
"strings"
@ -14,7 +15,7 @@ import (
// fetchAndShowCompletion prompts the LLM with the given messages and streams
// the response to stdout. Returns all model reply messages.
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
content := make(chan string) // receives the reponse from LLM
defer close(content)
@ -23,7 +24,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
if err != nil {
return nil, err
return "", err
}
requestParams := model.RequestParameters{
@ -33,9 +34,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
ToolBag: ctx.EnabledTools,
}
var apiReplies []model.Message
response, err := completionProvider.CreateChatCompletionStream(
requestParams, messages, &apiReplies, content,
context.Background(), requestParams, messages, callback, content,
)
if response != "" {
// there was some content, so break to a new line after it
@ -46,8 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
err = nil
}
}
return apiReplies, err
return response, nil
}
// lookupConversation either returns the conversation found by the
@ -98,20 +97,21 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
// render a message header with no contents
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
replies, err := FetchAndShowCompletion(ctx, allMessages)
if err != nil {
lmcli.Fatal("Error fetching LLM response: %v\n", err)
replyCallback := func(reply model.Message) {
if !persist {
return
}
if persist {
for _, reply := range replies {
reply.ConversationID = c.ID
err = ctx.Store.SaveMessage(&reply)
if err != nil {
lmcli.Warn("Could not save reply: %v\n", err)
}
}
_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
if err != nil {
lmcli.Fatal("Error fetching LLM response: %v\n", err)
}
}
@ -153,7 +153,7 @@ func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
MaxTokens: 25,
}
response, err := completionProvider.CreateChatCompletion(requestParams, generateRequest, nil)
response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
if err != nil {
return "", err
}

View File

@ -3,14 +3,15 @@ package anthropic
import (
"bufio"
"bytes"
"context"
"encoding/json"
"encoding/xml"
"fmt"
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
)
type AnthropicClient struct {
@ -102,7 +103,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
return requestBody
}
func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
url := "https://api.anthropic.com/v1/messages"
jsonBody, err := json.Marshal(r)
@ -110,7 +111,7 @@ func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
return nil, fmt.Errorf("failed to marshal request body: %v", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
}
@ -129,13 +130,14 @@ func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
}
func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
replies *[]model.Message,
callback provider.ReplyCallback,
) (string, error) {
request := buildRequest(params, messages)
resp, err := sendRequest(c, request)
resp, err := sendRequest(ctx, c, request)
if err != nil {
return "", err
}
@ -160,22 +162,25 @@ func (c *AnthropicClient) CreateChatCompletion(
default:
return "", fmt.Errorf("unsupported message type: %s", content.Type)
}
*replies = append(*replies, reply)
if callback != nil {
callback(reply)
}
}
return sb.String(), nil
}
func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
replies *[]model.Message,
callback provider.ReplyCallback,
output chan<- string,
) (string, error) {
request := buildRequest(params, messages)
request.Stream = true
resp, err := sendRequest(c, request)
resp, err := sendRequest(ctx, c, request)
if err != nil {
return "", err
}
@ -288,23 +293,25 @@ func (c *AnthropicClient) CreateChatCompletionStream(
ToolResults: toolResults,
}
if replies != nil {
*replies = append(append(*replies, toolCall), toolReply)
if callback != nil {
callback(toolCall)
callback(toolReply)
}
// Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages
messages = append(append(messages, toolCall), toolReply)
return c.CreateChatCompletionStream(params, messages, replies, output)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
}
}
case "message_stop":
// return the completed message
reply := model.Message{
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: sb.String(),
})
}
*replies = append(*replies, reply)
return sb.String(), nil
case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])

View File

@ -9,6 +9,7 @@ import (
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
openai "github.com/sashabaranov/go-openai"
)
@ -157,13 +158,14 @@ func handleToolCalls(
}
func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
replies *[]model.Message,
callback provider.ReplyCallback,
) (string, error) {
client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages)
resp, err := client.CreateChatCompletion(context.Background(), req)
resp, err := client.CreateChatCompletion(ctx, req)
if err != nil {
return "", err
}
@ -176,17 +178,19 @@ func (c *OpenAIClient) CreateChatCompletion(
if err != nil {
return "", err
}
if results != nil {
*replies = append(*replies, results...)
if callback != nil {
for _, result := range results {
callback(result)
}
}
// Recurse into CreateChatCompletion with the tool call replies
messages = append(messages, results...)
return c.CreateChatCompletion(params, messages, replies)
return c.CreateChatCompletion(ctx, params, messages, callback)
}
if replies != nil {
*replies = append(*replies, model.Message{
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: choice.Message.Content,
})
@ -197,15 +201,16 @@ func (c *OpenAIClient) CreateChatCompletion(
}
func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
replies *[]model.Message,
callbback provider.ReplyCallback,
output chan<- string,
) (string, error) {
client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages)
stream, err := client.CreateChatCompletionStream(context.Background(), req)
stream, err := client.CreateChatCompletionStream(ctx, req)
if err != nil {
return "", err
}
@ -250,17 +255,20 @@ func (c *OpenAIClient) CreateChatCompletionStream(
if err != nil {
return content.String(), err
}
if results != nil {
*replies = append(*replies, results...)
if callbback != nil {
for _, result := range results {
callbback(result)
}
}
// Recurse into CreateChatCompletionStream with the tool call replies
messages = append(messages, results...)
return c.CreateChatCompletionStream(params, messages, replies, output)
return c.CreateChatCompletionStream(ctx, params, messages, callbback, output)
}
if replies != nil {
*replies = append(*replies, model.Message{
if callbback != nil {
callbback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(),
})

View File

@ -1,23 +1,31 @@
package provider
import "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
import (
"context"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
type ReplyCallback func(model.Message)
type ChatCompletionClient interface {
// CreateChatCompletion requests a response to the provided messages.
// Replies are appended to the given replies struct, and the
// complete user-facing response is returned as a string.
CreateChatCompletion(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
replies *[]model.Message,
callback ReplyCallback,
) (string, error)
// Like CreateChageCompletion, except the response is streamed via
// the output channel as it's received.
CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters,
messages []model.Message,
replies *[]model.Message,
callback ReplyCallback,
output chan<- string,
) (string, error)
}