Add name prefix and / separator (e.g. anthropic/claude-3-haiku...)

This commit is contained in:
Matt Low 2024-05-19 02:38:47 +00:00
parent a291e7b42c
commit 1bd953676d
2 changed files with 23 additions and 0 deletions

View File

@ -22,6 +22,7 @@ type Config struct {
EnabledTools []string `yaml:"enabledTools"` EnabledTools []string `yaml:"enabledTools"`
} `yaml:"tools"` } `yaml:"tools"`
Providers []*struct { Providers []*struct {
Name *string `yaml:"name"`
Kind *string `yaml:"kind"` Kind *string `yaml:"kind"`
BaseURL *string `yaml:"baseUrl"` BaseURL *string `yaml:"baseUrl"`
APIKey *string `yaml:"apiKey"` APIKey *string `yaml:"apiKey"`

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
@ -58,16 +59,37 @@ func NewContext() (*Context, error) {
} }
func (c *Context) GetModels() (models []string) { func (c *Context) GetModels() (models []string) {
modelCounts := make(map[string]int)
for _, p := range c.Config.Providers { for _, p := range c.Config.Providers {
for _, m := range *p.Models { for _, m := range *p.Models {
modelCounts[m]++
models = append(models, *p.Name+"/"+m)
}
}
for m, c := range modelCounts {
if c == 1 {
models = append(models, m) models = append(models, m)
} }
} }
return return
} }
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) { func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
parts := strings.Split(model, "/")
var provider string
if len(parts) > 1 {
provider = parts[0]
model = parts[1]
}
for _, p := range c.Config.Providers { for _, p := range c.Config.Providers {
if provider != "" && *p.Name != provider {
continue
}
for _, m := range *p.Models { for _, m := range *p.Models {
if m == model { if m == model {
switch *p.Kind { switch *p.Kind {