From 8e4ff90ab4b60d22edbf8361fd7c97bc3ce0ce84 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Sun, 5 May 2024 08:08:17 +0000 Subject: [PATCH] Multiple provider configuration Add support for having multiple openai or anthropic compatible providers accessible via different baseUrls --- pkg/lmcli/config.go | 16 +++---- pkg/lmcli/lmcli.go | 52 ++++++++++++++--------- pkg/lmcli/provider/anthropic/anthropic.go | 4 +- pkg/lmcli/provider/anthropic/types.go | 1 + 4 files changed, 41 insertions(+), 32 deletions(-) diff --git a/pkg/lmcli/config.go b/pkg/lmcli/config.go index b5eb5ab..c281f53 100644 --- a/pkg/lmcli/config.go +++ b/pkg/lmcli/config.go @@ -19,16 +19,14 @@ type Config struct { TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"` } `yaml:"conversations"` Tools *struct { - EnabledTools *[]string `yaml:"enabledTools"` + EnabledTools []string `yaml:"enabledTools"` } `yaml:"tools"` - OpenAI *struct { - APIKey *string `yaml:"apiKey" default:"your_key_here"` - Models *[]string `yaml:"models"` - } `yaml:"openai"` - Anthropic *struct { - APIKey *string `yaml:"apiKey" default:"your_key_here"` - Models *[]string `yaml:"models"` - } `yaml:"anthropic"` + Providers []*struct { + Kind *string `yaml:"kind"` + BaseURL *string `yaml:"baseUrl"` + APIKey *string `yaml:"apiKey"` + Models *[]string `yaml:"models"` + } `yaml:"providers"` Chroma *struct { Style *string `yaml:"style" default:"onedark"` Formatter *string `yaml:"formatter" default:"terminal16m"` diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index 97ea288..0f814e6 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -43,7 +43,7 @@ func NewContext() (*Context, error) { chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style) var enabledTools []model.Tool - for _, toolName := range *config.Tools.EnabledTools { + for _, toolName := range config.Tools.EnabledTools { tool, ok := tools.AvailableTools[toolName] if ok { enabledTools = append(enabledTools, tool) @@ -54,31 +54,43 @@ func NewContext() (*Context, error) { } func (c *Context) GetModels() (models []string) { - for _, m := range *c.Config.Anthropic.Models { - models = append(models, m) - } - for _, m := range *c.Config.OpenAI.Models { - models = append(models, m) + for _, p := range c.Config.Providers { + for _, m := range *p.Models { + models = append(models, m) + } } return } func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) { - for _, m := range *c.Config.Anthropic.Models { - if m == model { - anthropic := &anthropic.AnthropicClient{ - APIKey: *c.Config.Anthropic.APIKey, + for _, p := range c.Config.Providers { + for _, m := range *p.Models { + if m == model { + switch *p.Kind { + case "anthropic": + url := "https://api.anthropic.com/v1" + if p.BaseURL != nil { + url = *p.BaseURL + } + anthropic := &anthropic.AnthropicClient{ + BaseURL: url, + APIKey: *p.APIKey, + } + return anthropic, nil + case "openai": + url := "https://api.openai.com/v1" + if p.BaseURL != nil { + url = *p.BaseURL + } + openai := &openai.OpenAIClient{ + BaseURL: url, + APIKey: *p.APIKey, + } + return openai, nil + default: + return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind) + } } - return anthropic, nil - } - } - for _, m := range *c.Config.OpenAI.Models { - if m == model { - openai := &openai.OpenAIClient{ - BaseURL: "https://api.openai.com/v1", - APIKey: *c.Config.OpenAI.APIKey, - } - return openai, nil } } return nil, fmt.Errorf("unknown model: %s", model) diff --git a/pkg/lmcli/provider/anthropic/anthropic.go b/pkg/lmcli/provider/anthropic/anthropic.go index ecf6674..8d951ca 100644 --- a/pkg/lmcli/provider/anthropic/anthropic.go +++ b/pkg/lmcli/provider/anthropic/anthropic.go @@ -81,14 +81,12 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ } func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) { - url := "https://api.anthropic.com/v1/messages" - jsonBody, err := json.Marshal(r) if err != nil { return nil, fmt.Errorf("failed to marshal request body: %v", err) } - req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/messages", bytes.NewBuffer(jsonBody)) if err != nil { return nil, fmt.Errorf("failed to create HTTP request: %v", err) } diff --git a/pkg/lmcli/provider/anthropic/types.go b/pkg/lmcli/provider/anthropic/types.go index 54b03ed..a1173e4 100644 --- a/pkg/lmcli/provider/anthropic/types.go +++ b/pkg/lmcli/provider/anthropic/types.go @@ -1,6 +1,7 @@ package anthropic type AnthropicClient struct { + BaseURL string APIKey string }