// lmcli/pkg/tui/model/model.go

package model
import (
"context"
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"github.com/charmbracelet/lipgloss"
)
// AppModel holds the TUI's application state: the lmcli context, the
// conversation currently being worked on and its messages, and the
// model/provider/agent selection used when generating replies.
type AppModel struct {
	// Ctx supplies configuration, conversation storage, agents and providers.
	Ctx *lmcli.Context
	// Conversations is not used in this file; presumably populated by the
	// conversation list view — confirm against callers.
	Conversations conversation.ConversationList
	// Conversation is the active conversation (saved by PersistConversation).
	Conversation conversation.Conversation
	// Messages is the active message path; PersistMessages treats index 0 as
	// the root and each later element as a reply to the one before it.
	Messages []conversation.Message
	// Model and ProviderName select the model used by Prompt.
	Model        string
	ProviderName string
	// Provider is not used in this file — NOTE(review): confirm whether it is
	// still needed, since Prompt resolves the provider on every call.
	Provider provider.ChatCompletionProvider
	// Agent's toolbox (if non-nil) is offered to the model by Prompt.
	Agent *lmcli.Agent
}
// NewAppModel constructs an AppModel bound to ctx.
//
// If initialConversation is nil, a fresh conversation is started: state is
// cleared and the system prompt is applied. Otherwise initialConversation is
// copied into the model. Model, ProviderName and Agent are initialised from
// the configured defaults.
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
	app := &AppModel{
		Ctx:   ctx,
		Model: *ctx.Config.Defaults.Model,
	}

	if initialConversation == nil {
		app.NewConversation()
	} else {
		// Previously an empty branch: the caller's conversation was dropped.
		app.Conversation = *initialConversation
	}

	// Resolve the default model to its canonical name and provider. If that
	// fails, keep the configured default instead of overwriting it with the
	// zero values returned alongside the error. Prompt will then surface the
	// error when it resolves the model again.
	if model, provider, _, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, ""); err == nil {
		app.Model = model
		app.ProviderName = provider
	}

	app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
	return app
}
// defaultStyle is the faint base style used by ActiveModel.
var defaultStyle = lipgloss.NewStyle().Faint(true)

// accentStyle is defaultStyle with its foreground set to ANSI color 6.
var accentStyle = defaultStyle.Foreground(lipgloss.Color("6"))
// ActiveModel renders the current selection as "model@provider".
//
// The caller's style is the starting point. Properties it leaves unset are
// inherited from defaultStyle (for the names) or accentStyle (for the "@").
func (a *AppModel) ActiveModel(style lipgloss.Style) string {
	base := style.Inherit(defaultStyle)
	accent := style.Inherit(accentStyle)

	modelText := base.Render(a.Model)
	separator := accent.Render("@")
	providerText := base.Render(a.ProviderName)

	return modelText + separator + providerText
}
// MessageCycleDirection chooses which way CycleSelectedRoot and
// CycleSelectedReply move through a list of sibling messages.
type MessageCycleDirection int

// CycleNext moves to the following sibling, wrapping around to the first.
const CycleNext MessageCycleDirection = 1

// CyclePrev moves to the preceding sibling, wrapping around to the last.
const CyclePrev MessageCycleDirection = -1
2024-09-20 20:47:03 -06:00
func (m *AppModel) ClearConversation() {
m.Conversation = conversation.Conversation{}
m.Messages = []conversation.Message{}
2024-09-20 20:47:03 -06:00
}
// ApplySystemPrompt applies a system prompt to Messages.
//
// The default agent's SystemPrompt is preferred. When there is no such agent,
// or its prompt is empty, Ctx.DefaultSystemPrompt() is used instead. If both
// are empty, Messages is left untouched.
//
// NOTE(review): the agent is looked up from config rather than read from
// a.Agent. NewAppModel calls this (via NewConversation) before a.Agent is
// assigned.
func (a *AppModel) ApplySystemPrompt() {
	prompt := ""
	if agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent); agent != nil {
		prompt = agent.SystemPrompt
	}
	if prompt == "" {
		prompt = a.Ctx.DefaultSystemPrompt()
	}
	if prompt == "" {
		return
	}
	a.Messages = conversation.ApplySystemPrompt(a.Messages, prompt, false)
}
// NewConversation discards the active conversation and starts an empty,
// unsaved one, seeded with the system prompt when one is configured.
func (a *AppModel) NewConversation() {
	a.ClearConversation()
	a.ApplySystemPrompt()
}
// LoadConversationMessages returns the message path that the conversation
// store's PathToLeaf yields for the active conversation's SelectedRoot.
//
// The underlying storage error is wrapped (not flattened), so callers can
// inspect it with errors.Is / errors.As.
func (a *AppModel) LoadConversationMessages() ([]conversation.Message, error) {
	root := a.Conversation.SelectedRoot
	messages, err := a.Ctx.Conversations.PathToLeaf(root)
	if err != nil {
		return nil, fmt.Errorf("Could not load conversation messages: %v %w", root, err)
	}
	return messages, nil
}
// GenerateConversationTitle asks cmdutil.GenerateTitle for a title
// summarising messages, using this model's lmcli context.
func (a *AppModel) GenerateConversationTitle(messages []conversation.Message) (string, error) {
	title, err := cmdutil.GenerateTitle(a.Ctx, messages)
	return title, err
}
// CloneMessage clones message via the store's CloneBranch (presumably
// copying message and its descendants) and returns the clone.
//
// When selected is true, the clone also becomes the active choice, and that
// selection is persisted before returning:
//   - with no parent, it becomes its conversation's SelectedRoot;
//   - otherwise, it becomes its parent's SelectedReply.
//
// Storage errors are wrapped so callers can inspect them with errors.Is/As.
func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) {
	msg, _, err := a.Ctx.Conversations.CloneBranch(message)
	if err != nil {
		return nil, fmt.Errorf("Could not clone message: %w", err)
	}
	if !selected {
		return msg, nil
	}

	if msg.Parent == nil {
		msg.Conversation.SelectedRoot = msg
		err = a.Ctx.Conversations.UpdateConversation(msg.Conversation)
	} else {
		msg.Parent.SelectedReply = msg
		err = a.Ctx.Conversations.UpdateMessage(msg.Parent)
	}
	if err != nil {
		return nil, fmt.Errorf("Could not update selected message: %w", err)
	}
	return msg, nil
}
// UpdateMessageContent writes message back to the conversation store.
func (a *AppModel) UpdateMessageContent(message *conversation.Message) error {
	if err := a.Ctx.Conversations.UpdateMessage(message); err != nil {
		return err
	}
	return nil
}
// cycleSelectedMessage locates selected in choices (matching by ID) and
// returns a pointer to its neighbour, wrapping at both ends.
//
// The neighbour is the preceding element when dir is CyclePrev, and the
// following one for any other value. The returned pointer aliases choices'
// backing array. An error is returned if selected is nil or is not present
// in choices.
func cycleSelectedMessage(selected *conversation.Message, choices []conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
	// Guard against a missing selection, which would otherwise panic on
	// selected.ID below.
	if selected == nil {
		return nil, fmt.Errorf("Cannot cycle messages: no message is currently selected")
	}

	currentIndex := -1
	for i := range choices {
		if choices[i].ID == selected.ID {
			currentIndex = i
			break
		}
	}
	if currentIndex < 0 {
		return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
	}

	n := len(choices)
	var next int
	if dir == CyclePrev {
		// Add n before the modulo so that index 0 wraps to n-1 rather than -1.
		next = (currentIndex - 1 + n) % n
	} else {
		next = (currentIndex + 1) % n
	}
	return &choices[next], nil
}
// CycleSelectedRoot moves conv.SelectedRoot to the next/previous entry of
// conv.RootMessages (see cycleSelectedMessage), persists the change, and
// returns the newly selected root.
//
// It returns (nil, nil) when there are fewer than two roots to cycle
// between. If persisting fails, conv.SelectedRoot is restored to its
// previous value and the wrapped storage error is returned.
func (a *AppModel) CycleSelectedRoot(conv *conversation.Conversation, dir MessageCycleDirection) (*conversation.Message, error) {
	if len(conv.RootMessages) < 2 {
		return nil, nil
	}

	nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, conv.RootMessages, dir)
	if err != nil {
		return nil, err
	}

	prevRoot := conv.SelectedRoot
	conv.SelectedRoot = nextRoot
	if err := a.Ctx.Conversations.UpdateConversation(conv); err != nil {
		// Leave conv as it was before the call rather than half-updated.
		conv.SelectedRoot = prevRoot
		return nil, fmt.Errorf("Could not update conversation SelectedRoot: %w", err)
	}
	return nextRoot, nil
}
// CycleSelectedReply moves message.SelectedReply to the next/previous entry
// of message.Replies (see cycleSelectedMessage), persists the change, and
// returns the newly selected reply.
//
// It returns (nil, nil) when there are fewer than two replies to cycle
// between. If persisting fails, message.SelectedReply is restored to its
// previous value and the wrapped storage error is returned.
func (a *AppModel) CycleSelectedReply(message *conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
	if len(message.Replies) < 2 {
		return nil, nil
	}

	nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
	if err != nil {
		return nil, err
	}

	prevReply := message.SelectedReply
	message.SelectedReply = nextReply
	if err := a.Ctx.Conversations.UpdateMessage(message); err != nil {
		// Leave message as it was before the call rather than half-updated.
		message.SelectedReply = prevReply
		return nil, fmt.Errorf("Could not update message SelectedReply: %w", err)
	}
	return nextReply, nil
}
// PersistMessages saves a.Messages to the conversation store and returns the
// persisted copies in the same order.
//
// a.Messages is treated as one linear path:
//   - an unsaved first message (ID == 0) is saved with Conversation set to
//     &a.Conversation; the saved message's conversation then gets it as its
//     only RootMessage and SelectedRoot, and that update is persisted;
//   - messages with ID > 0 are updated in place;
//   - any other message is saved as a reply to the message before it.
//
// a.Messages itself is not modified (each element is processed as a copy).
// Storage errors are wrapped so callers can inspect them with errors.Is/As.
func (a *AppModel) PersistMessages() ([]conversation.Message, error) {
	messages := make([]conversation.Message, len(a.Messages))
	for i, m := range a.Messages {
		switch {
		case i == 0 && m.ID == 0:
			m.Conversation = &a.Conversation
			saved, err := a.Ctx.Conversations.SaveMessage(m)
			if err != nil {
				return nil, fmt.Errorf("Could not create first message %d: %w", a.Messages[i].ID, err)
			}
			messages[i] = *saved
			// This is the first message, so it becomes the conversation's root.
			saved.Conversation.RootMessages = []conversation.Message{*saved}
			saved.Conversation.SelectedRoot = &saved.Conversation.RootMessages[0]
			// Previously this error was silently dropped, leaving the root
			// selection unpersisted without the caller knowing.
			if err := a.Ctx.Conversations.UpdateConversation(saved.Conversation); err != nil {
				return nil, fmt.Errorf("Could not set conversation root message: %w", err)
			}
		case m.ID > 0:
			// Existing message, update it.
			if err := a.Ctx.Conversations.UpdateMessage(&m); err != nil {
				return nil, fmt.Errorf("Could not update message %d: %w", a.Messages[i].ID, err)
			}
			messages[i] = m
		case i > 0:
			// New message, reply to the previous (already persisted) one.
			replies, err := a.Ctx.Conversations.Reply(&messages[i-1], m)
			if err != nil {
				return nil, fmt.Errorf("Could not reply with new message: %w", err)
			}
			messages[i] = replies[0]
		default:
			return nil, fmt.Errorf("No messages to reply to (this is a bug)")
		}
	}
	return messages, nil
}
// PersistConversation saves the active conversation and returns the stored
// version.
//
// A conversation with ID > 0 is updated from a copy of a.Conversation.
// Otherwise a new conversation is created with an empty title. If creation
// fails, or yields no conversation, a copy of a.Conversation is returned
// together with the creation error (which may be nil). a.Conversation itself
// is not modified here.
func (a *AppModel) PersistConversation() (conversation.Conversation, error) {
	if a.Conversation.ID > 0 {
		existing := a.Conversation
		err := a.Ctx.Conversations.UpdateConversation(&existing)
		return existing, err
	}

	created, err := a.Ctx.Conversations.CreateConversation("")
	if err == nil && created != nil {
		return *created, nil
	}
	return a.Conversation, err
}
// ExecuteToolCalls runs toolCalls against an agent's toolbox and returns
// their results.
//
// It uses a.Agent, the same agent whose toolbox Prompt offers to the model,
// so calls execute against the tools the model was actually given. If
// a.Agent is unset, it falls back to the configured default agent. An error
// is returned when neither is available.
func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) {
	agent := a.Agent
	if agent == nil {
		agent = a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent)
	}
	if agent == nil {
		return nil, fmt.Errorf("Attempted to execute tool calls with no agent configured")
	}
	return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
}
// Prompt requests a streamed chat completion for messages from the currently
// selected model/provider.
//
// chatReplyChunks is passed through to the provider's
// CreateChatCompletionStream. A value received on stopSignal while the
// request is in flight cancels it. Any message the provider returns is
// converted to a conversation.Message and tagged with the provider and model
// names that were selected when Prompt was called. It is returned together
// with the provider's error, which may be non-nil; the message may then be
// partial.
//
// NOTE(review): stopSignal is only read while Prompt is running. A caller
// doing a blocking send on an unbuffered stopSignal after Prompt returns will
// block; send non-blockingly or only while a request is pending.
func (a *AppModel) Prompt(
	messages []conversation.Message,
	chatReplyChunks chan provider.Chunk,
	stopSignal chan struct{},
) (*conversation.Message, error) {
	// Snapshot the selection: the reply's metadata must record what this
	// request used, not whatever a.Model/a.ProviderName are changed to later.
	// (Previously it pointed at the live struct fields.)
	modelName := a.Model
	providerName := a.ProviderName

	model, _, p, err := a.Ctx.GetModelProvider(modelName, providerName)
	if err != nil {
		return nil, err
	}

	params := provider.RequestParameters{
		Model:       model,
		MaxTokens:   *a.Ctx.Config.Defaults.MaxTokens,
		Temperature: *a.Ctx.Config.Defaults.Temperature,
	}
	if a.Agent != nil {
		params.Toolbox = a.Agent.Toolbox
	}

	ctx, cancel := context.WithCancel(context.Background())
	// Cancelling on return releases the context and lets the watcher below
	// exit. Otherwise one goroutine would leak per call, and leaked watchers
	// could swallow stop signals meant for a later request.
	defer cancel()
	go func() {
		select {
		case <-stopSignal:
			cancel()
		case <-ctx.Done():
		}
	}()

	reply, err := p.CreateChatCompletionStream(
		ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
	)
	if reply == nil {
		return nil, err
	}

	msg := conversation.MessageFromAPI(*reply)
	msg.Metadata.GenerationProvider = &providerName
	msg.Metadata.GenerationModel = &modelName
	return &msg, err
}