Matt Low
0384c7cb66
This refactor splits out all conversation concerns into a new `conversation` package. There is now a split between `conversation` and `api`s representation of `Message`, the latter storing the minimum information required for interaction with LLM providers. There is necessary conversation between the two when making LLM calls.
340 lines
9.2 KiB
Go
340 lines
9.2 KiB
Go
package util
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
|
"github.com/charmbracelet/lipgloss"
|
|
)
|
|
|
|
// Prompt prompts the configured the configured model and streams the response
|
|
// to stdout. Returns all model reply messages.
|
|
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
|
|
m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
params := provider.RequestParameters{
|
|
Model: m,
|
|
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
|
Temperature: *ctx.Config.Defaults.Temperature,
|
|
}
|
|
|
|
system := ctx.DefaultSystemPrompt()
|
|
|
|
agent := ctx.GetAgent(ctx.Config.Defaults.Agent)
|
|
if agent != nil {
|
|
if agent.SystemPrompt != "" {
|
|
system = agent.SystemPrompt
|
|
}
|
|
params.Toolbox = agent.Toolbox
|
|
}
|
|
|
|
if system != "" {
|
|
messages = conversation.ApplySystemPrompt(messages, system, false)
|
|
}
|
|
|
|
content := make(chan provider.Chunk)
|
|
defer close(content)
|
|
|
|
// render the content received over the channel
|
|
go ShowDelayedContent(content)
|
|
|
|
reply, err := p.CreateChatCompletionStream(
|
|
context.Background(), params, conversation.MessagesToAPI(messages), content,
|
|
)
|
|
|
|
if reply.Content != "" {
|
|
// there was some content, so break to a new line after it
|
|
fmt.Println()
|
|
|
|
if err != nil {
|
|
lmcli.Warn("Received partial response. Error: %v\n", err)
|
|
err = nil
|
|
}
|
|
}
|
|
return reply, err
|
|
}
|
|
|
|
// lookupConversation either returns the conversation found by the
|
|
// short name or exits the program
|
|
func LookupConversation(ctx *lmcli.Context, shortName string) *conversation.Conversation {
|
|
c, err := ctx.Conversations.FindConversationByShortName(shortName)
|
|
if err != nil {
|
|
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
|
}
|
|
if c.ID == 0 {
|
|
lmcli.Fatal("Conversation not found: %s\n", shortName)
|
|
}
|
|
return c
|
|
}
|
|
|
|
func LookupConversationE(ctx *lmcli.Context, shortName string) (*conversation.Conversation, error) {
|
|
c, err := ctx.Conversations.FindConversationByShortName(shortName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
|
}
|
|
if c.ID == 0 {
|
|
return nil, fmt.Errorf("Conversation not found: %s", shortName)
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
func HandleConversationReply(ctx *lmcli.Context, c *conversation.Conversation, persist bool, toSend ...conversation.Message) {
|
|
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
|
|
if err != nil {
|
|
lmcli.Fatal("Could not load messages: %v\n", err)
|
|
}
|
|
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
|
|
}
|
|
|
|
// handleConversationReply handles sending messages to an existing
|
|
// conversation, optionally persisting both the sent replies and responses.
|
|
func HandleReply(ctx *lmcli.Context, to *conversation.Message, persist bool, messages ...conversation.Message) {
|
|
if to == nil {
|
|
lmcli.Fatal("Can't prompt from an empty message.")
|
|
}
|
|
|
|
existing, err := ctx.Conversations.PathToRoot(to)
|
|
if err != nil {
|
|
lmcli.Fatal("Could not load messages: %v\n", err)
|
|
}
|
|
|
|
RenderConversation(ctx, append(existing, messages...), true)
|
|
|
|
var savedReplies []conversation.Message
|
|
if persist && len(messages) > 0 {
|
|
savedReplies, err = ctx.Conversations.Reply(to, messages...)
|
|
if err != nil {
|
|
lmcli.Warn("Could not save messages: %v\n", err)
|
|
}
|
|
}
|
|
|
|
// render a message header with no contents
|
|
RenderMessage(ctx, (&conversation.Message{Role: api.MessageRoleAssistant}))
|
|
|
|
var lastSavedMessage *conversation.Message
|
|
lastSavedMessage = to
|
|
if len(savedReplies) > 0 {
|
|
lastSavedMessage = &savedReplies[len(savedReplies)-1]
|
|
}
|
|
|
|
replyCallback := func(reply conversation.Message) {
|
|
if !persist {
|
|
return
|
|
}
|
|
savedReplies, err = ctx.Conversations.Reply(lastSavedMessage, reply)
|
|
if err != nil {
|
|
lmcli.Warn("Could not save reply: %v\n", err)
|
|
}
|
|
lastSavedMessage = &savedReplies[0]
|
|
}
|
|
|
|
_, err = Prompt(ctx, append(existing, messages...), replyCallback)
|
|
if err != nil {
|
|
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
|
}
|
|
}
|
|
|
|
func FormatForExternalPrompt(messages []conversation.Message, system bool) string {
|
|
sb := strings.Builder{}
|
|
for _, message := range messages {
|
|
if message.Content == "" {
|
|
continue
|
|
}
|
|
switch message.Role {
|
|
case api.MessageRoleAssistant, api.MessageRoleToolCall:
|
|
sb.WriteString("Assistant:\n\n")
|
|
case api.MessageRoleUser:
|
|
sb.WriteString("User:\n\n")
|
|
default:
|
|
continue
|
|
}
|
|
sb.WriteString(fmt.Sprintf("%s", lipgloss.NewStyle().PaddingLeft(1).Render(message.Content)))
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
func GenerateTitle(ctx *lmcli.Context, messages []conversation.Message) (string, error) {
|
|
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
|
|
|
|
Example conversation:
|
|
|
|
[{"role": "user", "content": "Can you help me with my math homework?"},{"role": "assistant", "content": "Sure, what topic are you struggling with?"}]
|
|
|
|
Example response:
|
|
|
|
{"title": "Help with math homework"}
|
|
`
|
|
type msg struct {
|
|
Role string
|
|
Content string
|
|
}
|
|
|
|
var msgs []msg
|
|
for _, m := range messages {
|
|
switch m.Role {
|
|
case api.MessageRoleAssistant, api.MessageRoleUser:
|
|
msgs = append(msgs, msg{string(m.Role), m.Content})
|
|
}
|
|
}
|
|
|
|
// Serialize the conversation to JSON
|
|
jsonBytes, err := json.Marshal(msgs)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
generateRequest := []conversation.Message{
|
|
{
|
|
Role: api.MessageRoleSystem,
|
|
Content: systemPrompt,
|
|
},
|
|
{
|
|
Role: api.MessageRoleUser,
|
|
Content: string(jsonBytes),
|
|
},
|
|
}
|
|
|
|
m, _, p, err := ctx.GetModelProvider(
|
|
*ctx.Config.Conversations.TitleGenerationModel, "",
|
|
)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
requestParams := provider.RequestParameters{
|
|
Model: m,
|
|
MaxTokens: 25,
|
|
}
|
|
|
|
response, err := p.CreateChatCompletion(
|
|
context.Background(), requestParams, conversation.MessagesToAPI(generateRequest),
|
|
)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Parse the JSON response
|
|
var jsonResponse struct {
|
|
Title string `json:"title"`
|
|
}
|
|
err = json.Unmarshal([]byte(response.Content), &jsonResponse)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return jsonResponse.Title, nil
|
|
}
|
|
|
|
// ShowWaitAnimation prints an animated ellipses to stdout until something is
// received on the signal channel. An empty string is sent back on the channel
// to notify the caller that the animation has completed (cursor restored).
func ShowWaitAnimation(signal chan any) {
	// Save the current cursor position
	fmt.Print("\033[s")

	animationStep := 0
	for {
		select {
		case <-signal:
			// Restore the cursor position
			fmt.Print("\033[u")
			signal <- ""
			return
		default:
			// cycle of 6 steps: print three dots, then erase them
			modSix := animationStep % 6
			if modSix == 3 || modSix == 0 {
				// move the cursor back to the saved position
				fmt.Print("\033[u")
			}
			if modSix < 3 {
				fmt.Print(".")
			} else {
				fmt.Print(" ")
			}
			animationStep++
			time.Sleep(250 * time.Millisecond)
		}
	}
}
|
|
|
|
// ShowDelayedContent displays a waiting animation to stdout while waiting
|
|
// for content to be received on the provided channel. As soon as any (possibly
|
|
// chunked) content is received on the channel, the waiting animation is
|
|
// replaced by the content.
|
|
// Blocks until the channel is closed.
|
|
func ShowDelayedContent(content <-chan provider.Chunk) {
|
|
waitSignal := make(chan any)
|
|
go ShowWaitAnimation(waitSignal)
|
|
|
|
firstChunk := true
|
|
for chunk := range content {
|
|
if firstChunk {
|
|
// notify wait animation that we've received data
|
|
waitSignal <- ""
|
|
// wait for signal that wait animation has completed
|
|
<-waitSignal
|
|
firstChunk = false
|
|
}
|
|
fmt.Print(chunk.Content)
|
|
}
|
|
}
|
|
|
|
// RenderConversation renders the given messages to TTY, with optional space
|
|
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
|
// are printed immediately after the final message (1 if false, 2 if true)
|
|
func RenderConversation(ctx *lmcli.Context, messages []conversation.Message, spaceForResponse bool) {
|
|
l := len(messages)
|
|
for i, message := range messages {
|
|
RenderMessage(ctx, &message)
|
|
if i < l-1 || spaceForResponse {
|
|
// print an additional space before the next message
|
|
fmt.Println()
|
|
}
|
|
}
|
|
}
|
|
|
|
func RenderMessage(ctx *lmcli.Context, m *conversation.Message) {
|
|
var messageAge string
|
|
if m.CreatedAt.IsZero() {
|
|
messageAge = "now"
|
|
} else {
|
|
now := time.Now()
|
|
messageAge = util.HumanTimeElapsedSince(now.Sub(m.CreatedAt))
|
|
}
|
|
|
|
headerStyle := lipgloss.NewStyle().Bold(true)
|
|
|
|
switch m.Role {
|
|
case api.MessageRoleSystem:
|
|
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
|
|
case api.MessageRoleUser:
|
|
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
|
|
case api.MessageRoleAssistant:
|
|
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
|
|
}
|
|
|
|
role := headerStyle.Render(m.Role.FriendlyRole())
|
|
|
|
separatorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3"))
|
|
separator := separatorStyle.Render("===")
|
|
timestamp := separatorStyle.Render(messageAge)
|
|
|
|
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
|
if m.Content != "" {
|
|
ctx.Chroma.Highlight(os.Stdout, m.Content)
|
|
fmt.Println()
|
|
}
|
|
}
|