diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go
index cd34895..40cbe92 100644
--- a/pkg/cmd/util/util.go
+++ b/pkg/cmd/util/util.go
@@ -23,19 +23,19 @@ func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Me
 	// render all content received over the channel
 	go ShowDelayedContent(content)
 
-	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
+	m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model)
 	if err != nil {
 		return "", err
 	}
 
 	requestParams := model.RequestParameters{
-		Model:       *ctx.Config.Defaults.Model,
+		Model:       m,
 		MaxTokens:   *ctx.Config.Defaults.MaxTokens,
 		Temperature: *ctx.Config.Defaults.Temperature,
 		ToolBag:     ctx.EnabledTools,
 	}
 
-	response, err := completionProvider.CreateChatCompletionStream(
+	response, err := provider.CreateChatCompletionStream(
 		context.Background(), requestParams, messages, callback, content,
 	)
 	if response != "" {
@@ -187,17 +187,17 @@ Example response:
 		},
 	}
 
-	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel)
+	m, provider, err := ctx.GetModelProvider(*ctx.Config.Conversations.TitleGenerationModel)
 	if err != nil {
 		return "", err
 	}
 
 	requestParams := model.RequestParameters{
-		Model:     *ctx.Config.Conversations.TitleGenerationModel,
+		Model:     m,
 		MaxTokens: 25,
 	}
 
-	response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
+	response, err := provider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
 	if err != nil {
 		return "", err
 	}
diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go
index 814be7a..f47b8d9 100644
--- a/pkg/lmcli/lmcli.go
+++ b/pkg/lmcli/lmcli.go
@@ -78,7 +78,7 @@ func (c *Context) GetModels() (models []string) {
 	return
 }
 
-func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
+func (c *Context) GetModelProvider(model string) (string, provider.ChatCompletionClient, error) {
 	parts := strings.Split(model, "/")
 
 	var provider string
@@ -100,7 +100,7 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
 					if p.BaseURL != nil {
 						url = *p.BaseURL
 					}
-					return &anthropic.AnthropicClient{
+					return model, &anthropic.AnthropicClient{
 						BaseURL: url,
 						APIKey:  *p.APIKey,
 					}, nil
@@ -109,7 +109,7 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
 					if p.BaseURL != nil {
 						url = *p.BaseURL
 					}
-					return &google.Client{
+					return model, &google.Client{
 						BaseURL: url,
 						APIKey:  *p.APIKey,
 					}, nil
@@ -118,17 +118,17 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
 					if p.BaseURL != nil {
 						url = *p.BaseURL
 					}
-					return &openai.OpenAIClient{
+					return model, &openai.OpenAIClient{
 						BaseURL: url,
 						APIKey:  *p.APIKey,
 					}, nil
 				default:
-					return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
+					return "", nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
 				}
 			}
 		}
 	}
-	return nil, fmt.Errorf("unknown model: %s", model)
+	return "", nil, fmt.Errorf("unknown model: %s", model)
 }
 
 func (c *Context) GetSystemPrompt() string {
diff --git a/pkg/tui/views/chat/chat.go b/pkg/tui/views/chat/chat.go
index cd69232..c8185f2 100644
--- a/pkg/tui/views/chat/chat.go
+++ b/pkg/tui/views/chat/chat.go
@@ -1051,13 +1051,13 @@ func (m *Model) promptLLM() tea.Cmd {
 	m.elapsed = 0
 
 	return func() tea.Msg {
-		completionProvider, err := m.State.Ctx.GetCompletionProvider(*m.State.Ctx.Config.Defaults.Model)
+		model, provider, err := m.State.Ctx.GetModelProvider(*m.State.Ctx.Config.Defaults.Model)
 		if err != nil {
 			return shared.MsgError(err)
 		}
 
 		requestParams := models.RequestParameters{
-			Model:       *m.State.Ctx.Config.Defaults.Model,
+			Model:       model,
 			MaxTokens:   *m.State.Ctx.Config.Defaults.MaxTokens,
 			Temperature: *m.State.Ctx.Config.Defaults.Temperature,
 			ToolBag:     m.State.Ctx.EnabledTools,
@@ -1078,7 +1078,7 @@ func (m *Model) promptLLM() tea.Cmd {
 			}
 		}()
 
-		resp, err := completionProvider.CreateChatCompletionStream(
+		resp, err := provider.CreateChatCompletionStream(
 			ctx, requestParams, toPrompt, replyHandler, m.replyChunkChan,
 		)
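
Note for reviewers: the net effect of this change is that GetModelProvider (formerly GetCompletionProvider) resolves a configured model spec and hands back the model name alongside the client, so call sites put the returned name into RequestParameters rather than the raw config value. The context line `parts := strings.Split(model, "/")` suggests a spec may be qualified as "provider/model", with the prefix presumably stripped before the name is returned. Below is a minimal stand-alone sketch of that splitting behavior; splitModelSpec and the example specs are hypothetical names invented for illustration, not code from this diff.

package main

import (
	"fmt"
	"strings"
)

// splitModelSpec illustrates the "provider/model" handling implied by
// GetModelProvider's context lines. Hypothetical helper, not from the diff:
// a qualified spec yields its provider prefix and bare model name.
func splitModelSpec(spec string) (provider, model string) {
	parts := strings.Split(spec, "/")
	if len(parts) > 1 {
		// Qualified spec: strip the provider prefix, keep the bare model name.
		return parts[0], parts[1]
	}
	// Unqualified spec: any configured provider offering this model may match.
	return "", spec
}

func main() {
	p, m := splitModelSpec("openai/gpt-4o")
	fmt.Printf("provider=%q model=%q\n", p, m) // provider="openai" model="gpt-4o"

	p, m = splitModelSpec("gpt-4o")
	fmt.Printf("provider=%q model=%q\n", p, m) // provider="" model="gpt-4o"
}

Returning the bare name matters because the old code sent *ctx.Config.Defaults.Model verbatim as RequestParameters.Model; with a provider-qualified spec, that raw string would presumably not be a model identifier the upstream API accepts.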