Work to simplify model config handling
Keep a direct reference to a provider.ModelConfig in TUI's AppModel, rather than the names of the provider and model
This commit is contained in:
@@ -19,7 +19,7 @@ import (
|
|||||||
// Prompt prompts the configured the configured model and streams the response
|
// Prompt prompts the configured the configured model and streams the response
|
||||||
// to stdout. Returns all model reply messages.
|
// to stdout. Returns all model reply messages.
|
||||||
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
|
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
|
||||||
modelConfig, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
modelConfig, err := ctx.LookupModelProvider(*ctx.Config.Defaults.Model, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -203,7 +203,7 @@ Example response:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
modelConfig, err := ctx.GetModelProvider(
|
modelConfig, err := ctx.LookupModelProvider(
|
||||||
*ctx.Config.Conversations.TitleGenerationModel, "",
|
*ctx.Config.Conversations.TitleGenerationModel, "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -139,9 +139,6 @@ func (c *Context) DefaultSystemPrompt() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.ModelConfig {
|
func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.ModelConfig {
|
||||||
// Set model name
|
|
||||||
cfg.Model = m.Name
|
|
||||||
|
|
||||||
// Set max tokens
|
// Set max tokens
|
||||||
if m.MaxTokens == nil {
|
if m.MaxTokens == nil {
|
||||||
cfg.MaxTokens = *c.Config.Defaults.MaxTokens
|
cfg.MaxTokens = *c.Config.Defaults.MaxTokens
|
||||||
@@ -164,7 +161,76 @@ func (c *Context) fillModelConfig(cfg *provider.ModelConfig, m Model) *provider.
|
|||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetModelProvider(model string, providerName string) (*provider.ModelConfig, error) {
|
func (c *Context) GetDefaultModel() (*provider.ModelConfig, error) {
|
||||||
|
return c.LookupModelProvider(*c.Config.Defaults.Model, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) GetProviderModels(p *Provider) []provider.ModelConfig {
|
||||||
|
models := []provider.ModelConfig{}
|
||||||
|
for _, m := range p.Models() {
|
||||||
|
var cfg provider.ModelConfig
|
||||||
|
switch p.Kind {
|
||||||
|
case "anthropic":
|
||||||
|
url := "https://api.anthropic.com"
|
||||||
|
if p.BaseURL != "" {
|
||||||
|
url = p.BaseURL
|
||||||
|
}
|
||||||
|
cfg = provider.ModelConfig{
|
||||||
|
Client: &anthropic.AnthropicClient{
|
||||||
|
BaseURL: url,
|
||||||
|
APIKey: p.APIKey,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case "google":
|
||||||
|
url := "https://generativelanguage.googleapis.com"
|
||||||
|
if p.BaseURL != "" {
|
||||||
|
url = p.BaseURL
|
||||||
|
}
|
||||||
|
cfg = provider.ModelConfig{
|
||||||
|
Client: &google.Client{
|
||||||
|
BaseURL: url,
|
||||||
|
APIKey: p.APIKey,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case "ollama":
|
||||||
|
url := "http://localhost:11434/api"
|
||||||
|
if p.BaseURL != "" {
|
||||||
|
url = p.BaseURL
|
||||||
|
}
|
||||||
|
cfg = provider.ModelConfig{
|
||||||
|
Client: &ollama.OllamaClient{
|
||||||
|
BaseURL: url,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
case "openai":
|
||||||
|
url := "https://api.openai.com/v1"
|
||||||
|
if p.BaseURL != "" {
|
||||||
|
url = p.BaseURL
|
||||||
|
}
|
||||||
|
cfg = provider.ModelConfig{
|
||||||
|
Client: &openai.OpenAIClient{
|
||||||
|
BaseURL: url,
|
||||||
|
APIKey: p.APIKey,
|
||||||
|
Headers: p.Headers,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.Model = m.Name
|
||||||
|
|
||||||
|
if len(p.Name) > 0 {
|
||||||
|
cfg.Provider = p.Name
|
||||||
|
} else {
|
||||||
|
cfg.Provider = p.Kind
|
||||||
|
}
|
||||||
|
|
||||||
|
c.fillModelConfig(&cfg, m)
|
||||||
|
models = append(models, cfg)
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) LookupModelProvider(model string, providerName string) (*provider.ModelConfig, error) {
|
||||||
parts := strings.Split(model, "@")
|
parts := strings.Split(model, "@")
|
||||||
|
|
||||||
if providerName == "" && len(parts) > 1 {
|
if providerName == "" && len(parts) > 1 {
|
||||||
@@ -182,61 +248,9 @@ func (c *Context) GetModelProvider(model string, providerName string) (*provider
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range p.Models() {
|
for _, modelConfig := range c.GetProviderModels(p) {
|
||||||
var cfg *provider.ModelConfig
|
if modelConfig.Model == model {
|
||||||
if m.Name == model {
|
return &modelConfig, nil
|
||||||
switch p.Kind {
|
|
||||||
case "anthropic":
|
|
||||||
url := "https://api.anthropic.com"
|
|
||||||
if p.BaseURL != "" {
|
|
||||||
url = p.BaseURL
|
|
||||||
}
|
|
||||||
cfg = &provider.ModelConfig{
|
|
||||||
Client: &anthropic.AnthropicClient{
|
|
||||||
BaseURL: url,
|
|
||||||
APIKey: p.APIKey,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return c.fillModelConfig(cfg, m), nil
|
|
||||||
case "google":
|
|
||||||
url := "https://generativelanguage.googleapis.com"
|
|
||||||
if p.BaseURL != "" {
|
|
||||||
url = p.BaseURL
|
|
||||||
}
|
|
||||||
cfg := &provider.ModelConfig{
|
|
||||||
Client: &google.Client{
|
|
||||||
BaseURL: url,
|
|
||||||
APIKey: p.APIKey,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return c.fillModelConfig(cfg, m), nil
|
|
||||||
case "ollama":
|
|
||||||
url := "http://localhost:11434/api"
|
|
||||||
if p.BaseURL != "" {
|
|
||||||
url = p.BaseURL
|
|
||||||
}
|
|
||||||
cfg := &provider.ModelConfig{
|
|
||||||
Client: &ollama.OllamaClient{
|
|
||||||
BaseURL: url,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return c.fillModelConfig(cfg, m), nil
|
|
||||||
case "openai":
|
|
||||||
url := "https://api.openai.com/v1"
|
|
||||||
if p.BaseURL != "" {
|
|
||||||
url = p.BaseURL
|
|
||||||
}
|
|
||||||
cfg := &provider.ModelConfig{
|
|
||||||
Client: &openai.OpenAIClient{
|
|
||||||
BaseURL: url,
|
|
||||||
APIKey: p.APIKey,
|
|
||||||
Headers: p.Headers,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return c.fillModelConfig(cfg, m), nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ type AppModel struct {
|
|||||||
Conversations conversation.ConversationList
|
Conversations conversation.ConversationList
|
||||||
Conversation conversation.Conversation
|
Conversation conversation.Conversation
|
||||||
Messages []conversation.Message
|
Messages []conversation.Message
|
||||||
Model string
|
Model provider.ModelConfig
|
||||||
ProviderName string
|
|
||||||
Provider provider.ChatCompletionProvider
|
Provider provider.ChatCompletionProvider
|
||||||
Agent *lmcli.Agent
|
Agent *lmcli.Agent
|
||||||
|
|
||||||
@@ -29,7 +28,6 @@ type AppModel struct {
|
|||||||
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
|
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
|
||||||
app := &AppModel{
|
app := &AppModel{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
Model: *ctx.Config.Defaults.Model,
|
|
||||||
modifiedMessages: make(map[uint]bool),
|
modifiedMessages: make(map[uint]bool),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,9 +37,8 @@ func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversat
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
modelConfig, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
modelConfig, _ := ctx.GetDefaultModel()
|
||||||
app.Model = modelConfig.Model
|
app.Model = *modelConfig
|
||||||
app.ProviderName = modelConfig.Provider
|
|
||||||
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
|
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
|
||||||
return app
|
return app
|
||||||
}
|
}
|
||||||
@@ -54,7 +51,7 @@ var (
|
|||||||
func (a *AppModel) ActiveModel(style lipgloss.Style) string {
|
func (a *AppModel) ActiveModel(style lipgloss.Style) string {
|
||||||
defaultStyle := style.Inherit(defaultStyle)
|
defaultStyle := style.Inherit(defaultStyle)
|
||||||
accentStyle := style.Inherit(accentStyle)
|
accentStyle := style.Inherit(accentStyle)
|
||||||
return defaultStyle.Render(a.Model) + accentStyle.Render("@") + defaultStyle.Render(a.ProviderName)
|
return defaultStyle.Render(a.Model.Model) + accentStyle.Render("@") + defaultStyle.Render(a.Model.Provider)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MessageCycleDirection int
|
type MessageCycleDirection int
|
||||||
@@ -257,13 +254,7 @@ func (a *AppModel) Prompt(
|
|||||||
chatReplyChunks chan provider.Chunk,
|
chatReplyChunks chan provider.Chunk,
|
||||||
stopSignal chan struct{},
|
stopSignal chan struct{},
|
||||||
) (*conversation.Message, error) {
|
) (*conversation.Message, error) {
|
||||||
modelConfig, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
|
params := provider.NewRequestParameters(a.Model)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
p := modelConfig.Client
|
|
||||||
|
|
||||||
params := provider.NewRequestParameters(*modelConfig)
|
|
||||||
|
|
||||||
if a.Agent != nil {
|
if a.Agent != nil {
|
||||||
params.Toolbox = a.Agent.Toolbox
|
params.Toolbox = a.Agent.Toolbox
|
||||||
@@ -278,14 +269,14 @@ func (a *AppModel) Prompt(
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
msg, err := p.CreateChatCompletionStream(
|
msg, err := a.Model.Client.CreateChatCompletionStream(
|
||||||
ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
|
ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
if msg != nil {
|
if msg != nil {
|
||||||
msg := conversation.MessageFromAPI(*msg)
|
msg := conversation.MessageFromAPI(*msg)
|
||||||
msg.Metadata.GenerationProvider = &a.ProviderName
|
msg.Metadata.GenerationProvider = &a.Model.Provider
|
||||||
msg.Metadata.GenerationModel = &a.Model
|
msg.Metadata.GenerationModel = &a.Model.Model
|
||||||
return &msg, err
|
return &msg, err
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -261,7 +261,7 @@ func (m *Model) conversationMessagesView() string {
|
|||||||
heading := m.renderMessageHeading(-1, &conversation.Message{
|
heading := m.renderMessageHeading(-1, &conversation.Message{
|
||||||
Role: api.MessageRoleAssistant,
|
Role: api.MessageRoleAssistant,
|
||||||
Metadata: conversation.MessageMeta{
|
Metadata: conversation.MessageMeta{
|
||||||
GenerationModel: &m.App.Model,
|
GenerationModel: &m.App.Model.Model,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
sb.WriteString(heading)
|
sb.WriteString(heading)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package settings
|
|||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles/list"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles/list"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
@@ -21,11 +22,6 @@ type Model struct {
|
|||||||
height int
|
height int
|
||||||
}
|
}
|
||||||
|
|
||||||
type modelOpt struct {
|
|
||||||
provider string
|
|
||||||
model string
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
modelListId int = iota + 1
|
modelListId int = iota + 1
|
||||||
)
|
)
|
||||||
@@ -72,9 +68,8 @@ func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
|
|||||||
case list.MsgOptionSelected:
|
case list.MsgOptionSelected:
|
||||||
switch msg.ID {
|
switch msg.ID {
|
||||||
case modelListId:
|
case modelListId:
|
||||||
if modelOpt, ok := msg.Option.Value.(modelOpt); ok {
|
if modelConfig, ok := msg.Option.Value.(provider.ModelConfig); ok {
|
||||||
m.App.Model = modelOpt.model
|
m.App.Model = modelConfig
|
||||||
m.App.ProviderName = modelOpt.provider
|
|
||||||
}
|
}
|
||||||
return m, shared.ChangeView(m.prevView)
|
return m, shared.ChangeView(m.prevView)
|
||||||
}
|
}
|
||||||
@@ -103,10 +98,10 @@ func (m *Model) getModelOptions() []list.OptionGroup {
|
|||||||
group := list.OptionGroup{
|
group := list.OptionGroup{
|
||||||
Name: providerLabel,
|
Name: providerLabel,
|
||||||
}
|
}
|
||||||
for _, model := range p.Models() {
|
for _, model := range m.App.Ctx.GetProviderModels(p) {
|
||||||
group.Options = append(group.Options, list.Option{
|
group.Options = append(group.Options, list.Option{
|
||||||
Label: model.Name,
|
Label: model.Model,
|
||||||
Value: modelOpt{provider, model.Name},
|
Value: model,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
modelOpts = append(modelOpts, group)
|
modelOpts = append(modelOpts, group)
|
||||||
|
|||||||
Reference in New Issue
Block a user