lmcli/pkg/cmd/util/util.go

317 lines
8.6 KiB
Go
Raw Normal View History

package util
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/charmbracelet/lipgloss"
)
// Prompt prompts the configured the configured model and streams the response
// to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
content := make(chan string) // receives the reponse from LLM
defer close(content)
// render all content received over the channel
go ShowDelayedContent(content)
m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model)
if err != nil {
return "", err
}
requestParams := model.RequestParameters{
Model: m,
MaxTokens: *ctx.Config.Defaults.MaxTokens,
Temperature: *ctx.Config.Defaults.Temperature,
2024-03-12 02:01:53 -06:00
ToolBag: ctx.EnabledTools,
}
response, err := provider.CreateChatCompletionStream(
context.Background(), requestParams, messages, callback, content,
)
if response != "" {
// there was some content, so break to a new line after it
fmt.Println()
if err != nil {
lmcli.Warn("Received partial response. Error: %v\n", err)
err = nil
}
}
2024-04-29 00:14:21 -06:00
return response, err
}
// LookupConversation fetches the conversation identified by shortName from
// the store. A lookup error, or a result whose ID is zero (treated as "not
// found"), is reported through lmcli.Fatal, which is expected to terminate the
// program.
func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
	conv, lookupErr := ctx.Store.ConversationByShortName(shortName)
	if lookupErr != nil {
		lmcli.Fatal("Could not lookup conversation: %v\n", lookupErr)
	}

	// a zero ID means the store had no conversation with this short name
	if conv.ID == 0 {
		lmcli.Fatal("Conversation not found: %s\n", shortName)
	}

	return conv
}
// LookupConversationE fetches the conversation identified by shortName from
// the store, returning an error instead of exiting the program.
//
// An error is returned when the store lookup fails (the underlying error is
// wrapped, so it remains reachable via errors.Is / errors.As) or when the
// returned conversation has an ID of zero, which is treated as "not found".
func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
	c, err := ctx.Store.ConversationByShortName(shortName)
	if err != nil {
		return nil, fmt.Errorf("Could not lookup conversation: %w", err)
	}
	if c.ID == 0 {
		return nil, fmt.Errorf("Conversation not found: %s", shortName)
	}
	return c, nil
}
// HandleConversationReply sends toSend as a reply to the last message on the
// path from the conversation's selected root to its leaf, delegating the
// actual work to HandleReply. persist is passed through to HandleReply.
//
// A failure to load the path is reported through lmcli.Fatal.
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
	messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
	if err != nil {
		lmcli.Fatal("Could not load messages: %v\n", err)
	}

	// An empty path has no message to reply to. Pass nil so HandleReply's own
	// nil check reports it, instead of panicking on an out-of-range index.
	var leaf *model.Message
	if len(messages) > 0 {
		leaf = &messages[len(messages)-1]
	}

	HandleReply(ctx, leaf, persist, toSend...)
}
// HandleReply sends messages as a reply to the existing message to, prompts
// the model with the resulting thread, and streams the response to stdout.
//
// The thread (the path from to back to its root, followed by messages) is
// rendered before prompting. When persist is true, both the sent messages and
// every reply received from the model are saved to the store; save failures
// are reported as warnings and do not abort the prompt.
//
// A nil to, a failure to load the existing thread, or an error from the
// prompt itself is reported through lmcli.Fatal.
func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
	if to == nil {
		lmcli.Fatal("Can't prompt from an empty message.")
	}

	existing, err := ctx.Store.PathToRoot(to)
	if err != nil {
		lmcli.Fatal("Could not load messages: %v\n", err)
	}

	thread := append(existing, messages...)
	RenderConversation(ctx, thread, true)

	var savedReplies []model.Message
	if persist && len(messages) > 0 {
		savedReplies, err = ctx.Store.Reply(to, messages...)
		if err != nil {
			lmcli.Warn("Could not save messages: %v\n", err)
		}
	}

	// render a message header with no contents
	RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))

	// lastSavedMessage is the message the next saved reply is attached to. It
	// starts at the last message we sent (if saved), otherwise at to.
	lastSavedMessage := to
	if len(savedReplies) > 0 {
		lastSavedMessage = &savedReplies[len(savedReplies)-1]
	}

	replyCallback := func(reply model.Message) {
		if !persist {
			return
		}
		saved, err := ctx.Store.Reply(lastSavedMessage, reply)
		if err != nil {
			lmcli.Warn("Could not save reply: %v\n", err)
			// nothing was saved; later replies keep attaching to the
			// previously saved message
			return
		}
		if len(saved) == 0 {
			return
		}
		lastSavedMessage = &saved[0]
	}

	_, err = Prompt(ctx, thread, replyCallback)
	if err != nil {
		lmcli.Fatal("Error fetching LLM response: %v\n", err)
	}
}
// FormatForExternalPrompt flattens messages into a single plain-text
// transcript. Each user message is introduced by a "User:" heading and each
// assistant or tool-call message by an "Assistant:" heading, followed by the
// message content rendered with one column of left padding. Messages with
// empty content, and messages with any other role, are omitted.
//
// The system parameter is currently not consulted.
func FormatForExternalPrompt(messages []model.Message, system bool) string {
	var out strings.Builder
	contentStyle := func(text string) string {
		return lipgloss.NewStyle().PaddingLeft(1).Render(text)
	}

	for i := range messages {
		msg := messages[i]
		if msg.Content == "" {
			continue
		}

		var heading string
		switch msg.Role {
		case model.MessageRoleUser:
			heading = "User:\n\n"
		case model.MessageRoleAssistant, model.MessageRoleToolCall:
			heading = "Assistant:\n\n"
		default:
			// roles without a heading are left out of the transcript
			continue
		}

		out.WriteString(heading)
		out.WriteString(contentStyle(msg.Content))
	}

	return out.String()
}
func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
2024-05-30 12:22:48 -06:00
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
Example conversation:
[{"role": "user", "content": "Can you help me with my math homework?"},{"role": "assistant", "content": "Sure, what topic are you struggling with?"}]
2024-05-30 12:22:48 -06:00
Example response:
{"title": "Help with math homework"}
`
type msg struct {
Role string
Content string
}
var msgs []msg
for _, m := range messages {
msgs = append(msgs, msg{string(m.Role), m.Content})
}
2024-05-22 23:59:46 -06:00
// Serialize the conversation to JSON
conversation, err := json.Marshal(msgs)
if err != nil {
return "", err
}
generateRequest := []model.Message{
{
Role: model.MessageRoleSystem,
Content: systemPrompt,
},
{
Role: model.MessageRoleUser,
Content: string(conversation),
},
}
m, provider, err := ctx.GetModelProvider(*ctx.Config.Conversations.TitleGenerationModel)
if err != nil {
return "", err
}
requestParams := model.RequestParameters{
Model: m,
MaxTokens: 25,
}
response, err := provider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
if err != nil {
return "", err
}
// Parse the JSON response
var jsonResponse struct {
Title string `json:"title"`
}
err = json.Unmarshal([]byte(response), &jsonResponse)
if err != nil {
return "", err
}
return jsonResponse.Title, nil
}
// ShowWaitAnimation prints an animated ellipsis to stdout until something is
// received on the signal channel. Once signalled, it restores the cursor to
// the position where the animation started and then sends an empty string
// back on the same channel to notify the caller that the animation has
// completed.
//
// The caller must receive that acknowledgement; the channel send blocks until
// it does.
func ShowWaitAnimation(signal chan any) {
	// Save the current cursor position
	fmt.Print("\033[s")

	// Fires immediately for the first frame and is re-armed after each frame.
	// Waiting on the timer inside the select (rather than sleeping between
	// frames) lets a stop signal be handled as soon as it arrives.
	frameTimer := time.NewTimer(0)
	defer frameTimer.Stop()

	animationStep := 0
	for {
		select {
		case <-signal:
			// Restore the cursor position
			fmt.Print("\033[u")
			signal <- ""
			return
		case <-frameTimer.C:
			// Steps 0-2 print one dot each; steps 3-5 print one space each
			// over the dots. The cursor returns to the saved position at the
			// start of each half of the cycle.
			modSix := animationStep % 6
			if modSix == 3 || modSix == 0 {
				fmt.Print("\033[u")
			}
			if modSix < 3 {
				fmt.Print(".")
			} else {
				fmt.Print(" ")
			}
			animationStep++
			frameTimer.Reset(250 * time.Millisecond)
		}
	}
}
// ShowDelayedContent displays a waiting animation to stdout while waiting
// for content to be received on the provided channel. As soon as any (possibly
// chunked) content is received on the channel, the waiting animation is
// replaced by the content.
//
// If the channel is closed before any content arrives, the waiting animation
// is stopped before returning.
//
// Blocks until the channel is closed.
func ShowDelayedContent(content <-chan string) {
	waitSignal := make(chan any)
	go ShowWaitAnimation(waitSignal)

	// stopAnimation performs the stop handshake with the animation goroutine
	// exactly once: notify it, then wait for its completion acknowledgement.
	animating := true
	stopAnimation := func() {
		if !animating {
			return
		}
		// notify wait animation that it should stop
		waitSignal <- ""
		// wait for signal that wait animation has completed
		<-waitSignal
		animating = false
	}
	// Covers the case where the channel closes without any chunk, which
	// would otherwise leave the animation running forever.
	defer stopAnimation()

	for chunk := range content {
		// no-op after the first chunk
		stopAnimation()
		fmt.Print(chunk)
	}
}
// RenderConversation writes each of the given messages to stdout via
// RenderMessage, printing a blank line between consecutive messages. When
// spaceForResponse is true, a blank line is also printed after the final
// message, leaving room for a message that follows.
func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
	lastIdx := len(messages) - 1
	for idx := range messages {
		current := messages[idx]
		RenderMessage(ctx, &current)

		// separate from the next message, or from the upcoming response
		if spaceForResponse || idx < lastIdx {
			fmt.Println()
		}
	}
}
// RenderMessage prints a single message to stdout: a header line of the form
// "=== <role> - <age> ===" followed by a blank line and, if the message has
// content, the content passed through the context's Chroma highlighter and a
// trailing newline.
//
// The age is "now" for messages without a creation time, otherwise the
// elapsed time since creation as formatted by util.HumanTimeElapsedSince.
// System, user and assistant roles each get their own header colour; other
// roles are rendered bold without a colour.
func RenderMessage(ctx *lmcli.Context, m *model.Message) {
	age := "now"
	if !m.CreatedAt.IsZero() {
		age = util.HumanTimeElapsedSince(time.Since(m.CreatedAt))
	}

	roleStyle := lipgloss.NewStyle().Bold(true)
	switch m.Role {
	case model.MessageRoleSystem:
		roleStyle = roleStyle.Foreground(lipgloss.Color("9")) // bright red
	case model.MessageRoleUser:
		roleStyle = roleStyle.Foreground(lipgloss.Color("10")) // bright green
	case model.MessageRoleAssistant:
		roleStyle = roleStyle.Foreground(lipgloss.Color("12")) // bright blue
	}
	roleLabel := roleStyle.Render(m.Role.FriendlyRole())

	ruleStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3"))
	rule := ruleStyle.Render("===")
	ageLabel := ruleStyle.Render(age)

	fmt.Printf("%s %s - %s %s\n\n", rule, roleLabel, ageLabel, rule)

	// header-only rendering for messages without content
	if m.Content == "" {
		return
	}
	ctx.Chroma.Highlight(os.Stdout, m.Content)
	fmt.Println()
}