Fixed handling of long (slash-separated) and short model identifiers

Renamed `GetCompletionProvider` to `GetModelProvider` and updated it to
return the model's short name (the one to use when making API requests).
Matt Low 2024-05-30 19:04:48 +00:00
parent b29a4c8b84
commit 465b1d333e
3 changed files with 15 additions and 15 deletions
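
For context on the commit title: a long identifier such as "openai/gpt-4o" names both a configured provider and a model, while a short identifier such as "gpt-4o" names only the model. A minimal sketch of that parsing step, assuming the two-segment format; splitModelIdentifier is a hypothetical helper, since the real parsing happens inline at the top of GetModelProvider (second file below):

package main

import (
	"fmt"
	"strings"
)

// splitModelIdentifier is a hypothetical stand-in for the parsing step at the
// top of GetModelProvider: a "provider/model" identifier splits into its two
// parts, while a bare model name leaves the provider unspecified.
func splitModelIdentifier(identifier string) (providerName, shortName string) {
	parts := strings.Split(identifier, "/")
	if len(parts) > 1 {
		return parts[0], parts[1]
	}
	return "", identifier
}

func main() {
	fmt.Println(splitModelIdentifier("openai/gpt-4o")) // openai gpt-4o
	fmt.Println(splitModelIdentifier("gpt-4o"))        //  gpt-4o
}

Returning the short name alongside the client is the point of the rename: callers can no longer accidentally send the prefixed long form as the Model field of a request.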

File 1 of 3:

@@ -23,19 +23,19 @@ func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Me
 	// render all content received over the channel
 	go ShowDelayedContent(content)

-	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
+	m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model)
 	if err != nil {
 		return "", err
 	}

 	requestParams := model.RequestParameters{
-		Model:       *ctx.Config.Defaults.Model,
+		Model:       m,
 		MaxTokens:   *ctx.Config.Defaults.MaxTokens,
 		Temperature: *ctx.Config.Defaults.Temperature,
 		ToolBag:     ctx.EnabledTools,
 	}

-	response, err := completionProvider.CreateChatCompletionStream(
+	response, err := provider.CreateChatCompletionStream(
 		context.Background(), requestParams, messages, callback, content,
 	)
 	if response != "" {
@@ -187,17 +187,17 @@ Example response:
 		},
 	}

-	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel)
+	m, provider, err := ctx.GetModelProvider(*ctx.Config.Conversations.TitleGenerationModel)
 	if err != nil {
 		return "", err
 	}

 	requestParams := model.RequestParameters{
-		Model:     *ctx.Config.Conversations.TitleGenerationModel,
+		Model:     m,
 		MaxTokens: 25,
 	}

-	response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
+	response, err := provider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
 	if err != nil {
 		return "", err
 	}
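
As I read it, the bug at call sites like the two above: the raw config value, which may be the long provider/model form, was previously copied into RequestParameters.Model and sent to the provider as-is. A small illustrative sketch of the before/after, with a hypothetical requestParameters type standing in for model.RequestParameters:

package main

import "fmt"

// requestParameters mirrors the shape of model.RequestParameters in the diff.
type requestParameters struct {
	Model     string
	MaxTokens int
}

func main() {
	configured := "openai/gpt-4o" // long-form identifier from config

	// Before: the configured string went straight into the request, so the
	// provider would be asked for a model literally named "openai/gpt-4o".
	before := requestParameters{Model: configured, MaxTokens: 25}

	// After: GetModelProvider strips the provider prefix first, so only the
	// short name reaches the API.
	shortName := "gpt-4o" // what the new GetModelProvider returns
	after := requestParameters{Model: shortName, MaxTokens: 25}

	fmt.Println(before.Model, "->", after.Model) // openai/gpt-4o -> gpt-4o
}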

File 2 of 3:

@@ -78,7 +78,7 @@ func (c *Context) GetModels() (models []string) {
 	return
 }

-func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
+func (c *Context) GetModelProvider(model string) (string, provider.ChatCompletionClient, error) {
 	parts := strings.Split(model, "/")

 	var provider string
@@ -100,7 +100,7 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
 					if p.BaseURL != nil {
 						url = *p.BaseURL
 					}
-					return &anthropic.AnthropicClient{
+					return model, &anthropic.AnthropicClient{
 						BaseURL: url,
 						APIKey:  *p.APIKey,
 					}, nil
@@ -109,7 +109,7 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
 					if p.BaseURL != nil {
 						url = *p.BaseURL
 					}
-					return &google.Client{
+					return model, &google.Client{
 						BaseURL: url,
 						APIKey:  *p.APIKey,
 					}, nil
@@ -118,17 +118,17 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
 					if p.BaseURL != nil {
 						url = *p.BaseURL
 					}
-					return &openai.OpenAIClient{
+					return model, &openai.OpenAIClient{
 						BaseURL: url,
 						APIKey:  *p.APIKey,
 					}, nil
 				default:
-					return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
+					return "", nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
 				}
 			}
 		}
 	}
-	return nil, fmt.Errorf("unknown model: %s", model)
+	return "", nil, fmt.Errorf("unknown model: %s", model)
 }

 func (c *Context) GetSystemPrompt() string {
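
The hunks above show only the changed return statements; the matching logic between them is elided. Presumably it walks the configured providers, honors an explicit provider prefix when one was given, and switches on each provider's kind. A rough, self-contained reconstruction under those assumptions, where providerConfig and chatClient are invented stand-ins for lmcli's config types and provider.ChatCompletionClient:

package main

import (
	"fmt"
	"strings"
)

// Invented stand-ins; the real code uses lmcli's config types and returns
// a provider.ChatCompletionClient.
type providerConfig struct {
	Name   string
	Kind   string
	Models []string
}

type chatClient struct{ kind string }

// getModelProvider sketches the resolution implied by the diff: split the
// identifier, walk the configured providers, and return the short model name
// together with a client for the first provider that can serve it.
func getModelProvider(providers []providerConfig, identifier string) (string, *chatClient, error) {
	parts := strings.Split(identifier, "/")
	wantProvider, model := "", identifier
	if len(parts) > 1 {
		wantProvider, model = parts[0], parts[1]
	}
	for _, p := range providers {
		if wantProvider != "" && p.Name != wantProvider {
			continue // long form: only the named provider may match
		}
		for _, m := range p.Models {
			if m != model {
				continue
			}
			switch p.Kind {
			case "anthropic", "google", "openai":
				return model, &chatClient{kind: p.Kind}, nil
			default:
				return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
			}
		}
	}
	return "", nil, fmt.Errorf("unknown model: %s", identifier)
}

func main() {
	cfg := []providerConfig{{Name: "openai", Kind: "openai", Models: []string{"gpt-4o"}}}
	m, client, err := getModelProvider(cfg, "openai/gpt-4o")
	fmt.Println(m, client.kind, err) // gpt-4o openai <nil>
}

Both failure paths return an empty short name alongside the nil client, matching the two fmt.Errorf changes in the diff.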

File 3 of 3:

@@ -1051,13 +1051,13 @@ func (m *Model) promptLLM() tea.Cmd {
 	m.elapsed = 0

 	return func() tea.Msg {
-		completionProvider, err := m.State.Ctx.GetCompletionProvider(*m.State.Ctx.Config.Defaults.Model)
+		model, provider, err := m.State.Ctx.GetModelProvider(*m.State.Ctx.Config.Defaults.Model)
 		if err != nil {
 			return shared.MsgError(err)
 		}

 		requestParams := models.RequestParameters{
-			Model:       *m.State.Ctx.Config.Defaults.Model,
+			Model:       model,
 			MaxTokens:   *m.State.Ctx.Config.Defaults.MaxTokens,
 			Temperature: *m.State.Ctx.Config.Defaults.Temperature,
 			ToolBag:     m.State.Ctx.EnabledTools,
@@ -1078,7 +1078,7 @@ func (m *Model) promptLLM() tea.Cmd {
 			}
 		}()

-		resp, err := completionProvider.CreateChatCompletionStream(
+		resp, err := provider.CreateChatCompletionStream(
 			ctx, requestParams, toPrompt, replyHandler, m.replyChunkChan,
 		)