Update ChatCompletionClient to accept context.Context
This commit is contained in:
parent
045146bb5c
commit
8bdb155bf7
@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
@ -35,7 +36,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
|
||||
|
||||
var apiReplies []model.Message
|
||||
response, err := completionProvider.CreateChatCompletionStream(
|
||||
requestParams, messages, &apiReplies, content,
|
||||
context.Background(), requestParams, messages, &apiReplies, content,
|
||||
)
|
||||
if response != "" {
|
||||
// there was some content, so break to a new line after it
|
||||
@ -153,7 +154,7 @@ func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
|
||||
MaxTokens: 25,
|
||||
}
|
||||
|
||||
response, err := completionProvider.CreateChatCompletion(requestParams, generateRequest, nil)
|
||||
response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -3,14 +3,15 @@ package anthropic
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
)
|
||||
|
||||
type AnthropicClient struct {
|
||||
@ -102,7 +103,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
|
||||
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
||||
url := "https://api.anthropic.com/v1/messages"
|
||||
|
||||
jsonBody, err := json.Marshal(r)
|
||||
@ -110,7 +111,7 @@ func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
||||
}
|
||||
@ -129,13 +130,14 @@ func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
) (string, error) {
|
||||
request := buildRequest(params, messages)
|
||||
|
||||
resp, err := sendRequest(c, request)
|
||||
resp, err := sendRequest(ctx, c, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -167,6 +169,7 @@ func (c *AnthropicClient) CreateChatCompletion(
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
@ -175,7 +178,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
request := buildRequest(params, messages)
|
||||
request.Stream = true
|
||||
|
||||
resp, err := sendRequest(c, request)
|
||||
resp, err := sendRequest(ctx, c, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -295,7 +298,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
// added to the original messages
|
||||
messages = append(append(messages, toolCall), toolReply)
|
||||
return c.CreateChatCompletionStream(params, messages, replies, output)
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
|
@ -157,13 +157,14 @@ func handleToolCalls(
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
) (string, error) {
|
||||
client := openai.NewClient(c.APIKey)
|
||||
req := createChatCompletionRequest(c, params, messages)
|
||||
resp, err := client.CreateChatCompletion(context.Background(), req)
|
||||
resp, err := client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -182,7 +183,7 @@ func (c *OpenAIClient) CreateChatCompletion(
|
||||
|
||||
// Recurse into CreateChatCompletion with the tool call replies
|
||||
messages = append(messages, results...)
|
||||
return c.CreateChatCompletion(params, messages, replies)
|
||||
return c.CreateChatCompletion(ctx, params, messages, replies)
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
@ -197,6 +198,7 @@ func (c *OpenAIClient) CreateChatCompletion(
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
@ -205,7 +207,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
client := openai.NewClient(c.APIKey)
|
||||
req := createChatCompletionRequest(c, params, messages)
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), req)
|
||||
stream, err := client.CreateChatCompletionStream(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -256,7 +258,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
messages = append(messages, results...)
|
||||
return c.CreateChatCompletionStream(params, messages, replies, output)
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
|
||||
}
|
||||
|
||||
if replies != nil {
|
||||
|
@ -1,12 +1,17 @@
|
||||
package provider
|
||||
|
||||
import "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
type ChatCompletionClient interface {
|
||||
// CreateChatCompletion requests a response to the provided messages.
|
||||
// Replies are appended to the given replies struct, and the
|
||||
// complete user-facing response is returned as a string.
|
||||
CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
@ -15,6 +20,7 @@ type ChatCompletionClient interface {
|
||||
// Like CreateChageCompletion, except the response is streamed via
|
||||
// the output channel as it's received.
|
||||
CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
replies *[]model.Message,
|
||||
|
Loading…
Reference in New Issue
Block a user