lmcli/pkg/tui/views/chat/conversation.go

247 lines
6.0 KiB
Go
Raw Normal View History

2024-06-02 16:40:46 -06:00
package chat
import (
"context"
"fmt"
"time"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tea "github.com/charmbracelet/bubbletea"
)
// setMessage overwrites the message stored at index i with msg and refreshes
// the corresponding rendered entry in the message cache.
//
// It panics with "i out of range" when i is not below len(m.messages).
func (m *Model) setMessage(i int, msg models.Message) {
	if len(m.messages) <= i {
		panic("i out of range")
	}
	m.messages[i] = msg
	rendered := m.renderMessage(i)
	m.messageCache[i] = rendered
}
// addMessage appends msg to the conversation and appends its rendered form
// to the message cache, so both slices keep the same indices.
func (m *Model) addMessage(msg models.Message) {
	m.messages = append(m.messages, msg)
	newIdx := len(m.messages) - 1
	m.messageCache = append(m.messageCache, m.renderMessage(newIdx))
}
// setMessageContents replaces only the Content of the message at index i
// (leaving its other fields untouched) and re-renders that message's cache
// entry.
//
// It panics with "i out of range" when i is not below len(m.messages).
func (m *Model) setMessageContents(i int, content string) {
	if !(i < len(m.messages)) {
		panic("i out of range")
	}
	target := &m.messages[i]
	target.Content = content
	m.messageCache[i] = m.renderMessage(i)
}
// rebuildMessageCache discards the existing rendered cache and re-renders
// every message in m.messages, one cache entry per message.
func (m *Model) rebuildMessageCache() {
	m.messageCache = make([]string, len(m.messages))
	for idx := range m.messageCache {
		m.messageCache[idx] = m.renderMessage(idx)
	}
}
// updateContent pushes a freshly rendered conversation into the content
// viewport. When the viewport was already scrolled to the bottom before the
// refresh, it is moved back to the bottom so streamed output stays in view.
func (m *Model) updateContent() {
	wasAtBottom := m.content.AtBottom()
	m.content.SetContent(m.conversationMessagesView())
	if !wasAtBottom {
		return
	}
	m.content.GotoBottom()
}
// loadConversation returns a command that looks up the conversation with the
// given short name in the store.
//
// The command yields:
//   - nil when shortname is empty (nothing to load),
//   - shared.MsgError when the lookup fails or no conversation matched
//     (the store is treated as returning ID 0 for "not found"),
//   - msgConversationLoaded with the conversation otherwise.
func (m *Model) loadConversation(shortname string) tea.Cmd {
	return func() tea.Msg {
		if shortname == "" {
			return nil
		}
		c, err := m.State.Ctx.Store.ConversationByShortName(shortname)
		if err != nil {
			// %w keeps the underlying store error inspectable via errors.Is/As.
			return shared.MsgError(fmt.Errorf("Could not lookup conversation: %w", err))
		}
		if c.ID == 0 {
			return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname))
		}
		return msgConversationLoaded(c)
	}
}
// loadMessages returns a command that loads the path of messages from the
// conversation's selected root down to its leaf (via the store's PathToLeaf).
//
// The command yields msgMessagesLoaded on success and shared.MsgError on
// failure.
func (m *Model) loadMessages(c *models.Conversation) tea.Cmd {
	return func() tea.Msg {
		messages, err := m.State.Ctx.Store.PathToLeaf(c.SelectedRoot)
		if err != nil {
			// No trailing newline: the error is shown inline in the TUI.
			// %w keeps the underlying store error inspectable.
			return shared.MsgError(fmt.Errorf("Could not load conversation messages: %w", err))
		}
		return msgMessagesLoaded(messages)
	}
}
// generateConversationTitle returns a command that asks cmdutil.GenerateTitle
// for a title based on the model's messages (read when the command runs).
//
// The command yields msgConversationTitleChanged on success and
// shared.MsgError on failure.
func (m *Model) generateConversationTitle() tea.Cmd {
	return func() tea.Msg {
		title, err := cmdutil.GenerateTitle(m.State.Ctx, m.messages)
		if err == nil {
			return msgConversationTitleChanged(title)
		}
		return shared.MsgError(err)
	}
}
// cycleMessages returns the sibling that follows curr within msgs, or the
// one that precedes it when dir is CyclePrev, wrapping around at either end.
// curr is matched against msgs by ID. The returned pointer refers to the
// element inside msgs, not to a copy.
//
// An error is returned when curr is nil or when its ID does not occur in msgs
// (which also covers an empty msgs).
func cycleMessages(curr *models.Message, msgs []models.Message, dir MessageCycleDirection) (*models.Message, error) {
	if curr == nil {
		// e.g. a message whose SelectedReply/SelectedRoot was never set;
		// dereferencing it below would panic.
		return nil, fmt.Errorf("no current message selected")
	}

	currentIndex := -1
	// Index instead of ranging by value to avoid copying each Message.
	for i := range msgs {
		if msgs[i].ID == curr.ID {
			currentIndex = i
			break
		}
	}
	if currentIndex < 0 {
		return nil, fmt.Errorf("message not found")
	}

	var next int
	if dir == CyclePrev {
		// Wrap around to the last reply if at the beginning
		next = (currentIndex - 1 + len(msgs)) % len(msgs)
	} else {
		// Wrap around to the first reply if at the end
		next = (currentIndex + 1) % len(msgs)
	}
	return &msgs[next], nil
}
// cycleSelectedRoot moves conv.SelectedRoot to the next (or, for CyclePrev,
// previous) entry of m.rootMessages and persists the change with
// UpdateConversation.
//
// It returns (nil, nil) when there are fewer than two roots, because there is
// nothing to cycle. On a persistence failure, conv.SelectedRoot is restored
// to its previous value so the in-memory state matches the store, and the
// error is returned.
func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDirection) (*models.Message, error) {
	if len(m.rootMessages) < 2 {
		return nil, nil
	}

	nextRoot, err := cycleMessages(conv.SelectedRoot, m.rootMessages, dir)
	if err != nil {
		return nil, err
	}

	prevRoot := conv.SelectedRoot
	conv.SelectedRoot = nextRoot
	if err := m.State.Ctx.Store.UpdateConversation(conv); err != nil {
		// Roll back so the UI doesn't show a selection that was never saved.
		conv.SelectedRoot = prevRoot
		return nil, fmt.Errorf("Could not update conversation SelectedRoot: %w", err)
	}
	return nextRoot, nil
}
// cycleSelectedReply moves message.SelectedReply to the next (or, for
// CyclePrev, previous) entry of message.Replies and persists the change with
// UpdateMessage.
//
// It returns (nil, nil) when the message has fewer than two replies, because
// there is nothing to cycle. On a persistence failure,
// message.SelectedReply is restored to its previous value so the in-memory
// state matches the store, and the error is returned.
func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDirection) (*models.Message, error) {
	if len(message.Replies) < 2 {
		return nil, nil
	}

	nextReply, err := cycleMessages(message.SelectedReply, message.Replies, dir)
	if err != nil {
		return nil, err
	}

	prevReply := message.SelectedReply
	message.SelectedReply = nextReply
	if err := m.State.Ctx.Store.UpdateMessage(message); err != nil {
		// Roll back so the UI doesn't show a selection that was never saved.
		message.SelectedReply = prevReply
		return nil, fmt.Errorf("Could not update message SelectedReply: %w", err)
	}
	return nextReply, nil
}
// persistConversation writes the model's messages to the store.
//
// If the conversation has not been saved yet (ID 0), all messages are passed
// to StartConversation, and m.conversation and m.messages are replaced by the
// values it returns.
//
// Otherwise each message is handled in order:
//   - a message with an ID is updated in place,
//   - a message without an ID is saved as a reply to the message before it;
//     the saved copy replaces it in m.messages and is appended to the
//     parent's Replies,
//   - a message without an ID at index 0 is an error.
//
// The first error encountered is returned immediately.
func (m *Model) persistConversation() error {
	if m.conversation.ID == 0 {
		// Start a new conversation with all messages so far
		conv, stored, err := m.State.Ctx.Store.StartConversation(m.messages...)
		if err != nil {
			return err
		}
		m.conversation, m.messages = conv, stored
		return nil
	}

	// Existing conversation: sync each message individually.
	for i := range m.messages {
		switch {
		case m.messages[i].ID > 0:
			// Already stored; push its current contents.
			// TODO: check for content/tool equality before updating?
			if err := m.State.Ctx.Store.UpdateMessage(&m.messages[i]); err != nil {
				return err
			}
		case i == 0:
			// Unsaved first message with nothing to attach it to;
			// this shouldn't happen?
			return fmt.Errorf("Error: no messages to reply to")
		default:
			// New message: store it as a reply to its predecessor.
			parent := &m.messages[i-1]
			saved, err := m.State.Ctx.Store.Reply(parent, m.messages[i])
			if err != nil {
				return err
			}
			parent.Replies = append(parent.Replies, saved[0])
			m.messages[i] = saved[0]
		}
	}
	return nil
}
// promptLLM starts a streamed completion for the current conversation.
//
// Before returning, it puts the view into "waiting for reply" mode and resets
// the token/timing counters. If the last message is not already from the
// assistant (or there are no messages), it also appends an empty assistant
// placeholder message.
//
// The returned command resolves the default model/provider from config and
// streams a completion for the messages that existed before the placeholder
// was added. Complete messages from the provider are sent on m.replyChan, and
// m.replyChunkChan is handed to the provider as its chunk channel. A value
// received on m.stopSignal while the request is running cancels it.
//
// The command yields:
//   - shared.MsgError if the model/provider cannot be resolved,
//   - msgResponseError if the stream fails and was not canceled,
//   - msgResponseEnd with the provider's response otherwise.
func (m *Model) promptLLM() tea.Cmd {
	m.waitingForReply = true
	m.replyCursor.Blink = false
	m.status = "Press ctrl+c to cancel"

	// Snapshot the prompt before adding the placeholder. Copy it rather than
	// re-slice: the request runs on another goroutine while the UI keeps
	// mutating m.messages, and a shared backing array would be a data race.
	toPrompt := append([]models.Message(nil), m.messages...)

	// Add response placeholder message (guarding against an empty
	// conversation, which previously panicked on the index below).
	if n := len(m.messages); n == 0 || m.messages[n-1].Role != models.MessageRoleAssistant {
		m.addMessage(models.Message{
			Role:    models.MessageRoleAssistant,
			Content: "",
		})
	}

	m.tokenCount = 0
	m.startTime = time.Now()
	m.elapsed = 0

	return func() tea.Msg {
		model, provider, err := m.State.Ctx.GetModelProvider(*m.State.Ctx.Config.Defaults.Model)
		if err != nil {
			return shared.MsgError(err)
		}

		requestParams := models.RequestParameters{
			Model:       model,
			MaxTokens:   *m.State.Ctx.Config.Defaults.MaxTokens,
			Temperature: *m.State.Ctx.Config.Defaults.Temperature,
			ToolBag:     m.State.Ctx.EnabledTools,
		}

		replyHandler := func(msg models.Message) {
			m.replyChan <- msg
		}

		ctx, cancel := context.WithCancel(context.Background())
		// Ends the stop-watcher goroutine below once this request is over.
		defer cancel()

		// Watch for a stop request only for the lifetime of this request.
		// Without the ctx.Done() case the goroutine outlives the request and,
		// being first in the channel's receive queue, swallows the stop
		// signal meant for a later request.
		// NOTE(review): if m.stopSignal is unbuffered, a stop sent after the
		// stream has finished but before msgResponseEnd is handled has no
		// receiver; confirm the sender tolerates that.
		go func() {
			select {
			case <-m.stopSignal:
				cancel()
			case <-ctx.Done():
			}
		}()

		resp, err := provider.CreateChatCompletionStream(
			ctx, requestParams, toPrompt, replyHandler, m.replyChunkChan,
		)
		// The deferred cancel has not run yet, so a non-nil ctx.Err() here
		// means a stop was requested. Unlike a shared bool, this check is
		// race-free.
		if err != nil && ctx.Err() == nil {
			return msgResponseError(err)
		}
		return msgResponseEnd(resp)
	}
}