Matt Low
0384c7cb66
This refactor splits out all conversation concerns into a new `conversation` package. There is now a split between the `conversation` and `api` packages' representations of `Message`, the latter storing the minimum information required for interaction with LLM providers. Conversion between the two is necessary when making LLM calls.
285 lines
7.6 KiB
Go
285 lines
7.6 KiB
Go
package model
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
|
"github.com/charmbracelet/lipgloss"
|
|
)
|
|
|
|
// LoadedConversation pairs a conversation with one of its messages.
// AppModel.LoadConversations builds these from the messages returned by
// LatestConversationMessages, copying each message's Conversation into Conv.
type LoadedConversation struct {
	// Conv is a copy of the conversation the message belongs to.
	Conv conversation.Conversation
	// LastReply is the message returned for this conversation by
	// LatestConversationMessages (presumably its most recent one — TODO confirm).
	LastReply conversation.Message
}
|
|
|
|
// AppModel holds the application state used by the methods in this file:
// the lmcli context, the loaded conversations, the current conversation and
// its in-memory messages, the selected model/provider, and the active agent.
type AppModel struct {
	// Ctx supplies configuration, the conversation store and model/agent lookup.
	Ctx *lmcli.Context
	// Conversations is the list of loaded conversations; presumably populated
	// from LoadConversations by callers — TODO confirm.
	Conversations []LoadedConversation
	// Conversation is the current conversation; ClearConversation resets it to nil.
	Conversation *conversation.Conversation
	// Messages are the in-memory messages of the current conversation.
	Messages []conversation.Message
	// Model is the model name used by Prompt; NewAppModel resolves it from
	// the configured default.
	Model string
	// ProviderName is the provider paired with Model.
	ProviderName string
	// NOTE(review): Provider is neither read nor written by any code in this
	// file; Prompt resolves its provider afresh on each call.
	Provider provider.ChatCompletionProvider
	// Agent is the agent whose Toolbox Prompt sends with requests; may be nil.
	Agent *lmcli.Agent
}
|
|
|
|
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
|
|
app := &AppModel{
|
|
Ctx: ctx,
|
|
Conversation: initialConversation,
|
|
Model: *ctx.Config.Defaults.Model,
|
|
}
|
|
|
|
if initialConversation == nil {
|
|
app.NewConversation()
|
|
}
|
|
|
|
model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
|
app.Model = model
|
|
app.ProviderName = provider
|
|
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
|
|
return app
|
|
}
|
|
|
|
var (
|
|
defaultStyle = lipgloss.NewStyle().Faint(true)
|
|
accentStyle = defaultStyle.Foreground(lipgloss.Color("6"))
|
|
)
|
|
|
|
func (a *AppModel) ActiveModel(style lipgloss.Style) string {
|
|
defaultStyle := style.Inherit(defaultStyle)
|
|
accentStyle := style.Inherit(accentStyle)
|
|
return defaultStyle.Render(a.Model) + accentStyle.Render("@") + defaultStyle.Render(a.ProviderName)
|
|
}
|
|
|
|
// MessageCycleDirection selects which neighbouring sibling
// cycleSelectedMessage moves to.
type MessageCycleDirection int

// Cycling directions. CyclePrev steps back to the preceding sibling and
// CycleNext advances to the following one. Both wrap around at the ends.
const (
	CyclePrev MessageCycleDirection = -1
	CycleNext MessageCycleDirection = 1
)
|
|
|
|
func (m *AppModel) ClearConversation() {
|
|
m.Conversation = nil
|
|
m.Messages = []conversation.Message{}
|
|
}
|
|
|
|
func (m *AppModel) ApplySystemPrompt() {
|
|
var system string
|
|
agent := m.Ctx.GetAgent(m.Ctx.Config.Defaults.Agent)
|
|
if agent != nil && agent.SystemPrompt != "" {
|
|
system = agent.SystemPrompt
|
|
}
|
|
if system == "" {
|
|
system = m.Ctx.DefaultSystemPrompt()
|
|
}
|
|
if system != "" {
|
|
m.Messages = conversation.ApplySystemPrompt(m.Messages, system, false)
|
|
}
|
|
}
|
|
|
|
func (m *AppModel) NewConversation() {
|
|
m.ClearConversation()
|
|
m.ApplySystemPrompt()
|
|
}
|
|
|
|
func (m *AppModel) LoadConversations() (error, []LoadedConversation) {
|
|
messages, err := m.Ctx.Conversations.LatestConversationMessages()
|
|
if err != nil {
|
|
return fmt.Errorf("Could not load conversations: %v", err), nil
|
|
}
|
|
|
|
conversations := make([]LoadedConversation, len(messages))
|
|
for i, msg := range messages {
|
|
conversations[i] = LoadedConversation{
|
|
Conv: *msg.Conversation,
|
|
LastReply: msg,
|
|
}
|
|
}
|
|
return nil, conversations
|
|
}
|
|
|
|
func (a *AppModel) LoadConversationMessages() ([]conversation.Message, error) {
|
|
messages, err := a.Ctx.Conversations.PathToLeaf(a.Conversation.SelectedRoot)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not load conversation messages: %v %v", a.Conversation.SelectedRoot, err)
|
|
}
|
|
return messages, nil
|
|
}
|
|
|
|
func (a *AppModel) GenerateConversationTitle(messages []conversation.Message) (string, error) {
|
|
return cmdutil.GenerateTitle(a.Ctx, messages)
|
|
}
|
|
|
|
func (a *AppModel) UpdateConversationTitle(conversation *conversation.Conversation) error {
|
|
return a.Ctx.Conversations.UpdateConversation(conversation)
|
|
}
|
|
|
|
func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) {
|
|
msg, _, err := a.Ctx.Conversations.CloneBranch(message)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not clone message: %v", err)
|
|
}
|
|
if selected {
|
|
if msg.Parent == nil {
|
|
msg.Conversation.SelectedRoot = msg
|
|
err = a.Ctx.Conversations.UpdateConversation(msg.Conversation)
|
|
} else {
|
|
msg.Parent.SelectedReply = msg
|
|
err = a.Ctx.Conversations.UpdateMessage(msg.Parent)
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not update selected message: %v", err)
|
|
}
|
|
}
|
|
return msg, nil
|
|
}
|
|
|
|
func (a *AppModel) UpdateMessageContent(message *conversation.Message) error {
|
|
return a.Ctx.Conversations.UpdateMessage(message)
|
|
}
|
|
|
|
func cycleSelectedMessage(selected *conversation.Message, choices []conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
|
|
currentIndex := -1
|
|
for i, reply := range choices {
|
|
if reply.ID == selected.ID {
|
|
currentIndex = i
|
|
break
|
|
}
|
|
}
|
|
|
|
if currentIndex < 0 {
|
|
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
|
|
}
|
|
|
|
var next int
|
|
if dir == CyclePrev {
|
|
next = (currentIndex - 1 + len(choices)) % len(choices)
|
|
} else {
|
|
next = (currentIndex + 1) % len(choices)
|
|
}
|
|
return &choices[next], nil
|
|
}
|
|
|
|
func (a *AppModel) CycleSelectedRoot(conv *conversation.Conversation, dir MessageCycleDirection) (*conversation.Message, error) {
|
|
if len(conv.RootMessages) < 2 {
|
|
return nil, nil
|
|
}
|
|
|
|
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, conv.RootMessages, dir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
conv.SelectedRoot = nextRoot
|
|
err = a.Ctx.Conversations.UpdateConversation(conv)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err)
|
|
}
|
|
return nextRoot, nil
|
|
}
|
|
|
|
func (a *AppModel) CycleSelectedReply(message *conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
|
|
if len(message.Replies) < 2 {
|
|
return nil, nil
|
|
}
|
|
|
|
nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
message.SelectedReply = nextReply
|
|
err = a.Ctx.Conversations.UpdateMessage(message)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not update message SelectedReply: %v", err)
|
|
}
|
|
return nextReply, nil
|
|
}
|
|
|
|
func (a *AppModel) PersistConversation(conversation *conversation.Conversation, messages []conversation.Message) (*conversation.Conversation, []conversation.Message, error) {
|
|
var err error
|
|
if conversation == nil || conversation.ID == 0 {
|
|
conversation, messages, err = a.Ctx.Conversations.StartConversation(messages...)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("Could not start new conversation: %v", err)
|
|
}
|
|
return conversation, messages, nil
|
|
}
|
|
|
|
for i := range messages {
|
|
if messages[i].ID > 0 {
|
|
err := a.Ctx.Conversations.UpdateMessage(&messages[i])
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
} else if i > 0 {
|
|
saved, err := a.Ctx.Conversations.Reply(&messages[i-1], messages[i])
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
messages[i] = saved[0]
|
|
} else {
|
|
return nil, nil, fmt.Errorf("Error: no messages to reply to")
|
|
}
|
|
}
|
|
return conversation, messages, nil
|
|
}
|
|
|
|
func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) {
|
|
agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent)
|
|
if agent == nil {
|
|
return nil, fmt.Errorf("Attempted to execute tool calls with no agent configured")
|
|
}
|
|
|
|
return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
|
|
}
|
|
|
|
func (a *AppModel) Prompt(
|
|
messages []conversation.Message,
|
|
chatReplyChunks chan provider.Chunk,
|
|
stopSignal chan struct{},
|
|
) (*conversation.Message, error) {
|
|
model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
params := provider.RequestParameters{
|
|
Model: model,
|
|
MaxTokens: *a.Ctx.Config.Defaults.MaxTokens,
|
|
Temperature: *a.Ctx.Config.Defaults.Temperature,
|
|
}
|
|
|
|
if a.Agent != nil {
|
|
params.Toolbox = a.Agent.Toolbox
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
go func() {
|
|
select {
|
|
case <-stopSignal:
|
|
cancel()
|
|
}
|
|
}()
|
|
|
|
msg, err := p.CreateChatCompletionStream(
|
|
ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
|
|
)
|
|
|
|
if msg != nil {
|
|
msg := conversation.MessageFromAPI(*msg)
|
|
msg.Metadata.GenerationProvider = &a.ProviderName
|
|
msg.Metadata.GenerationModel = &a.Model
|
|
return &msg, err
|
|
}
|
|
return nil, err
|
|
}
|