Compare commits
No commits in common. "db465f1bf05c747e065ad84ee7daa37547af442c" and "aeeb7bb7f79ccc5a1380b47a54e4b33559381101" have entirely different histories.
db465f1bf0
...
aeeb7bb7f7
@ -3,7 +3,6 @@ package cmd
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui"
|
||||
"github.com/spf13/cobra"
|
||||
@ -15,16 +14,11 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
Short: "Open the chat interface",
|
||||
Long: `Open the chat interface, optionally on a given conversation.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// TODO: implement jump-to-conversation logic
|
||||
shortname := ""
|
||||
if len(args) == 1 {
|
||||
shortname = args[0]
|
||||
}
|
||||
if shortname != ""{
|
||||
_, err := cmdutil.LookupConversationE(ctx, shortname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := tui.Launch(ctx, shortname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@ -27,12 +28,36 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
return err
|
||||
}
|
||||
|
||||
clone, messageCnt, err := ctx.Store.CloneConversation(*toClone)
|
||||
messagesToCopy, err := ctx.Store.Messages(toClone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to clone conversation: %v", err)
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
|
||||
}
|
||||
|
||||
fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title)
|
||||
clone := &model.Conversation{
|
||||
Title: toClone.Title + " - Clone",
|
||||
}
|
||||
if err := ctx.Store.SaveConversation(clone); err != nil {
|
||||
return fmt.Errorf("Cloud not create clone: %s", err)
|
||||
}
|
||||
|
||||
var errors []error
|
||||
messageCnt := 0
|
||||
for _, message := range messagesToCopy {
|
||||
newMessage := message
|
||||
newMessage.ConversationID = clone.ID
|
||||
newMessage.ID = 0
|
||||
if err := ctx.Store.SaveMessage(&newMessage); err != nil {
|
||||
errors = append(errors, err)
|
||||
} else {
|
||||
messageCnt++
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("Messages failed to be cloned: %v", errors)
|
||||
}
|
||||
|
||||
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
|
@ -26,7 +26,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not retrieve conversation messages: %v", err)
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||
}
|
||||
@ -39,7 +39,21 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
}
|
||||
|
||||
desiredIdx := len(messages) - 1 - offset
|
||||
toEdit := messages[desiredIdx]
|
||||
|
||||
// walk backwards through the conversation deleting messages until and
|
||||
// including the last user message
|
||||
toRemove := []model.Message{}
|
||||
var toEdit *model.Message
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if i == desiredIdx {
|
||||
toEdit = &messages[i]
|
||||
}
|
||||
toRemove = append(toRemove, messages[i])
|
||||
messages = messages[:i]
|
||||
if toEdit != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
|
||||
switch newContents {
|
||||
@ -49,17 +63,26 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
toEdit.Content = newContents
|
||||
|
||||
role, _ := cmd.Flags().GetString("role")
|
||||
if role != "" {
|
||||
if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
|
||||
if role == "" {
|
||||
role = string(toEdit.Role)
|
||||
} else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
|
||||
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
|
||||
}
|
||||
toEdit.Role = model.MessageRole(role)
|
||||
|
||||
for _, message := range toRemove {
|
||||
err = ctx.Store.DeleteMessage(&message)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not delete message: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return ctx.Store.UpdateMessage(&toEdit)
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: model.MessageRole(role),
|
||||
Content: newContents,
|
||||
})
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
|
@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
@ -20,7 +21,7 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
Short: "List conversations",
|
||||
Long: `List conversations in order of recent activity`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
messages, err := ctx.Store.LatestConversationMessages()
|
||||
conversations, err := ctx.Store.Conversations()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not fetch conversations: %v", err)
|
||||
}
|
||||
@ -57,8 +58,13 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
|
||||
all, _ := cmd.Flags().GetBool("all")
|
||||
|
||||
for _, message := range messages {
|
||||
messageAge := now.Sub(message.CreatedAt)
|
||||
for _, conversation := range conversations {
|
||||
lastMessage, err := ctx.Store.LastMessage(&conversation)
|
||||
if lastMessage == nil || err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
messageAge := now.Sub(lastMessage.CreatedAt)
|
||||
|
||||
var category string
|
||||
for _, c := range categories {
|
||||
@ -70,9 +76,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
|
||||
formatted := fmt.Sprintf(
|
||||
"%s - %s - %s",
|
||||
message.Conversation.ShortName.String,
|
||||
conversation.ShortName.String,
|
||||
util.HumanTimeElapsedSince(messageAge),
|
||||
message.Conversation.Title,
|
||||
conversation.Title,
|
||||
)
|
||||
|
||||
categorized[category] = append(
|
||||
@ -90,10 +96,14 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
continue
|
||||
}
|
||||
|
||||
slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
|
||||
return int(a.timeSinceReply - b.timeSinceReply)
|
||||
})
|
||||
|
||||
fmt.Printf("%s:\n", category.name)
|
||||
for _, conv := range conversationLines {
|
||||
if conversationsPrinted >= count && !all {
|
||||
fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted)
|
||||
fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
|
||||
break outer
|
||||
}
|
||||
|
||||
|
@ -15,43 +15,42 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
Short: "Start a new conversation",
|
||||
Long: `Start a new conversation with the Large Language Model.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "")
|
||||
if input == "" {
|
||||
messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
||||
if messageContents == "" {
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
var messages []model.Message
|
||||
conversation := &model.Conversation{}
|
||||
err := ctx.Store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not save new conversation: %v", err)
|
||||
}
|
||||
|
||||
// TODO: probably just make this part of the conversation
|
||||
system := ctx.GetSystemPrompt()
|
||||
if system != "" {
|
||||
messages = append(messages, model.Message{
|
||||
messages := []model.Message{
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: model.MessageRoleSystem,
|
||||
Content: system,
|
||||
})
|
||||
}
|
||||
|
||||
messages = append(messages, model.Message{
|
||||
Content: ctx.GetSystemPrompt(),
|
||||
},
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: model.MessageRoleUser,
|
||||
Content: input,
|
||||
})
|
||||
|
||||
conversation, messages, err := ctx.Store.StartConversation(messages...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not start a new conversation: %v", err)
|
||||
Content: messageContents,
|
||||
},
|
||||
}
|
||||
|
||||
cmdutil.HandleReply(ctx, &messages[len(messages)-1], true)
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true, messages...)
|
||||
|
||||
title, err := cmdutil.GenerateTitle(ctx, messages)
|
||||
title, err := cmdutil.GenerateTitle(ctx, conversation)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err)
|
||||
lmcli.Warn("Could not generate title for conversation: %v\n", err)
|
||||
}
|
||||
|
||||
conversation.Title = title
|
||||
err = ctx.Store.UpdateConversation(conversation)
|
||||
|
||||
err = ctx.Store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save conversation title: %v\n", err)
|
||||
lmcli.Warn("Could not save conversation after generating title: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
@ -15,26 +15,21 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
Short: "Do a one-shot prompt",
|
||||
Long: `Prompt the Large Language Model and get a response.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "")
|
||||
if input == "" {
|
||||
message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
||||
if message == "" {
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
var messages []model.Message
|
||||
|
||||
// TODO: stop supplying system prompt as a message
|
||||
system := ctx.GetSystemPrompt()
|
||||
if system != "" {
|
||||
messages = append(messages, model.Message{
|
||||
messages := []model.Message{
|
||||
{
|
||||
Role: model.MessageRoleSystem,
|
||||
Content: system,
|
||||
})
|
||||
}
|
||||
|
||||
messages = append(messages, model.Message{
|
||||
Content: ctx.GetSystemPrompt(),
|
||||
},
|
||||
{
|
||||
Role: model.MessageRoleUser,
|
||||
Content: input,
|
||||
})
|
||||
Content: message,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := cmdutil.Prompt(ctx, messages, nil)
|
||||
if err != nil {
|
||||
|
@ -24,17 +24,12 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
var err error
|
||||
var title string
|
||||
|
||||
generate, _ := cmd.Flags().GetBool("generate")
|
||||
var title string
|
||||
if generate {
|
||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve conversation messages: %v", err)
|
||||
}
|
||||
title, err = cmdutil.GenerateTitle(ctx, messages)
|
||||
title, err = cmdutil.GenerateTitle(ctx, conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not generate conversation title: %v", err)
|
||||
}
|
||||
@ -46,9 +41,9 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
}
|
||||
|
||||
conversation.Title = title
|
||||
err = ctx.Store.UpdateConversation(conversation)
|
||||
err = ctx.Store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not update conversation title: %v\n", err)
|
||||
lmcli.Warn("Could not save conversation with new title: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
@ -13,7 +13,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "retry <conversation>",
|
||||
Short: "Retry the last user reply in a conversation",
|
||||
Long: `Prompt the conversation from the last user response.`,
|
||||
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
@ -25,28 +25,25 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
// Load the complete thread from the root message
|
||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||
}
|
||||
|
||||
// Find the last user message in the conversation
|
||||
var lastUserMessage *model.Message
|
||||
var i int
|
||||
for i = len(messages) - 1; i >= 0; i-- {
|
||||
// walk backwards through the conversation and delete messages, break
|
||||
// when we find the latest user response
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == model.MessageRoleUser {
|
||||
lastUserMessage = &messages[i]
|
||||
break
|
||||
}
|
||||
|
||||
err = ctx.Store.DeleteMessage(&messages[i])
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not delete previous reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
if lastUserMessage == nil {
|
||||
return fmt.Errorf("No user message found in the conversation: %s", conversation.Title)
|
||||
}
|
||||
|
||||
// Start a new branch at the last user message
|
||||
cmdutil.HandleReply(ctx, lastUserMessage, true)
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
|
@ -57,7 +57,7 @@ func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversatio
|
||||
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
||||
}
|
||||
if c.ID == 0 {
|
||||
lmcli.Fatal("Conversation not found: %s\n", shortName)
|
||||
lmcli.Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
return c
|
||||
}
|
||||
@ -68,63 +68,48 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversat
|
||||
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
||||
}
|
||||
if c.ID == 0 {
|
||||
return nil, fmt.Errorf("Conversation not found: %s", shortName)
|
||||
return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
||||
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Could not load messages: %v\n", err)
|
||||
}
|
||||
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
|
||||
}
|
||||
|
||||
// handleConversationReply handles sending messages to an existing
|
||||
// conversation, optionally persisting both the sent replies and responses.
|
||||
func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
|
||||
if to == nil {
|
||||
lmcli.Fatal("Can't prompt from an empty message.")
|
||||
}
|
||||
|
||||
existing, err := ctx.Store.PathToRoot(to)
|
||||
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
||||
existing, err := ctx.Store.Messages(c)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Could not load messages: %v\n", err)
|
||||
lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
|
||||
}
|
||||
|
||||
RenderConversation(ctx, append(existing, messages...), true)
|
||||
|
||||
var savedReplies []model.Message
|
||||
if persist && len(messages) > 0 {
|
||||
savedReplies, err = ctx.Store.Reply(to, messages...)
|
||||
if persist {
|
||||
for _, message := range toSend {
|
||||
err = ctx.Store.SaveMessage(&message)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save messages: %v\n", err)
|
||||
lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allMessages := append(existing, toSend...)
|
||||
|
||||
RenderConversation(ctx, allMessages, true)
|
||||
|
||||
// render a message header with no contents
|
||||
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
||||
|
||||
var lastMessage *model.Message
|
||||
|
||||
lastMessage = to
|
||||
if len(savedReplies) > 1 {
|
||||
lastMessage = &savedReplies[len(savedReplies)-1]
|
||||
}
|
||||
|
||||
replyCallback := func(reply model.Message) {
|
||||
if !persist {
|
||||
return
|
||||
}
|
||||
savedReplies, err = ctx.Store.Reply(lastMessage, reply)
|
||||
|
||||
reply.ConversationID = c.ID
|
||||
err = ctx.Store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save reply: %v\n", err)
|
||||
}
|
||||
lastMessage = &savedReplies[0]
|
||||
}
|
||||
|
||||
_, err = Prompt(ctx, append(existing, messages...), replyCallback)
|
||||
_, err = Prompt(ctx, allMessages, replyCallback)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
@ -149,7 +134,12 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
|
||||
func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
|
||||
messages, err := ctx.Store.Messages(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
|
||||
|
||||
Example conversation:
|
||||
@ -196,7 +186,6 @@ Title: A brief introduction
|
||||
|
||||
response = strings.TrimPrefix(response, "Title: ")
|
||||
response = strings.Trim(response, "\"")
|
||||
response = strings.TrimSpace(response)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
@ -24,9 +24,9 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||
}
|
||||
|
||||
cmdutil.RenderConversation(ctx, messages, false)
|
||||
|
@ -22,7 +22,6 @@ type Config struct {
|
||||
EnabledTools []string `yaml:"enabledTools"`
|
||||
} `yaml:"tools"`
|
||||
Providers []*struct {
|
||||
Name *string `yaml:"name"`
|
||||
Kind *string `yaml:"kind"`
|
||||
BaseURL *string `yaml:"baseUrl"`
|
||||
APIKey *string `yaml:"apiKey"`
|
||||
|
@ -4,12 +4,10 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
@ -36,9 +34,7 @@ func NewContext() (*Context, error) {
|
||||
}
|
||||
|
||||
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
|
||||
//Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||
}
|
||||
@ -61,37 +57,16 @@ func NewContext() (*Context, error) {
|
||||
}
|
||||
|
||||
func (c *Context) GetModels() (models []string) {
|
||||
modelCounts := make(map[string]int)
|
||||
for _, p := range c.Config.Providers {
|
||||
for _, m := range *p.Models {
|
||||
modelCounts[m]++
|
||||
models = append(models, *p.Name+"/"+m)
|
||||
}
|
||||
}
|
||||
|
||||
for m, c := range modelCounts {
|
||||
if c == 1 {
|
||||
models = append(models, m)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
|
||||
parts := strings.Split(model, "/")
|
||||
|
||||
var provider string
|
||||
if len(parts) > 1 {
|
||||
provider = parts[0]
|
||||
model = parts[1]
|
||||
}
|
||||
|
||||
for _, p := range c.Config.Providers {
|
||||
if provider != "" && *p.Name != provider {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, m := range *p.Models {
|
||||
if m == model {
|
||||
switch *p.Kind {
|
||||
@ -100,28 +75,21 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
|
||||
if p.BaseURL != nil {
|
||||
url = *p.BaseURL
|
||||
}
|
||||
return &anthropic.AnthropicClient{
|
||||
anthropic := &anthropic.AnthropicClient{
|
||||
BaseURL: url,
|
||||
APIKey: *p.APIKey,
|
||||
}, nil
|
||||
case "google":
|
||||
url := "https://generativelanguage.googleapis.com"
|
||||
if p.BaseURL != nil {
|
||||
url = *p.BaseURL
|
||||
}
|
||||
return &google.Client{
|
||||
BaseURL: url,
|
||||
APIKey: *p.APIKey,
|
||||
}, nil
|
||||
return anthropic, nil
|
||||
case "openai":
|
||||
url := "https://api.openai.com/v1"
|
||||
if p.BaseURL != nil {
|
||||
url = *p.BaseURL
|
||||
}
|
||||
return &openai.OpenAIClient{
|
||||
openai := &openai.OpenAIClient{
|
||||
BaseURL: url,
|
||||
APIKey: *p.APIKey,
|
||||
}, nil
|
||||
}
|
||||
return openai, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider kind: %s", *p.Kind)
|
||||
}
|
||||
|
@ -17,27 +17,18 @@ const (
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID uint `gorm:"index"`
|
||||
Conversation Conversation `gorm:"foreignKey:ConversationID"`
|
||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||
Content string
|
||||
Role MessageRole
|
||||
CreatedAt time.Time `gorm:"index"`
|
||||
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||
CreatedAt time.Time
|
||||
ToolCalls ToolCalls // a json array of tool calls (from the modl)
|
||||
ToolResults ToolResults // a json array of tool results
|
||||
ParentID *uint `gorm:"index"`
|
||||
Parent *Message `gorm:"foreignKey:ParentID"`
|
||||
Replies []Message `gorm:"foreignKey:ParentID"`
|
||||
|
||||
SelectedReplyID *uint
|
||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||
}
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
SelectedRootID *uint
|
||||
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||||
}
|
||||
|
||||
type RequestParameters struct {
|
||||
|
@ -1,413 +0,0 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
)
|
||||
|
||||
func convertTools(tools []model.Tool) []Tool {
|
||||
geminiTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
params := make(map[string]ToolParameter)
|
||||
var required []string
|
||||
|
||||
for _, param := range tool.Parameters {
|
||||
// TODO: proper enum handing
|
||||
params[param.Name] = ToolParameter{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Values: param.Enum,
|
||||
}
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
geminiTools[i] = Tool{
|
||||
FunctionDeclarations: []FunctionDeclaration{
|
||||
{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: ToolParameters{
|
||||
Type: "OBJECT",
|
||||
Properties: params,
|
||||
Required: required,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return geminiTools
|
||||
}
|
||||
|
||||
func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
|
||||
converted := make([]ContentPart, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
args := make(map[string]string)
|
||||
for k, v := range call.Parameters {
|
||||
args[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
converted[i].FunctionCall = &FunctionCall{
|
||||
Name: call.Name,
|
||||
Args: args,
|
||||
}
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
|
||||
converted := make([]model.ToolCall, len(functionCalls))
|
||||
for i, call := range functionCalls {
|
||||
params := make(map[string]interface{})
|
||||
for k, v := range call.Args {
|
||||
params[k] = v
|
||||
}
|
||||
converted[i].Name = call.Name
|
||||
converted[i].Parameters = params
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
|
||||
results := make([]FunctionResponse, len(toolResults))
|
||||
for i, result := range toolResults {
|
||||
var obj interface{}
|
||||
err := json.Unmarshal([]byte(result.Result), &obj)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
|
||||
}
|
||||
results[i] = FunctionResponse{
|
||||
Name: result.ToolName,
|
||||
Response: obj,
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func createGenerateContentRequest(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
) (*GenerateContentRequest, error) {
|
||||
requestContents := make([]Content, 0, len(messages))
|
||||
|
||||
startIdx := 0
|
||||
var system string
|
||||
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
||||
system = messages[0].Content
|
||||
startIdx = 1
|
||||
}
|
||||
|
||||
for _, m := range messages[startIdx:] {
|
||||
switch m.Role {
|
||||
case "tool_call":
|
||||
content := Content{
|
||||
Role: "model",
|
||||
Parts: convertToolCallToGemini(m.ToolCalls),
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
case "tool_result":
|
||||
results, err := convertToolResultsToGemini(m.ToolResults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// expand tool_result messages' results into multiple gemini messages
|
||||
for _, result := range results {
|
||||
content := Content{
|
||||
Role: "function",
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
FunctionResp: &result,
|
||||
},
|
||||
},
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
}
|
||||
default:
|
||||
var role string
|
||||
switch m.Role {
|
||||
case model.MessageRoleAssistant:
|
||||
role = "model"
|
||||
case model.MessageRoleUser:
|
||||
role = "user"
|
||||
}
|
||||
|
||||
if role == "" {
|
||||
panic("Unhandled role: " + m.Role)
|
||||
}
|
||||
|
||||
content := Content{
|
||||
Role: role,
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
Text: m.Content,
|
||||
},
|
||||
},
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
}
|
||||
}
|
||||
|
||||
request := &GenerateContentRequest{
|
||||
Contents: requestContents,
|
||||
SystemInstructions: system,
|
||||
GenerationConfig: &GenerationConfig{
|
||||
MaxOutputTokens: ¶ms.MaxTokens,
|
||||
Temperature: ¶ms.Temperature,
|
||||
TopP: ¶ms.TopP,
|
||||
},
|
||||
}
|
||||
|
||||
if len(params.ToolBag) > 0 {
|
||||
request.Tools = convertTools(params.ToolBag)
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func handleToolCalls(
|
||||
params model.RequestParameters,
|
||||
content string,
|
||||
toolCalls []model.ToolCall,
|
||||
callback provider.ReplyCallback,
|
||||
messages []model.Message,
|
||||
) ([]model.Message, error) {
|
||||
lastMessage := messages[len(messages)-1]
|
||||
continuation := false
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
continuation = true
|
||||
}
|
||||
|
||||
toolCall := model.Message{
|
||||
Role: model.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
|
||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolResult := model.Message{
|
||||
Role: model.MessageRoleToolResult,
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(toolCall)
|
||||
callback(toolResult)
|
||||
}
|
||||
|
||||
if continuation {
|
||||
messages[len(messages)-1] = toolCall
|
||||
} else {
|
||||
messages = append(messages, toolCall)
|
||||
}
|
||||
messages = append(messages, toolResult)
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *Client) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
) (string, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req, err := createGenerateContentRequest(params, messages)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(
|
||||
"%s/v1beta/models/%s:generateContent?key=%s",
|
||||
c.BaseURL, params.Model, c.APIKey,
|
||||
)
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp GenerateContentResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
choice := completionResp.Candidates[0]
|
||||
|
||||
var content string
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content = lastMessage.Content
|
||||
}
|
||||
|
||||
var toolCalls []FunctionCall
|
||||
for _, part := range choice.Content.Parts {
|
||||
if part.Text != "" {
|
||||
content += part.Text
|
||||
}
|
||||
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
messages, err := handleToolCalls(
|
||||
params, content, convertToolCallToAPI(toolCalls), callback, messages,
|
||||
)
|
||||
if err != nil {
|
||||
return content, err
|
||||
}
|
||||
|
||||
return c.CreateChatCompletion(ctx, params, messages, callback)
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
output chan<- string,
|
||||
) (string, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req, err := createGenerateContentRequest(params, messages)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(
|
||||
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
||||
c.BaseURL, params.Model, c.APIKey,
|
||||
)
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content.WriteString(lastMessage.Content)
|
||||
}
|
||||
|
||||
var toolCalls []FunctionCall
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||
continue
|
||||
}
|
||||
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
|
||||
var streamResp GenerateContentResponse
|
||||
err = json.Unmarshal(line, &streamResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, candidate := range streamResp.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||
} else if part.Text != "" {
|
||||
output <- part.Text
|
||||
content.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are function calls, handle them and recurse
|
||||
if len(toolCalls) > 0 {
|
||||
messages, err := handleToolCalls(
|
||||
params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
|
||||
)
|
||||
if err != nil {
|
||||
return content.String(), err
|
||||
}
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
return content.String(), nil
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
package google
|
||||
|
||||
type Client struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
type ContentPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Args map[string]string `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response interface{} `json:"response"`
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Role string `json:"role"`
|
||||
Parts []ContentPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GenerationConfig struct {
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty"`
|
||||
TopK *int `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type GenerateContentRequest struct {
|
||||
Contents []Content `json:"contents"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
SystemInstructions string `json:"systemInstructions,omitempty"`
|
||||
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
|
||||
}
|
||||
|
||||
type Candidate struct {
|
||||
Content Content `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type UsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
}
|
||||
|
||||
type GenerateContentResponse struct {
|
||||
Candidates []Candidate `json:"candidates"`
|
||||
UsageMetadata UsageMetadata `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
|
||||
}
|
||||
|
||||
type FunctionDeclaration struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters ToolParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
type ToolParameters struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Values []string `json:"values,omitempty"`
|
||||
}
|
@ -13,26 +13,21 @@ import (
|
||||
)
|
||||
|
||||
type ConversationStore interface {
|
||||
Conversations() ([]model.Conversation, error)
|
||||
|
||||
ConversationByShortName(shortName string) (*model.Conversation, error)
|
||||
ConversationShortNameCompletions(search string) []string
|
||||
RootMessages(conversationID uint) ([]model.Message, error)
|
||||
LatestConversationMessages() ([]model.Message, error)
|
||||
|
||||
StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
|
||||
UpdateConversation(conversation *model.Conversation) error
|
||||
SaveConversation(conversation *model.Conversation) error
|
||||
DeleteConversation(conversation *model.Conversation) error
|
||||
CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)
|
||||
|
||||
MessageByID(messageID uint) (*model.Message, error)
|
||||
MessageReplies(messageID uint) ([]model.Message, error)
|
||||
Messages(conversation *model.Conversation) ([]model.Message, error)
|
||||
LastMessage(conversation *model.Conversation) (*model.Message, error)
|
||||
|
||||
SaveMessage(message *model.Message) error
|
||||
DeleteMessage(message *model.Message) error
|
||||
UpdateMessage(message *model.Message) error
|
||||
DeleteMessage(message *model.Message, prune bool) error
|
||||
CloneBranch(toClone model.Message) (*model.Message, uint, error)
|
||||
Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
|
||||
|
||||
PathToRoot(message *model.Message) ([]model.Message, error)
|
||||
PathToLeaf(message *model.Message) ([]model.Message, error)
|
||||
AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error)
|
||||
}
|
||||
|
||||
type SQLStore struct {
|
||||
@ -57,52 +52,47 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
||||
return &SQLStore{db, _sqids}, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) saveNewConversation(c *model.Conversation) error {
|
||||
// Save the new conversation
|
||||
err := s.db.Save(&c).Error
|
||||
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
|
||||
err := s.db.Save(&conversation).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate and save its "short name"
|
||||
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
|
||||
c.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||
return s.UpdateConversation(c)
|
||||
}
|
||||
|
||||
func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
|
||||
if c == nil || c.ID == 0 {
|
||||
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
||||
if !conversation.ShortName.Valid {
|
||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||
err = s.db.Save(&conversation).Error
|
||||
}
|
||||
return s.db.Updates(&c).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
|
||||
// Delete messages first
|
||||
err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.db.Delete(&c).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
|
||||
panic("Not yet implemented")
|
||||
//return s.db.Delete(&message).Error
|
||||
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
|
||||
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{})
|
||||
return s.db.Delete(&conversation).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) UpdateMessage(m *model.Message) error {
|
||||
if m == nil || m.ID == 0 {
|
||||
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
||||
}
|
||||
return s.db.Updates(&m).Error
|
||||
func (s *SQLStore) SaveMessage(message *model.Message) error {
|
||||
return s.db.Create(message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteMessage(message *model.Message) error {
|
||||
return s.db.Delete(&message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) UpdateMessage(message *model.Message) error {
|
||||
return s.db.Updates(&message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
|
||||
var conversations []model.Conversation
|
||||
err := s.db.Find(&conversations).Error
|
||||
return conversations, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
||||
var conversations []model.Conversation
|
||||
// ignore error for completions
|
||||
s.db.Find(&conversations)
|
||||
completions := make([]string, 0, len(conversations))
|
||||
var completions []string
|
||||
conversations, _ := s.Conversations() // ignore error for completions
|
||||
for _, conversation := range conversations {
|
||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||
@ -116,249 +106,27 @@ func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversatio
|
||||
return nil, errors.New("shortName is empty")
|
||||
}
|
||||
var conversation model.Conversation
|
||||
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
||||
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
||||
return &conversation, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
|
||||
var rootMessages []model.Message
|
||||
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rootMessages, nil
|
||||
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
|
||||
var messages []model.Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
||||
return messages, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
|
||||
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
|
||||
var message model.Message
|
||||
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
||||
return &message, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
|
||||
var replies []model.Message
|
||||
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
|
||||
return replies, err
|
||||
}
|
||||
|
||||
// StartConversation starts a new conversation with the provided messages
|
||||
func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, nil, fmt.Errorf("Must provide at least 1 message")
|
||||
}
|
||||
|
||||
// Create new conversation
|
||||
conversation := &model.Conversation{}
|
||||
err := s.saveNewConversation(conversation)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Create first message
|
||||
messages[0].ConversationID = conversation.ID
|
||||
err = s.db.Create(&messages[0]).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Update conversation's selected root message
|
||||
conversation.SelectedRoot = &messages[0]
|
||||
err = s.UpdateConversation(conversation)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Add additional replies to conversation
|
||||
if len(messages) > 1 {
|
||||
newMessages, err := s.Reply(&messages[0], messages[1:]...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
messages = append([]model.Message{messages[0]}, newMessages...)
|
||||
}
|
||||
return conversation, messages, nil
|
||||
}
|
||||
|
||||
// CloneConversation clones the given conversation and all of its root meesages
|
||||
func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
|
||||
rootMessages, err := s.RootMessages(toClone.ID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
clone := &model.Conversation{
|
||||
Title: toClone.Title + " - Clone",
|
||||
}
|
||||
if err := s.saveNewConversation(clone); err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not create clone: %s", err)
|
||||
}
|
||||
|
||||
var errors []error
|
||||
var messageCnt uint = 0
|
||||
for _, root := range rootMessages {
|
||||
messageCnt++
|
||||
newRoot := root
|
||||
newRoot.ConversationID = clone.ID
|
||||
|
||||
cloned, count, err := s.CloneBranch(newRoot)
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
continue
|
||||
}
|
||||
messageCnt += count
|
||||
|
||||
if root.ID == *toClone.SelectedRootID {
|
||||
clone.SelectedRootID = &cloned.ID
|
||||
if err := s.UpdateConversation(clone); err != nil {
|
||||
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
|
||||
}
|
||||
|
||||
return clone, messageCnt, nil
|
||||
}
|
||||
|
||||
// Reply replies to the given parentMessage with a series of messages
|
||||
func (s *SQLStore) Reply(parentMessage *model.Message, messages ...model.Message) ([]model.Message, error) {
|
||||
var savedMessages []model.Message
|
||||
currentParent := parentMessage
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
for i := range messages {
|
||||
message := messages[i]
|
||||
message.ConversationID = currentParent.ConversationID
|
||||
message.ParentID = ¤tParent.ID
|
||||
message.ID = 0
|
||||
message.CreatedAt = time.Time{}
|
||||
|
||||
if err := tx.Create(&message).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update parent selected reply
|
||||
currentParent.SelectedReply = &message
|
||||
if err := tx.Model(currentParent).Update("selected_reply_id", message.ID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
savedMessages = append(savedMessages, message)
|
||||
currentParent = &message
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return savedMessages, err
|
||||
}
|
||||
|
||||
// CloneBranch returns a deep clone of the given message and its replies, returning
|
||||
// a new message object. The new message will be attached to the same parent as
|
||||
// the message to clone.
|
||||
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
|
||||
newMessage := messageToClone
|
||||
newMessage.ID = 0
|
||||
newMessage.Replies = nil
|
||||
newMessage.SelectedReplyID = nil
|
||||
newMessage.SelectedReply = nil
|
||||
|
||||
if err := s.db.Create(&newMessage).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
|
||||
}
|
||||
|
||||
originalReplies, err := s.MessageReplies(messageToClone.ID)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
|
||||
}
|
||||
|
||||
var replyCount uint = 0
|
||||
for _, reply := range originalReplies {
|
||||
replyCount++
|
||||
|
||||
newReply := reply
|
||||
newReply.ConversationID = messageToClone.ConversationID
|
||||
newReply.ParentID = &newMessage.ID
|
||||
newReply.Parent = &newMessage
|
||||
|
||||
res, c, err := s.CloneBranch(newReply)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
newMessage.Replies = append(newMessage.Replies, *res)
|
||||
replyCount += c
|
||||
|
||||
if reply.ID == *messageToClone.SelectedReplyID {
|
||||
newMessage.SelectedReplyID = &res.ID
|
||||
if err := s.UpdateMessage(&newMessage); err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &newMessage, replyCount, nil
|
||||
}
|
||||
|
||||
// PathToRoot traverses message Parent until reaching the tree root
|
||||
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
|
||||
if message == nil {
|
||||
return nil, fmt.Errorf("Message is nil")
|
||||
}
|
||||
var path []model.Message
|
||||
current := message
|
||||
for {
|
||||
path = append([]model.Message{*current}, path...)
|
||||
if current.Parent == nil {
|
||||
break
|
||||
}
|
||||
|
||||
var err error
|
||||
current, err = s.MessageByID(*current.ParentID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("finding parent message: %w", err)
|
||||
}
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// PathToLeaf traverses message SelectedReply until reaching a tree leaf
|
||||
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
|
||||
if message == nil {
|
||||
return nil, fmt.Errorf("Message is nil")
|
||||
}
|
||||
var path []model.Message
|
||||
current := message
|
||||
for {
|
||||
path = append(path, *current)
|
||||
if current.SelectedReplyID == nil {
|
||||
break
|
||||
}
|
||||
|
||||
var err error
|
||||
current, err = s.MessageByID(*current.SelectedReplyID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("finding selected reply: %w", err)
|
||||
}
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
|
||||
var latestMessages []model.Message
|
||||
|
||||
subQuery := s.db.Model(&model.Message{}).
|
||||
Select("MAX(created_at) as max_created_at, conversation_id").
|
||||
Group("conversation_id")
|
||||
|
||||
err := s.db.Model(&model.Message{}).
|
||||
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
|
||||
Order("created_at DESC").
|
||||
Preload("Conversation").
|
||||
Find(&latestMessages).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return latestMessages, nil
|
||||
// AddReply adds the given messages as a reply to the given conversation, can be
|
||||
// used to easily copy a message associated with one conversation, to another
|
||||
func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) {
|
||||
m.ConversationID = c.ID
|
||||
m.ID = 0
|
||||
m.CreatedAt = time.Time{}
|
||||
return &m, s.SaveMessage(&m)
|
||||
}
|
||||
|
@ -11,9 +11,7 @@ import (
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
|
||||
|
||||
Use these results for your own reference in completing your task, they do not need to be shown to the user.
|
||||
const TREE_DESCRIPTION = `Retrieve a tree view of a directory's contents.
|
||||
|
||||
Example result:
|
||||
{
|
||||
@ -37,45 +35,48 @@ var DirTreeTool = model.Tool{
|
||||
Description: "If set, display the tree starting from this path relative to the current one.",
|
||||
},
|
||||
{
|
||||
Name: "depth",
|
||||
Name: "max_depth",
|
||||
Type: "integer",
|
||||
Description: "Depth of directory recursion. Default 0. Use -1 for unlimited.",
|
||||
Description: "Maximum depth of recursion. Default is unlimited.",
|
||||
},
|
||||
},
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
var relativeDir string
|
||||
if tmp, ok := args["relative_path"]; ok {
|
||||
tmp, ok := args["relative_dir"]
|
||||
if ok {
|
||||
relativeDir, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("expected string for relative_path, got %T", tmp)
|
||||
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
var depth int = 0 // Default value if not provided
|
||||
if tmp, ok := args["depth"]; ok {
|
||||
switch v := tmp.(type) {
|
||||
case float64:
|
||||
depth = int(v)
|
||||
case string:
|
||||
var err error
|
||||
if depth, err = strconv.Atoi(v); err != nil {
|
||||
return "", fmt.Errorf("invalid `depth` value, expected integer but got string that cannot convert: %v", tmp)
|
||||
var maxDepth int = -1
|
||||
tmp, ok = args["max_depth"]
|
||||
if ok {
|
||||
maxDepth, ok = tmp.(int)
|
||||
if !ok {
|
||||
if tmps, ok := tmp.(string); ok {
|
||||
tmpi, err := strconv.Atoi(tmps)
|
||||
maxDepth = tmpi
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Invalid max_depth in function arguments: %v", tmp)
|
||||
}
|
||||
} else {
|
||||
return "", fmt.Errorf("Invalid max_depth in function arguments: %v", tmp)
|
||||
}
|
||||
default:
|
||||
return "", fmt.Errorf("expected int or string for max_depth, got %T", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
result := tree(relativeDir, depth)
|
||||
result := tree(relativeDir, maxDepth)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not serialize result: %v", err)
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func tree(path string, depth int) model.CallResult {
|
||||
func tree(path string, maxDepth int) model.CallResult {
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
@ -86,7 +87,7 @@ func tree(path string, depth int) model.CallResult {
|
||||
|
||||
var treeOutput strings.Builder
|
||||
treeOutput.WriteString(path + "\n")
|
||||
err := buildTree(&treeOutput, path, "", depth)
|
||||
err := buildTree(&treeOutput, path, "", maxDepth)
|
||||
if err != nil {
|
||||
return model.CallResult{
|
||||
Message: err.Error(),
|
||||
@ -96,7 +97,7 @@ func tree(path string, depth int) model.CallResult {
|
||||
return model.CallResult{Result: treeOutput.String()}
|
||||
}
|
||||
|
||||
func buildTree(output *strings.Builder, path string, prefix string, depth int) error {
|
||||
func buildTree(output *strings.Builder, path string, prefix string, maxDepth int) error {
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -123,14 +124,14 @@ func buildTree(output *strings.Builder, path string, prefix string, depth int) e
|
||||
output.WriteString(prefix + branch + file.Name())
|
||||
if file.IsDir() {
|
||||
output.WriteString("/\n")
|
||||
if depth != 0 {
|
||||
if maxDepth != 0 {
|
||||
var nextPrefix string
|
||||
if isLast {
|
||||
nextPrefix = prefix + " "
|
||||
} else {
|
||||
nextPrefix = prefix + "│ "
|
||||
}
|
||||
buildTree(output, filepath.Join(path, file.Name()), nextPrefix, depth-1)
|
||||
buildTree(output, filepath.Join(path, file.Name()), nextPrefix, maxDepth-1)
|
||||
}
|
||||
} else {
|
||||
output.WriteString(sizeStr + "\n")
|
||||
@ -139,3 +140,4 @@ func buildTree(output *strings.Builder, path string, prefix string, depth int) e
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -9,9 +9,7 @@ import (
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
|
||||
|
||||
Use the file contents for your own reference in completing your task, they do not need to be shown to the user.
|
||||
const READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
|
||||
|
||||
Each line of the returned content is prefixed with its line number and a tab (\t).
|
||||
|
||||
|
136
pkg/tui/chat.go
136
pkg/tui/chat.go
@ -66,10 +66,6 @@ type chatModel struct {
|
||||
replyChunkChan chan string
|
||||
persistence bool // whether we will save new messages in the conversation
|
||||
|
||||
tokenCount uint
|
||||
startTime time.Time
|
||||
elapsed time.Duration
|
||||
|
||||
// ui state
|
||||
focus focusState
|
||||
wrap bool // whether message content is wrapped to viewport width
|
||||
@ -286,9 +282,6 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
|
||||
}
|
||||
m.updateContent()
|
||||
cmds = append(cmds, m.waitForChunk()) // wait for the next chunk
|
||||
|
||||
m.tokenCount++
|
||||
m.elapsed = time.Now().Sub(m.startTime)
|
||||
case msgAssistantReply:
|
||||
// the last reply that was being worked on is finished
|
||||
reply := models.Message(msg)
|
||||
@ -307,9 +300,14 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
|
||||
}
|
||||
|
||||
if m.persistence {
|
||||
err := m.persistConversation()
|
||||
var err error
|
||||
if m.conversation.ID == 0 {
|
||||
err = m.ctx.Store.SaveConversation(m.conversation)
|
||||
}
|
||||
if err != nil {
|
||||
cmds = append(cmds, wrapError(err))
|
||||
} else {
|
||||
cmds = append(cmds, m.persistConversation())
|
||||
}
|
||||
}
|
||||
|
||||
@ -336,7 +334,7 @@ func (m chatModel) Update(msg tea.Msg) (chatModel, tea.Cmd) {
|
||||
title := string(msg)
|
||||
m.conversation.Title = title
|
||||
if m.persistence {
|
||||
err := m.ctx.Store.UpdateConversation(m.conversation)
|
||||
err := m.ctx.Store.SaveConversation(m.conversation)
|
||||
if err != nil {
|
||||
cmds = append(cmds, wrapError(err))
|
||||
}
|
||||
@ -464,8 +462,8 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
|
||||
m.input.Blur()
|
||||
return true, nil
|
||||
case "ctrl+s":
|
||||
input := strings.TrimSpace(m.input.Value())
|
||||
if input == "" {
|
||||
userInput := strings.TrimSpace(m.input.Value())
|
||||
if strings.TrimSpace(userInput) == "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@ -473,20 +471,36 @@ func (m *chatModel) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
|
||||
return true, wrapError(fmt.Errorf("Can't reply to a user message"))
|
||||
}
|
||||
|
||||
m.addMessage(models.Message{
|
||||
reply := models.Message{
|
||||
Role: models.MessageRoleUser,
|
||||
Content: input,
|
||||
})
|
||||
|
||||
m.input.SetValue("")
|
||||
Content: userInput,
|
||||
}
|
||||
|
||||
if m.persistence {
|
||||
err := m.persistConversation()
|
||||
var err error
|
||||
if m.conversation.ID == 0 {
|
||||
err = m.ctx.Store.SaveConversation(m.conversation)
|
||||
}
|
||||
if err != nil {
|
||||
return true, wrapError(err)
|
||||
}
|
||||
|
||||
// ensure all messages up to the one we're about to add are persisted
|
||||
cmd := m.persistConversation()
|
||||
if cmd != nil {
|
||||
return true, cmd
|
||||
}
|
||||
|
||||
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
|
||||
if err != nil {
|
||||
return true, wrapError(err)
|
||||
}
|
||||
reply = *savedReply
|
||||
}
|
||||
|
||||
m.input.SetValue("")
|
||||
m.addMessage(reply)
|
||||
|
||||
m.updateContent()
|
||||
m.content.GotoBottom()
|
||||
return true, m.promptLLM()
|
||||
@ -679,16 +693,10 @@ func (m *chatModel) footerView() string {
|
||||
saving,
|
||||
segmentStyle.Render(status),
|
||||
}
|
||||
rightSegments := []string{}
|
||||
|
||||
if m.elapsed > 0 && m.tokenCount > 0 {
|
||||
throughput := fmt.Sprintf("%.0f t/sec", float64(m.tokenCount)/m.elapsed.Seconds())
|
||||
rightSegments = append(rightSegments, segmentStyle.Render(throughput))
|
||||
rightSegments := []string{
|
||||
segmentStyle.Render(fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)),
|
||||
}
|
||||
|
||||
model := fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)
|
||||
rightSegments = append(rightSegments, segmentStyle.Render(model))
|
||||
|
||||
left := strings.Join(leftSegments, segmentSeparator)
|
||||
right := strings.Join(rightSegments, segmentSeparator)
|
||||
|
||||
@ -762,7 +770,7 @@ func (m *chatModel) loadConversation(shortname string) tea.Cmd {
|
||||
|
||||
func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
messages, err := m.ctx.Store.PathToLeaf(c.SelectedRoot)
|
||||
messages, err := m.ctx.Store.Messages(c)
|
||||
if err != nil {
|
||||
return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
|
||||
}
|
||||
@ -770,48 +778,62 @@ func (m *chatModel) loadMessages(c *models.Conversation) tea.Cmd {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *chatModel) persistConversation() error {
|
||||
if m.conversation.ID == 0 {
|
||||
// Start a new conversation with all messages so far
|
||||
c, messages, err := m.ctx.Store.StartConversation(m.messages...)
|
||||
func (m *chatModel) persistConversation() tea.Cmd {
|
||||
existingMessages, err := m.ctx.Store.Messages(m.conversation)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.conversation = c
|
||||
m.messages = messages
|
||||
|
||||
return nil
|
||||
return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err))
|
||||
}
|
||||
|
||||
// else, we'll handle updating an existing conversation's messages
|
||||
for i := 0; i < len(m.messages); i++ {
|
||||
if m.messages[i].ID > 0 {
|
||||
// message has an ID, update its contents
|
||||
// TODO: check for content/tool equality before updating?
|
||||
err := m.ctx.Store.UpdateMessage(&m.messages[i])
|
||||
if err != nil {
|
||||
return err
|
||||
existingById := make(map[uint]*models.Message, len(existingMessages))
|
||||
for _, msg := range existingMessages {
|
||||
existingById[msg.ID] = &msg
|
||||
}
|
||||
} else if i > 0 {
|
||||
// messages is new, so add it as a reply to previous message
|
||||
saved, err := m.ctx.Store.Reply(&m.messages[i-1], m.messages[i])
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
currentById := make(map[uint]*models.Message, len(m.messages))
|
||||
for _, msg := range m.messages {
|
||||
currentById[msg.ID] = &msg
|
||||
}
|
||||
|
||||
for _, msg := range existingMessages {
|
||||
_, ok := currentById[msg.ID]
|
||||
if !ok {
|
||||
err := m.ctx.Store.DeleteMessage(&msg)
|
||||
if err != nil {
|
||||
return wrapError(fmt.Errorf("Failed to remove messages: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, msg := range m.messages {
|
||||
if msg.ID > 0 {
|
||||
exist, ok := existingById[msg.ID]
|
||||
if ok {
|
||||
if msg.Content == exist.Content {
|
||||
continue
|
||||
}
|
||||
// update message when contents don't match that of store
|
||||
err := m.ctx.Store.UpdateMessage(&msg)
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
m.messages[i] = saved[0]
|
||||
} else {
|
||||
// message has no id and no previous messages to add it to
|
||||
// this shouldn't happen?
|
||||
return fmt.Errorf("Error: no messages to reply to")
|
||||
// this would be quite odd... and I'm not sure how to handle
|
||||
// it at the time of writing this
|
||||
}
|
||||
} else {
|
||||
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
m.setMessage(i, *newMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *chatModel) generateConversationTitle() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
title, err := cmdutil.GenerateTitle(m.ctx, m.messages)
|
||||
title, err := cmdutil.GenerateTitle(m.ctx, m.conversation)
|
||||
if err != nil {
|
||||
return msgError(err)
|
||||
}
|
||||
@ -835,10 +857,6 @@ func (m *chatModel) promptLLM() tea.Cmd {
|
||||
m.waitingForReply = true
|
||||
m.status = "Press ctrl+c to cancel"
|
||||
|
||||
m.tokenCount = 0
|
||||
m.startTime = time.Now()
|
||||
m.elapsed = 0
|
||||
|
||||
return func() tea.Msg {
|
||||
completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model)
|
||||
if err != nil {
|
||||
|
@ -2,6 +2,7 @@ package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -114,7 +115,7 @@ func (m *conversationsModel) handleResize(width, height int) {
|
||||
func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
switch msg := msg.(type) {
|
||||
case msgStateEnter:
|
||||
case msgStateChange:
|
||||
cmds = append(cmds, m.loadConversations())
|
||||
m.content.SetContent(m.renderConversationList())
|
||||
case tea.WindowSizeMsg:
|
||||
@ -144,16 +145,24 @@ func (m conversationsModel) Update(msg tea.Msg) (conversationsModel, tea.Cmd) {
|
||||
|
||||
func (m *conversationsModel) loadConversations() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
messages, err := m.ctx.Store.LatestConversationMessages()
|
||||
conversations, err := m.ctx.Store.Conversations()
|
||||
if err != nil {
|
||||
return msgError(fmt.Errorf("Could not load conversations: %v", err))
|
||||
}
|
||||
|
||||
loaded := make([]loadedConversation, len(messages))
|
||||
for i, m := range messages {
|
||||
loaded[i].lastReply = m
|
||||
loaded[i].conv = m.Conversation
|
||||
loaded := make([]loadedConversation, len(conversations))
|
||||
for i, c := range conversations {
|
||||
lastMessage, err := m.ctx.Store.LastMessage(&c)
|
||||
if err != nil {
|
||||
return msgError(err)
|
||||
}
|
||||
loaded[i].conv = c
|
||||
loaded[i].lastReply = *lastMessage
|
||||
}
|
||||
|
||||
slices.SortFunc(loaded, func(a, b loadedConversation) int {
|
||||
return b.lastReply.CreatedAt.Compare(a.lastReply.CreatedAt)
|
||||
})
|
||||
|
||||
return msgConversationsLoaded(loaded)
|
||||
}
|
||||
|
@ -147,16 +147,16 @@ func SetStructDefaults(data interface{}) bool {
|
||||
case reflect.String:
|
||||
defaultValue := defaultTag
|
||||
field.Set(reflect.ValueOf(&defaultValue))
|
||||
case reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||
intValue, _ := strconv.ParseUint(defaultTag, 10, e.Bits())
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetUint(intValue)
|
||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||
intValue, _ := strconv.ParseInt(defaultTag, 10, e.Bits())
|
||||
intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetInt(intValue)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
floatValue, _ := strconv.ParseFloat(defaultTag, e.Bits())
|
||||
case reflect.Float32:
|
||||
floatValue, _ := strconv.ParseFloat(defaultTag, 32)
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetFloat(floatValue)
|
||||
case reflect.Float64:
|
||||
floatValue, _ := strconv.ParseFloat(defaultTag, 64)
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetFloat(floatValue)
|
||||
case reflect.Bool:
|
||||
|
Loading…
Reference in New Issue
Block a user