lmcli/pkg/tui/model/model.go

264 lines
7.1 KiB
Go
Raw Normal View History

package model
import (
"context"
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"github.com/charmbracelet/lipgloss"
)
// AppModel holds the TUI's application state: the lmcli context, the active
// conversation and its messages, and the model/provider/agent used when
// generating replies.
type AppModel struct {
// Ctx supplies configuration (Ctx.Config.Defaults) and the conversation
// store (Ctx.Conversations) used by the methods of AppModel.
Ctx *lmcli.Context
// Conversations is a conversation list. NOTE(review): the methods in this
// file never read or write it — presumably maintained by the TUI views; confirm.
Conversations conversation.ConversationList
// Conversation is the active conversation. ClearConversation sets it to nil,
// and PersistConversation treats nil or ID == 0 as "not yet saved".
Conversation *conversation.Conversation
// Messages is the working message list for the active conversation;
// ClearConversation resets it and ApplySystemPrompt rewrites it.
Messages []conversation.Message
// Model and ProviderName identify the model used for replies. Prompt passes
// them to Ctx.GetModelProvider and records them in the reply's metadata;
// ActiveModel renders them as "model@provider".
Model string
ProviderName string
// Provider is a chat-completion provider. Prompt does not read this field;
// it resolves a provider on every call via Ctx.GetModelProvider.
Provider provider.ChatCompletionProvider
// Agent, when non-nil, supplies the Toolbox that Prompt offers to the model.
// NewAppModel initialises it from Ctx.Config.Defaults.Agent.
Agent *lmcli.Agent
}
// NewAppModel builds an AppModel bound to ctx.
//
// When initialConversation is nil, a fresh, unsaved conversation is started
// and the system prompt is applied to it. The model and provider names come
// from resolving the configured default model, and Agent comes from the
// configured default agent.
//
// NewAppModel cannot return an error. If the default model cannot be resolved,
// Model keeps the configured name and ProviderName stays empty. Prompt resolves
// the provider again with those values and will report the failure there.
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
	defaultModel := *ctx.Config.Defaults.Model
	app := &AppModel{
		Ctx:          ctx,
		Conversation: initialConversation,
		Model:        defaultModel,
	}
	if initialConversation == nil {
		// ApplySystemPrompt looks the agent up through ctx itself, so this is
		// safe to run before app.Agent is assigned below.
		app.NewConversation()
	}
	// Only overwrite the configured model when resolution succeeded; on error
	// the returned names are not meaningful.
	if model, providerName, _, err := ctx.GetModelProvider(defaultModel, ""); err == nil {
		app.Model = model
		app.ProviderName = providerName
	}
	app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
	return app
}
// Package-level styles used by ActiveModel: defaultStyle is faint text, and
// accentStyle is the same faint text with terminal color "6" as foreground.
// NOTE(review): accentStyle is derived from defaultStyle by value. This relies
// on lipgloss Style setters returning an independent copy (true for lipgloss
// releases whose Style holds no shared rules map) — confirm for the pinned version.
var (
defaultStyle = lipgloss.NewStyle().Faint(true)
accentStyle = defaultStyle.Foreground(lipgloss.Color("6"))
)
// ActiveModel renders the active model as "model@provider".
//
// The caller's style is combined with the package's faint styles through
// Inherit, so attributes already set on style win. The model and provider
// names use the plain faint style; the "@" separator uses the accent color.
func (a *AppModel) ActiveModel(style lipgloss.Style) string {
	name := style.Inherit(defaultStyle)
	sep := style.Inherit(accentStyle)

	model := name.Render(a.Model)
	at := sep.Render("@")
	prov := name.Render(a.ProviderName)
	return model + at + prov
}
// MessageCycleDirection selects which way CycleSelectedRoot and
// CycleSelectedReply step through sibling messages.
type MessageCycleDirection int

// Directions understood by the Cycle* methods. Cycling wraps around at both ends.
const (
	// CyclePrev steps to the previous sibling, wrapping to the last one.
	CyclePrev MessageCycleDirection = -1
	// CycleNext steps to the next sibling, wrapping to the first one.
	CycleNext MessageCycleDirection = 1
)
2024-09-20 20:47:03 -06:00
func (m *AppModel) ClearConversation() {
m.Conversation = nil
m.Messages = []conversation.Message{}
2024-09-20 20:47:03 -06:00
}
// ApplySystemPrompt runs a.Messages through conversation.ApplySystemPrompt.
//
// The prompt text comes from the configured default agent
// (Ctx.Config.Defaults.Agent) when that agent exists and has a non-empty
// SystemPrompt. Otherwise it comes from Ctx.DefaultSystemPrompt(). If both are
// empty, Messages is left untouched.
//
// This consults the configured default agent, not a.Agent.
func (a *AppModel) ApplySystemPrompt() {
	prompt := ""
	if agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent); agent != nil {
		prompt = agent.SystemPrompt
	}
	if prompt == "" {
		prompt = a.Ctx.DefaultSystemPrompt()
	}
	if prompt == "" {
		return
	}
	a.Messages = conversation.ApplySystemPrompt(a.Messages, prompt, false)
}
// NewConversation switches to a fresh, unsaved conversation. It clears the
// current conversation and messages, then applies the system prompt to the
// emptied message list.
func (a *AppModel) NewConversation() {
	a.ClearConversation()
	a.ApplySystemPrompt()
}
// LoadConversationMessages returns the message path of the active
// conversation. It follows a.Conversation.SelectedRoot down to a leaf via
// Ctx.Conversations.PathToLeaf.
//
// It returns an error instead of panicking when no conversation is active
// (e.g. right after ClearConversation). Store errors are wrapped with %w so
// callers can inspect them with errors.Is/As.
func (a *AppModel) LoadConversationMessages() ([]conversation.Message, error) {
	if a.Conversation == nil {
		return nil, fmt.Errorf("Could not load conversation messages: no active conversation")
	}
	messages, err := a.Ctx.Conversations.PathToLeaf(a.Conversation.SelectedRoot)
	if err != nil {
		return nil, fmt.Errorf("Could not load conversation messages: %v %w", a.Conversation.SelectedRoot, err)
	}
	return messages, nil
}
// GenerateConversationTitle delegates to cmdutil.GenerateTitle with the app's
// context. It returns that function's title and error unchanged.
func (a *AppModel) GenerateConversationTitle(messages []conversation.Message) (string, error) {
	title, err := cmdutil.GenerateTitle(a.Ctx, messages)
	return title, err
}
// UpdateConversationTitle saves conv through Ctx.Conversations.UpdateConversation.
// It does not modify conv itself; any title change must already be applied.
func (a *AppModel) UpdateConversationTitle(conv *conversation.Conversation) error {
	err := a.Ctx.Conversations.UpdateConversation(conv)
	return err
}
// CloneMessage clones message via Ctx.Conversations.CloneBranch and returns
// the new copy.
//
// When selected is true, the clone also becomes the active branch:
//   - with no parent, it becomes its conversation's SelectedRoot and the
//     conversation is saved;
//   - otherwise, it becomes its parent's SelectedReply and the parent is saved.
//
// Errors are wrapped with %w so callers can inspect the underlying cause.
func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) {
	msg, _, err := a.Ctx.Conversations.CloneBranch(message)
	if err != nil {
		return nil, fmt.Errorf("Could not clone message: %w", err)
	}
	if !selected {
		return msg, nil
	}

	if msg.Parent == nil {
		msg.Conversation.SelectedRoot = msg
		err = a.Ctx.Conversations.UpdateConversation(msg.Conversation)
	} else {
		msg.Parent.SelectedReply = msg
		err = a.Ctx.Conversations.UpdateMessage(msg.Parent)
	}
	if err != nil {
		return nil, fmt.Errorf("Could not update selected message: %w", err)
	}
	return msg, nil
}
// UpdateMessageContent saves message through Ctx.Conversations.UpdateMessage
// and returns the store's error unchanged.
func (a *AppModel) UpdateMessageContent(message *conversation.Message) error {
	err := a.Ctx.Conversations.UpdateMessage(message)
	return err
}
// cycleSelectedMessage returns a pointer to the element of choices that comes
// after selected (or before it, when dir is CyclePrev), wrapping around at
// either end.
//
// Elements are matched by ID, and any dir other than CyclePrev is treated as
// CycleNext. The returned pointer aliases the choices slice.
//
// It returns an error if selected is nil or is not present in choices.
func cycleSelectedMessage(selected *conversation.Message, choices []conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
	if selected == nil {
		return nil, fmt.Errorf("No message is currently selected, cannot cycle")
	}
	currentIndex := -1
	// Index instead of ranging by value to avoid copying each Message.
	for i := range choices {
		if choices[i].ID == selected.ID {
			currentIndex = i
			break
		}
	}
	if currentIndex < 0 {
		return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
	}
	n := len(choices)
	next := (currentIndex + 1) % n
	if dir == CyclePrev {
		next = (currentIndex - 1 + n) % n
	}
	return &choices[next], nil
}
// CycleSelectedRoot makes the next (or previous, per dir) root message of conv
// its SelectedRoot and saves conv.
//
// It returns the newly selected root. When conv has fewer than two roots,
// there is nothing to cycle and it returns nil, nil.
//
// If saving fails, conv.SelectedRoot is restored to its previous value so the
// in-memory selection does not diverge from the unsaved store.
func (a *AppModel) CycleSelectedRoot(conv *conversation.Conversation, dir MessageCycleDirection) (*conversation.Message, error) {
	if len(conv.RootMessages) < 2 {
		return nil, nil
	}
	nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, conv.RootMessages, dir)
	if err != nil {
		return nil, err
	}

	previous := conv.SelectedRoot
	conv.SelectedRoot = nextRoot
	if err = a.Ctx.Conversations.UpdateConversation(conv); err != nil {
		conv.SelectedRoot = previous
		return nil, fmt.Errorf("Could not update conversation SelectedRoot: %w", err)
	}
	return nextRoot, nil
}
// CycleSelectedReply makes the next (or previous, per dir) reply of message
// its SelectedReply and saves message.
//
// It returns the newly selected reply. When message has fewer than two
// replies, there is nothing to cycle and it returns nil, nil.
//
// If saving fails, message.SelectedReply is restored to its previous value so
// the in-memory selection does not diverge from the unsaved store.
func (a *AppModel) CycleSelectedReply(message *conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
	if len(message.Replies) < 2 {
		return nil, nil
	}
	nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
	if err != nil {
		return nil, err
	}

	previous := message.SelectedReply
	message.SelectedReply = nextReply
	if err = a.Ctx.Conversations.UpdateMessage(message); err != nil {
		message.SelectedReply = previous
		return nil, fmt.Errorf("Could not update message SelectedReply: %w", err)
	}
	return nextReply, nil
}
// PersistConversation saves messages to the conversation store.
//
// If conv is nil or not yet saved (ID == 0), a new conversation is started
// from all messages, and the stored conversation and messages are returned.
//
// Otherwise, messages are processed in order:
//   - a message with an ID is updated;
//   - a message without one is saved as a reply to the message before it.
//
// Saved copies replace the unsaved entries, so the caller's slice is modified
// in place (also when a later step fails). An unsaved first message is an
// error because there is nothing to reply to.
func (a *AppModel) PersistConversation(conv *conversation.Conversation, messages []conversation.Message) (*conversation.Conversation, []conversation.Message, error) {
	var err error
	if conv == nil || conv.ID == 0 {
		conv, messages, err = a.Ctx.Conversations.StartConversation(messages...)
		if err != nil {
			return nil, nil, fmt.Errorf("Could not start new conversation: %w", err)
		}
		return conv, messages, nil
	}

	for i := range messages {
		switch {
		case messages[i].ID > 0:
			if err = a.Ctx.Conversations.UpdateMessage(&messages[i]); err != nil {
				return nil, nil, err
			}
		case i > 0:
			// messages[i-1] is saved by now: it either had an ID already or
			// was replaced by its saved copy in the previous iteration.
			saved, err := a.Ctx.Conversations.Reply(&messages[i-1], messages[i])
			if err != nil {
				return nil, nil, err
			}
			if len(saved) == 0 {
				return nil, nil, fmt.Errorf("Could not save reply: store returned no messages")
			}
			messages[i] = saved[0]
		default:
			return nil, nil, fmt.Errorf("Error: no messages to reply to")
		}
	}
	return conv, messages, nil
}
// ExecuteToolCalls runs toolCalls against the active agent's toolbox via
// agents.ExecuteToolCalls.
//
// It uses a.Agent, the same agent whose Toolbox Prompt offers to the model, so
// calls are executed against the tools that were actually advertised. When
// a.Agent is nil, it falls back to the configured default agent. If neither
// exists, it returns an error.
func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) {
	agent := a.Agent
	if agent == nil {
		agent = a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent)
	}
	if agent == nil {
		return nil, fmt.Errorf("Attempted to execute tool calls with no agent configured")
	}
	return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
}
// Prompt streams a chat completion for messages from the active model/provider.
//
// Chunks are forwarded to chatReplyChunks as they arrive. A value received on
// stopSignal while the stream is running cancels the request. When the
// provider produces a message, it is returned with GenerationProvider and
// GenerationModel metadata set, together with any error. Otherwise
// (nil, err) is returned.
//
// Prompt only listens on stopSignal while it is running.
// NOTE(review): senders should not block sending on stopSignal after Prompt
// has returned (e.g. use a select with default) — confirm for the TUI caller.
func (a *AppModel) Prompt(
	messages []conversation.Message,
	chatReplyChunks chan provider.Chunk,
	stopSignal chan struct{},
) (*conversation.Message, error) {
	// Snapshot the names: the reply metadata must record what this request
	// used and must not alias AppModel fields that may change later.
	modelName := a.Model
	providerName := a.ProviderName

	model, _, p, err := a.Ctx.GetModelProvider(modelName, providerName)
	if err != nil {
		return nil, err
	}

	params := provider.RequestParameters{
		Model:       model,
		MaxTokens:   *a.Ctx.Config.Defaults.MaxTokens,
		Temperature: *a.Ctx.Config.Defaults.Temperature,
	}
	if a.Agent != nil {
		params.Toolbox = a.Agent.Toolbox
	}

	ctx, cancel := context.WithCancel(context.Background())
	// Releasing the context on return also lets the watcher below exit, so
	// no goroutine is left behind to consume a later call's stop signal.
	defer cancel()
	go func() {
		select {
		case <-stopSignal:
			cancel()
		case <-ctx.Done():
		}
	}()

	reply, err := p.CreateChatCompletionStream(
		ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
	)
	if reply == nil {
		return nil, err
	}
	msg := conversation.MessageFromAPI(*reply)
	msg.Metadata.GenerationProvider = &providerName
	msg.Metadata.GenerationModel = &modelName
	return &msg, err
}