lmcli/pkg/tui/model/model.go

290 lines
7.5 KiB
Go
Raw Normal View History

package model
import (
"context"
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"github.com/charmbracelet/lipgloss"
)
// LoadedConversation pairs a conversation with one of its messages, as
// produced by AppModel.LoadConversations.
type LoadedConversation struct {
// Conv is a copy of the conversation the message belongs to.
Conv api.Conversation
// LastReply is the message returned for this conversation by the store's
// LatestConversationMessages query.
LastReply api.Message
}
// AppModel holds the state shared by the TUI: the lmcli context, the active
// conversation and its messages, and the model/provider used for prompting.
type AppModel struct {
// Ctx gives access to configuration, the conversation store, agents and
// model providers.
Ctx *lmcli.Context
// Conversations is not assigned anywhere in this file; presumably it caches
// the result of LoadConversations — confirm against callers.
Conversations []LoadedConversation
// Conversation is the active conversation. It is nil after
// ClearConversation, i.e. for a conversation that was never loaded.
Conversation *api.Conversation
// RootMessages is reset to empty by ClearConversation; presumably it holds
// the root messages of the active conversation — confirm against callers.
RootMessages []api.Message
// Messages is the working message list; ClearConversation empties it and
// ApplySystemPrompt rewrites it via api.ApplySystemPrompt.
Messages []api.Message
// Model is the model name passed to GetModelProvider when prompting.
Model string
// ProviderName is the provider name passed to GetModelProvider when
// prompting.
ProviderName string
// Provider is not assigned anywhere in this file.
// NOTE(review): confirm whether it is still used by callers.
Provider provider.ChatCompletionProvider
}
// NewAppModel builds the application model for the given lmcli context.
//
// When initialConversation is nil a fresh conversation is started via
// NewConversation. The configured default model is then resolved through
// ctx.GetModelProvider to obtain the model and provider names.
//
// ctx.Config.Defaults.Model must be non-nil; it is dereferenced
// unconditionally.
func NewAppModel(ctx *lmcli.Context, initialConversation *api.Conversation) *AppModel {
	app := &AppModel{
		Ctx:          ctx,
		Conversation: initialConversation,
		Model:        *ctx.Config.Defaults.Model,
	}

	if initialConversation == nil {
		app.NewConversation()
	}

	// Only adopt the resolved names when resolution succeeded. The constructor
	// has no error return, so on failure the configured default model name is
	// kept (rather than overwritten with whatever accompanied the error) and
	// ProviderName stays empty; Prompt performs the same lookup again and
	// reports the error to its caller.
	if model, providerName, _, err := ctx.GetModelProvider(app.Model, ""); err == nil {
		app.Model = model
		app.ProviderName = providerName
	}

	return app
}
// Base styles used by AppModel.ActiveModel; the caller-supplied style
// inherits from them.
var (
// defaultStyle renders text faint.
defaultStyle = lipgloss.NewStyle().Faint(true)
// accentStyle is defaultStyle with foreground colour "6".
accentStyle = defaultStyle.Foreground(lipgloss.Color("6"))
)
// ActiveModel renders the active model as "<model>@<provider>". The model and
// provider names use style inheriting from the package default style, and the
// "@" separator uses style inheriting from the package accent style.
func (a *AppModel) ActiveModel(style lipgloss.Style) string {
	base := style.Inherit(defaultStyle)
	accent := style.Inherit(accentStyle)

	modelPart := base.Render(a.Model)
	separator := accent.Render("@")
	providerPart := base.Render(a.ProviderName)

	return modelPart + separator + providerPart
}
// MessageCycleDirection selects which way to move when cycling through
// sibling messages (see cycleSelectedMessage).
type MessageCycleDirection int
const (
// CycleNext moves to the following sibling, wrapping to the first.
CycleNext MessageCycleDirection = 1
// CyclePrev moves to the preceding sibling, wrapping to the last.
CyclePrev MessageCycleDirection = -1
)
2024-09-20 20:47:03 -06:00
// ClearConversation drops the active conversation and resets both message
// lists to empty (non-nil) slices.
func (a *AppModel) ClearConversation() {
	a.Conversation = nil
	a.RootMessages = make([]api.Message, 0)
	a.Messages = make([]api.Message, 0)
}
// ApplySystemPrompt applies a system prompt to the working message list.
//
// The prompt of the configured default agent is preferred; when there is no
// such agent or its prompt is empty, the context's default system prompt is
// used instead. If both are empty the messages are left untouched.
func (a *AppModel) ApplySystemPrompt() {
	prompt := ""
	if agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent); agent != nil {
		prompt = agent.SystemPrompt
	}
	if prompt == "" {
		prompt = a.Ctx.DefaultSystemPrompt()
	}
	if prompt == "" {
		return
	}
	a.Messages = api.ApplySystemPrompt(a.Messages, prompt, false)
}
// NewConversation starts from a clean slate: it clears the active
// conversation and its messages, then applies the system prompt to the
// now-empty message list.
func (a *AppModel) NewConversation() {
	a.ClearConversation()
	a.ApplySystemPrompt()
}
// LoadConversations fetches the store's latest conversation messages and
// returns one LoadedConversation per message, each holding a copy of the
// message's conversation alongside the message itself.
//
// Note the result order: the error comes first, then the conversations.
func (a *AppModel) LoadConversations() (error, []LoadedConversation) {
	latest, err := a.Ctx.Store.LatestConversationMessages()
	if err != nil {
		return fmt.Errorf("Could not load conversations: %v", err), nil
	}

	loaded := make([]LoadedConversation, 0, len(latest))
	for _, reply := range latest {
		loaded = append(loaded, LoadedConversation{
			Conv:      *reply.Conversation,
			LastReply: reply,
		})
	}
	return nil, loaded
}
// LoadConversationRootMessages returns the root messages the store holds for
// the active conversation.
//
// It returns an error when no conversation is active (Conversation is nil,
// as it is after ClearConversation) or when the store lookup fails. The
// store error is wrapped, so callers can inspect it with errors.Is/As.
func (a *AppModel) LoadConversationRootMessages() ([]api.Message, error) {
	if a.Conversation == nil {
		return nil, fmt.Errorf("Could not load conversation root messages: no active conversation")
	}

	messages, err := a.Ctx.Store.RootMessages(a.Conversation.ID)
	if err != nil {
		// Report the conversation ID: it is the key the lookup used.
		return nil, fmt.Errorf("Could not load conversation root messages: %v %w", a.Conversation.ID, err)
	}
	return messages, nil
}
// LoadConversationMessages returns the messages the store yields for the path
// from the active conversation's selected root to a leaf.
//
// It returns an error when no conversation is active (Conversation is nil,
// as it is after ClearConversation) or when the store lookup fails. The
// store error is wrapped, so callers can inspect it with errors.Is/As.
func (a *AppModel) LoadConversationMessages() ([]api.Message, error) {
	if a.Conversation == nil {
		return nil, fmt.Errorf("Could not load conversation messages: no active conversation")
	}

	messages, err := a.Ctx.Store.PathToLeaf(a.Conversation.SelectedRoot)
	if err != nil {
		return nil, fmt.Errorf("Could not load conversation messages: %v %w", a.Conversation.SelectedRoot, err)
	}
	return messages, nil
}
// GenerateConversationTitle delegates to cmdutil.GenerateTitle with the
// model's context and the given messages, returning its title and error
// unchanged.
func (a *AppModel) GenerateConversationTitle(messages []api.Message) (string, error) {
return cmdutil.GenerateTitle(a.Ctx, messages)
}
// UpdateConversationTitle saves the given conversation through the store's
// UpdateConversation. The whole conversation value is handed to the store,
// not only its title.
func (a *AppModel) UpdateConversationTitle(conversation *api.Conversation) error {
return a.Ctx.Store.UpdateConversation(conversation)
}
// CloneMessage clones the branch starting at message via the store and
// returns the cloned message.
//
// When selected is true the clone also becomes the active choice: for a
// clone without a parent it is set as its conversation's SelectedRoot,
// otherwise as its parent's SelectedReply, and that change is saved through
// the store.
func (a *AppModel) CloneMessage(message api.Message, selected bool) (*api.Message, error) {
	clone, _, err := a.Ctx.Store.CloneBranch(message)
	if err != nil {
		return nil, fmt.Errorf("Could not clone message: %v", err)
	}
	if !selected {
		return clone, nil
	}

	var updateErr error
	if parent := clone.Parent; parent != nil {
		parent.SelectedReply = clone
		updateErr = a.Ctx.Store.UpdateMessage(parent)
	} else {
		clone.Conversation.SelectedRoot = clone
		updateErr = a.Ctx.Store.UpdateConversation(clone.Conversation)
	}
	if updateErr != nil {
		return nil, fmt.Errorf("Could not update selected message: %v", updateErr)
	}
	return clone, nil
}
// UpdateMessageContent saves the given message through the store's
// UpdateMessage. The whole message value is handed to the store, not only
// its content.
func (a *AppModel) UpdateMessageContent(message *api.Message) error {
return a.Ctx.Store.UpdateMessage(message)
}
// cycleSelectedMessage returns a pointer to the element of choices adjacent
// to selected, matching by ID and wrapping around at either end.
//
// dir == CyclePrev steps backwards; any other value steps forwards. The
// returned pointer refers into the choices slice itself, not to a copy.
//
// An error is returned when selected is nil or when no element of choices
// (including an empty choices) has selected's ID.
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
	// Without a current selection there is no position to step from; report
	// it instead of dereferencing a nil pointer.
	if selected == nil {
		return nil, fmt.Errorf("Cannot cycle messages: no message is currently selected")
	}

	currentIndex := -1
	// Index rather than range by value so messages are not copied just to
	// compare IDs.
	for i := range choices {
		if choices[i].ID == selected.ID {
			currentIndex = i
			break
		}
	}
	if currentIndex < 0 {
		return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
	}

	// count is at least 1 here, since a match was found.
	count := len(choices)
	step := 1
	if dir == CyclePrev {
		// Stepping back by one is stepping forward by count-1 modulo count;
		// this keeps the operand non-negative.
		step = count - 1
	}
	return &choices[(currentIndex+step)%count], nil
}
// CycleSelectedRoot moves conv's selected root to the neighbouring entry of
// rootMessages in direction dir and saves the conversation.
//
// With fewer than two root messages it does nothing and returns (nil, nil).
// conv.SelectedRoot is updated before the store is called, so it keeps the
// new value even if saving fails.
func (a *AppModel) CycleSelectedRoot(conv *api.Conversation, rootMessages []api.Message, dir MessageCycleDirection) (*api.Message, error) {
	if len(rootMessages) <= 1 {
		return nil, nil
	}

	root, err := cycleSelectedMessage(conv.SelectedRoot, rootMessages, dir)
	if err != nil {
		return nil, err
	}

	conv.SelectedRoot = root
	if err := a.Ctx.Store.UpdateConversation(conv); err != nil {
		return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err)
	}
	return root, nil
}
// CycleSelectedReply moves message's selected reply to the neighbouring entry
// of message.Replies in direction dir and saves the message.
//
// With fewer than two replies it does nothing and returns (nil, nil).
// message.SelectedReply is updated before the store is called, so it keeps
// the new value even if saving fails.
func (a *AppModel) CycleSelectedReply(message *api.Message, dir MessageCycleDirection) (*api.Message, error) {
	if len(message.Replies) <= 1 {
		return nil, nil
	}

	reply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
	if err != nil {
		return nil, err
	}

	message.SelectedReply = reply
	if err := a.Ctx.Store.UpdateMessage(message); err != nil {
		return nil, fmt.Errorf("Could not update message SelectedReply: %v", err)
	}
	return reply, nil
}
// PersistConversation saves messages to the store.
//
// When conversation is nil or has a zero ID, a new conversation is started
// from all messages and the store's conversation and messages are returned.
//
// Otherwise each message is handled in order: one with a positive ID is
// updated in place; one without an ID is saved as a reply to the message
// before it and replaced in the slice by the first message the store
// returns. A first message without an ID has nothing to reply to and yields
// an error. The caller's messages slice is modified in place.
func (a *AppModel) PersistConversation(conversation *api.Conversation, messages []api.Message) (*api.Conversation, []api.Message, error) {
	if conversation == nil || conversation.ID == 0 {
		started, stored, err := a.Ctx.Store.StartConversation(messages...)
		if err != nil {
			return nil, nil, fmt.Errorf("Could not start new conversation: %v", err)
		}
		return started, stored, nil
	}

	for idx := range messages {
		switch {
		case messages[idx].ID > 0:
			// Already persisted: save its current state.
			if err := a.Ctx.Store.UpdateMessage(&messages[idx]); err != nil {
				return nil, nil, err
			}
		case idx > 0:
			// Not yet persisted: attach it as a reply to its predecessor.
			replies, err := a.Ctx.Store.Reply(&messages[idx-1], messages[idx])
			if err != nil {
				return nil, nil, err
			}
			messages[idx] = replies[0]
		default:
			return nil, nil, fmt.Errorf("Error: no messages to reply to")
		}
	}
	return conversation, messages, nil
}
// ExecuteToolCalls runs toolCalls against the toolbox of the configured
// default agent. It returns an error when no such agent is available.
func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) {
	if agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent); agent != nil {
		return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
	}
	return nil, fmt.Errorf("Attempted to execute tool calls with no agent configured")
}
// Prompt requests a streamed chat completion for messages using the model's
// current Model and ProviderName.
//
// Request parameters take MaxTokens and Temperature from the configured
// defaults (both must be non-nil; they are dereferenced unconditionally) and,
// when a default agent is configured, that agent's toolbox. chatReplyChunks
// is handed to the provider's CreateChatCompletionStream.
//
// Receiving a value on stopSignal (or closing it) while the request is in
// flight cancels the request's context. The context is also cancelled when
// Prompt returns, so the provider must not rely on it afterwards.
//
// A non-nil returned message is tagged with the provider and model names that
// were active when it was generated. The message and error are otherwise
// returned exactly as the provider produced them.
func (a *AppModel) Prompt(
	messages []api.Message,
	chatReplyChunks chan provider.Chunk,
	stopSignal chan struct{},
) (*api.Message, error) {
	model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
	if err != nil {
		return nil, err
	}

	params := provider.RequestParameters{
		Model:       model,
		MaxTokens:   *a.Ctx.Config.Defaults.MaxTokens,
		Temperature: *a.Ctx.Config.Defaults.Temperature,
	}
	if agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent); agent != nil {
		params.Toolbox = agent.Toolbox
	}

	ctx, cancel := context.WithCancel(context.Background())
	// Always release the context, and with it the watcher goroutine below,
	// once the request is over.
	defer cancel()

	go func() {
		select {
		case <-stopSignal:
			cancel()
		case <-ctx.Done():
			// Request finished without a stop signal. Exit so this goroutine
			// neither leaks nor consumes a stop signal meant for a later
			// Prompt call sharing the same channel.
		}
	}()

	msg, err := p.CreateChatCompletionStream(ctx, params, messages, chatReplyChunks)
	if msg != nil {
		// Point the metadata at copies so that later changes to a.Model or
		// a.ProviderName do not rewrite what this message records.
		providerName, modelName := a.ProviderName, a.Model
		msg.Metadata.GenerationProvider = &providerName
		msg.Metadata.GenerationModel = &modelName
	}
	return msg, err
}