Matt Low
91d3c9c2e1
Instead of CreateChatCompletion* accepting a pointer to a slice of reply messages, it accepts a callback which is called with each successive reply in the conversation. This gives the caller more flexibility in how it handles replies (e.g. it can react to them immediately now, instead of waiting for the entire call to finish).
264 lines
7.4 KiB
Go
264 lines
7.4 KiB
Go
package util
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
|
"github.com/charmbracelet/lipgloss"
|
|
)
|
|
|
|
// fetchAndShowCompletion prompts the LLM with the given messages and streams
|
|
// the response to stdout. Returns all model reply messages.
|
|
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
|
|
content := make(chan string) // receives the reponse from LLM
|
|
defer close(content)
|
|
|
|
// render all content received over the channel
|
|
go ShowDelayedContent(content)
|
|
|
|
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
requestParams := model.RequestParameters{
|
|
Model: *ctx.Config.Defaults.Model,
|
|
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
|
Temperature: *ctx.Config.Defaults.Temperature,
|
|
ToolBag: ctx.EnabledTools,
|
|
}
|
|
|
|
response, err := completionProvider.CreateChatCompletionStream(
|
|
context.Background(), requestParams, messages, callback, content,
|
|
)
|
|
if response != "" {
|
|
// there was some content, so break to a new line after it
|
|
fmt.Println()
|
|
|
|
if err != nil {
|
|
lmcli.Warn("Received partial response. Error: %v\n", err)
|
|
err = nil
|
|
}
|
|
}
|
|
return response, nil
|
|
}
|
|
|
|
// lookupConversation either returns the conversation found by the
|
|
// short name or exits the program
|
|
func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
|
|
c, err := ctx.Store.ConversationByShortName(shortName)
|
|
if err != nil {
|
|
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
|
}
|
|
if c.ID == 0 {
|
|
lmcli.Fatal("Conversation not found with short name: %s\n", shortName)
|
|
}
|
|
return c
|
|
}
|
|
|
|
func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
|
|
c, err := ctx.Store.ConversationByShortName(shortName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
|
}
|
|
if c.ID == 0 {
|
|
return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
// handleConversationReply handles sending messages to an existing
|
|
// conversation, optionally persisting both the sent replies and responses.
|
|
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
|
existing, err := ctx.Store.Messages(c)
|
|
if err != nil {
|
|
lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
|
|
}
|
|
|
|
if persist {
|
|
for _, message := range toSend {
|
|
err = ctx.Store.SaveMessage(&message)
|
|
if err != nil {
|
|
lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
allMessages := append(existing, toSend...)
|
|
|
|
RenderConversation(ctx, allMessages, true)
|
|
|
|
// render a message header with no contents
|
|
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
|
|
|
replyCallback := func(reply model.Message) {
|
|
if !persist {
|
|
return
|
|
}
|
|
|
|
reply.ConversationID = c.ID
|
|
err = ctx.Store.SaveMessage(&reply)
|
|
if err != nil {
|
|
lmcli.Warn("Could not save reply: %v\n", err)
|
|
}
|
|
}
|
|
|
|
_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
|
|
if err != nil {
|
|
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
|
}
|
|
}
|
|
|
|
func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
|
sb := strings.Builder{}
|
|
for _, message := range messages {
|
|
if message.Role != model.MessageRoleUser && (message.Role != model.MessageRoleSystem || !system) {
|
|
continue
|
|
}
|
|
sb.WriteString(fmt.Sprintf("<%s>\n", message.Role.FriendlyRole()))
|
|
sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.Content))
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
|
|
messages, err := ctx.Store.Messages(c)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
const header = "Generate a concise 4-5 word title for the conversation below."
|
|
prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, FormatForExternalPrompt(messages, false))
|
|
|
|
generateRequest := []model.Message{
|
|
{
|
|
Role: model.MessageRoleUser,
|
|
Content: prompt,
|
|
},
|
|
}
|
|
|
|
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
requestParams := model.RequestParameters{
|
|
Model: *ctx.Config.Conversations.TitleGenerationModel,
|
|
MaxTokens: 25,
|
|
}
|
|
|
|
response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ShowWaitAnimation prints an animated ellipsis to stdout until something is
// received on the signal channel. Once a value is received, the cursor is
// restored to the position it had when the animation started and an empty
// string is sent back on the channel to notify the caller that the animation
// has completed.
//
// The animation has six frames, 250ms apart: three frames each add a ".",
// then three frames each overwrite one with a space. The cursor is moved
// back to the saved position at the start of each half.
//
// The signal is also watched while waiting for the next frame, so the caller
// is acknowledged immediately rather than after the current frame delay.
func ShowWaitAnimation(signal chan any) {
	const frameDelay = 250 * time.Millisecond

	// Save the current cursor position
	fmt.Print("\033[s")

	// stop restores the cursor position and acknowledges the caller.
	stop := func() {
		fmt.Print("\033[u")
		signal <- ""
	}

	frameTicker := time.NewTicker(frameDelay)
	defer frameTicker.Stop()

	for animationStep := 0; ; animationStep++ {
		// Don't draw a frame if the caller is already asking us to stop.
		select {
		case <-signal:
			stop()
			return
		default:
		}

		modSix := animationStep % 6
		if modSix == 3 || modSix == 0 {
			// Move the cursor to the saved position
			fmt.Print("\033[u")
		}
		if modSix < 3 {
			fmt.Print(".")
		} else {
			fmt.Print(" ")
		}

		// Wait for the next frame, waking up early if signalled.
		select {
		case <-signal:
			stop()
			return
		case <-frameTicker.C:
		}
	}
}
|
|
|
|
// ShowDelayedContent displays a waiting animation to stdout while waiting
|
|
// for content to be received on the provided channel. As soon as any (possibly
|
|
// chunked) content is received on the channel, the waiting animation is
|
|
// replaced by the content.
|
|
// Blocks until the channel is closed.
|
|
func ShowDelayedContent(content <-chan string) {
|
|
waitSignal := make(chan any)
|
|
go ShowWaitAnimation(waitSignal)
|
|
|
|
firstChunk := true
|
|
for chunk := range content {
|
|
if firstChunk {
|
|
// notify wait animation that we've received data
|
|
waitSignal <- ""
|
|
// wait for signal that wait animation has completed
|
|
<-waitSignal
|
|
firstChunk = false
|
|
}
|
|
fmt.Print(chunk)
|
|
}
|
|
}
|
|
|
|
// RenderConversation renders the given messages to TTY, with optional space
|
|
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
|
// are printed immediately after the final message (1 if false, 2 if true)
|
|
func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
|
|
l := len(messages)
|
|
for i, message := range messages {
|
|
RenderMessage(ctx, &message)
|
|
if i < l-1 || spaceForResponse {
|
|
// print an additional space before the next message
|
|
fmt.Println()
|
|
}
|
|
}
|
|
}
|
|
|
|
func RenderMessage(ctx *lmcli.Context, m *model.Message) {
|
|
var messageAge string
|
|
if m.CreatedAt.IsZero() {
|
|
messageAge = "now"
|
|
} else {
|
|
now := time.Now()
|
|
messageAge = util.HumanTimeElapsedSince(now.Sub(m.CreatedAt))
|
|
}
|
|
|
|
headerStyle := lipgloss.NewStyle().Bold(true)
|
|
|
|
switch m.Role {
|
|
case model.MessageRoleSystem:
|
|
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
|
|
case model.MessageRoleUser:
|
|
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
|
|
case model.MessageRoleAssistant:
|
|
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
|
|
}
|
|
|
|
role := headerStyle.Render(m.Role.FriendlyRole())
|
|
|
|
separatorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3"))
|
|
separator := separatorStyle.Render("===")
|
|
timestamp := separatorStyle.Render(messageAge)
|
|
|
|
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
|
if m.Content != "" {
|
|
ctx.Chroma.Highlight(os.Stdout, m.Content)
|
|
fmt.Println()
|
|
}
|
|
}
|