Update ChatCompletionClient to accept context.Context

This commit is contained in:
Matt Low 2024-03-12 18:24:05 +00:00
parent 045146bb5c
commit 8bdb155bf7
4 changed files with 25 additions and 13 deletions

View File

@ -1,6 +1,7 @@
package util package util
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"strings" "strings"
@ -35,7 +36,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
var apiReplies []model.Message var apiReplies []model.Message
response, err := completionProvider.CreateChatCompletionStream( response, err := completionProvider.CreateChatCompletionStream(
requestParams, messages, &apiReplies, content, context.Background(), requestParams, messages, &apiReplies, content,
) )
if response != "" { if response != "" {
// there was some content, so break to a new line after it // there was some content, so break to a new line after it
@ -153,7 +154,7 @@ func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
MaxTokens: 25, MaxTokens: 25,
} }
response, err := completionProvider.CreateChatCompletion(requestParams, generateRequest, nil) response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -3,14 +3,15 @@ package anthropic
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
type AnthropicClient struct { type AnthropicClient struct {
@ -102,7 +103,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
return requestBody return requestBody
} }
func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) { func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
url := "https://api.anthropic.com/v1/messages" url := "https://api.anthropic.com/v1/messages"
jsonBody, err := json.Marshal(r) jsonBody, err := json.Marshal(r)
@ -110,7 +111,7 @@ func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
return nil, fmt.Errorf("failed to marshal request body: %v", err) return nil, fmt.Errorf("failed to marshal request body: %v", err)
} }
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %v", err) return nil, fmt.Errorf("failed to create HTTP request: %v", err)
} }
@ -129,13 +130,14 @@ func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
} }
func (c *AnthropicClient) CreateChatCompletion( func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, replies *[]model.Message,
) (string, error) { ) (string, error) {
request := buildRequest(params, messages) request := buildRequest(params, messages)
resp, err := sendRequest(c, request) resp, err := sendRequest(ctx, c, request)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -167,6 +169,7 @@ func (c *AnthropicClient) CreateChatCompletion(
} }
func (c *AnthropicClient) CreateChatCompletionStream( func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, replies *[]model.Message,
@ -175,7 +178,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
request := buildRequest(params, messages) request := buildRequest(params, messages)
request.Stream = true request.Stream = true
resp, err := sendRequest(c, request) resp, err := sendRequest(ctx, c, request)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -295,7 +298,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages // added to the original messages
messages = append(append(messages, toolCall), toolReply) messages = append(append(messages, toolCall), toolReply)
return c.CreateChatCompletionStream(params, messages, replies, output) return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
} }
} }
case "message_stop": case "message_stop":

View File

@ -157,13 +157,14 @@ func handleToolCalls(
} }
func (c *OpenAIClient) CreateChatCompletion( func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, replies *[]model.Message,
) (string, error) { ) (string, error) {
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages) req := createChatCompletionRequest(c, params, messages)
resp, err := client.CreateChatCompletion(context.Background(), req) resp, err := client.CreateChatCompletion(ctx, req)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -182,7 +183,7 @@ func (c *OpenAIClient) CreateChatCompletion(
// Recurse into CreateChatCompletion with the tool call replies // Recurse into CreateChatCompletion with the tool call replies
messages = append(messages, results...) messages = append(messages, results...)
return c.CreateChatCompletion(params, messages, replies) return c.CreateChatCompletion(ctx, params, messages, replies)
} }
if replies != nil { if replies != nil {
@ -197,6 +198,7 @@ func (c *OpenAIClient) CreateChatCompletion(
} }
func (c *OpenAIClient) CreateChatCompletionStream( func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, replies *[]model.Message,
@ -205,7 +207,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages) req := createChatCompletionRequest(c, params, messages)
stream, err := client.CreateChatCompletionStream(context.Background(), req) stream, err := client.CreateChatCompletionStream(ctx, req)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -256,7 +258,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
messages = append(messages, results...) messages = append(messages, results...)
return c.CreateChatCompletionStream(params, messages, replies, output) return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
} }
if replies != nil { if replies != nil {

View File

@ -1,12 +1,17 @@
package provider package provider
import "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" import (
"context"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
type ChatCompletionClient interface { type ChatCompletionClient interface {
// CreateChatCompletion requests a response to the provided messages. // CreateChatCompletion requests a response to the provided messages.
// Replies are appended to the given replies struct, and the // Replies are appended to the given replies struct, and the
// complete user-facing response is returned as a string. // complete user-facing response is returned as a string.
CreateChatCompletion( CreateChatCompletion(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, replies *[]model.Message,
@ -15,6 +20,7 @@ type ChatCompletionClient interface {
// Like CreateChageCompletion, except the response is streamed via // Like CreateChageCompletion, except the response is streamed via
// the output channel as it's received. // the output channel as it's received.
CreateChatCompletionStream( CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, replies *[]model.Message,