lmcli/pkg/cmd/util/util.go
Matt Low 91d3c9c2e1 Update ChatCompletionClient
Instead of CreateChatCompletion* accepting a pointer to a slice of reply
messages, it accepts a callback which is called with each successive
reply in the conversation.

This gives the caller more flexibility in how it handles replies (e.g.
it can react to each reply immediately, instead of waiting for the
entire call to finish).
2024-03-12 20:39:34 +00:00
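
As an illustration of the new style, a minimal caller sketch. The
CreateChatCompletionStream signature and the callback type are taken from
FetchAndShowCompletion in the file below; the provider, params, messages,
and content values are assumed to be set up as in that function:

	// collect replies as they arrive (mirroring the old pointer-to-slice
	// behavior), or react to each one immediately instead
	var replies []model.Message
	callback := func(reply model.Message) {
		replies = append(replies, reply)
	}
	response, err := provider.CreateChatCompletionStream(
		context.Background(), params, messages, callback, content,
	)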


package util

import (
	"context"
	"fmt"
	"os"
	"strings"
	"time"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/util"

	"github.com/charmbracelet/lipgloss"
)
// FetchAndShowCompletion prompts the LLM with the given messages and streams
// the response to stdout. Each reply message is passed to the given callback,
// and the full concatenated response is returned.
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
	content := make(chan string) // receives the response from the LLM
	defer close(content)

	// render all content received over the channel
	go ShowDelayedContent(content)

	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
	if err != nil {
		return "", err
	}

	requestParams := model.RequestParameters{
		Model:       *ctx.Config.Defaults.Model,
		MaxTokens:   *ctx.Config.Defaults.MaxTokens,
		Temperature: *ctx.Config.Defaults.Temperature,
		ToolBag:     ctx.EnabledTools,
	}

	response, err := completionProvider.CreateChatCompletionStream(
		context.Background(), requestParams, messages, callback, content,
	)
	if response != "" {
		// there was some content, so break to a new line after it
		fmt.Println()

		if err != nil {
			lmcli.Warn("Received partial response. Error: %v\n", err)
			err = nil
		}
	}
	return response, err
}
// LookupConversation returns the conversation found by the given short
// name, or exits the program if it cannot be found.
func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
	c, err := ctx.Store.ConversationByShortName(shortName)
	if err != nil {
		lmcli.Fatal("Could not lookup conversation: %v\n", err)
	}
	if c.ID == 0 {
		lmcli.Fatal("Conversation not found with short name: %s\n", shortName)
	}
	return c
}
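
// LookupConversationE is like LookupConversation, but returns an error
// instead of exiting the program when the conversation cannot be found.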
func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
	c, err := ctx.Store.ConversationByShortName(shortName)
	if err != nil {
		return nil, fmt.Errorf("Could not lookup conversation: %v", err)
	}
	if c.ID == 0 {
		return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
	}
	return c, nil
}
// HandleConversationReply handles sending messages to an existing
// conversation, optionally persisting both the sent messages and the
// model's replies.
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
	existing, err := ctx.Store.Messages(c)
	if err != nil {
		lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
	}

	if persist {
		for _, message := range toSend {
			err = ctx.Store.SaveMessage(&message)
			if err != nil {
				lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
			}
		}
	}

	allMessages := append(existing, toSend...)

	RenderConversation(ctx, allMessages, true)

	// render a message header with no contents
	RenderMessage(ctx, &model.Message{Role: model.MessageRoleAssistant})

	replyCallback := func(reply model.Message) {
		if !persist {
			return
		}

		reply.ConversationID = c.ID
		err = ctx.Store.SaveMessage(&reply)
		if err != nil {
			lmcli.Warn("Could not save reply: %v\n", err)
		}
	}

	_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
	if err != nil {
		lmcli.Fatal("Error fetching LLM response: %v\n", err)
	}
}
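
// FormatForExternalPrompt renders the given messages as plain text for
// inclusion in an external prompt. Only user messages are kept (plus
// system messages when system is true), each wrapped in a role-tagged,
// triple-quoted block.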
func FormatForExternalPrompt(messages []model.Message, system bool) string {
	sb := strings.Builder{}
	for _, message := range messages {
		if message.Role != model.MessageRoleUser && (message.Role != model.MessageRoleSystem || !system) {
			continue
		}
		sb.WriteString(fmt.Sprintf("<%s>\n", message.Role.FriendlyRole()))
		sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.Content))
	}
	return sb.String()
}
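
// GenerateTitle requests a short title for the given conversation from
// the configured title generation model, prompting it with the
// conversation's user messages.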
func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
	messages, err := ctx.Store.Messages(c)
	if err != nil {
		return "", err
	}

	const header = "Generate a concise 4-5 word title for the conversation below."
	prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, FormatForExternalPrompt(messages, false))

	generateRequest := []model.Message{
		{
			Role:    model.MessageRoleUser,
			Content: prompt,
		},
	}

	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel)
	if err != nil {
		return "", err
	}

	requestParams := model.RequestParameters{
		Model:     *ctx.Config.Conversations.TitleGenerationModel,
		MaxTokens: 25,
	}

	response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
	if err != nil {
		return "", err
	}

	return response, nil
}
// ShowWaitAnimation prints an animated ellipsis to stdout until something is
// received on the signal channel. An empty string is then sent back over the
// channel to notify the caller that the animation has completed (the cursor
// has been returned to its saved position).
func ShowWaitAnimation(signal chan any) {
	// Save the current cursor position
	fmt.Print("\033[s")

	animationStep := 0
	for {
		select {
		case <-signal:
			// Restore the cursor position
			fmt.Print("\033[u")
			signal <- ""
			return
		default:
			// Cycle through 6 steps: print up to three dots, then erase
			// them by overwriting with spaces from the saved position
			modSix := animationStep % 6
			if modSix == 3 || modSix == 0 {
				// Move the cursor back to the saved position
				fmt.Print("\033[u")
			}
			if modSix < 3 {
				fmt.Print(".")
			} else {
				fmt.Print(" ")
			}
			animationStep++
			time.Sleep(250 * time.Millisecond)
		}
	}
}
// ShowDelayedContent displays a waiting animation to stdout while waiting
// for content to be received on the provided channel. As soon as any (possibly
// chunked) content is received on the channel, the waiting animation is
// replaced by the content.
// Blocks until the channel is closed.
func ShowDelayedContent(content <-chan string) {
	waitSignal := make(chan any)
	go ShowWaitAnimation(waitSignal)

	firstChunk := true
	for chunk := range content {
		if firstChunk {
			// notify wait animation that we've received data
			waitSignal <- ""
			// wait for signal that wait animation has completed
			<-waitSignal
			firstChunk = false
		}
		fmt.Print(chunk)
	}
}
// RenderConversation renders the given messages to the TTY, with optional
// space for a subsequent message. spaceForResponse controls how many '\n'
// characters are printed immediately after the final message (1 if false,
// 2 if true).
func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
	l := len(messages)
	for i, message := range messages {
		RenderMessage(ctx, &message)
		if i < l-1 || spaceForResponse {
			// print an additional blank line before the next message
			fmt.Println()
		}
	}
}
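
// RenderMessage prints a single message to stdout: a colored header showing
// the message's role and age, followed by its syntax-highlighted content.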
func RenderMessage(ctx *lmcli.Context, m *model.Message) {
	var messageAge string
	if m.CreatedAt.IsZero() {
		messageAge = "now"
	} else {
		now := time.Now()
		messageAge = util.HumanTimeElapsedSince(now.Sub(m.CreatedAt))
	}

	headerStyle := lipgloss.NewStyle().Bold(true)

	switch m.Role {
	case model.MessageRoleSystem:
		headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
	case model.MessageRoleUser:
		headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
	case model.MessageRoleAssistant:
		headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
	}

	role := headerStyle.Render(m.Role.FriendlyRole())

	separatorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3"))
	separator := separatorStyle.Render("===")

	timestamp := separatorStyle.Render(messageAge)

	fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
	if m.Content != "" {
		ctx.Chroma.Highlight(os.Stdout, m.Content)
		fmt.Println()
	}
}