Update ChatCompletionClient
Instead of CreateChatCompletion* accepting a pointer to a slice of reply messages, it now accepts a callback which is called with each successive reply in the conversation. This gives the caller more flexibility in how it handles replies (e.g. it can react to them immediately now, instead of waiting for the entire call to finish).
This commit is contained in:
@@ -133,7 +133,7 @@ func (c *AnthropicClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
) (string, error) {
|
||||
request := buildRequest(params, messages)
|
||||
|
||||
@@ -162,7 +162,9 @@ func (c *AnthropicClient) CreateChatCompletion(
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported message type: %s", content.Type)
|
||||
}
|
||||
*replies = append(*replies, reply)
|
||||
if callback != nil {
|
||||
callback(reply)
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
@@ -172,7 +174,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
output chan<- string,
|
||||
) (string, error) {
|
||||
request := buildRequest(params, messages)
|
||||
@@ -291,23 +293,25 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(append(*replies, toolCall), toolReply)
|
||||
if callback != nil {
|
||||
callback(toolCall)
|
||||
callback(toolReply)
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
// added to the original messages
|
||||
messages = append(append(messages, toolCall), toolReply)
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
// return the completed message
|
||||
reply := model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: sb.String(),
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: sb.String(),
|
||||
})
|
||||
}
|
||||
*replies = append(*replies, reply)
|
||||
return sb.String(), nil
|
||||
case "error":
|
||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
@@ -160,7 +161,7 @@ func (c *OpenAIClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
) (string, error) {
|
||||
client := openai.NewClient(c.APIKey)
|
||||
req := createChatCompletionRequest(c, params, messages)
|
||||
@@ -177,17 +178,19 @@ func (c *OpenAIClient) CreateChatCompletion(
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if results != nil {
|
||||
*replies = append(*replies, results...)
|
||||
if callback != nil {
|
||||
for _, result := range results {
|
||||
callback(result)
|
||||
}
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletion with the tool call replies
|
||||
messages = append(messages, results...)
|
||||
return c.CreateChatCompletion(ctx, params, messages, replies)
|
||||
return c.CreateChatCompletion(ctx, params, messages, callback)
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(*replies, model.Message{
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: choice.Message.Content,
|
||||
})
|
||||
@@ -201,7 +204,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
callbback provider.ReplyCallback,
|
||||
output chan<- string,
|
||||
) (string, error) {
|
||||
client := openai.NewClient(c.APIKey)
|
||||
@@ -252,17 +255,20 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
if err != nil {
|
||||
return content.String(), err
|
||||
}
|
||||
if results != nil {
|
||||
*replies = append(*replies, results...)
|
||||
|
||||
if callbback != nil {
|
||||
for _, result := range results {
|
||||
callbback(result)
|
||||
}
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
messages = append(messages, results...)
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, callbback, output)
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
*replies = append(*replies, model.Message{
|
||||
if callbback != nil {
|
||||
callbback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
})
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
// ReplyCallback is a function that receives a single reply message.
// Clients call it once for each reply they produce, and skip the call
// when the callback is nil.
type ReplyCallback func(model.Message)
|
||||
|
||||
type ChatCompletionClient interface {
|
||||
// CreateChatCompletion requests a response to the provided messages.
|
||||
// Each reply is passed to the given callback, and the
|
||||
@@ -14,7 +16,7 @@ type ChatCompletionClient interface {
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
callback ReplyCallback,
|
||||
) (string, error)
|
||||
|
||||
// Like CreateChatCompletion, except the response is streamed via
|
||||
@@ -23,7 +25,7 @@ type ChatCompletionClient interface {
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
callback ReplyCallback,
|
||||
output chan<- string,
|
||||
) (string, error)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user