Private
Public Access
1
0

Work to simplify model config handling

Keep a direct reference to a provider.ModelConfig in TUI's AppModel,
rather than the names of the provider and model
This commit is contained in:
2025-07-29 01:41:58 +00:00
parent 5335b5c28f
commit 0cf0a4ff0d
5 changed files with 90 additions and 90 deletions

View File

@@ -19,7 +19,7 @@ import (
// Prompt prompts the configured the configured model and streams the response
// to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
modelConfig, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
modelConfig, err := ctx.LookupModelProvider(*ctx.Config.Defaults.Model, "")
if err != nil {
return nil, err
}
@@ -203,7 +203,7 @@ Example response:
},
}
modelConfig, err := ctx.GetModelProvider(
modelConfig, err := ctx.LookupModelProvider(
*ctx.Config.Conversations.TitleGenerationModel, "",
)
if err != nil {

View File

@@ -139,9 +139,6 @@ func (c *Context) DefaultSystemPrompt() string {
}
func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.ModelConfig {
// Set model name
cfg.Model = m.Name
// Set max tokens
if m.MaxTokens == nil {
cfg.MaxTokens = *c.Config.Defaults.MaxTokens
@@ -164,7 +161,76 @@ func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.
return cfg
}
func (c *Context) GetModelProvider(model string, providerName string) (*provider.ModelConfig, error) {
func (c *Context) GetDefaultModel() (*provider.ModelConfig, error) {
return c.LookupModelProvider(*c.Config.Defaults.Model, "")
}
func (c *Context) GetProviderModels(p *Provider) []provider.ModelConfig {
models := []provider.ModelConfig{}
for _, m := range p.Models() {
var cfg provider.ModelConfig
switch p.Kind {
case "anthropic":
url := "https://api.anthropic.com"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg = provider.ModelConfig{
Client: &anthropic.AnthropicClient{
BaseURL: url,
APIKey: p.APIKey,
},
}
case "google":
url := "https://generativelanguage.googleapis.com"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg = provider.ModelConfig{
Client: &google.Client{
BaseURL: url,
APIKey: p.APIKey,
},
}
case "ollama":
url := "http://localhost:11434/api"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg = provider.ModelConfig{
Client: &ollama.OllamaClient{
BaseURL: url,
},
}
case "openai":
url := "https://api.openai.com/v1"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg = provider.ModelConfig{
Client: &openai.OpenAIClient{
BaseURL: url,
APIKey: p.APIKey,
Headers: p.Headers,
},
}
}
cfg.Model = m.Name
if len(p.Name) > 0 {
cfg.Provider = p.Name
} else {
cfg.Provider = p.Kind
}
c.fillModelConfig(&cfg, m)
models = append(models, cfg)
}
return models
}
func (c *Context) LookupModelProvider(model string, providerName string) (*provider.ModelConfig, error) {
parts := strings.Split(model, "@")
if providerName == "" && len(parts) > 1 {
@@ -182,61 +248,9 @@ func (c *Context) GetModelProvider(model string, providerName string) (*provider
continue
}
for _, m := range p.Models() {
var cfg *provider.ModelConfig
if m.Name == model {
switch p.Kind {
case "anthropic":
url := "https://api.anthropic.com"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg = &provider.ModelConfig{
Client: &anthropic.AnthropicClient{
BaseURL: url,
APIKey: p.APIKey,
},
}
return c.fillModelConfig(cfg, m), nil
case "google":
url := "https://generativelanguage.googleapis.com"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg := &provider.ModelConfig{
Client: &google.Client{
BaseURL: url,
APIKey: p.APIKey,
},
}
return c.fillModelConfig(cfg, m), nil
case "ollama":
url := "http://localhost:11434/api"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg := &provider.ModelConfig{
Client: &ollama.OllamaClient{
BaseURL: url,
},
}
return c.fillModelConfig(cfg, m), nil
case "openai":
url := "https://api.openai.com/v1"
if p.BaseURL != "" {
url = p.BaseURL
}
cfg := &provider.ModelConfig{
Client: &openai.OpenAIClient{
BaseURL: url,
APIKey: p.APIKey,
Headers: p.Headers,
},
}
return c.fillModelConfig(cfg, m), nil
default:
return nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
}
for _, modelConfig := range c.GetProviderModels(p) {
if modelConfig.Model == model {
return &modelConfig, nil
}
}
}

View File

@@ -18,8 +18,7 @@ type AppModel struct {
Conversations conversation.ConversationList
Conversation conversation.Conversation
Messages []conversation.Message
Model string
ProviderName string
Model provider.ModelConfig
Provider provider.ChatCompletionProvider
Agent *lmcli.Agent
@@ -29,7 +28,6 @@ type AppModel struct {
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
app := &AppModel{
Ctx: ctx,
Model: *ctx.Config.Defaults.Model,
modifiedMessages: make(map[uint]bool),
}
@@ -39,9 +37,8 @@ func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversat
}
modelConfig, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
app.Model = modelConfig.Model
app.ProviderName = modelConfig.Provider
modelConfig, _ := ctx.GetDefaultModel()
app.Model = *modelConfig
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
return app
}
@@ -54,7 +51,7 @@ var (
func (a *AppModel) ActiveModel(style lipgloss.Style) string {
defaultStyle := style.Inherit(defaultStyle)
accentStyle := style.Inherit(accentStyle)
return defaultStyle.Render(a.Model) + accentStyle.Render("@") + defaultStyle.Render(a.ProviderName)
return defaultStyle.Render(a.Model.Model) + accentStyle.Render("@") + defaultStyle.Render(a.Model.Provider)
}
type MessageCycleDirection int
@@ -257,13 +254,7 @@ func (a *AppModel) Prompt(
chatReplyChunks chan provider.Chunk,
stopSignal chan struct{},
) (*conversation.Message, error) {
modelConfig, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
if err != nil {
return nil, err
}
p := modelConfig.Client
params := provider.NewRequestParameters(*modelConfig)
params := provider.NewRequestParameters(a.Model)
if a.Agent != nil {
params.Toolbox = a.Agent.Toolbox
@@ -278,14 +269,14 @@ func (a *AppModel) Prompt(
}
}()
msg, err := p.CreateChatCompletionStream(
msg, err := a.Model.Client.CreateChatCompletionStream(
ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
)
if msg != nil {
msg := conversation.MessageFromAPI(*msg)
msg.Metadata.GenerationProvider = &a.ProviderName
msg.Metadata.GenerationModel = &a.Model
msg.Metadata.GenerationProvider = &a.Model.Provider
msg.Metadata.GenerationModel = &a.Model.Model
return &msg, err
}
return nil, err

View File

@@ -261,7 +261,7 @@ func (m *Model) conversationMessagesView() string {
heading := m.renderMessageHeading(-1, &conversation.Message{
Role: api.MessageRoleAssistant,
Metadata: conversation.MessageMeta{
GenerationModel: &m.App.Model,
GenerationModel: &m.App.Model.Model,
},
})
sb.WriteString(heading)

View File

@@ -3,6 +3,7 @@ package settings
import (
"strings"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles/list"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
@@ -21,11 +22,6 @@ type Model struct {
height int
}
type modelOpt struct {
provider string
model string
}
const (
modelListId int = iota + 1
)
@@ -72,9 +68,8 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
case list.MsgOptionSelected:
switch msg.ID {
case modelListId:
if modelOpt, ok := msg.Option.Value.(modelOpt); ok {
m.App.Model = modelOpt.model
m.App.ProviderName = modelOpt.provider
if modelConfig, ok := msg.Option.Value.(provider.ModelConfig); ok {
m.App.Model = modelConfig
}
return m, shared.ChangeView(m.prevView)
}
@@ -103,10 +98,10 @@ func (m *Model) getModelOptions() []list.OptionGroup {
group := list.OptionGroup{
Name: providerLabel,
}
for _, model := range p.Models() {
for _, model := range m.App.Ctx.GetProviderModels(p) {
group.Options = append(group.Options, list.Option{
Label: model.Name,
Value: modelOpt{provider, model.Name},
Label: model.Model,
Value: model,
})
}
modelOpts = append(modelOpts, group)