From 434fc4672b0840d510d6f1ff939aa433f4eb9b2c Mon Sep 17 00:00:00 2001 From: Matt Low Date: Mon, 12 Aug 2024 17:14:53 +0000 Subject: [PATCH] Allow custom headers on OpenAI providers (to be added to more later) --- pkg/api/provider/openai/openai.go | 4 ++++ pkg/lmcli/config.go | 11 ++++++----- pkg/lmcli/lmcli.go | 1 + 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/pkg/api/provider/openai/openai.go b/pkg/api/provider/openai/openai.go index fecb76f..6c376c8 100644 --- a/pkg/api/provider/openai/openai.go +++ b/pkg/api/provider/openai/openai.go @@ -16,6 +16,7 @@ import ( type OpenAIClient struct { APIKey string BaseURL string + Headers map[string]string } type ChatCompletionMessage struct { @@ -198,6 +199,9 @@ func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.APIKey) + for header, val := range c.Headers { + req.Header.Set(header, val) + } client := &http.Client{} resp, err := client.Do(req) diff --git a/pkg/lmcli/config.go b/pkg/lmcli/config.go index d6a1461..f1949b1 100644 --- a/pkg/lmcli/config.go +++ b/pkg/lmcli/config.go @@ -31,11 +31,12 @@ type Config struct { Tools []string `yaml:"tools"` } `yaml:"agents"` Providers []*struct { - Name string `yaml:"name,omitempty"` - Kind string `yaml:"kind"` - BaseURL string `yaml:"baseUrl,omitempty"` - APIKey string `yaml:"apiKey,omitempty"` - Models []string `yaml:"models"` + Name string `yaml:"name,omitempty"` + Kind string `yaml:"kind"` + BaseURL string `yaml:"baseUrl,omitempty"` + APIKey string `yaml:"apiKey,omitempty"` + Models []string `yaml:"models"` + Headers map[string]string `yaml:"headers"` } `yaml:"providers"` } diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index 5e6c620..e5f6e8c 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -179,6 +179,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv return model, &openai.OpenAIClient{ BaseURL: url, APIKey: p.APIKey, + Headers: p.Headers, }, nil default: return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)