From 0cf0a4ff0ddea080a9e37ff1b1deaafc2185e8d3 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Tue, 29 Jul 2025 01:41:58 +0000 Subject: [PATCH] Work to simplify model config handling Keep a direct reference to a provider.ModelConfig in TUI's AppModel, rather than the names of the provider and model --- pkg/cmd/util/util.go | 4 +- pkg/lmcli/lmcli.go | 132 ++++++++++++++++------------- pkg/tui/model/model.go | 25 ++---- pkg/tui/views/chat/view.go | 2 +- pkg/tui/views/settings/settings.go | 17 ++-- 5 files changed, 90 insertions(+), 90 deletions(-) diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go index 493f60b..9d56541 100644 --- a/pkg/cmd/util/util.go +++ b/pkg/cmd/util/util.go @@ -19,7 +19,7 @@ import ( // Prompt prompts the configured the configured model and streams the response // to stdout. Returns all model reply messages. func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) { - modelConfig, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") + modelConfig, err := ctx.LookupModelProvider(*ctx.Config.Defaults.Model, "") if err != nil { return nil, err } @@ -203,7 +203,7 @@ Example response: }, } - modelConfig, err := ctx.GetModelProvider( + modelConfig, err := ctx.LookupModelProvider( *ctx.Config.Conversations.TitleGenerationModel, "", ) if err != nil { diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index f2a9586..5804057 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -139,9 +139,6 @@ func (c *Context) DefaultSystemPrompt() string { } func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.ModelConfig { - // Set model name - cfg.Model = m.Name - // Set max tokens if m.MaxTokens == nil { cfg.MaxTokens = *c.Config.Defaults.MaxTokens @@ -164,7 +161,76 @@ func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider. return cfg } -func (c *Context) GetModelProvider(model string, providerName string) (*provider.ModelConfig, error) { +func (c *Context) GetDefaultModel() (*provider.ModelConfig, error) { + return c.LookupModelProvider(*c.Config.Defaults.Model, "") +} + +func (c *Context) GetProviderModels(p *Provider) []provider.ModelConfig { + models := []provider.ModelConfig{} + for _, m := range p.Models() { + var cfg provider.ModelConfig + switch p.Kind { + case "anthropic": + url := "https://api.anthropic.com" + if p.BaseURL != "" { + url = p.BaseURL + } + cfg = provider.ModelConfig{ + Client: &anthropic.AnthropicClient{ + BaseURL: url, + APIKey: p.APIKey, + }, + } + case "google": + url := "https://generativelanguage.googleapis.com" + if p.BaseURL != "" { + url = p.BaseURL + } + cfg = provider.ModelConfig{ + Client: &google.Client{ + BaseURL: url, + APIKey: p.APIKey, + }, + } + case "ollama": + url := "http://localhost:11434/api" + if p.BaseURL != "" { + url = p.BaseURL + } + cfg = provider.ModelConfig{ + Client: &ollama.OllamaClient{ + BaseURL: url, + }, + } + case "openai": + url := "https://api.openai.com/v1" + if p.BaseURL != "" { + url = p.BaseURL + } + cfg = provider.ModelConfig{ + Client: &openai.OpenAIClient{ + BaseURL: url, + APIKey: p.APIKey, + Headers: p.Headers, + }, + } + } + + cfg.Model = m.Name + + if len(p.Name) > 0 { + cfg.Provider = p.Name + } else { + cfg.Provider = p.Kind + } + + c.fillModelConfig(&cfg, m) + models = append(models, cfg) + } + return models +} + +func (c *Context) LookupModelProvider(model string, providerName string) (*provider.ModelConfig, error) { parts := strings.Split(model, "@") if providerName == "" && len(parts) > 1 { @@ -182,61 +248,9 @@ func (c *Context) GetModelProvider(model string, providerName string) (*provider continue } - for _, m := range p.Models() { - var cfg *provider.ModelConfig - if m.Name == model { - switch p.Kind { - case "anthropic": - url := "https://api.anthropic.com" - if p.BaseURL != "" { - url = p.BaseURL - } - cfg = &provider.ModelConfig{ - Client: &anthropic.AnthropicClient{ - BaseURL: url, - APIKey: p.APIKey, - }, - } - return c.fillModelConfig(cfg, m), nil - case "google": - url := "https://generativelanguage.googleapis.com" - if p.BaseURL != "" { - url = p.BaseURL - } - cfg := &provider.ModelConfig{ - Client: &google.Client{ - BaseURL: url, - APIKey: p.APIKey, - }, - } - return c.fillModelConfig(cfg, m), nil - case "ollama": - url := "http://localhost:11434/api" - if p.BaseURL != "" { - url = p.BaseURL - } - cfg := &provider.ModelConfig{ - Client: &ollama.OllamaClient{ - BaseURL: url, - }, - } - return c.fillModelConfig(cfg, m), nil - case "openai": - url := "https://api.openai.com/v1" - if p.BaseURL != "" { - url = p.BaseURL - } - cfg := &provider.ModelConfig{ - Client: &openai.OpenAIClient{ - BaseURL: url, - APIKey: p.APIKey, - Headers: p.Headers, - }, - } - return c.fillModelConfig(cfg, m), nil - default: - return nil, fmt.Errorf("unknown provider kind: %s", p.Kind) - } + for _, modelConfig := range c.GetProviderModels(p) { + if modelConfig.Model == model { + return &modelConfig, nil } } } diff --git a/pkg/tui/model/model.go b/pkg/tui/model/model.go index 0cec615..1f108c6 100644 --- a/pkg/tui/model/model.go +++ b/pkg/tui/model/model.go @@ -18,8 +18,7 @@ type AppModel struct { Conversations conversation.ConversationList Conversation conversation.Conversation Messages []conversation.Message - Model string - ProviderName string + Model provider.ModelConfig Provider provider.ChatCompletionProvider Agent *lmcli.Agent @@ -29,7 +28,6 @@ type AppModel struct { func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel { app := &AppModel{ Ctx: ctx, - Model: *ctx.Config.Defaults.Model, modifiedMessages: make(map[uint]bool), } @@ -39,9 +37,8 @@ func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversat } - modelConfig, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") - app.Model = modelConfig.Model - app.ProviderName = modelConfig.Provider + modelConfig, _ := ctx.GetDefaultModel() + app.Model = *modelConfig app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent) return app } @@ -54,7 +51,7 @@ var ( func (a *AppModel) ActiveModel(style lipgloss.Style) string { defaultStyle := style.Inherit(defaultStyle) accentStyle := style.Inherit(accentStyle) - return defaultStyle.Render(a.Model) + accentStyle.Render("@") + defaultStyle.Render(a.ProviderName) + return defaultStyle.Render(a.Model.Model) + accentStyle.Render("@") + defaultStyle.Render(a.Model.Provider) } type MessageCycleDirection int @@ -257,13 +254,7 @@ func (a *AppModel) Prompt( chatReplyChunks chan provider.Chunk, stopSignal chan struct{}, ) (*conversation.Message, error) { - modelConfig, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName) - if err != nil { - return nil, err - } - p := modelConfig.Client - - params := provider.NewRequestParameters(*modelConfig) + params := provider.NewRequestParameters(a.Model) if a.Agent != nil { params.Toolbox = a.Agent.Toolbox @@ -278,14 +269,14 @@ func (a *AppModel) Prompt( } }() - msg, err := p.CreateChatCompletionStream( + msg, err := a.Model.Client.CreateChatCompletionStream( ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks, ) if msg != nil { msg := conversation.MessageFromAPI(*msg) - msg.Metadata.GenerationProvider = &a.ProviderName - msg.Metadata.GenerationModel = &a.Model + msg.Metadata.GenerationProvider = &a.Model.Provider + msg.Metadata.GenerationModel = &a.Model.Model return &msg, err } return nil, err diff --git a/pkg/tui/views/chat/view.go b/pkg/tui/views/chat/view.go index b52c650..84fde06 100644 --- a/pkg/tui/views/chat/view.go +++ b/pkg/tui/views/chat/view.go @@ -261,7 +261,7 @@ func (m *Model) conversationMessagesView() string { heading := m.renderMessageHeading(-1, &conversation.Message{ Role: api.MessageRoleAssistant, Metadata: conversation.MessageMeta{ - GenerationModel: &m.App.Model, + GenerationModel: &m.App.Model.Model, }, }) sb.WriteString(heading) diff --git a/pkg/tui/views/settings/settings.go b/pkg/tui/views/settings/settings.go index ff4fae8..089077d 100644 --- a/pkg/tui/views/settings/settings.go +++ b/pkg/tui/views/settings/settings.go @@ -3,6 +3,7 @@ package settings import ( "strings" + "git.mlow.ca/mlow/lmcli/pkg/provider" "git.mlow.ca/mlow/lmcli/pkg/tui/bubbles/list" "git.mlow.ca/mlow/lmcli/pkg/tui/model" "git.mlow.ca/mlow/lmcli/pkg/tui/shared" @@ -21,11 +22,6 @@ type Model struct { height int } -type modelOpt struct { - provider string - model string -} - const ( modelListId int = iota + 1 ) @@ -72,9 +68,8 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { case list.MsgOptionSelected: switch msg.ID { case modelListId: - if modelOpt, ok := msg.Option.Value.(modelOpt); ok { - m.App.Model = modelOpt.model - m.App.ProviderName = modelOpt.provider + if modelConfig, ok := msg.Option.Value.(provider.ModelConfig); ok { + m.App.Model = modelConfig } return m, shared.ChangeView(m.prevView) } @@ -103,10 +98,10 @@ func (m *Model) getModelOptions() []list.OptionGroup { group := list.OptionGroup{ Name: providerLabel, } - for _, model := range p.Models() { + for _, model := range m.App.Ctx.GetProviderModels(p) { group.Options = append(group.Options, list.Option{ - Label: model.Name, - Value: modelOpt{provider, model.Name}, + Label: model.Model, + Value: model, }) } modelOpts = append(modelOpts, group)