Private
Public Access
1
0

Large refactor - it compiles!

This refactor splits out all conversation concerns into a new
`conversation` package. There is now a split between `conversation` and
`api`s representation of `Message`, the latter storing the minimum
information required for interaction with LLM providers. There is
necessary conversation between the two when making LLM calls.
This commit is contained in:
2024-10-20 02:38:42 +00:00
parent 2ea8a73eb5
commit 0384c7cb66
33 changed files with 701 additions and 626 deletions

View File

@@ -6,30 +6,30 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/agents"
"git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/conversation"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/provider"
"github.com/charmbracelet/lipgloss"
)
type LoadedConversation struct {
Conv api.Conversation
LastReply api.Message
Conv conversation.Conversation
LastReply conversation.Message
}
type AppModel struct {
Ctx *lmcli.Context
Conversations []LoadedConversation
Conversation *api.Conversation
RootMessages []api.Message
Messages []api.Message
Conversation *conversation.Conversation
Messages []conversation.Message
Model string
ProviderName string
Provider provider.ChatCompletionProvider
Agent *lmcli.Agent
Agent *lmcli.Agent
}
func NewAppModel(ctx *lmcli.Context, initialConversation *api.Conversation) *AppModel {
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
app := &AppModel{
Ctx: ctx,
Conversation: initialConversation,
@@ -67,8 +67,7 @@ const (
func (m *AppModel) ClearConversation() {
m.Conversation = nil
m.Messages = []api.Message{}
m.RootMessages = []api.Message{}
m.Messages = []conversation.Message{}
}
func (m *AppModel) ApplySystemPrompt() {
@@ -81,7 +80,7 @@ func (m *AppModel) ApplySystemPrompt() {
system = m.Ctx.DefaultSystemPrompt()
}
if system != "" {
m.Messages = api.ApplySystemPrompt(m.Messages, system, false)
m.Messages = conversation.ApplySystemPrompt(m.Messages, system, false)
}
}
@@ -91,7 +90,7 @@ func (m *AppModel) NewConversation() {
}
func (m *AppModel) LoadConversations() (error, []LoadedConversation) {
messages, err := m.Ctx.Store.LatestConversationMessages()
messages, err := m.Ctx.Conversations.LatestConversationMessages()
if err != nil {
return fmt.Errorf("Could not load conversations: %v", err), nil
}
@@ -106,42 +105,34 @@ func (m *AppModel) LoadConversations() (error, []LoadedConversation) {
return nil, conversations
}
func (a *AppModel) LoadConversationRootMessages() ([]api.Message, error) {
messages, err := a.Ctx.Store.RootMessages(a.Conversation.ID)
if err != nil {
return nil, fmt.Errorf("Could not load conversation root messages: %v %v", a.Conversation.SelectedRoot, err)
}
return messages, nil
}
func (a *AppModel) LoadConversationMessages() ([]api.Message, error) {
messages, err := a.Ctx.Store.PathToLeaf(a.Conversation.SelectedRoot)
func (a *AppModel) LoadConversationMessages() ([]conversation.Message, error) {
messages, err := a.Ctx.Conversations.PathToLeaf(a.Conversation.SelectedRoot)
if err != nil {
return nil, fmt.Errorf("Could not load conversation messages: %v %v", a.Conversation.SelectedRoot, err)
}
return messages, nil
}
func (a *AppModel) GenerateConversationTitle(messages []api.Message) (string, error) {
func (a *AppModel) GenerateConversationTitle(messages []conversation.Message) (string, error) {
return cmdutil.GenerateTitle(a.Ctx, messages)
}
func (a *AppModel) UpdateConversationTitle(conversation *api.Conversation) error {
return a.Ctx.Store.UpdateConversation(conversation)
func (a *AppModel) UpdateConversationTitle(conversation *conversation.Conversation) error {
return a.Ctx.Conversations.UpdateConversation(conversation)
}
func (a *AppModel) CloneMessage(message api.Message, selected bool) (*api.Message, error) {
msg, _, err := a.Ctx.Store.CloneBranch(message)
func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) {
msg, _, err := a.Ctx.Conversations.CloneBranch(message)
if err != nil {
return nil, fmt.Errorf("Could not clone message: %v", err)
}
if selected {
if msg.Parent == nil {
msg.Conversation.SelectedRoot = msg
err = a.Ctx.Store.UpdateConversation(msg.Conversation)
err = a.Ctx.Conversations.UpdateConversation(msg.Conversation)
} else {
msg.Parent.SelectedReply = msg
err = a.Ctx.Store.UpdateMessage(msg.Parent)
err = a.Ctx.Conversations.UpdateMessage(msg.Parent)
}
if err != nil {
return nil, fmt.Errorf("Could not update selected message: %v", err)
@@ -150,11 +141,11 @@ func (a *AppModel) CloneMessage(message api.Message, selected bool) (*api.Messag
return msg, nil
}
func (a *AppModel) UpdateMessageContent(message *api.Message) error {
return a.Ctx.Store.UpdateMessage(message)
func (a *AppModel) UpdateMessageContent(message *conversation.Message) error {
return a.Ctx.Conversations.UpdateMessage(message)
}
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
func cycleSelectedMessage(selected *conversation.Message, choices []conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
currentIndex := -1
for i, reply := range choices {
if reply.ID == selected.ID {
@@ -176,25 +167,25 @@ func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir Mess
return &choices[next], nil
}
func (a *AppModel) CycleSelectedRoot(conv *api.Conversation, rootMessages []api.Message, dir MessageCycleDirection) (*api.Message, error) {
if len(rootMessages) < 2 {
func (a *AppModel) CycleSelectedRoot(conv *conversation.Conversation, dir MessageCycleDirection) (*conversation.Message, error) {
if len(conv.RootMessages) < 2 {
return nil, nil
}
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, rootMessages, dir)
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, conv.RootMessages, dir)
if err != nil {
return nil, err
}
conv.SelectedRoot = nextRoot
err = a.Ctx.Store.UpdateConversation(conv)
err = a.Ctx.Conversations.UpdateConversation(conv)
if err != nil {
return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err)
}
return nextRoot, nil
}
func (a *AppModel) CycleSelectedReply(message *api.Message, dir MessageCycleDirection) (*api.Message, error) {
func (a *AppModel) CycleSelectedReply(message *conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
if len(message.Replies) < 2 {
return nil, nil
}
@@ -205,17 +196,17 @@ func (a *AppModel) CycleSelectedReply(message *api.Message, dir MessageCycleDire
}
message.SelectedReply = nextReply
err = a.Ctx.Store.UpdateMessage(message)
err = a.Ctx.Conversations.UpdateMessage(message)
if err != nil {
return nil, fmt.Errorf("Could not update message SelectedReply: %v", err)
}
return nextReply, nil
}
func (a *AppModel) PersistConversation(conversation *api.Conversation, messages []api.Message) (*api.Conversation, []api.Message, error) {
func (a *AppModel) PersistConversation(conversation *conversation.Conversation, messages []conversation.Message) (*conversation.Conversation, []conversation.Message, error) {
var err error
if conversation == nil || conversation.ID == 0 {
conversation, messages, err = a.Ctx.Store.StartConversation(messages...)
conversation, messages, err = a.Ctx.Conversations.StartConversation(messages...)
if err != nil {
return nil, nil, fmt.Errorf("Could not start new conversation: %v", err)
}
@@ -224,12 +215,12 @@ func (a *AppModel) PersistConversation(conversation *api.Conversation, messages
for i := range messages {
if messages[i].ID > 0 {
err := a.Ctx.Store.UpdateMessage(&messages[i])
err := a.Ctx.Conversations.UpdateMessage(&messages[i])
if err != nil {
return nil, nil, err
}
} else if i > 0 {
saved, err := a.Ctx.Store.Reply(&messages[i-1], messages[i])
saved, err := a.Ctx.Conversations.Reply(&messages[i-1], messages[i])
if err != nil {
return nil, nil, err
}
@@ -251,10 +242,10 @@ func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult,
}
func (a *AppModel) Prompt(
messages []api.Message,
messages []conversation.Message,
chatReplyChunks chan provider.Chunk,
stopSignal chan struct{},
) (*api.Message, error) {
) (*conversation.Message, error) {
model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
if err != nil {
return nil, err
@@ -280,11 +271,14 @@ func (a *AppModel) Prompt(
}()
msg, err := p.CreateChatCompletionStream(
ctx, params, messages, chatReplyChunks,
ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
)
if msg != nil {
msg := conversation.MessageFromAPI(*msg)
msg.Metadata.GenerationProvider = &a.ProviderName
msg.Metadata.GenerationModel = &a.Model
return &msg, err
}
return msg, err
return nil, err
}