Private
Public Access
1
0

Fixed handling of long (slash separated) and short model identifiers

Renamed `GetCompletionProvider` to `GetModelProvider` and update it to
return the model's short name (the one to use when making requests)
This commit is contained in:
2024-05-30 19:04:48 +00:00
parent b29a4c8b84
commit 465b1d333e
3 changed files with 15 additions and 15 deletions

View File

@@ -78,7 +78,7 @@ func (c *Context) GetModels() (models []string) {
return
}
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
func (c *Context) GetModelProvider(model string) (string, provider.ChatCompletionClient, error) {
parts := strings.Split(model, "/")
var provider string
@@ -100,7 +100,7 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
if p.BaseURL != nil {
url = *p.BaseURL
}
return &anthropic.AnthropicClient{
return model, &anthropic.AnthropicClient{
BaseURL: url,
APIKey: *p.APIKey,
}, nil
@@ -109,7 +109,7 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
if p.BaseURL != nil {
url = *p.BaseURL
}
return &google.Client{
return model, &google.Client{
BaseURL: url,
APIKey: *p.APIKey,
}, nil
@@ -118,17 +118,17 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
if p.BaseURL != nil {
url = *p.BaseURL
}
return &openai.OpenAIClient{
return model, &openai.OpenAIClient{
BaseURL: url,
APIKey: *p.APIKey,
}, nil
default:
return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
return "", nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
}
}
}
}
return nil, fmt.Errorf("unknown model: %s", model)
return "", nil, fmt.Errorf("unknown model: %s", model)
}
func (c *Context) GetSystemPrompt() string {