lmcli/pkg/cmd/util/util.go

package util

import (
	"context"
	"fmt"
	"os"
	"strings"
	"time"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/util"

	"github.com/charmbracelet/lipgloss"
)

// Prompt prompts the configured model and streams the response to stdout.
// The provided callback is invoked for each reply message produced, and the
// accumulated response content is returned.
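//
// A minimal usage sketch (hypothetical message contents, assuming a
// populated *lmcli.Context named ctx):
//
//	messages := []model.Message{
//		{Role: model.MessageRoleUser, Content: "Hello!"},
//	}
//	response, err := Prompt(ctx, messages, func(reply model.Message) {
//		// handle each reply as it arrives, e.g. persist it
//	})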
func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
	content := make(chan string) // receives the response from the LLM
	defer close(content)

	// render all content received over the channel
	go ShowDelayedContent(content)

	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
	if err != nil {
		return "", err
	}

	requestParams := model.RequestParameters{
		Model:       *ctx.Config.Defaults.Model,
		MaxTokens:   *ctx.Config.Defaults.MaxTokens,
		Temperature: *ctx.Config.Defaults.Temperature,
		ToolBag:     ctx.EnabledTools,
	}

	response, err := completionProvider.CreateChatCompletionStream(
		context.Background(), requestParams, messages, callback, content,
	)
	if response != "" {
		// there was some content, so break to a new line after it
		fmt.Println()

		if err != nil {
			lmcli.Warn("Received partial response. Error: %v\n", err)
			err = nil
		}
	}

	return response, err
}

// LookupConversation looks up a conversation by its short name, exiting the
// program if it cannot be found.
func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
	c, err := ctx.Store.ConversationByShortName(shortName)
	if err != nil {
		lmcli.Fatal("Could not lookup conversation: %v\n", err)
	}
	if c.ID == 0 {
		lmcli.Fatal("Conversation not found: %s\n", shortName)
	}
	return c
}
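
// LookupConversationE looks up a conversation by its short name, returning
// an error instead of exiting the program when it cannot be found.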
func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
	c, err := ctx.Store.ConversationByShortName(shortName)
	if err != nil {
		return nil, fmt.Errorf("could not lookup conversation: %v", err)
	}
	if c.ID == 0 {
		return nil, fmt.Errorf("conversation not found: %s", shortName)
	}
	return c, nil
}
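
// HandleConversationReply replies to the leaf message of the conversation's
// currently selected thread, optionally persisting the new messages.
//
// A minimal usage sketch (hypothetical short name and content):
//
//	c := LookupConversation(ctx, "1a2b")
//	HandleConversationReply(ctx, c, true, model.Message{
//		Role:    model.MessageRoleUser,
//		Content: "Could you expand on that?",
//	})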
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
	messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
	if err != nil {
		lmcli.Fatal("Could not load messages: %v\n", err)
	}
	HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
}

// HandleReply sends the given messages in reply to the `to` message,
// optionally persisting both the sent messages and the model's responses.
func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
	if to == nil {
		lmcli.Fatal("Can't prompt from an empty message.")
	}

	existing, err := ctx.Store.PathToRoot(to)
	if err != nil {
		lmcli.Fatal("Could not load messages: %v\n", err)
	}

	RenderConversation(ctx, append(existing, messages...), true)

	var savedReplies []model.Message
	if persist && len(messages) > 0 {
		savedReplies, err = ctx.Store.Reply(to, messages...)
		if err != nil {
			lmcli.Warn("Could not save messages: %v\n", err)
		}
	}

	// render a message header with no contents
	RenderMessage(ctx, &model.Message{Role: model.MessageRoleAssistant})

	// track the message new replies should attach to, so the model's
	// response chains onto the most recently saved message
	lastMessage := to
	if len(savedReplies) > 0 {
		lastMessage = &savedReplies[len(savedReplies)-1]
	}

	replyCallback := func(reply model.Message) {
		if !persist {
			return
		}
		savedReplies, err = ctx.Store.Reply(lastMessage, reply)
		if err != nil {
			lmcli.Warn("Could not save reply: %v\n", err)
		}
		lastMessage = &savedReplies[0]
	}

	_, err = Prompt(ctx, append(existing, messages...), replyCallback)
	if err != nil {
		lmcli.Fatal("Error fetching LLM response: %v\n", err)
	}
}
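
// FormatForExternalPrompt renders the conversation as plain text, prefixing
// each message's content with a "User:" or "Assistant:" header. Messages
// with empty content, or with roles other than user/assistant/tool call,
// are skipped. Note: the system parameter is currently unused.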
func FormatForExternalPrompt(messages []model.Message, system bool) string {
	sb := strings.Builder{}
	for _, message := range messages {
		if message.Content == "" {
			continue
		}
		switch message.Role {
		case model.MessageRoleAssistant, model.MessageRoleToolCall:
			sb.WriteString("Assistant:\n\n")
		case model.MessageRoleUser:
			sb.WriteString("User:\n\n")
		default:
			continue
		}
		sb.WriteString(lipgloss.NewStyle().PaddingLeft(1).Render(message.Content))
	}
	return sb.String()
}
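
// GenerateTitle prompts the configured title generation model for a short
// (8 words or fewer) title summarizing the given conversation, and returns
// it with surrounding quotes and whitespace stripped.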
func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
	const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
Example conversation:
"""
User:
Hello!
Assistant:
Hello! How may I assist you?
"""
Example response:
"""
Title: A brief introduction
"""
`
	conversation := FormatForExternalPrompt(messages, false)

	generateRequest := []model.Message{
		{
			Role:    model.MessageRoleUser,
			Content: fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n%s", conversation, prompt),
		},
	}

	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel)
	if err != nil {
		return "", err
	}

	requestParams := model.RequestParameters{
		Model:     *ctx.Config.Conversations.TitleGenerationModel,
		MaxTokens: 25,
	}

	response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
	if err != nil {
		return "", err
	}

	response = strings.TrimPrefix(response, "Title: ")
	response = strings.Trim(response, "\"")
	response = strings.TrimSpace(response)
	return response, nil
}

// ShowWaitAnimation prints an animated ellipsis to stdout until something is
// received on the signal channel. An empty string is then sent back on the
// channel to notify the caller that the animation has completed (and the
// cursor returned to its saved position).
func ShowWaitAnimation(signal chan any) {
	// Save the current cursor position
	fmt.Print("\033[s")

	animationStep := 0
	for {
		select {
		case <-signal:
			// Restore the cursor position
			fmt.Print("\033[u")
			signal <- ""
			return
		default:
			// Return the cursor to the saved position at the start of each
			// half-cycle: 3 steps printing dots, then 3 steps erasing them
			modSix := animationStep % 6
			if modSix == 3 || modSix == 0 {
				fmt.Print("\033[u")
			}
			if modSix < 3 {
				fmt.Print(".")
			} else {
				fmt.Print(" ")
			}
			animationStep++
			time.Sleep(250 * time.Millisecond)
		}
	}
}

// ShowDelayedContent displays a waiting animation to stdout while waiting
// for content to be received on the provided channel. As soon as any
// (possibly chunked) content is received on the channel, the waiting
// animation is replaced by the content.
//
// Blocks until the channel is closed.
func ShowDelayedContent(content <-chan string) {
	waitSignal := make(chan any)
	go ShowWaitAnimation(waitSignal)

	firstChunk := true
	for chunk := range content {
		if firstChunk {
			// notify wait animation that we've received data
			waitSignal <- ""
			// wait for signal that wait animation has completed
			<-waitSignal
			firstChunk = false
		}
		fmt.Print(chunk)
	}
}

// RenderConversation renders the given messages to the TTY, with optional
// space for a subsequent message. spaceForResponse controls how many '\n'
// characters are printed immediately after the final message (1 if false,
// 2 if true).
func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
	l := len(messages)
	for i, message := range messages {
		RenderMessage(ctx, &message)
		if i < l-1 || spaceForResponse {
			// print an additional blank line before the next message
			fmt.Println()
		}
	}
}
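
// RenderMessage renders a single message to the TTY: a colored, bold header
// showing the message's role and age, followed by its syntax-highlighted
// content.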
func RenderMessage(ctx *lmcli.Context, m *model.Message) {
	var messageAge string
	if m.CreatedAt.IsZero() {
		messageAge = "now"
	} else {
		now := time.Now()
		messageAge = util.HumanTimeElapsedSince(now.Sub(m.CreatedAt))
	}

	headerStyle := lipgloss.NewStyle().Bold(true)
	switch m.Role {
	case model.MessageRoleSystem:
		headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
	case model.MessageRoleUser:
		headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
	case model.MessageRoleAssistant:
		headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
	}

	role := headerStyle.Render(m.Role.FriendlyRole())
	separatorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3"))
	separator := separatorStyle.Render("===")
	timestamp := separatorStyle.Render(messageAge)
	fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)

	if m.Content != "" {
		ctx.Chroma.Highlight(os.Stdout, m.Content)
		fmt.Println()
	}
}