From 69cdc0a5aa44a047821f82599d205c36bc40e4f9 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Wed, 25 Sep 2024 15:49:45 +0000 Subject: [PATCH] tui: Add setting view with support for changing the current model --- pkg/cmd/util/util.go | 6 +- pkg/lmcli/config.go | 1 + pkg/lmcli/lmcli.go | 17 +- pkg/tui/bubbles/list/list.go | 263 +++++++++++++++++++++++++++++ pkg/tui/model/model.go | 34 +++- pkg/tui/shared/shared.go | 10 +- pkg/tui/tui.go | 36 ++-- pkg/tui/views/chat/input.go | 11 ++ pkg/tui/views/chat/view.go | 7 +- pkg/tui/views/settings/settings.go | 137 +++++++++++++++ 10 files changed, 485 insertions(+), 37 deletions(-) create mode 100644 pkg/tui/bubbles/list/list.go create mode 100644 pkg/tui/views/settings/settings.go diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go index f6944bb..5d853fc 100644 --- a/pkg/cmd/util/util.go +++ b/pkg/cmd/util/util.go @@ -17,7 +17,7 @@ import ( // Prompt prompts the configured the configured model and streams the response // to stdout. Returns all model reply messages. func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) { - m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model) + m, _, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") if err != nil { return nil, err } @@ -204,8 +204,8 @@ Example response: }, } - m, provider, err := ctx.GetModelProvider( - *ctx.Config.Conversations.TitleGenerationModel, + m, _, provider, err := ctx.GetModelProvider( + *ctx.Config.Conversations.TitleGenerationModel, "", ) if err != nil { return "", err diff --git a/pkg/lmcli/config.go b/pkg/lmcli/config.go index e684f3f..914ae8d 100644 --- a/pkg/lmcli/config.go +++ b/pkg/lmcli/config.go @@ -31,6 +31,7 @@ type Config struct { } `yaml:"agents"` Providers []*struct { Name string `yaml:"name,omitempty"` + Display string `yaml:"display,omitempty"` Kind string `yaml:"kind"` BaseURL string `yaml:"baseUrl,omitempty"` APIKey string `yaml:"apiKey,omitempty"` diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index 2235376..1f0b565 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -123,11 +123,10 @@ func (c *Context) DefaultSystemPrompt() string { return c.Config.Defaults.SystemPrompt } -func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) { +func (c *Context) GetModelProvider(model string, provider string) (string, string, api.ChatCompletionProvider, error) { parts := strings.Split(model, "@") - var provider string - if len(parts) > 1 { + if provider == "" && len(parts) > 1 { model = parts[0] provider = parts[1] } @@ -150,7 +149,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv if p.BaseURL != "" { url = p.BaseURL } - return model, &anthropic.AnthropicClient{ + return model, name, &anthropic.AnthropicClient{ BaseURL: url, APIKey: p.APIKey, }, nil @@ -159,7 +158,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv if p.BaseURL != "" { url = p.BaseURL } - return model, &google.Client{ + return model, name, &google.Client{ BaseURL: url, APIKey: p.APIKey, }, nil @@ -168,7 +167,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv if p.BaseURL != "" { url = p.BaseURL } - return model, &ollama.OllamaClient{ + return model, name, &ollama.OllamaClient{ BaseURL: url, }, nil case "openai": @@ -176,18 +175,18 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv if p.BaseURL != "" { url = p.BaseURL } - return model, &openai.OpenAIClient{ + return model, name, &openai.OpenAIClient{ BaseURL: url, APIKey: p.APIKey, Headers: p.Headers, }, nil default: - return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind) + return "", "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind) } } } } - return "", nil, fmt.Errorf("unknown model: %s", model) + return "", "", nil, fmt.Errorf("unknown model: %s", model) } func configDir() string { diff --git a/pkg/tui/bubbles/list/list.go b/pkg/tui/bubbles/list/list.go new file mode 100644 index 0000000..946e313 --- /dev/null +++ b/pkg/tui/bubbles/list/list.go @@ -0,0 +1,263 @@ +package list + +import ( + "fmt" + "strings" + + "git.mlow.ca/mlow/lmcli/pkg/tui/shared" + tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" + "github.com/charmbracelet/bubbles/textinput" + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +type Option struct { + Label string + Value interface{} +} + +type OptionGroup struct { + Name string + Options []Option +} + +type Model struct { + ID int + HeaderStyle lipgloss.Style + ItemStyle lipgloss.Style + SelectedStyle lipgloss.Style + ItemRender func(Option, bool) string + + Width int + Height int + + optionGroups []OptionGroup + selected int + filterInput textinput.Model + filteredIndices []filteredIndex + content viewport.Model + itemYOffsets []int +} + +type filteredIndex struct { + groupIndex int + optionIndex int +} + +type MsgOptionSelected struct { + ID int + Option Option +} + +func New(opts []Option) Model { + return NewWithGroups([]OptionGroup{{Options: opts}}) +} + +func NewWithGroups(groups []OptionGroup) Model { + ti := textinput.New() + ti.Prompt = "/" + ti.PromptStyle = lipgloss.NewStyle().Faint(true) + + m := Model{ + HeaderStyle: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("12")).Padding(1, 0, 1, 1), + ItemStyle: lipgloss.NewStyle(), + SelectedStyle: lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("6")), + + optionGroups: groups, + selected: 0, + filterInput: ti, + filteredIndices: make([]filteredIndex, 0), + content: viewport.New(0, 0), + itemYOffsets: make([]int, 0), + } + + m.filterItems() + m.content.SetContent(m.renderList()) + return m +} + +func (m *Model) Focused() { + m.filterInput.Focused() +} + +func (m *Model) Focus() { + m.filterInput.Focus() +} + +func (m *Model) Blur() { + m.filterInput.Blur() +} + +func (m *Model) filterItems() { + filterText := strings.ToLower(m.filterInput.Value()) + + var prevSelection *filteredIndex + if m.selected <= len(m.filteredIndices)-1 { + prevSelection = &m.filteredIndices[m.selected] + } + + m.filteredIndices = make([]filteredIndex, 0) + + for groupIndex, group := range m.optionGroups { + for optionIndex, option := range group.Options { + if filterText == "" || + strings.Contains(strings.ToLower(option.Label), filterText) || + (group.Name != "" && strings.Contains(strings.ToLower(group.Name), filterText)) { + m.filteredIndices = append(m.filteredIndices, filteredIndex{groupIndex, optionIndex}) + } + } + } + + found := false + if len(m.filteredIndices) > 0 && prevSelection != nil { + // Preserve previous selection if possible + for i, filterIdx := range m.filteredIndices { + if prevSelection.groupIndex == filterIdx.groupIndex { + if prevSelection.optionIndex == filterIdx.optionIndex { + m.selected = i + found = true + break + } + } + } + } + if !found { + m.selected = 0 + } +} + +func (m *Model) Update(msg tea.Msg) (Model, tea.Cmd) { + var cmd tea.Cmd + + switch msg := msg.(type) { + case tea.KeyMsg: + if m.filterInput.Focused() { + switch msg.String() { + case "esc": + m.filterInput.Blur() + m.filterInput.SetValue("") + m.filterItems() + m.refreshContent() + return *m, shared.KeyHandled(msg) + case "enter": + m.filterInput.Blur() + m.refreshContent() + break + case "up", "down": + break + default: + m.filterInput, cmd = m.filterInput.Update(msg) + m.filterItems() + m.refreshContent() + return *m, cmd + } + } + + switch msg.String() { + case "up", "k": + m.moveSelection(-1) + return *m, shared.KeyHandled(msg) + case "down", "j": + m.moveSelection(1) + return *m, shared.KeyHandled(msg) + case "enter": + return *m, func() tea.Msg { + idx := m.filteredIndices[m.selected] + return MsgOptionSelected{ + ID: m.ID, + Option: m.optionGroups[idx.groupIndex].Options[idx.optionIndex], + } + } + case "/": + m.filterInput.Focus() + return *m, textinput.Blink + } + } + + m.content, cmd = m.content.Update(msg) + return *m, cmd +} + +func (m *Model) refreshContent() { + m.content.SetContent(m.renderList()) + m.ensureSelectedVisible() +} + +func (m *Model) ensureSelectedVisible() { + if m.selected == 0 { + m.content.GotoTop() + } else if m.selected == len(m.filteredIndices)-1 { + m.content.GotoBottom() + } else { + tuiutil.ScrollIntoView(&m.content, m.itemYOffsets[m.selected], 0) + } +} + +func (m *Model) moveSelection(delta int) { + prev := m.selected + m.selected = min(len(m.filteredIndices)-1, max(0, m.selected+delta)) + if prev != m.selected { + m.refreshContent() + } +} + +func (m *Model) View() string { + filter := "" + if m.filterInput.Focused() { + m.filterInput.Width = m.Width + filter = m.filterInput.View() + } + + contentHeight := m.Height - tuiutil.Height(filter) + m.content.Width, m.content.Height = m.Width, contentHeight + + parts := []string{m.content.View()} + if filter != "" { + parts = append(parts, filter) + } + return lipgloss.JoinVertical(lipgloss.Left, parts...) +} + +func (m *Model) renderList() string { + var sb strings.Builder + yOffset := 0 + lastGroupIndex := -1 + m.itemYOffsets = make([]int, len(m.filteredIndices)) + + for i, idx := range m.filteredIndices { + if idx.groupIndex != lastGroupIndex { + group := m.optionGroups[idx.groupIndex].Name + if group != "" { + headingStr := m.HeaderStyle.Render(group) + yOffset += tuiutil.Height(headingStr) + sb.WriteString(headingStr) + sb.WriteRune('\n') + } + lastGroupIndex = idx.groupIndex + } + + m.itemYOffsets[i] = yOffset + option := m.optionGroups[idx.groupIndex].Options[idx.optionIndex] + var item string + if m.ItemRender != nil { + item = m.ItemRender(option, i == m.selected) + } else { + prefix := " " + if i == m.selected { + prefix = "> " + item = m.SelectedStyle.Render(option.Label) + } else { + item = m.ItemStyle.Render(option.Label) + } + item = fmt.Sprintf("%s%s", prefix, item) + } + sb.WriteString(item) + yOffset += tuiutil.Height(item) + if i < len(m.filteredIndices)-1 { + sb.WriteRune('\n') + } + } + + return sb.String() +} diff --git a/pkg/tui/model/model.go b/pkg/tui/model/model.go index dc4648d..72371d6 100644 --- a/pkg/tui/model/model.go +++ b/pkg/tui/model/model.go @@ -8,6 +8,7 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/api" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli" + "github.com/charmbracelet/lipgloss" ) type LoadedConversation struct { @@ -21,6 +22,37 @@ type AppModel struct { Conversation *api.Conversation RootMessages []api.Message Messages []api.Message + Model string + ProviderName string + Provider api.ChatCompletionProvider +} + +func NewAppModel(ctx *lmcli.Context, initialConversation *api.Conversation) *AppModel { + app := &AppModel{ + Ctx: ctx, + Conversation: initialConversation, + Model: *ctx.Config.Defaults.Model, + } + + if initialConversation == nil { + app.NewConversation() + } + + model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") + app.Model = model + app.ProviderName = provider + return app +} + +var ( + defaultStyle = lipgloss.NewStyle().Faint(true) + accentStyle = defaultStyle.Foreground(lipgloss.Color("6")) +) + +func (a *AppModel) ActiveModel(style lipgloss.Style) string { + defaultStyle := style.Inherit(defaultStyle) + accentStyle := style.Inherit(accentStyle) + return defaultStyle.Render(a.Model) + accentStyle.Render("@") + defaultStyle.Render(a.ProviderName) } type MessageCycleDirection int @@ -194,7 +226,7 @@ func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, } func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan api.Chunk, stopSignal chan struct{}) (*api.Message, error) { - model, provider, err := a.Ctx.GetModelProvider(*a.Ctx.Config.Defaults.Model) + model, _, provider, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName) if err != nil { return nil, err } diff --git a/pkg/tui/shared/shared.go b/pkg/tui/shared/shared.go index 38600c1..70e8eed 100644 --- a/pkg/tui/shared/shared.go +++ b/pkg/tui/shared/shared.go @@ -22,24 +22,24 @@ type View int const ( ViewChat View = iota ViewConversations - //StateSettings + ViewSettings //StateHelp ) type ( // send to change the current state MsgViewChange View - // sent to a state when it is entered - MsgViewEnter struct{} + // sent to a state when it is entered, with the view we're leaving + MsgViewEnter View // sent when a recoverable error occurs (displayed to user) MsgError struct { Err error } // sent when the view has handled a key input MsgKeyHandled tea.KeyMsg ) -func ViewEnter() tea.Cmd { +func ViewEnter(from View) tea.Cmd { return func() tea.Msg { - return MsgViewEnter{} + return MsgViewEnter(from) } } diff --git a/pkg/tui/tui.go b/pkg/tui/tui.go index 8f3c2b6..23db52c 100644 --- a/pkg/tui/tui.go +++ b/pkg/tui/tui.go @@ -10,15 +10,11 @@ import ( tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" "git.mlow.ca/mlow/lmcli/pkg/tui/views/chat" "git.mlow.ca/mlow/lmcli/pkg/tui/views/conversations" + "git.mlow.ca/mlow/lmcli/pkg/tui/views/settings" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" ) -type LaunchOptions struct { - InitialConversation *api.Conversation - InitialView shared.View -} - type Model struct { App *model.AppModel @@ -26,7 +22,8 @@ type Model struct { width int height int - // errors we will display to the user and allow them to dismiss + // errors to display + // TODO: allow dismissing errors errs []error activeView shared.View @@ -34,11 +31,7 @@ type Model struct { } func initialModel(ctx *lmcli.Context, opts LaunchOptions) *Model { - app := &model.AppModel{ - Ctx: ctx, - Conversation: opts.InitialConversation, - } - app.NewConversation() + app := model.NewAppModel(ctx, opts.InitialConversation) m := Model{ App: app, @@ -46,6 +39,7 @@ func initialModel(ctx *lmcli.Context, opts LaunchOptions) *Model { views: map[shared.View]shared.ViewModel{ shared.ViewChat: chat.Chat(app), shared.ViewConversations: conversations.Conversations(app), + shared.ViewSettings: settings.Settings(app), }, } @@ -53,14 +47,15 @@ func initialModel(ctx *lmcli.Context, opts LaunchOptions) *Model { } func (m *Model) Init() tea.Cmd { - cmds := []tea.Cmd{ - func() tea.Msg { - return shared.MsgViewChange(m.activeView) - }, - } + var cmds []tea.Cmd for _, v := range m.views { + // Init views cmds = append(cmds, v.Init()) } + cmds = append(cmds, func() tea.Msg { + // Initial view change + return shared.MsgViewChange(m.activeView) + }) return tea.Batch(cmds...) } @@ -88,11 +83,11 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, cmd } case shared.MsgViewChange: + currView := m.activeView m.activeView = shared.View(msg) - return m, tea.Batch(tea.WindowSize(), shared.ViewEnter()) + return m, tea.Batch(tea.WindowSize(), shared.ViewEnter(currView)) case shared.MsgError: m.errs = append(m.errs, msg.Err) - return m, nil } view, cmd := m.views[m.activeView].Update(msg) @@ -134,6 +129,11 @@ func (m *Model) View() string { return lipgloss.JoinVertical(lipgloss.Left, sections...) } +type LaunchOptions struct { + InitialConversation *api.Conversation + InitialView shared.View +} + type LaunchOption func(*LaunchOptions) func WithInitialConversation(conv *api.Conversation) LaunchOption { diff --git a/pkg/tui/views/chat/input.go b/pkg/tui/views/chat/input.go index f2f56ad..ecd6fcd 100644 --- a/pkg/tui/views/chat/input.go +++ b/pkg/tui/views/chat/input.go @@ -12,6 +12,17 @@ import ( ) func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "ctrl+g": + if m.state == pendingResponse { + m.stopSignal <- struct{}{} + return shared.KeyHandled(msg) + } + return func() tea.Msg { + return shared.MsgViewChange(shared.ViewSettings) + } + } + switch m.focus { case focusInput: cmd := m.handleInputKey(msg) diff --git a/pkg/tui/views/chat/view.go b/pkg/tui/views/chat/view.go index 222ee57..0c40a41 100644 --- a/pkg/tui/views/chat/view.go +++ b/pkg/tui/views/chat/view.go @@ -288,7 +288,12 @@ func (m *Model) Footer(width int) string { rightSegments = append(rightSegments, segmentStyle.Render(throughput)) } - model := fmt.Sprintf("Model: %s", *m.App.Ctx.Config.Defaults.Model) + if m.App.ProviderName != "" { + provider := fmt.Sprintf("Provider: %s", m.App.ProviderName) + rightSegments = append(rightSegments, segmentStyle.Render(provider)) + } + + model := fmt.Sprintf("Model: %s", m.App.Model) rightSegments = append(rightSegments, segmentStyle.Render(model)) left := strings.Join(leftSegments, segmentSeparator) diff --git a/pkg/tui/views/settings/settings.go b/pkg/tui/views/settings/settings.go new file mode 100644 index 0000000..f79f06e --- /dev/null +++ b/pkg/tui/views/settings/settings.go @@ -0,0 +1,137 @@ +package settings + +import ( + "strings" + + "git.mlow.ca/mlow/lmcli/pkg/tui/bubbles/list" + "git.mlow.ca/mlow/lmcli/pkg/tui/model" + "git.mlow.ca/mlow/lmcli/pkg/tui/shared" + "git.mlow.ca/mlow/lmcli/pkg/tui/styles" + "github.com/charmbracelet/bubbles/viewport" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +type Model struct { + App *model.AppModel + prevView shared.View + content viewport.Model + modelList list.Model + width int + height int +} + +type modelOpt struct { + provider string + model string +} + +const ( + modelListId int = iota + 1 +) + +func Settings(app *model.AppModel) *Model { + m := &Model{ + App: app, + content: viewport.New(0, 0), + } + return m +} + +func (m *Model) Init() tea.Cmd { + m.modelList = list.NewWithGroups(m.getModelOptions()) + m.modelList.ID = modelListId + return nil +} + +func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) { + var cmd tea.Cmd + + switch msg := msg.(type) { + case tea.KeyMsg: + m.modelList, cmd = m.modelList.Update(msg) + if cmd != nil { + return m, cmd + } + + switch msg.String() { + case "esc": + return m, func() tea.Msg { + return shared.MsgViewChange(m.prevView) + } + } + case shared.MsgViewEnter: + m.prevView = shared.View(msg) + m.modelList.Focus() + m.content.SetContent(m.renderContent()) + case tea.WindowSizeMsg: + m.width, m.height = msg.Width, msg.Height + m.content.Width = msg.Width + m.content.Height = msg.Height + m.content.SetContent(m.renderContent()) + case list.MsgOptionSelected: + switch msg.ID { + case modelListId: + if modelOpt, ok := msg.Option.Value.(modelOpt); ok { + m.App.Model = modelOpt.model + m.App.ProviderName = modelOpt.provider + } + return m, shared.ChangeView(m.prevView) + } + } + + m.modelList, cmd = m.modelList.Update(msg) + if cmd != nil { + return m, cmd + } + + m.content.SetContent(m.renderContent()) + return m, nil +} + +func (m *Model) getModelOptions() []list.OptionGroup { + modelOpts := []list.OptionGroup{} + for _, p := range m.App.Ctx.Config.Providers { + provider := p.Name + if provider == "" { + provider = p.Kind + } + providerLabel := p.Display + if providerLabel == "" { + providerLabel = strings.ToUpper(provider[:1]) + provider[1:] + } + group := list.OptionGroup{ + Name: providerLabel, + } + for _, model := range p.Models { + group.Options = append(group.Options, list.Option{ + Label: model, + Value: modelOpt{provider, model}, + }) + } + modelOpts = append(modelOpts, group) + } + return modelOpts +} + +func (m *Model) Header(width int) string { + boldStyle := lipgloss.NewStyle().Bold(true) + // TODO: update header depending on active settings mode (model, agent, etc) + header := boldStyle.Render("Model selection") + return styles.Header.Width(width).Render(header) +} + +func (m *Model) Content(width, height int) string { + // TODO: see Header() + currentModel := " Active model: " + m.App.ActiveModel(lipgloss.NewStyle()) + m.modelList.Width, m.modelList.Height = width, height - 2 + return "\n" + currentModel + "\n" + m.modelList.View() +} + +func (m *Model) Footer(width int) string { + return "" +} + +func (m *Model) renderContent() string { + return m.modelList.View() +}