From 677cfcfebf2d4f66b7ca7ce1c1f0be852c1018db Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sun, 23 Jun 2024 04:17:53 +0000 Subject: [PATCH] Slight cleanup to openai Remove /v1 from base url, removed some slight repetition --- pkg/api/provider/openai/openai.go | 35 ++++++++++++------------------- pkg/lmcli/lmcli.go | 2 +- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/pkg/api/provider/openai/openai.go b/pkg/api/provider/openai/openai.go index 5e75231..3acd24c 100644 --- a/pkg/api/provider/openai/openai.go +++ b/pkg/api/provider/openai/openai.go @@ -185,7 +185,17 @@ func createChatCompletionRequest( return request } -func (c *OpenAIClient) sendRequest(req *http.Request) (*http.Response, error) { +func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) { + jsonData, err := json.Marshal(r) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.APIKey) @@ -213,17 +223,8 @@ func (c *OpenAIClient) CreateChatCompletion( } req := createChatCompletionRequest(params, messages) - jsonData, err := json.Marshal(req) - if err != nil { - return nil, err - } - httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) - if err != nil { - return nil, err - } - - resp, err := c.sendRequest(httpReq) + resp, err := c.sendRequest(ctx, req) if err != nil { return nil, err } @@ -273,17 +274,7 @@ func (c *OpenAIClient) CreateChatCompletionStream( req := createChatCompletionRequest(params, messages) req.Stream = true - jsonData, err := json.Marshal(req) - if err != nil { - return nil, err - } - - httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) - if err != nil { - return nil, err - } - - resp, err := c.sendRequest(httpReq) + resp, err := c.sendRequest(ctx, req) if err != nil { return nil, err } diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index db93f82..651088f 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -122,7 +122,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv BaseURL: url, }, nil case "openai": - url := "https://api.openai.com/v1" + url := "https://api.openai.com" if p.BaseURL != nil { url = *p.BaseURL }