Private
Public Access
1
0

Work to simplify model config handling

Keep a direct reference to a provider.ModelConfig in TUI's AppModel,
rather than the names of the provider and model.
This commit is contained in:
2025-07-29 01:41:58 +00:00
parent 5335b5c28f
commit 0cf0a4ff0d
5 changed files with 90 additions and 90 deletions

View File

@@ -19,7 +19,7 @@ import (
// Prompt prompts the configured model and streams the response // Prompt prompts the configured model and streams the response
// to stdout. Returns all model reply messages. // to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) { func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
modelConfig, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") modelConfig, err := ctx.LookupModelProvider(*ctx.Config.Defaults.Model, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -203,7 +203,7 @@ Example response:
}, },
} }
modelConfig, err := ctx.GetModelProvider( modelConfig, err := ctx.LookupModelProvider(
*ctx.Config.Conversations.TitleGenerationModel, "", *ctx.Config.Conversations.TitleGenerationModel, "",
) )
if err != nil { if err != nil {

View File

@@ -139,9 +139,6 @@ func (c *Context) DefaultSystemPrompt() string {
} }
func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.ModelConfig { func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.ModelConfig {
// Set model name
cfg.Model = m.Name
// Set max tokens // Set max tokens
if m.MaxTokens == nil { if m.MaxTokens == nil {
cfg.MaxTokens = *c.Config.Defaults.MaxTokens cfg.MaxTokens = *c.Config.Defaults.MaxTokens
@@ -164,7 +161,76 @@ func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.
return cfg return cfg
} }
func (c *Context) GetModelProvider(model string, providerName string) (*provider.ModelConfig, error) { func (c *Context) GetDefaultModel() (*provider.ModelConfig, error) {
return c.LookupModelProvider(*c.Config.Defaults.Model, "")
}
func (c *Context) GetProviderModels(p *Provider) []provider.ModelConfig {
models := []provider.ModelConfig{}
for _, m := range p.Models() {
var cfg provider.ModelConfig
switch p.Kind {
case "anthropic":
url := "https://api.anthropic.com"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg = provider.ModelConfig{
Client: &anthropic.AnthropicClient{
BaseURL: url,
APIKey: p.APIKey,
},
}
case "google":
url := "https://generativelanguage.googleapis.com"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg = provider.ModelConfig{
Client: &google.Client{
BaseURL: url,
APIKey: p.APIKey,
},
}
case "ollama":
url := "http://localhost:11434/api"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg = provider.ModelConfig{
Client: &ollama.OllamaClient{
BaseURL: url,
},
}
case "openai":
url := "https://api.openai.com/v1"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg = provider.ModelConfig{
Client: &openai.OpenAIClient{
BaseURL: url,
APIKey: p.APIKey,
Headers: p.Headers,
},
}
}
cfg.Model = m.Name
if len(p.Name) > 0 {
cfg.Provider = p.Name
} else {
cfg.Provider = p.Kind
}
c.fillModelConfig(&cfg, m)
models = append(models, cfg)
}
return models
}
func (c *Context) LookupModelProvider(model string, providerName string) (*provider.ModelConfig, error) {
parts := strings.Split(model, "@") parts := strings.Split(model, "@")
if providerName == "" && len(parts) > 1 { if providerName == "" && len(parts) > 1 {
@@ -182,61 +248,9 @@ func (c *Context) GetModelProvider(model string, providerName string) (*provider
continue continue
} }
for _, m := range p.Models() { for _, modelConfig := range c.GetProviderModels(p) {
var cfg *provider.ModelConfig if modelConfig.Model == model {
if m.Name == model { return &modelConfig, nil
switch p.Kind {
case "anthropic":
url := "https://api.anthropic.com"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg = &provider.ModelConfig{
Client: &anthropic.AnthropicClient{
BaseURL: url,
APIKey: p.APIKey,
},
}
return c.fillModelConfig(cfg, m), nil
case "google":
url := "https://generativelanguage.googleapis.com"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg := &provider.ModelConfig{
Client: &google.Client{
BaseURL: url,
APIKey: p.APIKey,
},
}
return c.fillModelConfig(cfg, m), nil
case "ollama":
url := "http://localhost:11434/api"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg := &provider.ModelConfig{
Client: &ollama.OllamaClient{
BaseURL: url,
},
}
return c.fillModelConfig(cfg, m), nil
case "openai":
url := "https://api.openai.com/v1"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg := &provider.ModelConfig{
Client: &openai.OpenAIClient{
BaseURL: url,
APIKey: p.APIKey,
Headers: p.Headers,
},
}
return c.fillModelConfig(cfg, m), nil
default:
return nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
}
} }
} }
} }

View File

@@ -18,8 +18,7 @@ type AppModel struct {
Conversations conversation.ConversationList Conversations conversation.ConversationList
Conversation conversation.Conversation Conversation conversation.Conversation
Messages []conversation.Message Messages []conversation.Message
Model string Model provider.ModelConfig
ProviderName string
Provider provider.ChatCompletionProvider Provider provider.ChatCompletionProvider
Agent *lmcli.Agent Agent *lmcli.Agent
@@ -29,7 +28,6 @@ type AppModel struct {
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel { func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
app := &AppModel{ app := &AppModel{
Ctx: ctx, Ctx: ctx,
Model: *ctx.Config.Defaults.Model,
modifiedMessages: make(map[uint]bool), modifiedMessages: make(map[uint]bool),
} }
@@ -39,9 +37,8 @@ func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversat
} }
modelConfig, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") modelConfig, _ := ctx.GetDefaultModel()
app.Model = modelConfig.Model app.Model = *modelConfig
app.ProviderName = modelConfig.Provider
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent) app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
return app return app
} }
@@ -54,7 +51,7 @@ var (
func (a *AppModel) ActiveModel(style lipgloss.Style) string { func (a *AppModel) ActiveModel(style lipgloss.Style) string {
defaultStyle := style.Inherit(defaultStyle) defaultStyle := style.Inherit(defaultStyle)
accentStyle := style.Inherit(accentStyle) accentStyle := style.Inherit(accentStyle)
return defaultStyle.Render(a.Model) + accentStyle.Render("@") + defaultStyle.Render(a.ProviderName) return defaultStyle.Render(a.Model.Model) + accentStyle.Render("@") + defaultStyle.Render(a.Model.Provider)
} }
type MessageCycleDirection int type MessageCycleDirection int
@@ -257,13 +254,7 @@ func (a *AppModel) Prompt(
chatReplyChunks chan provider.Chunk, chatReplyChunks chan provider.Chunk,
stopSignal chan struct{}, stopSignal chan struct{},
) (*conversation.Message, error) { ) (*conversation.Message, error) {
modelConfig, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName) params := provider.NewRequestParameters(a.Model)
if err != nil {
return nil, err
}
p := modelConfig.Client
params := provider.NewRequestParameters(*modelConfig)
if a.Agent != nil { if a.Agent != nil {
params.Toolbox = a.Agent.Toolbox params.Toolbox = a.Agent.Toolbox
@@ -278,14 +269,14 @@ func (a *AppModel) Prompt(
} }
}() }()
msg, err := p.CreateChatCompletionStream( msg, err := a.Model.Client.CreateChatCompletionStream(
ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks, ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
) )
if msg != nil { if msg != nil {
msg := conversation.MessageFromAPI(*msg) msg := conversation.MessageFromAPI(*msg)
msg.Metadata.GenerationProvider = &a.ProviderName msg.Metadata.GenerationProvider = &a.Model.Provider
msg.Metadata.GenerationModel = &a.Model msg.Metadata.GenerationModel = &a.Model.Model
return &msg, err return &msg, err
} }
return nil, err return nil, err

View File

@@ -261,7 +261,7 @@ func (m *Model) conversationMessagesView() string {
heading := m.renderMessageHeading(-1, &conversation.Message{ heading := m.renderMessageHeading(-1, &conversation.Message{
Role: api.MessageRoleAssistant, Role: api.MessageRoleAssistant,
Metadata: conversation.MessageMeta{ Metadata: conversation.MessageMeta{
GenerationModel: &m.App.Model, GenerationModel: &m.App.Model.Model,
}, },
}) })
sb.WriteString(heading) sb.WriteString(heading)

View File

@@ -3,6 +3,7 @@ package settings
import ( import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles/list" "git.mlow.ca/mlow/lmcli/pkg/tui/bubbles/list"
"git.mlow.ca/mlow/lmcli/pkg/tui/model" "git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
@@ -21,11 +22,6 @@ type Model struct {
height int height int
} }
type modelOpt struct {
provider string
model string
}
const ( const (
modelListId int = iota + 1 modelListId int = iota + 1
) )
@@ -72,9 +68,8 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
case list.MsgOptionSelected: case list.MsgOptionSelected:
switch msg.ID { switch msg.ID {
case modelListId: case modelListId:
if modelOpt, ok := msg.Option.Value.(modelOpt); ok { if modelConfig, ok := msg.Option.Value.(provider.ModelConfig); ok {
m.App.Model = modelOpt.model m.App.Model = modelConfig
m.App.ProviderName = modelOpt.provider
} }
return m, shared.ChangeView(m.prevView) return m, shared.ChangeView(m.prevView)
} }
@@ -103,10 +98,10 @@ func (m *Model) getModelOptions() []list.OptionGroup {
group := list.OptionGroup{ group := list.OptionGroup{
Name: providerLabel, Name: providerLabel,
} }
for _, model := range p.Models() { for _, model := range m.App.Ctx.GetProviderModels(p) {
group.Options = append(group.Options, list.Option{ group.Options = append(group.Options, list.Option{
Label: model.Name, Label: model.Model,
Value: modelOpt{provider, model.Name}, Value: model,
}) })
} }
modelOpts = append(modelOpts, group) modelOpts = append(modelOpts, group)