tui: Add setting view with support for changing the current model

This commit is contained in:
Matt Low 2024-09-25 15:49:45 +00:00
parent 3ec2675632
commit 69cdc0a5aa
10 changed files with 485 additions and 37 deletions

View File

@ -17,7 +17,7 @@ import (
// Prompt prompts the configured the configured model and streams the response // Prompt prompts the configured the configured model and streams the response
// to stdout. Returns all model reply messages. // to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) { func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) {
m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model) m, _, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -204,8 +204,8 @@ Example response:
}, },
} }
m, provider, err := ctx.GetModelProvider( m, _, provider, err := ctx.GetModelProvider(
*ctx.Config.Conversations.TitleGenerationModel, *ctx.Config.Conversations.TitleGenerationModel, "",
) )
if err != nil { if err != nil {
return "", err return "", err

View File

@ -31,6 +31,7 @@ type Config struct {
} `yaml:"agents"` } `yaml:"agents"`
Providers []*struct { Providers []*struct {
Name string `yaml:"name,omitempty"` Name string `yaml:"name,omitempty"`
Display string `yaml:"display,omitempty"`
Kind string `yaml:"kind"` Kind string `yaml:"kind"`
BaseURL string `yaml:"baseUrl,omitempty"` BaseURL string `yaml:"baseUrl,omitempty"`
APIKey string `yaml:"apiKey,omitempty"` APIKey string `yaml:"apiKey,omitempty"`

View File

@ -123,11 +123,10 @@ func (c *Context) DefaultSystemPrompt() string {
return c.Config.Defaults.SystemPrompt return c.Config.Defaults.SystemPrompt
} }
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) { func (c *Context) GetModelProvider(model string, provider string) (string, string, api.ChatCompletionProvider, error) {
parts := strings.Split(model, "@") parts := strings.Split(model, "@")
var provider string if provider == "" && len(parts) > 1 {
if len(parts) > 1 {
model = parts[0] model = parts[0]
provider = parts[1] provider = parts[1]
} }
@ -150,7 +149,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, &anthropic.AnthropicClient{ return model, name, &anthropic.AnthropicClient{
BaseURL: url, BaseURL: url,
APIKey: p.APIKey, APIKey: p.APIKey,
}, nil }, nil
@ -159,7 +158,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, &google.Client{ return model, name, &google.Client{
BaseURL: url, BaseURL: url,
APIKey: p.APIKey, APIKey: p.APIKey,
}, nil }, nil
@ -168,7 +167,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, &ollama.OllamaClient{ return model, name, &ollama.OllamaClient{
BaseURL: url, BaseURL: url,
}, nil }, nil
case "openai": case "openai":
@ -176,18 +175,18 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
if p.BaseURL != "" { if p.BaseURL != "" {
url = p.BaseURL url = p.BaseURL
} }
return model, &openai.OpenAIClient{ return model, name, &openai.OpenAIClient{
BaseURL: url, BaseURL: url,
APIKey: p.APIKey, APIKey: p.APIKey,
Headers: p.Headers, Headers: p.Headers,
}, nil }, nil
default: default:
return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind) return "", "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
} }
} }
} }
} }
return "", nil, fmt.Errorf("unknown model: %s", model) return "", "", nil, fmt.Errorf("unknown model: %s", model)
} }
func configDir() string { func configDir() string {

View File

@ -0,0 +1,263 @@
package list
import (
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type Option struct {
Label string
Value interface{}
}
type OptionGroup struct {
Name string
Options []Option
}
type Model struct {
ID int
HeaderStyle lipgloss.Style
ItemStyle lipgloss.Style
SelectedStyle lipgloss.Style
ItemRender func(Option, bool) string
Width int
Height int
optionGroups []OptionGroup
selected int
filterInput textinput.Model
filteredIndices []filteredIndex
content viewport.Model
itemYOffsets []int
}
type filteredIndex struct {
groupIndex int
optionIndex int
}
type MsgOptionSelected struct {
ID int
Option Option
}
func New(opts []Option) Model {
return NewWithGroups([]OptionGroup{{Options: opts}})
}
func NewWithGroups(groups []OptionGroup) Model {
ti := textinput.New()
ti.Prompt = "/"
ti.PromptStyle = lipgloss.NewStyle().Faint(true)
m := Model{
HeaderStyle: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("12")).Padding(1, 0, 1, 1),
ItemStyle: lipgloss.NewStyle(),
SelectedStyle: lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("6")),
optionGroups: groups,
selected: 0,
filterInput: ti,
filteredIndices: make([]filteredIndex, 0),
content: viewport.New(0, 0),
itemYOffsets: make([]int, 0),
}
m.filterItems()
m.content.SetContent(m.renderList())
return m
}
func (m *Model) Focused() {
m.filterInput.Focused()
}
func (m *Model) Focus() {
m.filterInput.Focus()
}
func (m *Model) Blur() {
m.filterInput.Blur()
}
func (m *Model) filterItems() {
filterText := strings.ToLower(m.filterInput.Value())
var prevSelection *filteredIndex
if m.selected <= len(m.filteredIndices)-1 {
prevSelection = &m.filteredIndices[m.selected]
}
m.filteredIndices = make([]filteredIndex, 0)
for groupIndex, group := range m.optionGroups {
for optionIndex, option := range group.Options {
if filterText == "" ||
strings.Contains(strings.ToLower(option.Label), filterText) ||
(group.Name != "" && strings.Contains(strings.ToLower(group.Name), filterText)) {
m.filteredIndices = append(m.filteredIndices, filteredIndex{groupIndex, optionIndex})
}
}
}
found := false
if len(m.filteredIndices) > 0 && prevSelection != nil {
// Preserve previous selection if possible
for i, filterIdx := range m.filteredIndices {
if prevSelection.groupIndex == filterIdx.groupIndex {
if prevSelection.optionIndex == filterIdx.optionIndex {
m.selected = i
found = true
break
}
}
}
}
if !found {
m.selected = 0
}
}
func (m *Model) Update(msg tea.Msg) (Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
if m.filterInput.Focused() {
switch msg.String() {
case "esc":
m.filterInput.Blur()
m.filterInput.SetValue("")
m.filterItems()
m.refreshContent()
return *m, shared.KeyHandled(msg)
case "enter":
m.filterInput.Blur()
m.refreshContent()
break
case "up", "down":
break
default:
m.filterInput, cmd = m.filterInput.Update(msg)
m.filterItems()
m.refreshContent()
return *m, cmd
}
}
switch msg.String() {
case "up", "k":
m.moveSelection(-1)
return *m, shared.KeyHandled(msg)
case "down", "j":
m.moveSelection(1)
return *m, shared.KeyHandled(msg)
case "enter":
return *m, func() tea.Msg {
idx := m.filteredIndices[m.selected]
return MsgOptionSelected{
ID: m.ID,
Option: m.optionGroups[idx.groupIndex].Options[idx.optionIndex],
}
}
case "/":
m.filterInput.Focus()
return *m, textinput.Blink
}
}
m.content, cmd = m.content.Update(msg)
return *m, cmd
}
func (m *Model) refreshContent() {
m.content.SetContent(m.renderList())
m.ensureSelectedVisible()
}
func (m *Model) ensureSelectedVisible() {
if m.selected == 0 {
m.content.GotoTop()
} else if m.selected == len(m.filteredIndices)-1 {
m.content.GotoBottom()
} else {
tuiutil.ScrollIntoView(&m.content, m.itemYOffsets[m.selected], 0)
}
}
func (m *Model) moveSelection(delta int) {
prev := m.selected
m.selected = min(len(m.filteredIndices)-1, max(0, m.selected+delta))
if prev != m.selected {
m.refreshContent()
}
}
func (m *Model) View() string {
filter := ""
if m.filterInput.Focused() {
m.filterInput.Width = m.Width
filter = m.filterInput.View()
}
contentHeight := m.Height - tuiutil.Height(filter)
m.content.Width, m.content.Height = m.Width, contentHeight
parts := []string{m.content.View()}
if filter != "" {
parts = append(parts, filter)
}
return lipgloss.JoinVertical(lipgloss.Left, parts...)
}
func (m *Model) renderList() string {
var sb strings.Builder
yOffset := 0
lastGroupIndex := -1
m.itemYOffsets = make([]int, len(m.filteredIndices))
for i, idx := range m.filteredIndices {
if idx.groupIndex != lastGroupIndex {
group := m.optionGroups[idx.groupIndex].Name
if group != "" {
headingStr := m.HeaderStyle.Render(group)
yOffset += tuiutil.Height(headingStr)
sb.WriteString(headingStr)
sb.WriteRune('\n')
}
lastGroupIndex = idx.groupIndex
}
m.itemYOffsets[i] = yOffset
option := m.optionGroups[idx.groupIndex].Options[idx.optionIndex]
var item string
if m.ItemRender != nil {
item = m.ItemRender(option, i == m.selected)
} else {
prefix := " "
if i == m.selected {
prefix = "> "
item = m.SelectedStyle.Render(option.Label)
} else {
item = m.ItemStyle.Render(option.Label)
}
item = fmt.Sprintf("%s%s", prefix, item)
}
sb.WriteString(item)
yOffset += tuiutil.Height(item)
if i < len(m.filteredIndices)-1 {
sb.WriteRune('\n')
}
}
return sb.String()
}

View File

@ -8,6 +8,7 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/charmbracelet/lipgloss"
) )
type LoadedConversation struct { type LoadedConversation struct {
@ -21,6 +22,37 @@ type AppModel struct {
Conversation *api.Conversation Conversation *api.Conversation
RootMessages []api.Message RootMessages []api.Message
Messages []api.Message Messages []api.Message
Model string
ProviderName string
Provider api.ChatCompletionProvider
}
func NewAppModel(ctx *lmcli.Context, initialConversation *api.Conversation) *AppModel {
app := &AppModel{
Ctx: ctx,
Conversation: initialConversation,
Model: *ctx.Config.Defaults.Model,
}
if initialConversation == nil {
app.NewConversation()
}
model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
app.Model = model
app.ProviderName = provider
return app
}
var (
defaultStyle = lipgloss.NewStyle().Faint(true)
accentStyle = defaultStyle.Foreground(lipgloss.Color("6"))
)
func (a *AppModel) ActiveModel(style lipgloss.Style) string {
defaultStyle := style.Inherit(defaultStyle)
accentStyle := style.Inherit(accentStyle)
return defaultStyle.Render(a.Model) + accentStyle.Render("@") + defaultStyle.Render(a.ProviderName)
} }
type MessageCycleDirection int type MessageCycleDirection int
@ -194,7 +226,7 @@ func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult,
} }
func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan api.Chunk, stopSignal chan struct{}) (*api.Message, error) { func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan api.Chunk, stopSignal chan struct{}) (*api.Message, error) {
model, provider, err := a.Ctx.GetModelProvider(*a.Ctx.Config.Defaults.Model) model, _, provider, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -22,24 +22,24 @@ type View int
const ( const (
ViewChat View = iota ViewChat View = iota
ViewConversations ViewConversations
//StateSettings ViewSettings
//StateHelp //StateHelp
) )
type ( type (
// send to change the current state // send to change the current state
MsgViewChange View MsgViewChange View
// sent to a state when it is entered // sent to a state when it is entered, with the view we're leaving
MsgViewEnter struct{} MsgViewEnter View
// sent when a recoverable error occurs (displayed to user) // sent when a recoverable error occurs (displayed to user)
MsgError struct { Err error } MsgError struct { Err error }
// sent when the view has handled a key input // sent when the view has handled a key input
MsgKeyHandled tea.KeyMsg MsgKeyHandled tea.KeyMsg
) )
func ViewEnter() tea.Cmd { func ViewEnter(from View) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
return MsgViewEnter{} return MsgViewEnter(from)
} }
} }

View File

@ -10,15 +10,11 @@ import (
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"git.mlow.ca/mlow/lmcli/pkg/tui/views/chat" "git.mlow.ca/mlow/lmcli/pkg/tui/views/chat"
"git.mlow.ca/mlow/lmcli/pkg/tui/views/conversations" "git.mlow.ca/mlow/lmcli/pkg/tui/views/conversations"
"git.mlow.ca/mlow/lmcli/pkg/tui/views/settings"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
) )
type LaunchOptions struct {
InitialConversation *api.Conversation
InitialView shared.View
}
type Model struct { type Model struct {
App *model.AppModel App *model.AppModel
@ -26,7 +22,8 @@ type Model struct {
width int width int
height int height int
// errors we will display to the user and allow them to dismiss // errors to display
// TODO: allow dismissing errors
errs []error errs []error
activeView shared.View activeView shared.View
@ -34,11 +31,7 @@ type Model struct {
} }
func initialModel(ctx *lmcli.Context, opts LaunchOptions) *Model { func initialModel(ctx *lmcli.Context, opts LaunchOptions) *Model {
app := &model.AppModel{ app := model.NewAppModel(ctx, opts.InitialConversation)
Ctx: ctx,
Conversation: opts.InitialConversation,
}
app.NewConversation()
m := Model{ m := Model{
App: app, App: app,
@ -46,6 +39,7 @@ func initialModel(ctx *lmcli.Context, opts LaunchOptions) *Model {
views: map[shared.View]shared.ViewModel{ views: map[shared.View]shared.ViewModel{
shared.ViewChat: chat.Chat(app), shared.ViewChat: chat.Chat(app),
shared.ViewConversations: conversations.Conversations(app), shared.ViewConversations: conversations.Conversations(app),
shared.ViewSettings: settings.Settings(app),
}, },
} }
@ -53,14 +47,15 @@ func initialModel(ctx *lmcli.Context, opts LaunchOptions) *Model {
} }
func (m *Model) Init() tea.Cmd { func (m *Model) Init() tea.Cmd {
cmds := []tea.Cmd{ var cmds []tea.Cmd
func() tea.Msg {
return shared.MsgViewChange(m.activeView)
},
}
for _, v := range m.views { for _, v := range m.views {
// Init views
cmds = append(cmds, v.Init()) cmds = append(cmds, v.Init())
} }
cmds = append(cmds, func() tea.Msg {
// Initial view change
return shared.MsgViewChange(m.activeView)
})
return tea.Batch(cmds...) return tea.Batch(cmds...)
} }
@ -88,11 +83,11 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, cmd return m, cmd
} }
case shared.MsgViewChange: case shared.MsgViewChange:
currView := m.activeView
m.activeView = shared.View(msg) m.activeView = shared.View(msg)
return m, tea.Batch(tea.WindowSize(), shared.ViewEnter()) return m, tea.Batch(tea.WindowSize(), shared.ViewEnter(currView))
case shared.MsgError: case shared.MsgError:
m.errs = append(m.errs, msg.Err) m.errs = append(m.errs, msg.Err)
return m, nil
} }
view, cmd := m.views[m.activeView].Update(msg) view, cmd := m.views[m.activeView].Update(msg)
@ -134,6 +129,11 @@ func (m *Model) View() string {
return lipgloss.JoinVertical(lipgloss.Left, sections...) return lipgloss.JoinVertical(lipgloss.Left, sections...)
} }
type LaunchOptions struct {
InitialConversation *api.Conversation
InitialView shared.View
}
type LaunchOption func(*LaunchOptions) type LaunchOption func(*LaunchOptions)
func WithInitialConversation(conv *api.Conversation) LaunchOption { func WithInitialConversation(conv *api.Conversation) LaunchOption {

View File

@ -12,6 +12,17 @@ import (
) )
func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd { func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "ctrl+g":
if m.state == pendingResponse {
m.stopSignal <- struct{}{}
return shared.KeyHandled(msg)
}
return func() tea.Msg {
return shared.MsgViewChange(shared.ViewSettings)
}
}
switch m.focus { switch m.focus {
case focusInput: case focusInput:
cmd := m.handleInputKey(msg) cmd := m.handleInputKey(msg)

View File

@ -288,7 +288,12 @@ func (m *Model) Footer(width int) string {
rightSegments = append(rightSegments, segmentStyle.Render(throughput)) rightSegments = append(rightSegments, segmentStyle.Render(throughput))
} }
model := fmt.Sprintf("Model: %s", *m.App.Ctx.Config.Defaults.Model) if m.App.ProviderName != "" {
provider := fmt.Sprintf("Provider: %s", m.App.ProviderName)
rightSegments = append(rightSegments, segmentStyle.Render(provider))
}
model := fmt.Sprintf("Model: %s", m.App.Model)
rightSegments = append(rightSegments, segmentStyle.Render(model)) rightSegments = append(rightSegments, segmentStyle.Render(model))
left := strings.Join(leftSegments, segmentSeparator) left := strings.Join(leftSegments, segmentSeparator)

View File

@ -0,0 +1,137 @@
package settings
import (
"strings"
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles/list"
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type Model struct {
App *model.AppModel
prevView shared.View
content viewport.Model
modelList list.Model
width int
height int
}
type modelOpt struct {
provider string
model string
}
const (
modelListId int = iota + 1
)
func Settings(app *model.AppModel) *Model {
m := &Model{
App: app,
content: viewport.New(0, 0),
}
return m
}
func (m *Model) Init() tea.Cmd {
m.modelList = list.NewWithGroups(m.getModelOptions())
m.modelList.ID = modelListId
return nil
}
func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
m.modelList, cmd = m.modelList.Update(msg)
if cmd != nil {
return m, cmd
}
switch msg.String() {
case "esc":
return m, func() tea.Msg {
return shared.MsgViewChange(m.prevView)
}
}
case shared.MsgViewEnter:
m.prevView = shared.View(msg)
m.modelList.Focus()
m.content.SetContent(m.renderContent())
case tea.WindowSizeMsg:
m.width, m.height = msg.Width, msg.Height
m.content.Width = msg.Width
m.content.Height = msg.Height
m.content.SetContent(m.renderContent())
case list.MsgOptionSelected:
switch msg.ID {
case modelListId:
if modelOpt, ok := msg.Option.Value.(modelOpt); ok {
m.App.Model = modelOpt.model
m.App.ProviderName = modelOpt.provider
}
return m, shared.ChangeView(m.prevView)
}
}
m.modelList, cmd = m.modelList.Update(msg)
if cmd != nil {
return m, cmd
}
m.content.SetContent(m.renderContent())
return m, nil
}
func (m *Model) getModelOptions() []list.OptionGroup {
modelOpts := []list.OptionGroup{}
for _, p := range m.App.Ctx.Config.Providers {
provider := p.Name
if provider == "" {
provider = p.Kind
}
providerLabel := p.Display
if providerLabel == "" {
providerLabel = strings.ToUpper(provider[:1]) + provider[1:]
}
group := list.OptionGroup{
Name: providerLabel,
}
for _, model := range p.Models {
group.Options = append(group.Options, list.Option{
Label: model,
Value: modelOpt{provider, model},
})
}
modelOpts = append(modelOpts, group)
}
return modelOpts
}
func (m *Model) Header(width int) string {
boldStyle := lipgloss.NewStyle().Bold(true)
// TODO: update header depending on active settings mode (model, agent, etc)
header := boldStyle.Render("Model selection")
return styles.Header.Width(width).Render(header)
}
func (m *Model) Content(width, height int) string {
// TODO: see Header()
currentModel := " Active model: " + m.App.ActiveModel(lipgloss.NewStyle())
m.modelList.Width, m.modelList.Height = width, height - 2
return "\n" + currentModel + "\n" + m.modelList.View()
}
func (m *Model) Footer(width int) string {
return ""
}
func (m *Model) renderContent() string {
return m.modelList.View()
}