diff --git a/pkg/lmcli/config.go b/pkg/lmcli/config.go index c281f53..f15b6ef 100644 --- a/pkg/lmcli/config.go +++ b/pkg/lmcli/config.go @@ -22,6 +22,7 @@ type Config struct { EnabledTools []string `yaml:"enabledTools"` } `yaml:"tools"` Providers []*struct { + Name *string `yaml:"name"` Kind *string `yaml:"kind"` BaseURL *string `yaml:"baseUrl"` APIKey *string `yaml:"apiKey"` diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index a4d2012..5efba8d 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider" @@ -58,16 +59,37 @@ func NewContext() (*Context, error) { } func (c *Context) GetModels() (models []string) { + modelCounts := make(map[string]int) for _, p := range c.Config.Providers { for _, m := range *p.Models { + modelCounts[m]++ + models = append(models, *p.Name+"/"+m) + } + } + + for m, c := range modelCounts { + if c == 1 { models = append(models, m) } } + return } func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) { + parts := strings.Split(model, "/") + + var provider string + if len(parts) > 1 { + provider = parts[0] + model = parts[1] + } + for _, p := range c.Config.Providers { + if provider != "" && *p.Name != provider { + continue + } + for _, m := range *p.Models { if m == model { switch *p.Kind {