230 lines
6.1 KiB
Go
230 lines
6.1 KiB
Go
package model
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
|
)
|
|
|
|
type LoadedConversation struct {
|
|
Conv api.Conversation
|
|
LastReply api.Message
|
|
}
|
|
|
|
type AppModel struct {
|
|
Ctx *lmcli.Context
|
|
Conversations []LoadedConversation
|
|
Conversation *api.Conversation
|
|
RootMessages []api.Message
|
|
Messages []api.Message
|
|
}
|
|
|
|
// MessageCycleDirection tells CycleSelectedRoot and CycleSelectedReply which
// neighbouring message to select next.
type MessageCycleDirection int

const (
	// CycleNext selects the following message, wrapping to the first.
	CycleNext MessageCycleDirection = 1
	// CyclePrev selects the preceding message, wrapping to the last.
	CyclePrev MessageCycleDirection = -1
)
|
|
|
|
func (m *AppModel) ClearConversation() {
|
|
m.Conversation = nil
|
|
m.Messages = []api.Message{}
|
|
m.RootMessages = []api.Message{}
|
|
}
|
|
|
|
func (m *AppModel) LoadConversations() (error, []LoadedConversation) {
|
|
messages, err := m.Ctx.Store.LatestConversationMessages()
|
|
if err != nil {
|
|
return shared.MsgError(fmt.Errorf("Could not load conversations: %v", err)), nil
|
|
}
|
|
|
|
conversations := make([]LoadedConversation, len(messages))
|
|
for i, msg := range messages {
|
|
conversations[i] = LoadedConversation{
|
|
Conv: *msg.Conversation,
|
|
LastReply: msg,
|
|
}
|
|
}
|
|
return nil, conversations
|
|
}
|
|
|
|
func (a *AppModel) LoadConversationRootMessages() ([]api.Message, error) {
|
|
messages, err := a.Ctx.Store.RootMessages(a.Conversation.ID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not load conversation root messages: %v %v", a.Conversation.SelectedRoot, err)
|
|
}
|
|
return messages, nil
|
|
}
|
|
|
|
func (a *AppModel) LoadConversationMessages() ([]api.Message, error) {
|
|
messages, err := a.Ctx.Store.PathToLeaf(a.Conversation.SelectedRoot)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not load conversation messages: %v %v", a.Conversation.SelectedRoot, err)
|
|
}
|
|
return messages, nil
|
|
}
|
|
|
|
func (a *AppModel) GenerateConversationTitle(messages []api.Message) (string, error) {
|
|
return cmdutil.GenerateTitle(a.Ctx, messages)
|
|
}
|
|
|
|
func (a *AppModel) UpdateConversationTitle(conversation *api.Conversation) error {
|
|
return a.Ctx.Store.UpdateConversation(conversation)
|
|
}
|
|
|
|
func (a *AppModel) CloneMessage(message api.Message, selected bool) (*api.Message, error) {
|
|
msg, _, err := a.Ctx.Store.CloneBranch(message)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not clone message: %v", err)
|
|
}
|
|
if selected {
|
|
if msg.Parent == nil {
|
|
msg.Conversation.SelectedRoot = msg
|
|
err = a.Ctx.Store.UpdateConversation(msg.Conversation)
|
|
} else {
|
|
msg.Parent.SelectedReply = msg
|
|
err = a.Ctx.Store.UpdateMessage(msg.Parent)
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not update selected message: %v", err)
|
|
}
|
|
}
|
|
return msg, nil
|
|
}
|
|
|
|
func (a *AppModel) UpdateMessageContent(message *api.Message) error {
|
|
return a.Ctx.Store.UpdateMessage(message)
|
|
}
|
|
|
|
func (a *AppModel) CycleSelectedRoot(conv *api.Conversation, rootMessages []api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
|
if len(rootMessages) < 2 {
|
|
return nil, nil
|
|
}
|
|
|
|
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, rootMessages, dir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
conv.SelectedRoot = nextRoot
|
|
err = a.Ctx.Store.UpdateConversation(conv)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err)
|
|
}
|
|
return nextRoot, nil
|
|
}
|
|
|
|
func (a *AppModel) CycleSelectedReply(message *api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
|
if len(message.Replies) < 2 {
|
|
return nil, nil
|
|
}
|
|
|
|
nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
message.SelectedReply = nextReply
|
|
err = a.Ctx.Store.UpdateMessage(message)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not update message SelectedReply: %v", err)
|
|
}
|
|
return nextReply, nil
|
|
}
|
|
|
|
func (a *AppModel) PersistConversation(conversation *api.Conversation, messages []api.Message) (*api.Conversation, []api.Message, error) {
|
|
var err error
|
|
if conversation == nil || conversation.ID == 0 {
|
|
conversation, messages, err = a.Ctx.Store.StartConversation(messages...)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("Could not start new conversation: %v", err)
|
|
}
|
|
return conversation, messages, nil
|
|
}
|
|
|
|
for i := range messages {
|
|
if messages[i].ID > 0 {
|
|
err := a.Ctx.Store.UpdateMessage(&messages[i])
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
} else if i > 0 {
|
|
saved, err := a.Ctx.Store.Reply(&messages[i-1], messages[i])
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
messages[i] = saved[0]
|
|
} else {
|
|
return nil, nil, fmt.Errorf("Error: no messages to reply to")
|
|
}
|
|
}
|
|
return conversation, messages, nil
|
|
}
|
|
|
|
func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) {
|
|
agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent)
|
|
if agent == nil {
|
|
return nil, fmt.Errorf("Attempted to execute tool calls with no agent configured")
|
|
}
|
|
|
|
return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
|
|
}
|
|
|
|
func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan api.Chunk, stopSignal chan struct{}) (*api.Message, error) {
|
|
model, provider, err := a.Ctx.GetModelProvider(*a.Ctx.Config.Defaults.Model)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
params := api.RequestParameters{
|
|
Model: model,
|
|
MaxTokens: *a.Ctx.Config.Defaults.MaxTokens,
|
|
Temperature: *a.Ctx.Config.Defaults.Temperature,
|
|
}
|
|
|
|
agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent)
|
|
if agent != nil {
|
|
params.Toolbox = agent.Toolbox
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
go func() {
|
|
select {
|
|
case <-stopSignal:
|
|
cancel()
|
|
}
|
|
}()
|
|
|
|
return provider.CreateChatCompletionStream(
|
|
ctx, params, messages, chatReplyChunks,
|
|
)
|
|
}
|
|
|
|
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
|
currentIndex := -1
|
|
for i, reply := range choices {
|
|
if reply.ID == selected.ID {
|
|
currentIndex = i
|
|
break
|
|
}
|
|
}
|
|
|
|
if currentIndex < 0 {
|
|
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
|
|
}
|
|
|
|
var next int
|
|
if dir == CyclePrev {
|
|
next = (currentIndex - 1 + len(choices)) % len(choices)
|
|
} else {
|
|
next = (currentIndex + 1) % len(choices)
|
|
}
|
|
return &choices[next], nil
|
|
}
|