317 lines
8.7 KiB
Go
317 lines
8.7 KiB
Go
package util
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
|
"github.com/charmbracelet/lipgloss"
|
|
)
|
|
|
|
// Prompt prompts the configured the configured model and streams the response
|
|
// to stdout. Returns all model reply messages.
|
|
func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
|
|
content := make(chan string) // receives the reponse from LLM
|
|
defer close(content)
|
|
|
|
// render all content received over the channel
|
|
go ShowDelayedContent(content)
|
|
|
|
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
requestParams := model.RequestParameters{
|
|
Model: *ctx.Config.Defaults.Model,
|
|
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
|
Temperature: *ctx.Config.Defaults.Temperature,
|
|
ToolBag: ctx.EnabledTools,
|
|
}
|
|
|
|
response, err := completionProvider.CreateChatCompletionStream(
|
|
context.Background(), requestParams, messages, callback, content,
|
|
)
|
|
if response != "" {
|
|
// there was some content, so break to a new line after it
|
|
fmt.Println()
|
|
|
|
if err != nil {
|
|
lmcli.Warn("Received partial response. Error: %v\n", err)
|
|
err = nil
|
|
}
|
|
}
|
|
return response, err
|
|
}
|
|
|
|
// lookupConversation either returns the conversation found by the
|
|
// short name or exits the program
|
|
func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
|
|
c, err := ctx.Store.ConversationByShortName(shortName)
|
|
if err != nil {
|
|
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
|
}
|
|
if c.ID == 0 {
|
|
lmcli.Fatal("Conversation not found: %s\n", shortName)
|
|
}
|
|
return c
|
|
}
|
|
|
|
func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
|
|
c, err := ctx.Store.ConversationByShortName(shortName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
|
}
|
|
if c.ID == 0 {
|
|
return nil, fmt.Errorf("Conversation not found: %s", shortName)
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
|
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
|
|
if err != nil {
|
|
lmcli.Fatal("Could not load messages: %v\n", err)
|
|
}
|
|
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
|
|
}
|
|
|
|
// handleConversationReply handles sending messages to an existing
|
|
// conversation, optionally persisting both the sent replies and responses.
|
|
func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
|
|
if to == nil {
|
|
lmcli.Fatal("Can't prompt from an empty message.")
|
|
}
|
|
|
|
existing, err := ctx.Store.PathToRoot(to)
|
|
if err != nil {
|
|
lmcli.Fatal("Could not load messages: %v\n", err)
|
|
}
|
|
|
|
RenderConversation(ctx, append(existing, messages...), true)
|
|
|
|
var savedReplies []model.Message
|
|
if persist && len(messages) > 0 {
|
|
savedReplies, err = ctx.Store.Reply(to, messages...)
|
|
if err != nil {
|
|
lmcli.Warn("Could not save messages: %v\n", err)
|
|
}
|
|
}
|
|
|
|
// render a message header with no contents
|
|
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
|
|
|
var lastSavedMessage *model.Message
|
|
lastSavedMessage = to
|
|
if len(savedReplies) > 0 {
|
|
lastSavedMessage = &savedReplies[len(savedReplies)-1]
|
|
}
|
|
|
|
replyCallback := func(reply model.Message) {
|
|
if !persist {
|
|
return
|
|
}
|
|
savedReplies, err = ctx.Store.Reply(lastSavedMessage, reply)
|
|
if err != nil {
|
|
lmcli.Warn("Could not save reply: %v\n", err)
|
|
}
|
|
lastSavedMessage = &savedReplies[0]
|
|
}
|
|
|
|
_, err = Prompt(ctx, append(existing, messages...), replyCallback)
|
|
if err != nil {
|
|
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
|
}
|
|
}
|
|
|
|
func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
|
sb := strings.Builder{}
|
|
for _, message := range messages {
|
|
if message.Content == "" {
|
|
continue
|
|
}
|
|
switch message.Role {
|
|
case model.MessageRoleAssistant, model.MessageRoleToolCall:
|
|
sb.WriteString("Assistant:\n\n")
|
|
case model.MessageRoleUser:
|
|
sb.WriteString("User:\n\n")
|
|
default:
|
|
continue
|
|
}
|
|
sb.WriteString(fmt.Sprintf("%s", lipgloss.NewStyle().PaddingLeft(1).Render(message.Content)))
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
|
|
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
|
|
|
|
Example conversation:
|
|
|
|
[{"role": "user", "content": "Can you help me with my math homework?"},{"role": "assistant", "content": "Sure, what topic are you struggling with?"}]
|
|
|
|
Example response:
|
|
|
|
{"title": "Help with math homework"}
|
|
`
|
|
type msg struct {
|
|
Role string
|
|
Content string
|
|
}
|
|
|
|
var msgs []msg
|
|
for _, m := range messages {
|
|
msgs = append(msgs, msg{string(m.Role), m.Content})
|
|
}
|
|
|
|
// Serialize the conversation to JSON
|
|
conversation, err := json.Marshal(msgs)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
generateRequest := []model.Message{
|
|
{
|
|
Role: model.MessageRoleSystem,
|
|
Content: systemPrompt,
|
|
},
|
|
{
|
|
Role: model.MessageRoleUser,
|
|
Content: string(conversation),
|
|
},
|
|
}
|
|
|
|
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
requestParams := model.RequestParameters{
|
|
Model: *ctx.Config.Conversations.TitleGenerationModel,
|
|
MaxTokens: 25,
|
|
}
|
|
|
|
response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Parse the JSON response
|
|
var jsonResponse struct {
|
|
Title string `json:"title"`
|
|
}
|
|
err = json.Unmarshal([]byte(response), &jsonResponse)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return jsonResponse.Title, nil
|
|
}
|
|
|
|
// ShowWaitAnimation prints an animated ellipsis to stdout until something is
// received on signal. It then restores the cursor to where the animation
// started and clears the rest of that line, so no leftover dots remain. Next,
// it sends an empty string back on signal to tell the caller the animation
// has finished, and returns.
//
// The animation advances every 250ms. Because the wait between frames is a
// select rather than a sleep, a stop signal is handled immediately instead of
// blocking the sender for up to one frame.
func ShowWaitAnimation(signal chan any) {
	// Save the current cursor position
	fmt.Print("\033[s")

	ticker := time.NewTicker(250 * time.Millisecond)
	defer ticker.Stop()

	for step := 0; ; step++ {
		// Six-frame cycle: three dots are drawn, then erased with spaces. The
		// cursor returns to the saved position at the start of each half.
		phase := step % 6
		if phase == 0 || phase == 3 {
			fmt.Print("\033[u")
		}
		if phase < 3 {
			fmt.Print(".")
		} else {
			fmt.Print(" ")
		}

		select {
		case <-signal:
			// Restore the cursor position and erase any dots still drawn.
			fmt.Print("\033[u\033[K")
			signal <- ""
			return
		case <-ticker.C:
		}
	}
}
|
|
|
|
// ShowDelayedContent displays a waiting animation to stdout while waiting
|
|
// for content to be received on the provided channel. As soon as any (possibly
|
|
// chunked) content is received on the channel, the waiting animation is
|
|
// replaced by the content.
|
|
// Blocks until the channel is closed.
|
|
func ShowDelayedContent(content <-chan string) {
|
|
waitSignal := make(chan any)
|
|
go ShowWaitAnimation(waitSignal)
|
|
|
|
firstChunk := true
|
|
for chunk := range content {
|
|
if firstChunk {
|
|
// notify wait animation that we've received data
|
|
waitSignal <- ""
|
|
// wait for signal that wait animation has completed
|
|
<-waitSignal
|
|
firstChunk = false
|
|
}
|
|
fmt.Print(chunk)
|
|
}
|
|
}
|
|
|
|
// RenderConversation renders the given messages to TTY, with optional space
|
|
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
|
// are printed immediately after the final message (1 if false, 2 if true)
|
|
func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
|
|
l := len(messages)
|
|
for i, message := range messages {
|
|
RenderMessage(ctx, &message)
|
|
if i < l-1 || spaceForResponse {
|
|
// print an additional space before the next message
|
|
fmt.Println()
|
|
}
|
|
}
|
|
}
|
|
|
|
func RenderMessage(ctx *lmcli.Context, m *model.Message) {
|
|
var messageAge string
|
|
if m.CreatedAt.IsZero() {
|
|
messageAge = "now"
|
|
} else {
|
|
now := time.Now()
|
|
messageAge = util.HumanTimeElapsedSince(now.Sub(m.CreatedAt))
|
|
}
|
|
|
|
headerStyle := lipgloss.NewStyle().Bold(true)
|
|
|
|
switch m.Role {
|
|
case model.MessageRoleSystem:
|
|
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
|
|
case model.MessageRoleUser:
|
|
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
|
|
case model.MessageRoleAssistant:
|
|
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
|
|
}
|
|
|
|
role := headerStyle.Render(m.Role.FriendlyRole())
|
|
|
|
separatorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3"))
|
|
separator := separatorStyle.Render("===")
|
|
timestamp := separatorStyle.Render(messageAge)
|
|
|
|
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
|
if m.Content != "" {
|
|
ctx.Chroma.Highlight(os.Stdout, m.Content)
|
|
fmt.Println()
|
|
}
|
|
}
|