lmcli/pkg/tui/model/model.go

230 lines
6.1 KiB
Go
Raw Normal View History

package model
import (
"context"
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
)
// LoadedConversation pairs a conversation with one of its messages, as built
// by LoadConversations from Store.LatestConversationMessages (presumably the
// conversation's most recent message).
type LoadedConversation struct {
	Conv      api.Conversation // the conversation the message belongs to
	LastReply api.Message      // the message returned for Conv by the store
}
// AppModel holds the TUI's application state together with the lmcli
// context used to reach the store, configuration, agents and providers.
type AppModel struct {
	// Ctx is the shared lmcli context.
	Ctx *lmcli.Context
	// Conversations holds the conversations listed for the user.
	Conversations []LoadedConversation
	// Conversation is the currently open conversation; nil after
	// ClearConversation.
	Conversation *api.Conversation
	// RootMessages holds the root messages of Conversation.
	RootMessages []api.Message
	// Messages holds the messages of Conversation currently in view
	// (presumably the selected root-to-leaf path — confirm with callers).
	Messages []api.Message
}
// MessageCycleDirection selects which way CycleSelectedRoot and
// CycleSelectedReply step through sibling messages.
type MessageCycleDirection int

const (
	// CyclePrev steps backwards, wrapping from the first choice to the last.
	CyclePrev MessageCycleDirection = -1
	// CycleNext steps forwards, wrapping from the last choice to the first.
	CycleNext MessageCycleDirection = 1
)
2024-09-20 20:47:03 -06:00
// ClearConversation closes the current conversation: Conversation becomes
// nil and both message slices are reset to empty, non-nil slices.
func (a *AppModel) ClearConversation() {
	a.Messages, a.RootMessages = []api.Message{}, []api.Message{}
	a.Conversation = nil
}
// LoadConversations fetches the messages reported by
// Store.LatestConversationMessages and pairs each one with its conversation.
//
// The (error, value) return order is unconventional but kept for
// compatibility with existing callers. On failure the error is a
// shared.MsgError wrapping the store error and the slice is nil.
func (a *AppModel) LoadConversations() (error, []LoadedConversation) {
	messages, err := a.Ctx.Store.LatestConversationMessages()
	if err != nil {
		// %w keeps the store error reachable through errors.Is / errors.As.
		return shared.MsgError(fmt.Errorf("Could not load conversations: %w", err)), nil
	}

	conversations := make([]LoadedConversation, len(messages))
	// Index rather than range by value so each api.Message is copied once.
	for i := range messages {
		conversations[i] = LoadedConversation{
			// NOTE(review): assumes the store populates Conversation on every
			// returned message; a nil pointer here would panic.
			Conv:      *messages[i].Conversation,
			LastReply: messages[i],
		}
	}
	return nil, conversations
}
// LoadConversationRootMessages returns the root messages of the open
// conversation (a.Conversation), as reported by Store.RootMessages.
//
// It returns an error when no conversation is open or the store query fails;
// the store error is wrapped and can be inspected with errors.Is / errors.As.
func (a *AppModel) LoadConversationRootMessages() ([]api.Message, error) {
	if a.Conversation == nil {
		return nil, fmt.Errorf("Could not load conversation root messages: no conversation is open")
	}
	messages, err := a.Ctx.Store.RootMessages(a.Conversation.ID)
	if err != nil {
		// Identify the conversation that was actually queried.
		return nil, fmt.Errorf("Could not load conversation root messages (conversation %d): %w", a.Conversation.ID, err)
	}
	return messages, nil
}
// LoadConversationMessages returns the messages produced by Store.PathToLeaf
// starting at the open conversation's SelectedRoot.
//
// It returns an error when no conversation is open or the store query fails;
// the store error is wrapped and can be inspected with errors.Is / errors.As.
func (a *AppModel) LoadConversationMessages() ([]api.Message, error) {
	if a.Conversation == nil {
		return nil, fmt.Errorf("Could not load conversation messages: no conversation is open")
	}
	messages, err := a.Ctx.Store.PathToLeaf(a.Conversation.SelectedRoot)
	if err != nil {
		// Identify the conversation by ID; formatting the SelectedRoot pointer
		// with %v would dump the whole message struct into the error.
		return nil, fmt.Errorf("Could not load conversation messages (conversation %d): %w", a.Conversation.ID, err)
	}
	return messages, nil
}
// GenerateConversationTitle delegates to cmdutil.GenerateTitle to produce a
// title for a conversation made up of the given messages.
func (a *AppModel) GenerateConversationTitle(messages []api.Message) (string, error) {
	title, err := cmdutil.GenerateTitle(a.Ctx, messages)
	return title, err
}
// UpdateConversationTitle saves conversation through Store.UpdateConversation.
// The caller is expected to have already set the new title on conversation;
// whatever other fields UpdateConversation writes are saved as well.
func (a *AppModel) UpdateConversationTitle(conversation *api.Conversation) error {
	store := a.Ctx.Store
	return store.UpdateConversation(conversation)
}
// CloneMessage copies message with Store.CloneBranch and returns the copy.
//
// When selected is true the copy also becomes the active branch and that
// choice is saved:
//   - a copy with no parent becomes its conversation's SelectedRoot;
//   - otherwise it becomes its parent's SelectedReply.
//
// Store errors are wrapped with %w so callers can inspect them.
func (a *AppModel) CloneMessage(message api.Message, selected bool) (*api.Message, error) {
	msg, _, err := a.Ctx.Store.CloneBranch(message)
	if err != nil {
		return nil, fmt.Errorf("Could not clone message: %w", err)
	}
	if !selected {
		return msg, nil
	}

	if msg.Parent == nil {
		// A parentless copy is a root of the conversation.
		msg.Conversation.SelectedRoot = msg
		err = a.Ctx.Store.UpdateConversation(msg.Conversation)
	} else {
		msg.Parent.SelectedReply = msg
		err = a.Ctx.Store.UpdateMessage(msg.Parent)
	}
	if err != nil {
		return nil, fmt.Errorf("Could not update selected message: %w", err)
	}
	return msg, nil
}
// UpdateMessageContent saves message through Store.UpdateMessage and reports
// the store's error, if any.
func (a *AppModel) UpdateMessageContent(message *api.Message) error {
	err := a.Ctx.Store.UpdateMessage(message)
	return err
}
// CycleSelectedRoot moves conv.SelectedRoot one step in direction dir through
// rootMessages, wrapping around at either end, and saves conv.
//
// It returns (nil, nil) when there are fewer than two roots, since there is
// nothing to cycle. On success it returns the newly selected root, which
// points into rootMessages.
//
// If saving fails, conv.SelectedRoot is restored to its previous value, so a
// caller that receives the error is not left holding an unsaved selection.
func (a *AppModel) CycleSelectedRoot(conv *api.Conversation, rootMessages []api.Message, dir MessageCycleDirection) (*api.Message, error) {
	if len(rootMessages) < 2 {
		return nil, nil
	}

	nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, rootMessages, dir)
	if err != nil {
		return nil, err
	}

	previous := conv.SelectedRoot
	conv.SelectedRoot = nextRoot
	if err = a.Ctx.Store.UpdateConversation(conv); err != nil {
		conv.SelectedRoot = previous
		return nil, fmt.Errorf("Could not update conversation SelectedRoot: %w", err)
	}
	return nextRoot, nil
}
// CycleSelectedReply moves message.SelectedReply one step in direction dir
// through message.Replies, wrapping around at either end, and saves message.
//
// It returns (nil, nil) when message has fewer than two replies. On success
// it returns the newly selected reply, which points into message.Replies.
//
// If saving fails, message.SelectedReply is restored to its previous value,
// so a caller that receives the error is not left holding an unsaved
// selection.
func (a *AppModel) CycleSelectedReply(message *api.Message, dir MessageCycleDirection) (*api.Message, error) {
	if len(message.Replies) < 2 {
		return nil, nil
	}

	nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
	if err != nil {
		return nil, err
	}

	previous := message.SelectedReply
	message.SelectedReply = nextReply
	if err = a.Ctx.Store.UpdateMessage(message); err != nil {
		message.SelectedReply = previous
		return nil, fmt.Errorf("Could not update message SelectedReply: %w", err)
	}
	return nextReply, nil
}
// PersistConversation saves messages to the store and returns the stored
// conversation and messages.
//
// Behavior depends on whether the conversation has been saved yet:
//   - If conversation is nil or unsaved (ID == 0), a new conversation is
//     started from messages with Store.StartConversation.
//   - Otherwise messages are processed in order. A message that already has
//     an ID is updated. A new message is stored as a reply to the message
//     before it.
//
// A new message at index 0 of an existing conversation is an error, since it
// has nothing to reply to. The messages slice is modified in place, with new
// messages replaced by their stored versions, and is also returned. Store
// errors are wrapped with %w.
func (a *AppModel) PersistConversation(conversation *api.Conversation, messages []api.Message) (*api.Conversation, []api.Message, error) {
	if conversation == nil || conversation.ID == 0 {
		started, stored, err := a.Ctx.Store.StartConversation(messages...)
		if err != nil {
			return nil, nil, fmt.Errorf("Could not start new conversation: %w", err)
		}
		return started, stored, nil
	}

	for i := range messages {
		switch {
		case messages[i].ID > 0:
			// Already stored: save any changes in place.
			if err := a.Ctx.Store.UpdateMessage(&messages[i]); err != nil {
				return nil, nil, fmt.Errorf("Could not update message %d: %w", messages[i].ID, err)
			}
		case i == 0:
			return nil, nil, fmt.Errorf("Error: no messages to reply to")
		default:
			// New message: store it as a reply to its (by now stored) predecessor.
			saved, err := a.Ctx.Store.Reply(&messages[i-1], messages[i])
			if err != nil {
				return nil, nil, fmt.Errorf("Could not save reply: %w", err)
			}
			if len(saved) == 0 {
				return nil, nil, fmt.Errorf("Could not save reply: store returned no messages")
			}
			messages[i] = saved[0]
		}
	}
	return conversation, messages, nil
}
// ExecuteToolCalls runs toolCalls with the toolbox of the agent named by
// Config.Defaults.Agent. It fails if GetAgent finds no such agent.
func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) {
	if agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent); agent != nil {
		return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
	}
	return nil, fmt.Errorf("Attempted to execute tool calls with no agent configured")
}
// PromptLLM streams a chat completion for messages and returns the provider's
// result.
//
// The request uses the configured default model, max tokens and temperature.
// It also includes the toolbox of the default agent when one is configured.
// chatReplyChunks is handed to the provider's CreateChatCompletionStream.
//
// A receive from stopSignal (a send or a close) cancels the in-flight
// request. The request context is always cancelled when PromptLLM returns,
// so the goroutine watching stopSignal never outlives the call.
//
// NOTE: once PromptLLM has returned, nothing receives from stopSignal any
// more. Callers should send on it non-blockingly, or make it buffered.
func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan api.Chunk, stopSignal chan struct{}) (*api.Message, error) {
	// These settings are dereferenced below; fail cleanly instead of panicking.
	if a.Ctx.Config.Defaults.Model == nil {
		return nil, fmt.Errorf("Could not prompt LLM: no default model configured")
	}
	if a.Ctx.Config.Defaults.MaxTokens == nil {
		return nil, fmt.Errorf("Could not prompt LLM: no default max tokens configured")
	}
	if a.Ctx.Config.Defaults.Temperature == nil {
		return nil, fmt.Errorf("Could not prompt LLM: no default temperature configured")
	}

	model, provider, err := a.Ctx.GetModelProvider(*a.Ctx.Config.Defaults.Model)
	if err != nil {
		return nil, err
	}

	params := api.RequestParameters{
		Model:       model,
		MaxTokens:   *a.Ctx.Config.Defaults.MaxTokens,
		Temperature: *a.Ctx.Config.Defaults.Temperature,
	}
	if agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent); agent != nil {
		params.Toolbox = agent.Toolbox
	}

	ctx, cancel := context.WithCancel(context.Background())
	// Cancelling on return does two things:
	//   - it releases the context;
	//   - it lets the watcher below exit even when stopSignal is never used.
	// Without it, a watcher would leak per call, and stale watchers would
	// consume later stop signals meant for the current request.
	defer cancel()
	go func() {
		select {
		case <-stopSignal:
			cancel()
		case <-ctx.Done():
		}
	}()

	return provider.CreateChatCompletionStream(ctx, params, messages, chatReplyChunks)
}
// cycleSelectedMessage returns a pointer to the element of choices after
// selected (dir == CycleNext) or before it (dir == CyclePrev). It wraps
// around at either end. Any dir other than CyclePrev steps forward.
//
// Elements are matched by ID, and the returned pointer aliases choices.
//
// If selected is nil there is no current position: stepping forward picks
// the first choice and stepping backward picks the last. An error is
// returned if selected is nil and choices is empty, or if selected is not
// among choices.
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
	n := len(choices)
	if selected == nil {
		// Without this guard selected.ID below would dereference nil.
		if n == 0 {
			return nil, fmt.Errorf("No messages to cycle through")
		}
		if dir == CyclePrev {
			return &choices[n-1], nil
		}
		return &choices[0], nil
	}

	currentIndex := -1
	// Index instead of ranging by value to avoid copying each api.Message.
	for i := range choices {
		if choices[i].ID == selected.ID {
			currentIndex = i
			break
		}
	}
	if currentIndex < 0 {
		return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
	}

	var next int
	if dir == CyclePrev {
		next = (currentIndex - 1 + n) % n
	} else {
		next = (currentIndex + 1) % n
	}
	return &choices[next], nil
}