From 8bdb155bf7ba3d977992a4e092d41594ea872552 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Tue, 12 Mar 2024 18:24:05 +0000 Subject: [PATCH] Update ChatCompletionClient to accept context.Context --- pkg/cmd/util/util.go | 5 +++-- pkg/lmcli/provider/anthropic/anthropic.go | 15 +++++++++------ pkg/lmcli/provider/openai/openai.go | 10 ++++++---- pkg/lmcli/provider/provider.go | 8 +++++++- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go index 88b41ab..d96cd0d 100644 --- a/pkg/cmd/util/util.go +++ b/pkg/cmd/util/util.go @@ -1,6 +1,7 @@ package util import ( + "context" "fmt" "os" "strings" @@ -35,7 +36,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod var apiReplies []model.Message response, err := completionProvider.CreateChatCompletionStream( - requestParams, messages, &apiReplies, content, + context.Background(), requestParams, messages, &apiReplies, content, ) if response != "" { // there was some content, so break to a new line after it @@ -153,7 +154,7 @@ func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) { MaxTokens: 25, } - response, err := completionProvider.CreateChatCompletion(requestParams, generateRequest, nil) + response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil) if err != nil { return "", err } diff --git a/pkg/lmcli/provider/anthropic/anthropic.go b/pkg/lmcli/provider/anthropic/anthropic.go index 9761bd0..e58954e 100644 --- a/pkg/lmcli/provider/anthropic/anthropic.go +++ b/pkg/lmcli/provider/anthropic/anthropic.go @@ -3,14 +3,15 @@ package anthropic import ( "bufio" "bytes" + "context" "encoding/json" "encoding/xml" "fmt" "net/http" "strings" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" ) type AnthropicClient struct { @@ -102,7 +103,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ return requestBody } -func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) { +func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) { url := "https://api.anthropic.com/v1/messages" jsonBody, err := json.Marshal(r) @@ -110,7 +111,7 @@ func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) { return nil, fmt.Errorf("failed to marshal request body: %v", err) } - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %v", err) } @@ -129,13 +130,14 @@ func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) { } func (c *AnthropicClient) CreateChatCompletion( + ctx context.Context, params model.RequestParameters, messages []model.Message, replies *[]model.Message, ) (string, error) { request := buildRequest(params, messages) - resp, err := sendRequest(c, request) + resp, err := sendRequest(ctx, c, request) if err != nil { return "", err } @@ -167,6 +169,7 @@ func (c *AnthropicClient) CreateChatCompletion( } func (c *AnthropicClient) CreateChatCompletionStream( + ctx context.Context, params model.RequestParameters, messages []model.Message, replies *[]model.Message, @@ -175,7 +178,7 @@ func (c *AnthropicClient) CreateChatCompletionStream( request := buildRequest(params, messages) request.Stream = true - resp, err := sendRequest(c, request) + resp, err := sendRequest(ctx, c, request) if err != nil { return "", err } @@ -295,7 +298,7 @@ func (c *AnthropicClient) CreateChatCompletionStream( // Recurse into CreateChatCompletionStream with the tool call replies // added to the original messages messages = append(append(messages, toolCall), toolReply) - return c.CreateChatCompletionStream(params, messages, replies, output) + return c.CreateChatCompletionStream(ctx, params, messages, replies, output) } } case "message_stop": diff --git a/pkg/lmcli/provider/openai/openai.go b/pkg/lmcli/provider/openai/openai.go index b21791d..35df832 100644 --- a/pkg/lmcli/provider/openai/openai.go +++ b/pkg/lmcli/provider/openai/openai.go @@ -157,13 +157,14 @@ func handleToolCalls( } func (c *OpenAIClient) CreateChatCompletion( + ctx context.Context, params model.RequestParameters, messages []model.Message, replies *[]model.Message, ) (string, error) { client := openai.NewClient(c.APIKey) req := createChatCompletionRequest(c, params, messages) - resp, err := client.CreateChatCompletion(context.Background(), req) + resp, err := client.CreateChatCompletion(ctx, req) if err != nil { return "", err } @@ -182,7 +183,7 @@ func (c *OpenAIClient) CreateChatCompletion( // Recurse into CreateChatCompletion with the tool call replies messages = append(messages, results...) - return c.CreateChatCompletion(params, messages, replies) + return c.CreateChatCompletion(ctx, params, messages, replies) } if replies != nil { @@ -197,6 +198,7 @@ func (c *OpenAIClient) CreateChatCompletion( } func (c *OpenAIClient) CreateChatCompletionStream( + ctx context.Context, params model.RequestParameters, messages []model.Message, replies *[]model.Message, @@ -205,7 +207,7 @@ func (c *OpenAIClient) CreateChatCompletionStream( client := openai.NewClient(c.APIKey) req := createChatCompletionRequest(c, params, messages) - stream, err := client.CreateChatCompletionStream(context.Background(), req) + stream, err := client.CreateChatCompletionStream(ctx, req) if err != nil { return "", err } @@ -256,7 +258,7 @@ func (c *OpenAIClient) CreateChatCompletionStream( // Recurse into CreateChatCompletionStream with the tool call replies messages = append(messages, results...) - return c.CreateChatCompletionStream(params, messages, replies, output) + return c.CreateChatCompletionStream(ctx, params, messages, replies, output) } if replies != nil { diff --git a/pkg/lmcli/provider/provider.go b/pkg/lmcli/provider/provider.go index f0e4ace..6c17a69 100644 --- a/pkg/lmcli/provider/provider.go +++ b/pkg/lmcli/provider/provider.go @@ -1,12 +1,17 @@ package provider -import "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" +import ( + "context" + + "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" +) type ChatCompletionClient interface { // CreateChatCompletion requests a response to the provided messages. // Replies are appended to the given replies struct, and the // complete user-facing response is returned as a string. CreateChatCompletion( + ctx context.Context, params model.RequestParameters, messages []model.Message, replies *[]model.Message, @@ -15,6 +20,7 @@ type ChatCompletionClient interface { // Like CreateChageCompletion, except the response is streamed via // the output channel as it's received. CreateChatCompletionStream( + ctx context.Context, params model.RequestParameters, messages []model.Message, replies *[]model.Message,